#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Private reader component for LeRobotDataset.

Handles random-access reading (HF dataset, delta indices, video decoding).
"""

from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import datasets
import torch

from .dataset_metadata import LeRobotDatasetMetadata
from .feature_utils import (
    check_delta_timestamps,
    get_delta_indices,
    get_hf_features_from_features,
)
from .io_utils import (
    hf_transform_to_torch,
    load_nested_dataset,
)
from .video_utils import decode_video_frames


class DatasetReader:
    """Encapsulates read-side state and methods for LeRobotDataset.

    Owns: hf_dataset, _absolute_to_relative_idx, delta_indices.
    """

    def __init__(
        self,
        meta: LeRobotDatasetMetadata,
        root: Path,
        episodes: list[int] | None,
        tolerance_s: float,
        video_backend: str,
        delta_timestamps: dict[str, list[float]] | None,
        image_transforms: Callable | None,
        return_uint8: bool = False,
    ):
        """Initialize the reader with metadata, filtering, and transform config.

        The HF dataset is not loaded here — call :meth:`try_load` or
        :meth:`load_and_activate` afterward.

        Args:
            meta: Dataset metadata instance.
            root: Local dataset root directory.
            episodes: Optional list of episode indices to select. ``None``
                means all episodes.
            tolerance_s: Timestamp synchronization tolerance in seconds.
            video_backend: Video decoding backend identifier.
            delta_timestamps: Optional dict mapping feature keys to lists of
                relative timestamp offsets for temporal context windows.
            image_transforms: Optional torchvision v2 transform applied to
                visual features.
            return_uint8: Forwarded unchanged to ``decode_video_frames`` as
                its ``return_uint8`` keyword (presumably keeps decoded frames
                as uint8 — confirm against ``video_utils``).
        """
        self._meta = meta
        self.root = root
        self.episodes = episodes
        self._tolerance_s = tolerance_s
        self._video_backend = video_backend
        self._image_transforms = image_transforms
        self._return_uint8 = return_uint8

        self.hf_dataset: datasets.Dataset | None = None
        self._absolute_to_relative_idx: dict[int, int] | None = None

        # Setup delta_indices (doesn't depend on hf_dataset)
        self.delta_indices = None
        if delta_timestamps is not None:
            check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
            self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)

    def try_load(self) -> bool:
        """Attempt to load from local cache. Returns True if data is sufficient.

        On failure (missing files, or cached data that does not cover the
        requested episodes) ``hf_dataset`` is reset to ``None`` and ``False``
        is returned; the index mapping is only built on success.
        """
        try:
            self.hf_dataset = self._load_hf_dataset()
        except (FileNotFoundError, NotADirectoryError):
            self.hf_dataset = None
            return False
        if not self._check_cached_episodes_sufficient():
            self.hf_dataset = None
            return False
        self._build_index_mapping()
        return True

    def load_and_activate(self) -> None:
        """Load HF dataset from disk and build index mapping. Call after data is on disk."""
        self.hf_dataset = self._load_hf_dataset()
        self._build_index_mapping()

    def _build_index_mapping(self) -> None:
        """Build absolute-to-relative index mapping from loaded hf_dataset.

        The mapping is only created when an episode subset is selected and a
        dataset is loaded; otherwise it is left as ``None``.
        """
        self._absolute_to_relative_idx = None
        if self.episodes is not None and self.hf_dataset is not None:
            indices = self.hf_dataset.data.column("index").to_numpy()
            # Key: value of the ``index`` column; value: row position in hf_dataset.
            self._absolute_to_relative_idx = dict(zip(indices.tolist(), range(len(indices)), strict=True))

    @property
    def num_frames(self) -> int:
        """Number of frames in selected episodes."""
        if self.episodes is not None and self.hf_dataset is not None:
            return len(self.hf_dataset)
        return self._meta.total_frames

    @property
    def num_episodes(self) -> int:
        """Number of episodes selected."""
        return len(self.episodes) if self.episodes is not None else self._meta.total_episodes

    def _load_hf_dataset(self) -> datasets.Dataset:
        """hf_dataset contains all the observations, states, actions, rewards, etc."""
        features = get_hf_features_from_features(self._meta.features)
        # Datasets annotated with the PR1 language columns may have been
        # written without registering those columns in ``meta/info.json``
        # (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were
        # back-filled by ``lerobot-annotate``). Probe a single parquet
        # shard and graft the column features on so the strict
        # ``Dataset.from_parquet`` cast doesn't fail with
        # ``column names don't match``.
        features = self._extend_features_with_language_columns(features)
        hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset

    def _extend_features_with_language_columns(self, features: datasets.Features) -> datasets.Features:
        """Add ``language_persistent`` / ``language_events`` to ``features``
        when the underlying parquet shards declare them but the metadata
        doesn't. No-op when neither column is present or both are already
        registered.
        """
        # Find any one parquet to peek at; bail if there are none yet
        # (the dataset will fail later for an unrelated reason and we
        # want that error to surface as-is).
        try:
            sample = next((self.root / "data").glob("*/*.parquet"))
        except StopIteration:
            return features

        from pyarrow import parquet as _pq  # noqa: PLC0415

        schema_names = set(_pq.read_schema(sample).names)

        from .language import (  # noqa: PLC0415
            LANGUAGE_EVENTS,
            LANGUAGE_PERSISTENT,
            language_events_column_feature,
            language_persistent_column_feature,
        )

        extra: dict[str, object] = {}
        if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
            extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
        if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
            extra[LANGUAGE_EVENTS] = language_events_column_feature()
        if not extra:
            return features
        return datasets.Features({**features, **extra})

    def _check_cached_episodes_sufficient(self) -> bool:
        """Check if the cached dataset contains all requested episodes and their video files."""
        if self.hf_dataset is None or len(self.hf_dataset) == 0:
            return False

        # ``unique`` may hand back tensors or plain values; normalise to plain values.
        available_episodes = {
            ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
            for ep_idx in self.hf_dataset.unique("episode_index")
        }

        if self.episodes is None:
            requested_episodes = set(range(self._meta.total_episodes))
        else:
            requested_episodes = set(self.episodes)

        if not requested_episodes.issubset(available_episodes):
            return False

        if len(self._meta.video_keys) > 0:
            for ep_idx in requested_episodes:
                for vid_key in self._meta.video_keys:
                    video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
                    if not video_path.exists():
                        return False

        return True

    def get_episodes_file_paths(self) -> list[str]:
        """Return deduplicated file paths (data + video) for selected episodes.

        Used to build the ``allow_patterns`` list for ``snapshot_download``.

        Returns:
            Paths as strings, each appearing once, in first-seen order
            (data files first, then video files).
        """
        episodes = self.episodes if self.episodes is not None else list(range(self._meta.total_episodes))
        fpaths = [str(self._meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
        if len(self._meta.video_keys) > 0:
            video_files = [
                str(self._meta.get_video_file_path(ep_idx, vid_key))
                for vid_key in self._meta.video_keys
                for ep_idx in episodes
            ]
            fpaths += video_files
        # episodes are stored in the same files, so we return unique paths only;
        # dict.fromkeys dedups while keeping a deterministic (insertion) order.
        return list(dict.fromkeys(fpaths))

    def _get_query_indices(
        self, abs_idx: int, ep_idx: int
    ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
        """Compute query indices for delta timestamps.

        Args:
            abs_idx: Absolute frame index of the anchor frame.
            ep_idx: Episode the anchor frame belongs to.

        Returns:
            A pair ``(query_indices, padding)``. ``query_indices`` maps each
            key of ``delta_indices`` to absolute indices clamped to
            ``[dataset_from_index, dataset_to_index - 1]`` of the episode.
            ``padding`` maps ``"<key>_is_pad"`` to a bool tensor that is True
            wherever the unclamped index fell outside that range.
        """
        ep = self._meta.episodes[ep_idx]
        ep_start = ep["dataset_from_index"]
        ep_end = ep["dataset_to_index"]

        query_indices: dict[str, list[int]] = {}
        padding: dict[str, torch.Tensor] = {}
        for key, delta_idx in self.delta_indices.items():
            targets = [abs_idx + delta for delta in delta_idx]
            query_indices[key] = [max(ep_start, min(ep_end - 1, target)) for target in targets]
            padding[f"{key}_is_pad"] = torch.BoolTensor(
                [not (ep_start <= target < ep_end) for target in targets]
            )
        return query_indices, padding

    def _get_query_timestamps(
        self,
        current_ts: float,
        query_indices: dict[str, list[int]] | None = None,
    ) -> dict[str, list[float]]:
        """Build the per-video-key list of timestamps to decode.

        For a video key present in ``query_indices`` the ``timestamp`` column
        is read at those rows (absolute indices are translated through
        ``_absolute_to_relative_idx`` when that mapping exists). Any other
        video key gets ``[current_ts]``.
        """
        query_timestamps = {}
        for key in self._meta.video_keys:
            if query_indices is not None and key in query_indices:
                if self._absolute_to_relative_idx is not None:
                    relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
                    timestamps = self.hf_dataset[relative_indices]["timestamp"]
                else:
                    timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
                query_timestamps[key] = torch.stack(timestamps).tolist()
            else:
                query_timestamps[key] = [current_ts]

        return query_timestamps

    def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
        """Query dataset for indices across keys, skipping video keys."""
        result: dict = {}
        for key, q_idx in query_indices.items():
            if key in self._meta.video_keys:
                continue
            relative_indices = (
                q_idx
                if self._absolute_to_relative_idx is None
                else [self._absolute_to_relative_idx[idx] for idx in q_idx]
            )
            # Try column-then-rows access first; fall back to rows-then-column
            # when that form raises one of the listed errors.
            try:
                result[key] = torch.stack(self.hf_dataset[key][relative_indices])
            except (KeyError, TypeError, IndexError):
                result[key] = torch.stack(self.hf_dataset[relative_indices][key])
        return result

    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
        """Decode the requested frames for every video key of one episode.

        Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
        in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
        Segmentation Fault.
        """
        ep = self._meta.episodes[ep_idx]

        def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
            # Query timestamps are offset by the episode's ``from_timestamp``
            # for this video key before decoding.
            from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
            shifted_query_ts = [from_timestamp + ts for ts in query_ts]
            video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
            frames = decode_video_frames(
                video_path,
                shifted_query_ts,
                self._tolerance_s,
                self._video_backend,
                return_uint8=self._return_uint8,
            )
            return vid_key, frames.squeeze(0)

        items = list(query_timestamps.items())

        # Single camera: no threading overhead
        if len(items) <= 1:
            return {vid_key: _decode_single(vid_key, query_ts)[1] for vid_key, query_ts in items}

        # Multi-camera: decode in parallel (video decoding releases the GIL)
        with ThreadPoolExecutor(max_workers=len(items)) as pool:
            futures = [pool.submit(_decode_single, k, ts) for k, ts in items]
            return dict(f.result() for f in futures)

    def get_item(self, idx) -> dict:
        """Core __getitem__ logic. Assumes hf_dataset is loaded.

        ``idx`` is a *relative* index into the (possibly episode-filtered)
        HF dataset, **not** the absolute frame index stored in the ``index``
        column. The absolute index is retrieved from the row itself.
        """
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()
        abs_idx = item["index"].item()

        query_indices = None
        if self.delta_indices is not None:
            query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        if len(self._meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
            video_frames = self._query_videos(query_timestamps, ep_idx)
            # Keys already present in ``item`` take precedence over decoded frames.
            item = {**video_frames, **item}

        if self._image_transforms is not None:
            image_keys = self._meta.camera_keys
            for cam in image_keys:
                item[cam] = self._image_transforms(item[cam])

        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self._meta.tasks.iloc[task_idx].name

        return item