mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
``_load_hf_dataset`` was building the strict cast schema only from ``meta/info.json["features"]``. Datasets annotated by ``lerobot-annotate`` but still tagged at the older codebase version (no ``language_persistent`` / ``language_events`` entry in ``info.json``) carry both columns in the parquet itself but not in the features dict, so ``Dataset.from_parquet`` blew up with ``CastError: column names don't match`` when trying to project a 9-column parquet onto a 7-column schema. Probe one parquet shard's actual schema; if either language column is present in the parquet but missing from ``features``, graft it on using PR 1's ``language_persistent_column_feature`` / ``language_events_column_feature`` helpers. No-op when neither column is present (fully backwards-compatible with v3.0 datasets), no-op when both are already registered (fully forwards-compatible with future v3.1 ``info.json`` writes). This unblocks dry-run inference on PR 2-annotated datasets that weren't re-tagged to v3.1 — including the ones in the field today. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
342 lines
14 KiB
Python
342 lines
14 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""
|
|
|
|
from collections.abc import Callable
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from pathlib import Path
|
|
|
|
import datasets
|
|
import torch
|
|
|
|
from .dataset_metadata import LeRobotDatasetMetadata
|
|
from .feature_utils import (
|
|
check_delta_timestamps,
|
|
get_delta_indices,
|
|
get_hf_features_from_features,
|
|
)
|
|
from .io_utils import (
|
|
hf_transform_to_torch,
|
|
load_nested_dataset,
|
|
)
|
|
from .video_utils import decode_video_frames
|
|
|
|
|
|
class DatasetReader:
    """Encapsulates read-side state and methods for LeRobotDataset.

    Owns: hf_dataset, _absolute_to_relative_idx, delta_indices.
    """

    def __init__(
        self,
        meta: LeRobotDatasetMetadata,
        root: Path,
        episodes: list[int] | None,
        tolerance_s: float,
        video_backend: str,
        delta_timestamps: dict[str, list[float]] | None,
        image_transforms: Callable | None,
        return_uint8: bool = False,
    ):
        """Initialize the reader with metadata, filtering, and transform config.

        The HF dataset is not loaded here — call :meth:`try_load` or
        :meth:`load_and_activate` afterward.

        Args:
            meta: Dataset metadata instance.
            root: Local dataset root directory.
            episodes: Optional list of episode indices to select. ``None``
                means all episodes.
            tolerance_s: Timestamp synchronization tolerance in seconds.
            video_backend: Video decoding backend identifier.
            delta_timestamps: Optional dict mapping feature keys to lists of
                relative timestamp offsets for temporal context windows.
            image_transforms: Optional torchvision v2 transform applied to
                visual features.
            return_uint8: Forwarded unchanged to ``decode_video_frames`` as its
                ``return_uint8`` keyword. NOTE(review): presumably keeps decoded
                frames as uint8 instead of converting them — confirm against
                ``decode_video_frames``.
        """
        self._meta = meta
        self.root = root
        self.episodes = episodes
        self._tolerance_s = tolerance_s
        self._video_backend = video_backend
        self._image_transforms = image_transforms
        self._return_uint8 = return_uint8

        self.hf_dataset: datasets.Dataset | None = None
        self._absolute_to_relative_idx: dict[int, int] | None = None

        # Setup delta_indices (doesn't depend on hf_dataset)
        self.delta_indices = None
        if delta_timestamps is not None:
            check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
            self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)

    def try_load(self) -> bool:
        """Attempt to load from local cache. Returns True if data is sufficient.

        On failure (missing files, or cached data that does not cover the
        requested episodes) ``hf_dataset`` is reset to ``None`` and ``False``
        is returned; the index mapping is only built on success.
        """
        try:
            self.hf_dataset = self._load_hf_dataset()
        except (FileNotFoundError, NotADirectoryError):
            self.hf_dataset = None
            return False
        if not self._check_cached_episodes_sufficient():
            self.hf_dataset = None
            return False
        self._build_index_mapping()
        return True

    def load_and_activate(self) -> None:
        """Load HF dataset from disk and build index mapping. Call after data is on disk."""
        self.hf_dataset = self._load_hf_dataset()
        self._build_index_mapping()

    def _build_index_mapping(self) -> None:
        """Build absolute-to-relative index mapping from loaded hf_dataset.

        The mapping goes from each value of the ``index`` column to that row's
        position in ``hf_dataset``. It is only built when an episode selection
        is active; otherwise it stays ``None`` and indices are used as-is.
        """
        self._absolute_to_relative_idx = None
        if self.episodes is not None and self.hf_dataset is not None:
            indices = self.hf_dataset.data.column("index").to_numpy()
            self._absolute_to_relative_idx = dict(zip(indices.tolist(), range(len(indices)), strict=True))

    @property
    def num_frames(self) -> int:
        """Number of frames in selected episodes."""
        if self.episodes is not None and self.hf_dataset is not None:
            return len(self.hf_dataset)
        # No selection (or nothing loaded yet): fall back to the metadata total.
        return self._meta.total_frames

    @property
    def num_episodes(self) -> int:
        """Number of episodes selected."""
        return len(self.episodes) if self.episodes is not None else self._meta.total_episodes

    def _load_hf_dataset(self) -> datasets.Dataset:
        """hf_dataset contains all the observations, states, actions, rewards, etc."""
        features = get_hf_features_from_features(self._meta.features)
        # Datasets annotated with the language columns may have been written
        # without registering those columns in ``meta/info.json`` (e.g. they
        # predate ``CODEBASE_VERSION="v3.1"`` and were back-filled by
        # ``lerobot-annotate``). Probe a single parquet shard and graft the
        # column features on so the strict ``Dataset.from_parquet`` cast
        # doesn't fail with ``column names don't match``.
        features = self._extend_features_with_language_columns(features)
        hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset

    def _extend_features_with_language_columns(self, features: datasets.Features) -> datasets.Features:
        """Add ``language_persistent`` / ``language_events`` to ``features``
        when the underlying parquet shards declare them but the metadata
        doesn't.

        Only one shard's schema is inspected. No-op (the same ``features``
        object is returned) when there is no shard to inspect, when neither
        column is present, or when both are already registered.
        """
        # Pick the lexicographically first shard: ``Path.glob`` yields entries
        # in an unspecified order, and probing a fixed shard keeps the result
        # reproducible across runs and filesystems. Bail if there are none yet
        # (the dataset will fail later for an unrelated reason and we want
        # that error to surface as-is).
        sample = min((self.root / "data").glob("*/*.parquet"), default=None)
        if sample is None:
            return features

        # Imported lazily so the no-shard path above needs neither module.
        from pyarrow import parquet as _pq  # noqa: PLC0415

        from .language import (  # noqa: PLC0415
            LANGUAGE_EVENTS,
            LANGUAGE_PERSISTENT,
            language_events_column_feature,
            language_persistent_column_feature,
        )

        # ``read_schema`` only reads the parquet footer, not the row data.
        schema_names = set(_pq.read_schema(sample).names)

        extra: dict[str, object] = {}
        if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
            extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
        if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
            extra[LANGUAGE_EVENTS] = language_events_column_feature()
        if not extra:
            return features
        return datasets.Features({**features, **extra})

    def _check_cached_episodes_sufficient(self) -> bool:
        """Check if the cached dataset contains all requested episodes and their video files."""
        if self.hf_dataset is None or len(self.hf_dataset) == 0:
            return False

        # ``unique`` may hand back tensors or plain ints; normalise to plain
        # values so the set comparison below works either way.
        available_episodes = {
            ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
            for ep_idx in self.hf_dataset.unique("episode_index")
        }

        if self.episodes is None:
            requested_episodes = set(range(self._meta.total_episodes))
        else:
            requested_episodes = set(self.episodes)

        if not requested_episodes.issubset(available_episodes):
            return False

        if len(self._meta.video_keys) > 0:
            for ep_idx in requested_episodes:
                for vid_key in self._meta.video_keys:
                    video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
                    if not video_path.exists():
                        return False

        return True

    def get_episodes_file_paths(self) -> list[str]:
        """Return deduplicated file paths (data + video) for selected episodes.

        Used to build the ``allow_patterns`` list for ``snapshot_download``.

        Returns:
            Paths as strings, data files first and then video files, each
            listed once in first-seen order.
        """
        episodes = self.episodes if self.episodes is not None else range(self._meta.total_episodes)
        fpaths = [str(self._meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
        if len(self._meta.video_keys) > 0:
            fpaths += [
                str(self._meta.get_video_file_path(ep_idx, vid_key))
                for vid_key in self._meta.video_keys
                for ep_idx in episodes
            ]
        # Episodes are stored in the same files, so we return unique paths only.
        # ``dict.fromkeys`` de-duplicates while keeping a deterministic order
        # (a ``set`` of strings iterates in a per-process hash order).
        return list(dict.fromkeys(fpaths))

    def _get_query_indices(
        self, abs_idx: int, ep_idx: int
    ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
        """Compute query indices for delta timestamps.

        Args:
            abs_idx: Absolute frame index of the current frame.
            ep_idx: Index of the episode the frame belongs to.

        Returns:
            A pair ``(query_indices, padding)``. ``query_indices[key]`` holds
            ``abs_idx + delta`` clamped to the episode's
            ``[dataset_from_index, dataset_to_index)`` range, and
            ``padding[f"{key}_is_pad"]`` is a bool tensor marking the entries
            that fell outside that range before clamping.
        """
        ep = self._meta.episodes[ep_idx]
        ep_start = ep["dataset_from_index"]
        ep_end = ep["dataset_to_index"]

        query_indices: dict[str, list[int]] = {}
        padding: dict[str, torch.Tensor] = {}
        for key, delta_idx in self.delta_indices.items():
            targets = [abs_idx + delta for delta in delta_idx]
            query_indices[key] = [max(ep_start, min(ep_end - 1, target)) for target in targets]
            padding[f"{key}_is_pad"] = torch.BoolTensor(
                [not (ep_start <= target < ep_end) for target in targets]
            )
        return query_indices, padding

    def _get_query_timestamps(
        self,
        current_ts: float,
        query_indices: dict[str, list[int]] | None = None,
    ) -> dict[str, list[float]]:
        """Return, per video key, the timestamps at which frames must be decoded.

        For a video key present in ``query_indices`` the timestamps are read
        from the ``timestamp`` column at those (absolute) indices; every other
        video key gets ``[current_ts]``.
        """
        query_timestamps = {}
        for key in self._meta.video_keys:
            if query_indices is not None and key in query_indices:
                if self._absolute_to_relative_idx is not None:
                    # Episode-filtered dataset: translate absolute frame
                    # indices into row positions first.
                    relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
                    timestamps = self.hf_dataset[relative_indices]["timestamp"]
                else:
                    timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
                query_timestamps[key] = torch.stack(timestamps).tolist()
            else:
                query_timestamps[key] = [current_ts]

        return query_timestamps

    def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
        """Query dataset for indices across keys, skipping video keys."""
        result: dict = {}
        for key, q_idx in query_indices.items():
            if key in self._meta.video_keys:
                continue
            relative_indices = (
                q_idx
                if self._absolute_to_relative_idx is None
                else [self._absolute_to_relative_idx[idx] for idx in q_idx]
            )
            try:
                # Column-first access; fall back to row-first access when the
                # column object does not support list indexing.
                result[key] = torch.stack(self.hf_dataset[key][relative_indices])
            except (KeyError, TypeError, IndexError):
                result[key] = torch.stack(self.hf_dataset[relative_indices][key])
        return result

    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
        """Decode the requested frames for every video key of one episode.

        Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
        in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
        Segmentation Fault.
        """
        ep = self._meta.episodes[ep_idx]

        def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
            # Query timestamps are episode-relative; shift them by the
            # episode's start offset inside the video file.
            from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
            shifted_query_ts = [from_timestamp + ts for ts in query_ts]
            video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
            frames = decode_video_frames(
                video_path,
                shifted_query_ts,
                self._tolerance_s,
                self._video_backend,
                return_uint8=self._return_uint8,
            )
            return vid_key, frames.squeeze(0)

        items = list(query_timestamps.items())

        # Single camera: no threading overhead
        if len(items) <= 1:
            return {vid_key: _decode_single(vid_key, query_ts)[1] for vid_key, query_ts in items}

        # Multi-camera: decode in parallel (video decoding releases the GIL)
        with ThreadPoolExecutor(max_workers=len(items)) as pool:
            futures = [pool.submit(_decode_single, k, ts) for k, ts in items]
            return dict(f.result() for f in futures)

    def get_item(self, idx) -> dict:
        """Core __getitem__ logic. Assumes hf_dataset is loaded.

        ``idx`` is a *relative* index into the (possibly episode-filtered)
        HF dataset, **not** the absolute frame index stored in the ``index``
        column. The absolute index is retrieved from the row itself.
        """
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()
        abs_idx = item["index"].item()

        query_indices = None
        if self.delta_indices is not None:
            query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            # Queried windows override the single-frame values of the same key.
            item = {**item, **padding, **query_result}

        if len(self._meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
            video_frames = self._query_videos(query_timestamps, ep_idx)
            # Values already present in ``item`` win over decoded frames.
            item = {**video_frames, **item}

        if self._image_transforms is not None:
            image_keys = self._meta.camera_keys
            for cam in image_keys:
                item[cam] = self._image_transforms(item[cam])

        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self._meta.tasks.iloc[task_idx].name

        return item
|