diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 245634382..8e50a2aec 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -424,7 +424,7 @@ robot = SO100Follower(robot_config) robot.connect() dataset = LeRobotDataset("/", episodes=[episode_idx]) -actions = dataset.hf_dataset.select_columns("action") +actions = dataset.select_columns("action") log_say(f"Replaying episode {episode_idx}") for idx in range(dataset.num_frames): diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index 13fdfd5f5..e999b5913 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -78,7 +78,7 @@ def replay(cfg: ReplayConfig): robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) - actions = dataset.hf_dataset.select_columns(ACTION) + actions = dataset.select_columns(ACTION) robot.connect() try: diff --git a/examples/dataset/load_lerobot_dataset.py b/examples/dataset/load_lerobot_dataset.py index ea3516710..44ae21a11 100644 --- a/examples/dataset/load_lerobot_dataset.py +++ b/examples/dataset/load_lerobot_dataset.py @@ -88,9 +88,8 @@ def main(): # The previous metadata class is contained in the 'meta' attribute of the dataset: print(dataset.meta) - # LeRobotDataset actually wraps an underlying Hugging Face dataset - # (see https://huggingface.co/docs/datasets for more information). - print(dataset.hf_dataset) + # You can inspect the dataset using its repr: + print(dataset) # LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working # with the latter, like iterating through the dataset. 
diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py index cf89aea16..0cfd4811c 100644 --- a/examples/lekiwi/replay.py +++ b/examples/lekiwi/replay.py @@ -35,9 +35,7 @@ def main(): # Fetch the dataset to replay dataset = LeRobotDataset("/", episodes=[EPISODE_IDX]) - # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) # Connect to the robot robot.connect() @@ -48,7 +46,7 @@ def main(): print("Starting replay loop...") log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): + for idx in range(dataset.num_frames): t0 = time.perf_counter() # Get recorded action from dataset diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 7b955cdb7..c544614a7 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -67,9 +67,7 @@ def main(): # Fetch the dataset to replay dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) - # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) # Connect to the robot robot.connect() @@ -80,7 +78,7 @@ def main(): print("Starting replay loop...") log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): + for idx in range(dataset.num_frames): t0 = time.perf_counter() # Get recorded action from dataset diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index b042e02dd..7caa560f0 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -68,9 +68,7 @@ def main(): 
# Fetch the dataset to replay dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) - # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) # Connect to the robot robot.connect() @@ -81,7 +79,7 @@ def main(): print("Starting replay loop...") log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): + for idx in range(dataset.num_frames): t0 = time.perf_counter() # Get recorded action from dataset diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py new file mode 100644 index 000000000..42c4ab810 --- /dev/null +++ b/src/lerobot/datasets/__init__.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.multi_dataset import MultiLeRobotDataset +from lerobot.datasets.sampler import EpisodeAwareSampler +from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset +from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig + +__all__ = [ + "EpisodeAwareSampler", + "ImageTransforms", + "ImageTransformsConfig", + "LeRobotDataset", + "LeRobotDatasetMetadata", + "MultiLeRobotDataset", + "StreamingLeRobotDataset", +] diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index 560a90a6e..a43ba07b4 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from pathlib import Path import numpy as np @@ -53,6 +54,13 @@ CODEBASE_VERSION = "v3.0" class LeRobotDatasetMetadata: + """Metadata container for a LeRobot dataset. + + Manages the ``info.json``, ``stats.json``, ``tasks.parquet``, and + ``episodes/`` parquet files that describe a dataset's structure, content, + and statistics. + """ + def __init__( self, repo_id: str, @@ -61,33 +69,51 @@ class LeRobotDatasetMetadata: force_cache_sync: bool = False, metadata_buffer_size: int = 10, ): + """Load or download metadata for an existing LeRobot dataset. + + Attempts to load metadata from local disk. If files are missing or + ``force_cache_sync`` is ``True``, downloads the ``meta/`` directory from + the Hub. + + Args: + repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``). + root: Local directory for the dataset. Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. + revision: Git revision (branch, tag, or commit hash). 
Defaults to + the current codebase version. + force_cache_sync: If ``True``, re-download metadata from the Hub + even when local files exist. + metadata_buffer_size: Number of episode metadata records to buffer + in memory before flushing to parquet. + """ self.repo_id = repo_id self.revision = revision if revision else CODEBASE_VERSION self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - self.writer = None + self._pq_writer = None self.latest_episode = None - self.metadata_buffer: list[dict] = [] - self.metadata_buffer_size = metadata_buffer_size + self._metadata_buffer: list[dict] = [] + self._metadata_buffer_size = metadata_buffer_size + self._finalized = False try: if force_cache_sync: raise FileNotFoundError - self.load_metadata() + self._load_metadata() except (FileNotFoundError, NotADirectoryError): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) (self.root / "meta").mkdir(exist_ok=True, parents=True) - self.pull_from_repo(allow_patterns="meta/") - self.load_metadata() + self._pull_from_repo(allow_patterns="meta/") + self._load_metadata() def _flush_metadata_buffer(self) -> None: """Write all buffered episode metadata to parquet file.""" - if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0: + if not hasattr(self, "_metadata_buffer") or len(self._metadata_buffer) == 0: return combined_dict = {} - for episode_dict in self.metadata_buffer: + for episode_dict in self._metadata_buffer: for key, value in episode_dict.items(): if key not in combined_dict: combined_dict[key] = [] @@ -96,40 +122,50 @@ class LeRobotDatasetMetadata: val = value[0] if isinstance(value, list) else value combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val) - first_ep = self.metadata_buffer[0] + first_ep = self._metadata_buffer[0] chunk_idx = first_ep["meta/episodes/chunk_index"][0] file_idx = first_ep["meta/episodes/file_index"][0] table = 
pa.Table.from_pydict(combined_dict) - if not self.writer: + if not self._pq_writer: path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)) path.parent.mkdir(parents=True, exist_ok=True) - self.writer = pq.ParquetWriter( + self._pq_writer = pq.ParquetWriter( path, schema=table.schema, compression="snappy", use_dictionary=True ) - self.writer.write_table(table) + self._pq_writer.write_table(table) - self.latest_episode = self.metadata_buffer[-1] - self.metadata_buffer.clear() + self.latest_episode = self._metadata_buffer[-1] + self._metadata_buffer.clear() def _close_writer(self) -> None: """Close and cleanup the parquet writer if it exists.""" self._flush_metadata_buffer() - writer = getattr(self, "writer", None) + writer = getattr(self, "_pq_writer", None) if writer is not None: writer.close() - self.writer = None + self._pq_writer = None + + def finalize(self) -> None: + """Flush metadata buffer and close the parquet writer. + + Idempotent — safe to call multiple times. + """ + if getattr(self, "_finalized", False): + return + self._close_writer() + self._finalized = True def __del__(self): - """ - Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor - """ - self._close_writer() + """Safety net: flush and close parquet writer on garbage collection.""" + # During interpreter shutdown, referenced objects may already be collected. 
+ with contextlib.suppress(Exception): + self.finalize() - def load_metadata(self): + def _load_metadata(self): self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) @@ -137,7 +173,7 @@ class LeRobotDatasetMetadata: self.episodes = load_episodes(self.root) self.stats = load_stats(self.root) - def pull_from_repo( + def _pull_from_repo( self, allow_patterns: list[str] | str | None = None, ignore_patterns: list[str] | str | None = None, @@ -153,6 +189,7 @@ class LeRobotDatasetMetadata: @property def url_root(self) -> str: + """Hugging Face Hub URL root for this dataset.""" return f"hf://datasets/{self.repo_id}" @property @@ -161,6 +198,17 @@ class LeRobotDatasetMetadata: return packaging.version.parse(self.info["codebase_version"]) def get_data_file_path(self, ep_index: int) -> Path: + """Return the relative parquet file path for the given episode index. + + Args: + ep_index: Zero-based episode index. + + Returns: + Path to the parquet file containing this episode's data. + + Raises: + IndexError: If ``ep_index`` is out of range. + """ if self.episodes is None: self.episodes = load_episodes(self.root) if ep_index >= len(self.episodes): @@ -174,6 +222,19 @@ class LeRobotDatasetMetadata: return Path(fpath) def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + """Return the relative video file path for the given episode and video key. + + Args: + ep_index: Zero-based episode index. + vid_key: Feature key identifying the video stream + (e.g. ``'observation.images.laptop'``). + + Returns: + Path to the video file containing this episode's frames. + + Raises: + IndexError: If ``ep_index`` is out of range. 
+ """ if self.episodes is None: self.episodes = load_episodes(self.root) if ep_index >= len(self.episodes): @@ -277,6 +338,17 @@ class LeRobotDatasetMetadata: return None def save_episode_tasks(self, tasks: list[str]): + """Register tasks for the current episode and persist to disk. + + New tasks that do not already exist in the dataset are assigned + sequential task indices and appended to the tasks parquet file. + + Args: + tasks: List of unique task descriptions in natural language. + + Raises: + ValueError: If ``tasks`` contains duplicates. + """ if len(set(tasks)) != len(tasks): raise ValueError(f"Tasks are not unique: {tasks}") @@ -336,8 +408,8 @@ class LeRobotDatasetMetadata: latest_path = ( self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - if self.writer is None - else self.writer.where + if self._pq_writer is None + else self._pq_writer.where ) if Path(latest_path).exists(): @@ -359,10 +431,10 @@ class LeRobotDatasetMetadata: episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames] # Add to buffer - self.metadata_buffer.append(episode_dict) + self._metadata_buffer.append(episode_dict) self.latest_episode = episode_dict - if len(self.metadata_buffer) >= self.metadata_buffer_size: + if len(self._metadata_buffer) >= self._metadata_buffer_size: self._flush_metadata_buffer() def save_episode( @@ -373,6 +445,20 @@ class LeRobotDatasetMetadata: episode_stats: dict[str, dict], episode_metadata: dict, ) -> None: + """Persist episode metadata, update dataset info, and aggregate stats. + + Writes the episode's metadata to the buffered parquet writer, increments + the total episode/frame counters in ``info.json``, and merges the + episode's statistics into the running dataset statistics. + + Args: + episode_index: Zero-based index of the episode being saved. + episode_length: Number of frames in this episode. + episode_tasks: List of task descriptions for this episode. 
+ episode_stats: Per-feature statistics for this episode. + episode_metadata: Additional metadata (chunk/file indices, frame + ranges, video timestamps, etc.). + """ episode_dict = { "episode_index": episode_index, "tasks": episode_tasks, @@ -479,7 +565,32 @@ class LeRobotDatasetMetadata: data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, ) -> "LeRobotDatasetMetadata": - """Creates metadata for a LeRobotDataset.""" + """Create metadata for a new LeRobot dataset from scratch. + + Initializes the ``info.json`` file on disk with the provided feature + schema and dataset settings. No episode data is written yet. + + Args: + repo_id: Repository identifier (e.g. ``'user/my_dataset'``). + fps: Frames per second used during data collection. + features: Feature specification dict mapping feature names to their + type/shape metadata. + robot_type: Optional robot type string stored in metadata. + root: Local directory for the dataset. Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. Must not already exist. + use_videos: If ``True``, visual modalities are encoded as MP4 videos. + metadata_buffer_size: Number of episode metadata records to buffer + before flushing to parquet. + chunks_size: Max number of files per chunk directory. ``None`` uses + the default. + data_files_size_in_mb: Max parquet file size in MB. ``None`` uses the + default. + video_files_size_in_mb: Max video file size in MB. ``None`` uses the + default. + + Returns: + A new :class:`LeRobotDatasetMetadata` instance. 
+ """ obj = cls.__new__(cls) obj.repo_id = repo_id obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id @@ -510,8 +621,9 @@ class LeRobotDatasetMetadata: ) write_json(obj.info, obj.root / INFO_PATH) obj.revision = None - obj.writer = None + obj._pq_writer = None obj.latest_episode = None - obj.metadata_buffer = [] - obj.metadata_buffer_size = metadata_buffer_size + obj._metadata_buffer = [] + obj._metadata_buffer_size = metadata_buffer_size + obj._finalized = False return obj diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py new file mode 100644 index 000000000..0233a3cf6 --- /dev/null +++ b/src/lerobot/datasets/dataset_reader.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Private reader component for LeRobotDataset. 
class DatasetReader:
    """Read-side component of LeRobotDataset.

    Owns the loaded Hugging Face dataset (``hf_dataset``), the optional
    absolute-to-relative frame index mapping (``_absolute_to_relative_idx``)
    that is built when an episode subset is selected, and the
    ``delta_indices`` derived from ``delta_timestamps``.
    """

    def __init__(
        self,
        meta: LeRobotDatasetMetadata,
        root: Path,
        episodes: list[int] | None,
        tolerance_s: float,
        video_backend: str,
        delta_timestamps: dict[str, list[float]] | None,
        image_transforms: Callable | None,
    ):
        """Initialize the reader with metadata, filtering, and transform config.

        The HF dataset is not loaded here — call :meth:`try_load` or
        :meth:`load_and_activate` afterward.

        Args:
            meta: Dataset metadata instance.
            root: Local dataset root directory.
            episodes: Optional list of episode indices to select. ``None``
                means all episodes.
            tolerance_s: Timestamp synchronization tolerance in seconds.
            video_backend: Video decoding backend identifier.
            delta_timestamps: Optional dict mapping feature keys to lists of
                relative timestamp offsets for temporal context windows.
            image_transforms: Optional torchvision v2 transform applied to
                visual features.
        """
        self._meta = meta
        self._root = root
        self.episodes = episodes
        self._tolerance_s = tolerance_s
        self._video_backend = video_backend
        self._image_transforms = image_transforms

        self.hf_dataset: datasets.Dataset | None = None
        self._absolute_to_relative_idx: dict[int, int] | None = None

        # Setup delta_indices (doesn't depend on hf_dataset)
        self.delta_indices = None
        if delta_timestamps is not None:
            check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
            self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)

    def try_load(self) -> bool:
        """Attempt to load from local cache.

        Returns:
            ``True`` if the dataset was loaded and contains every requested
            episode (and its video files); ``False`` otherwise, in which case
            ``hf_dataset`` is reset to ``None``.
        """
        try:
            self.hf_dataset = self._load_hf_dataset()
        except (FileNotFoundError, NotADirectoryError):
            self.hf_dataset = None
            return False
        if not self._check_cached_episodes_sufficient():
            self.hf_dataset = None
            return False
        self._build_index_mapping()
        return True

    def load_and_activate(self) -> None:
        """Load HF dataset from disk and build index mapping. Call after data is on disk."""
        self.hf_dataset = self._load_hf_dataset()
        self._build_index_mapping()

    def _build_index_mapping(self) -> None:
        """Build absolute-to-relative index mapping from loaded hf_dataset.

        The mapping is only built when an episode subset is selected and the
        dataset is loaded; otherwise it is left as ``None``.
        """
        self._absolute_to_relative_idx = None
        if self.episodes is not None and self.hf_dataset is not None:
            self._absolute_to_relative_idx = {
                abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
                for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
            }

    @property
    def num_frames(self) -> int:
        """Number of frames in selected episodes.

        Uses the length of the loaded ``hf_dataset`` when an episode subset is
        selected and data is loaded; falls back to ``meta.total_frames``.
        """
        if self.episodes is not None and self.hf_dataset is not None:
            return len(self.hf_dataset)
        return self._meta.total_frames

    @property
    def num_episodes(self) -> int:
        """Number of episodes selected."""
        return len(self.episodes) if self.episodes is not None else self._meta.total_episodes

    def _load_hf_dataset(self) -> datasets.Dataset:
        """Load the HF dataset under ``root / "data"`` and attach the torch transform.

        hf_dataset contains all the observations, states, actions, rewards, etc.
        """
        features = get_hf_features_from_features(self._meta.features)
        hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes)
        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset

    def _check_cached_episodes_sufficient(self) -> bool:
        """Check if the cached dataset contains all requested episodes and their video files."""
        if self.hf_dataset is None or len(self.hf_dataset) == 0:
            return False

        available_episodes = {
            ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
            for ep_idx in self.hf_dataset.unique("episode_index")
        }

        if self.episodes is None:
            requested_episodes = set(range(self._meta.total_episodes))
        else:
            requested_episodes = set(self.episodes)

        if not requested_episodes.issubset(available_episodes):
            return False

        if len(self._meta.video_keys) > 0:
            for ep_idx in requested_episodes:
                for vid_key in self._meta.video_keys:
                    video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
                    if not video_path.exists():
                        return False

        return True

    def get_episodes_file_paths(self) -> list[str]:
        """Return deduplicated file paths (data + video) for selected episodes.

        Used to build the ``allow_patterns`` list for ``snapshot_download``.

        Returns:
            Paths as strings, each appearing once, in first-seen order (data
            files first, then video files).
        """
        episodes = self.episodes if self.episodes is not None else list(range(self._meta.total_episodes))
        fpaths = [str(self._meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
        if len(self._meta.video_keys) > 0:
            fpaths += [
                str(self._meta.get_video_file_path(ep_idx, vid_key))
                for vid_key in self._meta.video_keys
                for ep_idx in episodes
            ]
        # Episodes are stored in the same files, so we return unique paths only.
        # dict.fromkeys keeps insertion order, which makes the result
        # deterministic (a set would not).
        return list(dict.fromkeys(fpaths))

    def _get_query_indices(
        self, abs_idx: int, ep_idx: int
    ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
        """Compute query indices for delta timestamps.

        Args:
            abs_idx: Absolute frame index of the current frame.
            ep_idx: Episode index of the current frame.

        Returns:
            A pair ``(query_indices, padding)``. ``query_indices`` maps each
            key of ``delta_indices`` to absolute indices clamped to
            ``[dataset_from_index, dataset_to_index - 1]`` of the episode.
            ``padding`` maps ``"{key}_is_pad"`` to a bool tensor that is
            ``True`` where the unclamped index fell outside the episode.
        """
        ep = self._meta.episodes[ep_idx]
        ep_start = ep["dataset_from_index"]
        ep_end = ep["dataset_to_index"]
        query_indices = {
            key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
            for key, delta_idx in self.delta_indices.items()
        }
        padding = {
            f"{key}_is_pad": torch.BoolTensor(
                [(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
            )
            for key, delta_idx in self.delta_indices.items()
        }
        return query_indices, padding

    def _get_query_timestamps(
        self,
        current_ts: float,
        query_indices: dict[str, list[int]] | None = None,
    ) -> dict[str, list[float]]:
        """Return, per video key, the timestamps of the frames to decode.

        For a video key present in ``query_indices`` the timestamps are read
        from the ``timestamp`` column of ``hf_dataset`` (absolute indices are
        translated through ``_absolute_to_relative_idx`` when that mapping
        exists). Any other video key gets ``[current_ts]``.
        """
        query_timestamps = {}
        for key in self._meta.video_keys:
            if query_indices is not None and key in query_indices:
                if self._absolute_to_relative_idx is not None:
                    relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
                    timestamps = self.hf_dataset[relative_indices]["timestamp"]
                else:
                    timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
                query_timestamps[key] = torch.stack(timestamps).tolist()
            else:
                query_timestamps[key] = [current_ts]

        return query_timestamps

    def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
        """Query dataset for indices across keys, skipping video keys."""
        result: dict = {}
        for key, q_idx in query_indices.items():
            if key in self._meta.video_keys:
                continue
            relative_indices = (
                q_idx
                if self._absolute_to_relative_idx is None
                else [self._absolute_to_relative_idx[idx] for idx in q_idx]
            )
            try:
                # Column-first access.
                result[key] = torch.stack(self.hf_dataset[key][relative_indices])
            except (KeyError, TypeError, IndexError):
                # Fall back to row-first access when the column object cannot
                # be indexed with a list.
                # NOTE(review): presumably a `datasets` version difference — confirm.
                result[key] = torch.stack(self.hf_dataset[relative_indices][key])
        return result

    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
        """Decode the requested frames of episode ``ep_idx`` for every video key.

        Query timestamps are shifted by the episode's
        ``videos/{vid_key}/from_timestamp`` before decoding, and the decoded
        frames are squeezed along dim 0.

        Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
        in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
        Segmentation Fault.
        """
        ep = self._meta.episodes[ep_idx]
        item = {}
        for vid_key, query_ts in query_timestamps.items():
            from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
            shifted_query_ts = [from_timestamp + ts for ts in query_ts]

            video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
            frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
            item[vid_key] = frames.squeeze(0)

        return item

    def get_item(self, idx) -> dict:
        """Core __getitem__ logic. Assumes hf_dataset is loaded.

        ``idx`` is a *relative* index into the (possibly episode-filtered)
        HF dataset, **not** the absolute frame index stored in the ``index``
        column. The absolute index is retrieved from the row itself.
        """
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()
        abs_idx = item["index"].item()

        query_indices = None
        if self.delta_indices is not None:
            query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        if len(self._meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
            video_frames = self._query_videos(query_timestamps, ep_idx)
            # Entries already present in `item` win over decoded video frames.
            item = {**video_frames, **item}

        if self._image_transforms is not None:
            image_keys = self._meta.camera_keys
            for cam in image_keys:
                item[cam] = self._image_transforms(item[cam])

        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self._meta.tasks.iloc[task_idx].name

        # add subtask information if available
        if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
            subtask_idx = item["subtask_index"].item()
            item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name

        return item
def _encode_video_worker(
    video_key: str,
    episode_index: int,
    root: Path,
    fps: int,
    vcodec: str = "libsvtav1",
    encoder_threads: int | None = None,
) -> Path:
    """Encode one episode's image frames for ``video_key`` into a temporary MP4.

    The source image directory is derived from ``DEFAULT_IMAGE_PATH`` under
    ``root`` and is deleted once encoding succeeds. The video is written into
    a fresh temporary directory created inside ``root``.

    Args:
        video_key: Feature key of the video stream to encode.
        episode_index: Index of the episode whose frames are encoded.
        root: Local dataset root directory.
        fps: Frame rate passed to the encoder.
        vcodec: Video codec passed to the encoder.
        encoder_threads: Thread count passed to the encoder. ``None`` leaves
            the choice to the encoder.

    Returns:
        Path of the encoded MP4 file inside the temporary directory.

    Raises:
        Exception: Whatever ``encode_video_frames`` raises is propagated
            unchanged, after the temporary directory has been removed.
    """
    temp_dir = Path(tempfile.mkdtemp(dir=root))
    temp_path = temp_dir / f"{video_key}_{episode_index:03d}.mp4"
    fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
    img_dir = (root / fpath).parent
    try:
        encode_video_frames(
            img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
        )
    except Exception:
        # The caller never learns the random temp dir name when encoding
        # fails, so nobody else can clean it up: remove it before re-raising.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    # Source frames are only removed once the video exists.
    shutil.rmtree(img_dir)
    return temp_path
def add_frame(self, frame: dict) -> None:
    """Add a frame to the episode_buffer.

    Visual features are either fed to the streaming encoder (video features,
    when a streaming encoder is configured) or written as image files under
    the dataset root; all other features are appended to the buffer as-is.
    ``frame_index`` is taken from the buffer size, and ``timestamp`` defaults
    to ``frame_index / fps`` when the frame does not provide one.

    The caller's ``frame`` dict is not modified.

    Args:
        frame: Mapping of feature name to value for a single time step. Must
            contain ``"task"``; may contain ``"timestamp"``.

    Raises:
        ValueError: If ``frame`` contains a key that is not a dataset feature.
    """
    # Work on a shallow copy so the caller's dict is neither rewritten with
    # numpy arrays nor stripped of its "timestamp"/"task" entries.
    # detach().cpu() makes the conversion valid for CUDA tensors and tensors
    # that track gradients; it is a no-op for plain CPU tensors.
    frame = {
        name: value.detach().cpu().numpy() if isinstance(value, torch.Tensor) else value
        for name, value in frame.items()
    }

    validate_frame(frame, self._meta.features)

    if self.episode_buffer is None:
        self.episode_buffer = self._create_episode_buffer()

    # Automatically add frame_index and timestamp to episode buffer
    frame_index = self.episode_buffer["size"]
    timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self._meta.fps
    self.episode_buffer["frame_index"].append(frame_index)
    self.episode_buffer["timestamp"].append(timestamp)
    self.episode_buffer["task"].append(frame.pop("task"))

    # Start streaming encoder on first frame of episode
    if frame_index == 0 and self._streaming_encoder is not None:
        self._streaming_encoder.start_episode(
            video_keys=list(self._meta.video_keys),
            temp_dir=self._root,
        )

    # Add frame features to episode_buffer
    for key in frame:
        if key not in self._meta.features:
            raise ValueError(
                f"An element of the frame is not in the features. '{key}' not in '{self._meta.features.keys()}'."
            )

        dtype = self._meta.features[key]["dtype"]
        if dtype == "video" and self._streaming_encoder is not None:
            # Frame goes straight to the encoder; the buffer only keeps a placeholder.
            self._streaming_encoder.feed_frame(key, frame[key])
            self.episode_buffer[key].append(None)
        elif dtype in ["image", "video"]:
            img_path = self._get_image_file_path(
                episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
            )
            if frame_index == 0:
                img_path.parent.mkdir(parents=True, exist_ok=True)
            compress_level = 1 if dtype == "video" else 6
            self._save_image(frame[key], img_path, compress_level)
            self.episode_buffer[key].append(str(img_path))
        else:
            self.episode_buffer[key].append(frame[key])

    self.episode_buffer["size"] += 1
np.stack(episode_buffer[key]) + + # Wait for image writer to end, so that episode stats over images can be computed + self._wait_image_writer() + + has_video_keys = len(self._meta.video_keys) > 0 + use_streaming = self._streaming_encoder is not None and has_video_keys + use_batched_encoding = self._batch_encoding_size > 1 + + if use_streaming: + non_video_buffer = { + k: v + for k, v in episode_buffer.items() + if self._meta.features.get(k, {}).get("dtype") not in ("video",) + } + non_video_features = {k: v for k, v in self._meta.features.items() if v["dtype"] != "video"} + ep_stats = compute_episode_stats(non_video_buffer, non_video_features) + else: + ep_stats = compute_episode_stats(episode_buffer, self._meta.features) + + ep_metadata = self._save_episode_data(episode_buffer) + + if use_streaming: + streaming_results = self._streaming_encoder.finish_episode() + for video_key in self._meta.video_keys: + temp_path, video_stats = streaming_results[video_key] + if video_stats is not None: + ep_stats[video_key] = { + k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) + for k, v in video_stats.items() + } + ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) + elif has_video_keys and not use_batched_encoding: + num_cameras = len(self._meta.video_keys) + if parallel_encoding and num_cameras > 1: + with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor: + future_to_key = { + executor.submit( + _encode_video_worker, + video_key, + episode_index, + self._root, + self._meta.fps, + self._vcodec, + self._encoder_threads, + ): video_key + for video_key in self._meta.video_keys + } + + results = {} + for future in concurrent.futures.as_completed(future_to_key): + video_key = future_to_key[future] + try: + temp_path = future.result() + results[video_key] = temp_path + except Exception as exc: + logger.error(f"Video encoding failed for {video_key}: {exc}") + raise exc + + for video_key in 
self._meta.video_keys: + temp_path = results[video_key] + ep_metadata.update( + self._save_episode_video(video_key, episode_index, temp_path=temp_path) + ) + else: + for video_key in self._meta.video_keys: + ep_metadata.update(self._save_episode_video(video_key, episode_index)) + + # `meta.save_episode` need to be executed after encoding the videos + self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) + + if has_video_keys and use_batched_encoding: + self._episodes_since_last_encoding += 1 + if self._episodes_since_last_encoding == self._batch_encoding_size: + start_ep = self._meta.total_episodes - self._batch_encoding_size + end_ep = self._meta.total_episodes + self._batch_save_episode_video(start_ep, end_ep) + self._episodes_since_last_encoding = 0 + + if episode_data is None: + self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0) + + def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None: + """Batch save videos for multiple episodes.""" + if end_episode is None: + end_episode = self._meta.total_episodes + + logger.info( + f"Batch encoding {self._batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}" + ) + + chunk_idx = self._meta.episodes[start_episode]["data/chunk_index"] + file_idx = self._meta.episodes[start_episode]["data/file_index"] + episode_df_path = self._root / DEFAULT_EPISODES_PATH.format( + chunk_index=chunk_idx, file_index=file_idx + ) + episode_df = pd.read_parquet(episode_df_path) + + for ep_idx in range(start_episode, end_episode): + logger.info(f"Encoding videos for episode {ep_idx}") + + if ( + self._meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx + or self._meta.episodes[ep_idx]["data/file_index"] != file_idx + ): + episode_df.to_parquet(episode_df_path) + self._meta.episodes = load_episodes(self._root) + + chunk_idx = self._meta.episodes[ep_idx]["data/chunk_index"] + file_idx = 
self._meta.episodes[ep_idx]["data/file_index"] + episode_df_path = self._root / DEFAULT_EPISODES_PATH.format( + chunk_index=chunk_idx, file_index=file_idx + ) + episode_df = pd.read_parquet(episode_df_path) + + video_ep_metadata = {} + for video_key in self._meta.video_keys: + video_ep_metadata.update(self._save_episode_video(video_key, ep_idx)) + video_ep_metadata.pop("episode_index") + video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes( + dtype_backend="pyarrow" + ) + + episode_df = episode_df.combine_first(video_ep_df) + episode_df.to_parquet(episode_df_path) + self._meta.episodes = load_episodes(self._root) + + def _save_episode_data(self, episode_buffer: dict) -> dict: + """Save episode data to a parquet file.""" + # Use metadata features as the authoritative schema + hf_features = get_hf_features_from_features(self._meta.features) + ep_dict = {key: episode_buffer[key] for key in hf_features} + ep_dataset = datasets.Dataset.from_dict(ep_dict, features=hf_features, split="train") + ep_dataset = embed_images(ep_dataset) + ep_num_frames = len(ep_dataset) + + if self._latest_episode is None: + chunk_idx, file_idx = 0, 0 + global_frame_index = 0 + self._current_file_start_frame = 0 + if self._meta.episodes is not None and len(self._meta.episodes) > 0: + latest_ep = self._meta.episodes[-1] + global_frame_index = latest_ep["dataset_to_index"] + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size) + self._current_file_start_frame = global_frame_index + else: + latest_ep = self._latest_episode + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + global_frame_index = latest_ep["index"][-1] + 1 + + latest_path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + latest_size_in_mb = get_file_size_in_mb(latest_path) + + frames_in_current_file = 
global_frame_index - self._current_file_start_frame + av_size_per_frame = ( + latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0 + ) + + if latest_size_in_mb + av_size_per_frame * ep_num_frames >= self._meta.data_files_size_in_mb: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size) + self.close_writer() + self._current_file_start_frame = global_frame_index + + ep_dict["data/chunk_index"] = chunk_idx + ep_dict["data/file_index"] = file_idx + + path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + + table = ep_dataset.with_format("arrow")[:] + if not self._pq_writer: + self._pq_writer = pq.ParquetWriter( + path, schema=table.schema, compression="snappy", use_dictionary=True + ) + self._pq_writer.write_table(table) + + metadata = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": global_frame_index, + "dataset_to_index": global_frame_index + ep_num_frames, + } + + self._latest_episode = {**ep_dict, **metadata} + self._recorded_frames += ep_num_frames + + return metadata + + def _save_episode_video( + self, + video_key: str, + episode_index: int, + temp_path: Path | None = None, + ) -> dict: + if temp_path is None: + ep_path = self._encode_temporary_episode_video(video_key, episode_index) + else: + ep_path = temp_path + + ep_size_in_mb = get_file_size_in_mb(ep_path) + ep_duration_in_s = get_video_duration_in_s(ep_path) + + if ( + episode_index == 0 + or self._meta.latest_episode is None + or f"videos/{video_key}/chunk_index" not in self._meta.latest_episode + ): + chunk_idx, file_idx = 0, 0 + if self._meta.episodes is not None and len(self._meta.episodes) > 0: + old_chunk_idx = self._meta.episodes[-1][f"videos/{video_key}/chunk_index"] + old_file_idx = self._meta.episodes[-1][f"videos/{video_key}/file_index"] + chunk_idx, file_idx = update_chunk_file_indices( + 
old_chunk_idx, old_file_idx, self._meta.chunks_size + ) + latest_duration_in_s = 0.0 + new_path = self._root / self._meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + else: + latest_ep = self._meta.latest_episode + chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0] + file_idx = latest_ep[f"videos/{video_key}/file_index"][0] + + latest_path = self._root / self._meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + latest_size_in_mb = get_file_size_in_mb(latest_path) + latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0] + + if latest_size_in_mb + ep_size_in_mb >= self._meta.video_files_size_in_mb: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size) + new_path = self._root / self._meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + latest_duration_in_s = 0.0 + else: + concatenate_video_files( + [latest_path, ep_path], + latest_path, + ) + + # Remove temporary directory + shutil.rmtree(str(ep_path.parent)) + + # Update video info (only needed when first episode is encoded) + if episode_index == 0: + self._meta.update_video_info(video_key) + write_info(self._meta.info, self._meta.root) + + metadata = { + "episode_index": episode_index, + f"videos/{video_key}/chunk_index": chunk_idx, + f"videos/{video_key}/file_index": file_idx, + f"videos/{video_key}/from_timestamp": latest_duration_in_s, + f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s, + } + return metadata + + def clear_episode_buffer(self, delete_images: bool = True) -> None: + """Discard the current episode buffer and optionally delete temp images. 
+ + Args: + delete_images: If ``True``, remove temporary image directories + written for the current episode. + """ + # Cancel streaming encoder if active + if self._streaming_encoder is not None: + self._streaming_encoder.cancel_episode() + + if delete_images: + if self.image_writer is not None: + self._wait_image_writer() + episode_index = self.episode_buffer["episode_index"] + # episode_index is `int` when freshly created, but becomes `np.ndarray` after + # save_episode() mutates the buffer. Handle both types here. + if isinstance(episode_index, np.ndarray): + episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] + for cam_key in self._meta.image_keys: + img_dir = self._get_image_file_dir(episode_index, cam_key) + if img_dir.is_dir(): + shutil.rmtree(img_dir) + + self.episode_buffer = self._create_episode_buffer() + + def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: + """Start an :class:`AsyncImageWriter` for background image persistence. + + Args: + num_processes: Number of subprocesses. ``0`` means threads only. + num_threads: Number of threads per process. + """ + if isinstance(self.image_writer, AsyncImageWriter): + logger.warning( + "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." 
+ ) + + self.image_writer = AsyncImageWriter( + num_processes=num_processes, + num_threads=num_threads, + ) + + def stop_image_writer(self) -> None: + """Stop the image writer (needed before pickling the dataset for DataLoader).""" + if self.image_writer is not None: + self.image_writer.stop() + self.image_writer = None + + def _wait_image_writer(self) -> None: + """Wait for asynchronous image writer to finish.""" + if self.image_writer is not None: + self.image_writer.wait_until_done() + + def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: + """Use ffmpeg to convert frames stored as png into mp4 videos.""" + return _encode_video_worker( + video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads + ) + + def close_writer(self) -> None: + """Close and cleanup the parquet writer if it exists.""" + if self._pq_writer is not None: + self._pq_writer.close() + self._pq_writer = None + + def flush_pending_videos(self) -> None: + """Flush any pending video encoding (streaming or batch). + + For streaming encoding: closes the encoder. + For batch encoding: encodes any remaining episodes that haven't been batch-encoded yet. 
+ """ + if self._streaming_encoder is not None: + self._streaming_encoder.close() + elif self._episodes_since_last_encoding > 0: + start_ep = self._meta.total_episodes - self._episodes_since_last_encoding + end_ep = self._meta.total_episodes + logger.info( + f"Encoding remaining {self._episodes_since_last_encoding} episodes, " + f"from episode {start_ep} to {end_ep - 1}" + ) + self._batch_save_episode_video(start_ep, end_ep) + + def cancel_pending_videos(self) -> None: + """Cancel any in-progress streaming encoding without flushing.""" + if self._streaming_encoder is not None: + self._streaming_encoder.cancel_episode() + + def cleanup_interrupted_episode(self, episode_index: int) -> None: + """Remove temporary image directories for an interrupted episode.""" + for key in self._meta.video_keys: + img_dir = self._get_image_file_path( + episode_index=episode_index, image_key=key, frame_index=0 + ).parent + if img_dir.exists(): + logger.debug( + f"Cleaning up interrupted episode images for episode {episode_index}, camera {key}" + ) + shutil.rmtree(img_dir) + + def finalize(self) -> None: + """Flush all pending work and release all resources. + + Idempotent — safe to call multiple times. + """ + if getattr(self, "_finalized", False): + return + # 1. Wait for async image writes to complete, then stop + if self.image_writer is not None: + self.image_writer.wait_until_done() + self.image_writer.stop() + self.image_writer = None + # 2. Flush pending video encoding (streaming or batch) + self.flush_pending_videos() + # 3. Close own parquet writer + self.close_writer() + # 4. Finalize metadata (idempotent) + self._meta.finalize() + self._finalized = True + + def __del__(self): + """Safety net: release resources on garbage collection.""" + # During interpreter shutdown, referenced objects may already be collected. 
+ with contextlib.suppress(Exception): + self.finalize() diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 9f40394de..603067757 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -32,10 +32,10 @@ def safe_stop_image_writer(func): return func(*args, **kwargs) except Exception as e: dataset = kwargs.get("dataset") - image_writer = getattr(dataset, "image_writer", None) if dataset else None - if image_writer is not None: + writer = getattr(dataset, "writer", None) if dataset else None + if writer is not None and writer.image_writer is not None: logger.warning("Waiting for image writer to terminate...") - image_writer.stop() + writer.image_writer.stop() raise e return wrapper diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 8f0600ba8..cba0c1cba 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -13,57 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import concurrent.futures import contextlib import logging -import shutil -import tempfile from collections.abc import Callable from pathlib import Path import datasets -import numpy as np -import pandas as pd -import PIL.Image -import pyarrow.parquet as pq import torch import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError -from lerobot.datasets.compute_stats import compute_episode_stats from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata -from lerobot.datasets.feature_utils import ( - check_delta_timestamps, - get_delta_indices, - get_hf_features_from_features, - validate_episode_buffer, - validate_frame, -) -from lerobot.datasets.image_writer import AsyncImageWriter, write_image -from lerobot.datasets.io_utils import ( - embed_images, - get_file_size_in_mb, - hf_transform_to_torch, - load_episodes, - load_nested_dataset, - write_info, -) +from lerobot.datasets.dataset_reader import DatasetReader +from lerobot.datasets.dataset_writer import DatasetWriter from lerobot.datasets.utils import ( - DEFAULT_EPISODES_PATH, - DEFAULT_IMAGE_PATH, create_lerobot_dataset_card, get_safe_version, is_valid_version, - update_chunk_file_indices, ) from lerobot.datasets.video_utils import ( StreamingVideoEncoder, - concatenate_video_files, - decode_video_frames, - encode_video_frames, get_safe_default_codec, - get_video_duration_in_s, resolve_vcodec, ) from lerobot.utils.constants import HF_LEROBOT_HOME @@ -71,24 +42,6 @@ from lerobot.utils.constants import HF_LEROBOT_HOME logger = logging.getLogger(__name__) -def _encode_video_worker( - video_key: str, - episode_index: int, - root: Path, - fps: int, - vcodec: str = "libsvtav1", - encoder_threads: int | None = None, -) -> Path: - temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" - fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) - img_dir 
= (root / fpath).parent - encode_video_frames( - img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads - ) - shutil.rmtree(img_dir) - return temp_path - - class LeRobotDataset(torch.utils.data.Dataset): def __init__( self, @@ -136,7 +89,7 @@ class LeRobotDataset(torch.utils.data.Dataset): - stats stores the dataset statistics of the different modalities for normalization - tasks contains the prompts for each task of the dataset, which can be used for task-conditioned training. - - hf_dataset (from datasets.Dataset), which will read any values from parquet files. + - data (backed by datasets.Dataset), which reads values from parquet files. - videos (optional) from which frames are loaded to be synchronous with data from parquet files. A typical LeRobotDataset looks like this from its root path: @@ -229,6 +182,11 @@ class LeRobotDataset(torch.utils.data.Dataset): encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc. + + Note: + Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to + ``__init__`` are deprecated. Use :meth:`create` for new datasets or :meth:`resume` + to append to existing ones. 
""" super().__init__() self.repo_id = repo_id @@ -238,21 +196,11 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episodes = episodes self.tolerance_s = tolerance_s self.revision = revision if revision else CODEBASE_VERSION - self.video_backend = video_backend if video_backend else get_safe_default_codec() - self.delta_indices = None - self.batch_encoding_size = batch_encoding_size - self.episodes_since_last_encoding = 0 - self.vcodec = resolve_vcodec(vcodec) + self._video_backend = video_backend if video_backend else get_safe_default_codec() + self._batch_encoding_size = batch_encoding_size + self._vcodec = resolve_vcodec(vcodec) self._encoder_threads = encoder_threads - # Unused attributes - self.image_writer = None - self.episode_buffer = None - self.writer = None - self.latest_episode = None - self._current_file_start_frame = None # Track the starting frame index of the current parquet file - self._streaming_encoder = None - self.root.mkdir(exist_ok=True, parents=True) # Load metadata @@ -260,64 +208,270 @@ class LeRobotDataset(torch.utils.data.Dataset): self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) - # Track dataset state for efficient incremental writing - self._lazy_loading = False - self._recorded_frames = self.meta.total_frames - self._writer_closed_for_reading = False + # Create reader (hf_dataset loaded below) + self.reader = DatasetReader( + meta=self.meta, + root=self.root, + episodes=episodes, + tolerance_s=tolerance_s, + video_backend=self._video_backend, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + ) # Load actual data - try: - if force_cache_sync: - raise FileNotFoundError - self.hf_dataset = self.load_hf_dataset() - # Check if cached dataset contains all requested episodes - if not self._check_cached_episodes_sufficient(): - raise FileNotFoundError("Cached dataset doesn't contain all requested episodes") - except (FileNotFoundError, NotADirectoryError): + if force_cache_sync or 
not self.reader.try_load(): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) - self.download(download_videos) - self.hf_dataset = self.load_hf_dataset() + self._download(download_videos) + self.reader.load_and_activate() - # Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded - # Build a mapping: absolute_index -> relative_index_in_filtered_dataset - self._absolute_to_relative_idx = None - if self.episodes is not None: - self._absolute_to_relative_idx = { - abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx - for rel_idx, abs_idx in enumerate(self.hf_dataset["index"]) - } + # Detect write-mode params for backward compatibility + _has_write_params = streaming_encoding or batch_encoding_size != 1 + if _has_write_params: + import warnings - # Setup delta_indices - if self.delta_timestamps is not None: - check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) - self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) - - # Initialize streaming encoder for resumed recording - if streaming_encoding and len(self.meta.video_keys) > 0: - self._streaming_encoder = StreamingVideoEncoder( - fps=self.meta.fps, - vcodec=self.vcodec, - pix_fmt="yuv420p", - g=2, - crf=30, - preset=None, - queue_maxsize=encoder_queue_maxsize, - encoder_threads=encoder_threads, + warnings.warn( + "Passing write-mode parameters (streaming_encoding, batch_encoding_size) to " + "LeRobotDataset.__init__() is deprecated. 
# ── Writer guard ──────────────────────────────────────────────────

def _require_writer(self, method_name: str) -> None:
    """Reject write operations on a dataset that has no writer.

    Args:
        method_name: Name of the calling method, used in the error message.

    Raises:
        RuntimeError: If ``self.writer`` is ``None``.
    """
    if self.writer is not None:
        return
    raise RuntimeError(
        f"Cannot call '{method_name}()' on a read-only dataset. "
        f"Use LeRobotDataset.create() for new recording or "
        f"LeRobotDataset.resume() for resume recording."
    )

# ── Reader guard ──────────────────────────────────────────────────

def _ensure_reader(self) -> DatasetReader:
    """Return the reader, building it from the dataset's settings on first use."""
    if self.reader is None:
        reader_kwargs = {
            "meta": self.meta,
            "root": self.root,
            "episodes": self.episodes,
            "tolerance_s": self.tolerance_s,
            "video_backend": self._video_backend,
            "delta_timestamps": self.delta_timestamps,
            "image_transforms": self.image_transforms,
        }
        self.reader = DatasetReader(**reader_kwargs)
    return self.reader

@staticmethod
def _build_streaming_encoder(
    fps: int,
    vcodec: str,
    encoder_queue_maxsize: int,
    encoder_threads: int | None,
) -> StreamingVideoEncoder:
    """Create a :class:`StreamingVideoEncoder` with the fixed settings used for recording."""
    encoder = StreamingVideoEncoder(
        fps=fps,
        vcodec=vcodec,
        pix_fmt="yuv420p",
        g=2,
        crf=30,
        preset=None,
        queue_maxsize=encoder_queue_maxsize,
        encoder_threads=encoder_threads,
    )
    return encoder

# ── Metadata properties ───────────────────────────────────────────

@property
def fps(self) -> int:
    """Frames per second, as recorded in the metadata."""
    return self.meta.fps

@property
def num_frames(self) -> int:
    """Frame count of the selected episodes (metadata total when no reader exists)."""
    # Deliberately not _ensure_reader(): without a reader the metadata is the
    # source of truth and no reader should be created as a side effect.
    return self.meta.total_frames if self.reader is None else self.reader.num_frames

@property
def num_episodes(self) -> int:
    """Count of selected episodes (metadata total when no reader exists)."""
    # Deliberately not _ensure_reader(): without a reader the metadata is the
    # source of truth and no reader should be created as a side effect.
    return self.meta.total_episodes if self.reader is None else self.reader.num_episodes

@property
def features(self) -> dict[str, dict]:
    """Mapping from feature name to its type/shape specification."""
    return self.meta.features

@property
def hf_dataset(self) -> datasets.Dataset:
    """The underlying Hugging Face Dataset object, loaded on first access."""
    reader = self.reader = self._ensure_reader()
    if reader.hf_dataset is None:
        reader.load_and_activate()
    return reader.hf_dataset

# ── Writer-delegated methods ──────────────────────────────────────

def add_frame(self, frame: dict) -> None:
    """Append one frame to the episode currently being recorded.

    Forwards to :meth:`DatasetWriter.add_frame`.

    Args:
        frame: Dict mapping feature names to the values of this frame.

    Raises:
        RuntimeError: If the dataset has no writer (read-only).
    """
    self._require_writer("add_frame")
    writer = self.writer
    writer.add_frame(frame)

def save_episode(self, episode_data: dict | None = None, parallel_encoding: bool = True) -> None:
    """Persist the recorded episode to disk.

    Forwards to :meth:`DatasetWriter.save_episode`.

    Args:
        episode_data: Optional pre-built episode dict; ``None`` means the
            writer's internal episode buffer.
        parallel_encoding: Forwarded to the writer; allows encoding the videos
            of several cameras in parallel.

    Raises:
        RuntimeError: If the dataset has no writer (read-only).
    """
    self._require_writer("save_episode")
    writer = self.writer
    writer.save_episode(episode_data, parallel_encoding)

def clear_episode_buffer(self, delete_images: bool = True) -> None:
    """Drop the episode being recorded without saving it.

    Forwards to :meth:`DatasetWriter.clear_episode_buffer`.

    Args:
        delete_images: If ``True``, temporary image files of the episode are
            removed as well.

    Raises:
        RuntimeError: If the dataset has no writer (read-only).
    """
    self._require_writer("clear_episode_buffer")
    writer = self.writer
    writer.clear_episode_buffer(delete_images)

def has_pending_frames(self) -> bool:
    """Tell whether the episode buffer holds frames that were not saved yet."""
    writer = self.writer
    if writer is None:
        return False
    buffer = writer.episode_buffer
    return buffer is not None and buffer["size"] > 0

def finalize(self):
    """Flush pending work and close the writer, if there is one.

    Repeated calls are no-ops: after the first call ``_is_finalized`` is set and
    the writer is not finalized again.
    """
    if not self._is_finalized:
        if self.writer is not None:
            self.writer.finalize()
        self._is_finalized = True

# ── Core Dataset methods ──────────────────────────────────────────

def __len__(self):
    """Number of frames in the selected episodes."""
    return self.num_frames

def __getitem__(self, idx) -> dict:
    """Fetch one frame through :meth:`DatasetReader.get_item`.

    Args:
        idx: Index into the (possibly episode-filtered) dataset.

    Returns:
        The dict produced by the reader for this frame.

    Raises:
        RuntimeError: If a writer is attached and :meth:`finalize` has not
            been called yet.
    """
    still_recording = self.writer is not None and not self._is_finalized
    if still_recording:
        raise RuntimeError(
            "Cannot read from a dataset that is being recorded. Call finalize() first, then access items."
        )
    reader = self._ensure_reader()
    # Data may not be loaded yet (e.g. first read after finalize()).
    if reader.hf_dataset is None:
        reader.load_and_activate()
    return reader.get_item(idx)

def select_columns(self, column_names: str | list[str]):
    """Return the underlying HF dataset restricted to ``column_names``."""
    underlying = self.hf_dataset
    return underlying.select_columns(column_names)

def get_raw_item(self, idx) -> dict:
    """Return row ``idx`` of the underlying HF dataset as-is (bypasses the reader's ``get_item``)."""
    underlying = self.hf_dataset
    return underlying[idx]

def __repr__(self):
    feature_keys = list(self.features)
    lines = (
        f"{self.__class__.__name__}({{\n",
        f"    Repository ID: '{self.repo_id}',\n",
        f"    Number of selected episodes: '{self.num_episodes}',\n",
        f"    Number of selected samples: '{self.num_frames}',\n",
        f"    Features: '{feature_keys}',\n",
        f"}})",
    )
    return "".join(lines)
+ + Args: + branch: Optional branch to push to. Created from the current + revision if it does not exist. + tags: Optional list of tags for the dataset card. + license: License identifier for the dataset card. + tag_version: If ``True``, create a Git tag for the current codebase + version. + push_videos: If ``False``, skip uploading the ``videos/`` directory. + private: If ``True``, create a private repository. + allow_patterns: Glob pattern(s) restricting which files to upload. + upload_large_folder: If ``True``, use ``upload_large_folder`` instead + of ``upload_folder`` for very large datasets. + **card_kwargs: Additional keyword arguments forwarded to dataset card + creation. + """ ignore_patterns = ["images/"] if not push_videos: ignore_patterns.append("videos/") @@ -374,795 +549,23 @@ class LeRobotDataset(torch.utils.data.Dataset): hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset") hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") - def pull_from_repo( - self, - allow_patterns: list[str] | str | None = None, - ignore_patterns: list[str] | str | None = None, - ) -> None: + def _download(self, download_videos: bool = True) -> None: + """Downloads the dataset from the given 'repo_id' at the provided version.""" + ignore_patterns = None if download_videos else "videos/" + files = None + if self.episodes is not None: + # Reader is guaranteed to exist here (created in __init__ before _download) + files = self.reader.get_episodes_file_paths() snapshot_download( self.repo_id, repo_type="dataset", revision=self.revision, local_dir=self.root, - allow_patterns=allow_patterns, + allow_patterns=files, ignore_patterns=ignore_patterns, ) - def download(self, download_videos: bool = True) -> None: - """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this - will only download those episodes (selected by their episode_index). 
If 'episodes' is None, the whole - dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present - in 'local_dir', they won't be downloaded again. - """ - # TODO(rcadene, aliberts): implement faster transfer - # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads - ignore_patterns = None if download_videos else "videos/" - files = None - if self.episodes is not None: - files = self.get_episodes_file_paths() - self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) - - def get_episodes_file_paths(self) -> list[Path]: - episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) - fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] - if len(self.meta.video_keys) > 0: - video_files = [ - str(self.meta.get_video_file_path(ep_idx, vid_key)) - for vid_key in self.meta.video_keys - for ep_idx in episodes - ] - fpaths += video_files - # episodes are stored in the same files, so we return unique paths only - fpaths = list(set(fpaths)) - return fpaths - - def load_hf_dataset(self) -> datasets.Dataset: - """hf_dataset contains all the observations, states, actions, rewards, etc.""" - features = get_hf_features_from_features(self.features) - hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes) - hf_dataset.set_transform(hf_transform_to_torch) - return hf_dataset - - def _check_cached_episodes_sufficient(self) -> bool: - """Check if the cached dataset contains all requested episodes and their video files.""" - if self.hf_dataset is None or len(self.hf_dataset) == 0: - return False - - # Get available episode indices from cached dataset - available_episodes = { - ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx - for ep_idx in self.hf_dataset.unique("episode_index") - } - - # Determine requested episodes - if self.episodes is None: - requested_episodes = 
set(range(self.meta.total_episodes)) - else: - requested_episodes = set(self.episodes) - - # Check if all requested episodes are available in cached data - if not requested_episodes.issubset(available_episodes): - return False - - # Check if all required video files exist - if len(self.meta.video_keys) > 0: - for ep_idx in requested_episodes: - for vid_key in self.meta.video_keys: - video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - if not video_path.exists(): - return False - - return True - - def create_hf_dataset(self) -> datasets.Dataset: - features = get_hf_features_from_features(self.features) - ft_dict = {col: [] for col in features} - hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") - hf_dataset.set_transform(hf_transform_to_torch) - return hf_dataset - - @property - def fps(self) -> int: - """Frames per second used during data collection.""" - return self.meta.fps - - @property - def num_frames(self) -> int: - """Number of frames in selected episodes. - - Note: When episodes a subset of the full dataset is requested, we must return the - actual loaded data length (len(self.hf_dataset)) rather than metadata total_frames. - self.meta.total_frames is the total number of frames in the full dataset. 
- """ - if self.episodes is not None and self.hf_dataset is not None: - return len(self.hf_dataset) - return self.meta.total_frames - - @property - def num_episodes(self) -> int: - """Number of episodes selected.""" - return len(self.episodes) if self.episodes is not None else self.meta.total_episodes - - @property - def features(self) -> dict[str, dict]: - return self.meta.features - - @property - def hf_features(self) -> datasets.Features: - """Features of the hf_dataset.""" - if self.hf_dataset is not None: - return self.hf_dataset.features - else: - return get_hf_features_from_features(self.features) - - def _get_query_indices( - self, abs_idx: int, ep_idx: int - ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]: - """Compute query indices for delta timestamps. - - Args: - abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes). - ep_idx: The episode index. - - Returns: - A tuple of (query_indices, padding) where: - - query_indices: Dict mapping keys to lists of absolute indices to query - - padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions - """ - ep = self.meta.episodes[ep_idx] - ep_start = ep["dataset_from_index"] - ep_end = ep["dataset_to_index"] - query_indices = { - key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx] - for key, delta_idx in self.delta_indices.items() - } - padding = { # Pad values outside of current episode range - f"{key}_is_pad": torch.BoolTensor( - [(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx] - ) - for key, delta_idx in self.delta_indices.items() - } - return query_indices, padding - - def _get_query_timestamps( - self, - current_ts: float, - query_indices: dict[str, list[int]] | None = None, - ) -> dict[str, list[float]]: - query_timestamps = {} - for key in self.meta.video_keys: - if query_indices is not None and key in query_indices: - if self._absolute_to_relative_idx is not None: - 
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]] - timestamps = self.hf_dataset[relative_indices]["timestamp"] - else: - timestamps = self.hf_dataset[query_indices[key]]["timestamp"] - query_timestamps[key] = torch.stack(timestamps).tolist() - else: - query_timestamps[key] = [current_ts] - - return query_timestamps - - def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: - """ - Query dataset for indices across keys, skipping video keys. - - Tries column-first [key][indices] for speed, falls back to row-first. - - Args: - query_indices: Dict mapping keys to index lists to retrieve - - Returns: - Dict with stacked tensors of queried data (video keys excluded) - """ - result: dict = {} - for key, q_idx in query_indices.items(): - if key in self.meta.video_keys: - continue - # Map absolute indices to relative indices if needed - relative_indices = ( - q_idx - if self._absolute_to_relative_idx is None - else [self._absolute_to_relative_idx[idx] for idx in q_idx] - ) - try: - result[key] = torch.stack(self.hf_dataset[key][relative_indices]) - except (KeyError, TypeError, IndexError): - result[key] = torch.stack(self.hf_dataset[relative_indices][key]) - return result - - def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: - """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function - in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a - Segmentation Fault. This probably happens because a memory reference to the video loader is created in - the main process and a subprocess fails to access it. - """ - ep = self.meta.episodes[ep_idx] - item = {} - for vid_key, query_ts in query_timestamps.items(): - # Episodes are stored sequentially on a single mp4 to reduce the number of files. 
- # Thus we load the start timestamp of the episode on this mp4 and, - # shift the query timestamp accordingly. - from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] - shifted_query_ts = [from_timestamp + ts for ts in query_ts] - - video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend) - item[vid_key] = frames.squeeze(0) - - return item - - def _ensure_hf_dataset_loaded(self): - """Lazy load the HF dataset only when needed for reading.""" - if self._lazy_loading or self.hf_dataset is None: - # Close the writer before loading to ensure parquet file is properly finalized - if self.writer is not None: - self._close_writer() - self._writer_closed_for_reading = True - self.hf_dataset = self.load_hf_dataset() - self._lazy_loading = False - - def __len__(self): - return self.num_frames - - def __getitem__(self, idx) -> dict: - # Ensure dataset is loaded when we actually need to read from it - self._ensure_hf_dataset_loaded() - item = self.hf_dataset[idx] - ep_idx = item["episode_index"].item() - # Use the absolute index from the dataset for delta timestamp calculations - abs_idx = item["index"].item() - - query_indices = None - if self.delta_indices is not None: - query_indices, padding = self._get_query_indices(abs_idx, ep_idx) - query_result = self._query_hf_dataset(query_indices) - item = {**item, **padding} - for key, val in query_result.items(): - item[key] = val - - if len(self.meta.video_keys) > 0: - current_ts = item["timestamp"].item() - query_timestamps = self._get_query_timestamps(current_ts, query_indices) - video_frames = self._query_videos(query_timestamps, ep_idx) - item = {**video_frames, **item} - - if self.image_transforms is not None: - image_keys = self.meta.camera_keys - for cam in image_keys: - item[cam] = self.image_transforms(item[cam]) - - # Add task as a string - task_idx = item["task_index"].item() - item["task"] = 
self.meta.tasks.iloc[task_idx].name - - # add subtask information if available - if "subtask_index" in self.features and self.meta.subtasks is not None: - subtask_idx = item["subtask_index"].item() - item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name - - return item - - def __repr__(self): - feature_keys = list(self.features) - return ( - f"{self.__class__.__name__}({{\n" - f" Repository ID: '{self.repo_id}',\n" - f" Number of selected episodes: '{self.num_episodes}',\n" - f" Number of selected samples: '{self.num_frames}',\n" - f" Features: '{feature_keys}',\n" - "})',\n" - ) - - def finalize(self): - """ - Close the parquet writers. This function needs to be called after data collection/conversion, else footer metadata won't be written to the parquet files. - The dataset won't be valid and can't be loaded as ds = LeRobotDataset(repo_id=repo, root=HF_LEROBOT_HOME.joinpath(repo)) - """ - self._close_writer() - self.meta._close_writer() - if self._streaming_encoder is not None: - self._streaming_encoder.close() - - def create_episode_buffer(self, episode_index: int | None = None) -> dict: - current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index - ep_buffer = {} - # size and task are special cases that are not in self.features - ep_buffer["size"] = 0 - ep_buffer["task"] = [] - for key in self.features: - ep_buffer[key] = current_ep_idx if key == "episode_index" else [] - return ep_buffer - - # TODO(Steven): consider move this to utils - def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: - fpath = DEFAULT_IMAGE_PATH.format( - image_key=image_key, episode_index=episode_index, frame_index=frame_index - ) - return self.root / fpath - - def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path: - return self._get_image_file_path(episode_index, image_key, frame_index=0).parent - - def _save_image( - self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, 
compress_level: int = 1 - ) -> None: - if self.image_writer is None: - if isinstance(image, torch.Tensor): - image = image.cpu().numpy() - write_image(image, fpath, compress_level=compress_level) - else: - self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level) - - def add_frame(self, frame: dict) -> None: - """ - This function only adds the frame to the episode_buffer. Apart from images — which are written in a - temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method - then needs to be called. - """ - # Convert torch to numpy if needed - for name in frame: - if isinstance(frame[name], torch.Tensor): - frame[name] = frame[name].numpy() - - validate_frame(frame, self.features) - - if self.episode_buffer is None: - self.episode_buffer = self.create_episode_buffer() - - # Automatically add frame_index and timestamp to episode buffer - frame_index = self.episode_buffer["size"] - timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps - self.episode_buffer["frame_index"].append(frame_index) - self.episode_buffer["timestamp"].append(timestamp) - self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing - - # Start streaming encoder on first frame of episode (once, before iterating keys) - if frame_index == 0 and self._streaming_encoder is not None: - self._streaming_encoder.start_episode( - video_keys=list(self.meta.video_keys), - temp_dir=self.root, - ) - - # Add frame features to episode_buffer - for key in frame: - if key not in self.features: - raise ValueError( - f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." 
- ) - - if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None: - self._streaming_encoder.feed_frame(key, frame[key]) - self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet) - elif self.features[key]["dtype"] in ["image", "video"]: - img_path = self._get_image_file_path( - episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index - ) - if frame_index == 0: - img_path.parent.mkdir(parents=True, exist_ok=True) - compress_level = 1 if self.features[key]["dtype"] == "video" else 6 - self._save_image(frame[key], img_path, compress_level) - self.episode_buffer[key].append(str(img_path)) - else: - self.episode_buffer[key].append(frame[key]) - - self.episode_buffer["size"] += 1 - - def save_episode( - self, - episode_data: dict | None = None, - parallel_encoding: bool = True, - ) -> None: - """ - This will save to disk the current episode in self.episode_buffer. - - Video encoding is handled automatically based on batch_encoding_size: - - If batch_encoding_size == 1: Videos are encoded immediately after each episode - - If batch_encoding_size > 1: Videos are encoded in batches. - - Args: - episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will - save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to - None. - parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor. - Defaults to True on Linux, False on macOS as it tends to use all the CPU available already. 
- """ - episode_buffer = episode_data if episode_data is not None else self.episode_buffer - - validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) - - # size and task are special cases that won't be added to hf_dataset - episode_length = episode_buffer.pop("size") - tasks = episode_buffer.pop("task") - episode_tasks = list(set(tasks)) - episode_index = episode_buffer["episode_index"] - - episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) - episode_buffer["episode_index"] = np.full((episode_length,), episode_index) - - # Update tasks and task indices with new tasks if any - self.meta.save_episode_tasks(episode_tasks) - - # Given tasks in natural language, find their corresponding task indices - episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) - - for key, ft in self.features.items(): - # index, episode_index, task_index are already processed above, and image and video - # are processed separately by storing image path and frame info as meta data - if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: - continue - episode_buffer[key] = np.stack(episode_buffer[key]) - - # Wait for image writer to end, so that episode stats over images can be computed - self._wait_image_writer() - - has_video_keys = len(self.meta.video_keys) > 0 - use_streaming = self._streaming_encoder is not None and has_video_keys - use_batched_encoding = self.batch_encoding_size > 1 - - if use_streaming: - # Compute stats for non-video features only (video stats come from encoder) - non_video_buffer = { - k: v - for k, v in episode_buffer.items() - if self.features.get(k, {}).get("dtype") not in ("video",) - } - non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"} - ep_stats = compute_episode_stats(non_video_buffer, non_video_features) - else: - ep_stats = compute_episode_stats(episode_buffer, self.features) - 
- ep_metadata = self._save_episode_data(episode_buffer) - - if use_streaming: - # Finish streaming encoding and collect results - streaming_results = self._streaming_encoder.finish_episode() - for video_key in self.meta.video_keys: - temp_path, video_stats = streaming_results[video_key] - if video_stats is not None: - # Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1) - ep_stats[video_key] = { - k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) - for k, v in video_stats.items() - } - ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) - elif has_video_keys and not use_batched_encoding: - num_cameras = len(self.meta.video_keys) - if parallel_encoding and num_cameras > 1: - # TODO(Steven): Ideally we would like to control the number of threads per encoding such that: - # num_cameras * num_threads = (total_cpu -1) - with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor: - future_to_key = { - executor.submit( - _encode_video_worker, - video_key, - episode_index, - self.root, - self.fps, - self.vcodec, - self._encoder_threads, - ): video_key - for video_key in self.meta.video_keys - } - - results = {} - for future in concurrent.futures.as_completed(future_to_key): - video_key = future_to_key[future] - try: - temp_path = future.result() - results[video_key] = temp_path - except Exception as exc: - logger.error(f"Video encoding failed for {video_key}: {exc}") - raise exc - - for video_key in self.meta.video_keys: - temp_path = results[video_key] - ep_metadata.update( - self._save_episode_video(video_key, episode_index, temp_path=temp_path) - ) - else: - for video_key in self.meta.video_keys: - ep_metadata.update(self._save_episode_video(video_key, episode_index)) - - # `meta.save_episode` need to be executed after encoding the videos - self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) - - if 
has_video_keys and use_batched_encoding: - # Check if we should trigger batch encoding - self.episodes_since_last_encoding += 1 - if self.episodes_since_last_encoding == self.batch_encoding_size: - start_ep = self.num_episodes - self.batch_encoding_size - end_ep = self.num_episodes - self._batch_save_episode_video(start_ep, end_ep) - self.episodes_since_last_encoding = 0 - - if not episode_data: - # Reset episode buffer and clean up temporary images (if not already deleted during video encoding) - self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0) - - def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None: - """ - Batch save videos for multiple episodes. - - Args: - start_episode: Starting episode index (inclusive) - end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode. - """ - if end_episode is None: - end_episode = self.num_episodes - - logger.info( - f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}" - ) - - chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"] - file_idx = self.meta.episodes[start_episode]["data/file_index"] - episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) - episode_df = pd.read_parquet(episode_df_path) - - for ep_idx in range(start_episode, end_episode): - logger.info(f"Encoding videos for episode {ep_idx}") - - if ( - self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx - or self.meta.episodes[ep_idx]["data/file_index"] != file_idx - ): - # The current episode is in a new chunk or file. - # Save previous episode dataframe and update the Hugging Face dataset by reloading it. 
- episode_df.to_parquet(episode_df_path) - self.meta.episodes = load_episodes(self.root) - - # Load new episode dataframe - chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"] - file_idx = self.meta.episodes[ep_idx]["data/file_index"] - episode_df_path = self.root / DEFAULT_EPISODES_PATH.format( - chunk_index=chunk_idx, file_index=file_idx - ) - episode_df = pd.read_parquet(episode_df_path) - - # Save the current episode's video metadata to the dataframe - video_ep_metadata = {} - for video_key in self.meta.video_keys: - video_ep_metadata.update(self._save_episode_video(video_key, ep_idx)) - video_ep_metadata.pop("episode_index") - video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes( - dtype_backend="pyarrow" - ) # allows NaN values along with integers - - episode_df = episode_df.combine_first(video_ep_df) - episode_df.to_parquet(episode_df_path) - self.meta.episodes = load_episodes(self.root) - - def _save_episode_data(self, episode_buffer: dict) -> dict: - """Save episode data to a parquet file and update the Hugging Face dataset of frames data. - - This function processes episodes data from a buffer, converts it into a Hugging Face dataset, - and saves it as a parquet file. It handles both the creation of new parquet files and the - updating of existing ones based on size constraints. After saving the data, it reloads - the Hugging Face dataset to ensure it is up-to-date. - - Notes: We both need to update parquet files and HF dataset: - - `pandas` loads parquet file in RAM - - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, - or loads directly from pyarrow cache. 
- """ - # Convert buffer into HF Dataset - ep_dict = {key: episode_buffer[key] for key in self.hf_features} - ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train") - ep_dataset = embed_images(ep_dataset) - ep_num_frames = len(ep_dataset) - - if self.latest_episode is None: - # Initialize indices and frame count for a new dataset made of the first episode data - chunk_idx, file_idx = 0, 0 - global_frame_index = 0 - self._current_file_start_frame = 0 - # However, if the episodes already exists - # It means we are resuming recording, so we need to load the latest episode - # Update the indices to avoid overwriting the latest episode - if self.meta.episodes is not None and len(self.meta.episodes) > 0: - latest_ep = self.meta.episodes[-1] - global_frame_index = latest_ep["dataset_to_index"] - chunk_idx = latest_ep["data/chunk_index"] - file_idx = latest_ep["data/file_index"] - - # When resuming, move to the next file - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - self._current_file_start_frame = global_frame_index - else: - # Retrieve information from the latest parquet file - latest_ep = self.latest_episode - chunk_idx = latest_ep["data/chunk_index"] - file_idx = latest_ep["data/file_index"] - global_frame_index = latest_ep["index"][-1] + 1 - - latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - latest_size_in_mb = get_file_size_in_mb(latest_path) - - frames_in_current_file = global_frame_index - self._current_file_start_frame - av_size_per_frame = ( - latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0 - ) - - # Determine if a new parquet file is needed - if ( - latest_size_in_mb + av_size_per_frame * ep_num_frames >= self.meta.data_files_size_in_mb - or self._writer_closed_for_reading - ): - # Size limit is reached or writer was closed for reading, prepare new parquet file - chunk_idx, file_idx = 
update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - self._close_writer() - self._writer_closed_for_reading = False - self._current_file_start_frame = global_frame_index - - ep_dict["data/chunk_index"] = chunk_idx - ep_dict["data/file_index"] = file_idx - - # Write the resulting dataframe from RAM to disk - path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) - path.parent.mkdir(parents=True, exist_ok=True) - - table = ep_dataset.with_format("arrow")[:] - if not self.writer: - self.writer = pq.ParquetWriter( - path, schema=table.schema, compression="snappy", use_dictionary=True - ) - self.writer.write_table(table) - - metadata = { - "data/chunk_index": chunk_idx, - "data/file_index": file_idx, - "dataset_from_index": global_frame_index, - "dataset_to_index": global_frame_index + ep_num_frames, - } - - # Store metadata with episode data for next episode - self.latest_episode = {**ep_dict, **metadata} - - # Mark that the HF dataset needs reloading (lazy loading approach) - # This avoids expensive reloading during sequential recording - self._lazy_loading = True - # Update recorded frames count for efficient length tracking - self._recorded_frames += ep_num_frames - - return metadata - - def _save_episode_video( - self, - video_key: str, - episode_index: int, - temp_path: Path | None = None, - ) -> dict: - # Encode episode frames into a temporary video - if temp_path is None: - ep_path = self._encode_temporary_episode_video(video_key, episode_index) - else: - ep_path = temp_path - - ep_size_in_mb = get_file_size_in_mb(ep_path) - ep_duration_in_s = get_video_duration_in_s(ep_path) - - if ( - episode_index == 0 - or self.meta.latest_episode is None - or f"videos/{video_key}/chunk_index" not in self.meta.latest_episode - ): - # Initialize indices for a new dataset made of the first episode data - chunk_idx, file_idx = 0, 0 - if self.meta.episodes is not None and len(self.meta.episodes) > 0: - # It means we are 
resuming recording, so we need to load the latest episode - # Update the indices to avoid overwriting the latest episode - old_chunk_idx = self.meta.episodes[-1][f"videos/{video_key}/chunk_index"] - old_file_idx = self.meta.episodes[-1][f"videos/{video_key}/file_index"] - chunk_idx, file_idx = update_chunk_file_indices( - old_chunk_idx, old_file_idx, self.meta.chunks_size - ) - latest_duration_in_s = 0.0 - new_path = self.root / self.meta.video_path.format( - video_key=video_key, chunk_index=chunk_idx, file_index=file_idx - ) - new_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(str(ep_path), str(new_path)) - else: - # Retrieve information from the latest updated video file using latest_episode - latest_ep = self.meta.latest_episode - chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0] - file_idx = latest_ep[f"videos/{video_key}/file_index"][0] - - latest_path = self.root / self.meta.video_path.format( - video_key=video_key, chunk_index=chunk_idx, file_index=file_idx - ) - latest_size_in_mb = get_file_size_in_mb(latest_path) - latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0] - - if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb: - # Move temporary episode video to a new video file in the dataset - chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) - new_path = self.root / self.meta.video_path.format( - video_key=video_key, chunk_index=chunk_idx, file_index=file_idx - ) - new_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(str(ep_path), str(new_path)) - latest_duration_in_s = 0.0 - else: - # Update latest video file - concatenate_video_files( - [latest_path, ep_path], - latest_path, - ) - - # Remove temporary directory - shutil.rmtree(str(ep_path.parent)) - - # Update video info (only needed when first episode is encoded since it reads from episode 0) - if episode_index == 0: - self.meta.update_video_info(video_key) - write_info(self.meta.info, 
self.meta.root) # ensure video info always written properly - - metadata = { - "episode_index": episode_index, - f"videos/{video_key}/chunk_index": chunk_idx, - f"videos/{video_key}/file_index": file_idx, - f"videos/{video_key}/from_timestamp": latest_duration_in_s, - f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s, - } - return metadata - - def clear_episode_buffer(self, delete_images: bool = True) -> None: - # Cancel streaming encoder if active - if self._streaming_encoder is not None: - self._streaming_encoder.cancel_episode() - - # Clean up image files for the current episode buffer - if delete_images: - # Wait for the async image writer to finish - if self.image_writer is not None: - self._wait_image_writer() - episode_index = self.episode_buffer["episode_index"] - if isinstance(episode_index, np.ndarray): - episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] - for cam_key in self.meta.image_keys: - img_dir = self._get_image_file_dir(episode_index, cam_key) - if img_dir.is_dir(): - shutil.rmtree(img_dir) - - # Reset the buffer - self.episode_buffer = self.create_episode_buffer() - - def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: - if isinstance(self.image_writer, AsyncImageWriter): - logger.warning( - "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." - ) - - self.image_writer = AsyncImageWriter( - num_processes=num_processes, - num_threads=num_threads, - ) - - def stop_image_writer(self) -> None: - """ - Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to - remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized. 
- """ - if self.image_writer is not None: - self.image_writer.stop() - self.image_writer = None - - def _wait_image_writer(self) -> None: - """Wait for asynchronous image writer to finish.""" - if self.image_writer is not None: - self.image_writer.wait_until_done() - - def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: - """ - Use ffmpeg to convert frames stored as png into mp4 videos. - Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, - since video encoding with ffmpeg is already using multithreading. - """ - return _encode_video_worker( - video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads - ) + # ── Class constructors ──────────────────────────────────────────── @classmethod def create( @@ -1184,7 +587,42 @@ class LeRobotDataset(torch.utils.data.Dataset): encoder_queue_maxsize: int = 30, encoder_threads: int | None = None, ) -> "LeRobotDataset": - """Create a LeRobot Dataset from scratch in order to record data.""" + """Create a new LeRobotDataset from scratch for recording data. + + Returns a write-mode dataset with an active :class:`DatasetWriter`. Use + :meth:`add_frame` / :meth:`save_episode` to populate it, then + :meth:`finalize` when done. + + Args: + repo_id: Repository identifier, typically ``'{hf_user}/{dataset_name}'``. + fps: Frames per second used during data collection. + features: Feature specification dict mapping feature names to their + type/shape metadata. + root: Local directory for dataset storage. Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. + robot_type: Optional robot type string stored in metadata. + use_videos: If ``True``, visual modalities are stored as MP4 videos. + If ``False``, they are stored as images. + tolerance_s: Timestamp synchronization tolerance in seconds. + image_writer_processes: Number of subprocesses for async image + writing. ``0`` means use threads only. 
+ image_writer_threads: Number of threads for async image writing. + video_backend: Video decoding backend (used when reading back). + batch_encoding_size: Number of episodes to accumulate before + batch-encoding videos. ``1`` means encode immediately. + vcodec: Video codec for encoding. Options include ``'libsvtav1'``, + ``'h264'``, ``'hevc'``, ``'auto'``. + metadata_buffer_size: Number of episode metadata records to buffer + before flushing to parquet. + streaming_encoding: If ``True``, encode video frames in real-time + during capture instead of writing images first. + encoder_queue_maxsize: Max buffered frames per camera when using + streaming encoding. + encoder_threads: Threads per encoder instance. ``None`` for auto. + + Returns: + A new :class:`LeRobotDataset` in write mode. + """ vcodec = resolve_vcodec(vcodec) obj = cls.__new__(cls) obj.meta = LeRobotDatasetMetadata.create( @@ -1200,45 +638,126 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.root = obj.meta.root obj.revision = None obj.tolerance_s = tolerance_s - obj.image_writer = None - obj.batch_encoding_size = batch_encoding_size - obj.episodes_since_last_encoding = 0 - obj.vcodec = vcodec - obj._encoder_threads = encoder_threads - - if image_writer_processes or image_writer_threads: - obj.start_image_writer(image_writer_processes, image_writer_threads) - - obj.episode_buffer = obj.create_episode_buffer() - - obj.episodes = None - obj.hf_dataset = obj.create_hf_dataset() obj.image_transforms = None obj.delta_timestamps = None - obj.delta_indices = None - obj._absolute_to_relative_idx = None - obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() - obj.writer = None - obj.latest_episode = None - obj._current_file_start_frame = None - # Initialize tracking for incremental recording - obj._lazy_loading = False - obj._recorded_frames = 0 - obj._writer_closed_for_reading = False + obj.episodes = None + obj._video_backend = video_backend if video_backend is not 
None else get_safe_default_codec() + obj._batch_encoding_size = batch_encoding_size + obj._vcodec = vcodec + obj._encoder_threads = encoder_threads - # Initialize streaming encoder + # Reader is lazily created on first access (write-only mode) + obj.reader = None + + # Create writer + streaming_enc = None if streaming_encoding and len(obj.meta.video_keys) > 0: - obj._streaming_encoder = StreamingVideoEncoder( - fps=fps, - vcodec=vcodec, - pix_fmt="yuv420p", - g=2, - crf=30, - preset=None, - queue_maxsize=encoder_queue_maxsize, - encoder_threads=encoder_threads, - ) - else: - obj._streaming_encoder = None + streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads) + obj.writer = DatasetWriter( + meta=obj.meta, + root=obj.root, + vcodec=vcodec, + encoder_threads=encoder_threads, + batch_encoding_size=batch_encoding_size, + streaming_encoder=streaming_enc, + ) + + if image_writer_processes or image_writer_threads: + obj.writer.start_image_writer(image_writer_processes, image_writer_threads) + + obj._is_finalized = False + + return obj + + @classmethod + def resume( + cls, + repo_id: str, + root: str | Path | None = None, + tolerance_s: float = 1e-4, + revision: str | None = None, + force_cache_sync: bool = False, + video_backend: str | None = None, + batch_encoding_size: int = 1, + vcodec: str = "libsvtav1", + image_writer_processes: int = 0, + image_writer_threads: int = 0, + streaming_encoding: bool = False, + encoder_queue_maxsize: int = 30, + encoder_threads: int | None = None, + ) -> "LeRobotDataset": + """Resume recording on an existing dataset. + + Loads metadata from an existing dataset (local or Hub) and creates a + :class:`DatasetWriter` for appending new episodes. The underlying HF + dataset is not loaded until :meth:`finalize` is called and data is + subsequently read. + + Args: + repo_id: Repository identifier of the existing dataset. + root: Local directory of the dataset. 
Defaults to + ``$HF_LEROBOT_HOME/{repo_id}``. + tolerance_s: Timestamp synchronization tolerance in seconds. + revision: Git revision (branch, tag, or commit hash). Defaults to + current codebase version tag. + force_cache_sync: If ``True``, re-download metadata from the Hub even + if a local cache exists. + video_backend: Video decoding backend for reading back data. + batch_encoding_size: Number of episodes to accumulate before + batch-encoding videos. + vcodec: Video codec for encoding. + image_writer_processes: Subprocesses for async image writing. + image_writer_threads: Threads for async image writing. + streaming_encoding: If ``True``, encode video in real-time during + capture. + encoder_queue_maxsize: Max buffered frames per camera for streaming. + encoder_threads: Threads per encoder instance. ``None`` for auto. + + Returns: + A :class:`LeRobotDataset` in write mode, ready to append episodes. + """ + vcodec = resolve_vcodec(vcodec) + obj = cls.__new__(cls) + obj.repo_id = repo_id + obj.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + obj.root.mkdir(exist_ok=True, parents=True) + obj.revision = revision if revision else CODEBASE_VERSION + obj.tolerance_s = tolerance_s + obj.image_transforms = None + obj.delta_timestamps = None + obj.episodes = None + obj._video_backend = video_backend if video_backend else get_safe_default_codec() + obj._batch_encoding_size = batch_encoding_size + obj._vcodec = vcodec + obj._encoder_threads = encoder_threads + + # Load metadata + obj.meta = LeRobotDatasetMetadata( + obj.repo_id, obj.root, obj.revision, force_cache_sync=force_cache_sync + ) + + # Reader is lazily created on first access (write-only mode) + obj.reader = None + + # Create writer for appending + streaming_enc = None + if streaming_encoding and len(obj.meta.video_keys) > 0: + streaming_enc = cls._build_streaming_encoder( + obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads + ) + obj.writer = DatasetWriter( + meta=obj.meta, + root=obj.root, + 
vcodec=vcodec, + encoder_threads=encoder_threads, + batch_encoding_size=batch_encoding_size, + streaming_encoder=streaming_enc, + initial_frames=obj.meta.total_frames, + ) + + if image_writer_processes or image_writer_threads: + obj.writer.start_image_writer(image_writer_processes, image_writer_threads) + + obj._is_finalized = False return obj diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py index 917d5c5eb..d16c5bb07 100644 --- a/src/lerobot/datasets/multi_dataset.py +++ b/src/lerobot/datasets/multi_dataset.py @@ -22,6 +22,7 @@ import torch import torch.utils from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.feature_utils import get_hf_features_from_features from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.video_utils import VideoFrame from lerobot.utils.constants import HF_LEROBOT_HOME @@ -125,7 +126,13 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): def features(self) -> datasets.Features: features = {} for dataset in self._datasets: - features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) + features.update( + { + k: v + for k, v in get_hf_features_from_features(dataset.features).items() + if k not in self.disabled_features + } + ) return features @property diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index e465b79b4..59c8c7d3e 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -741,6 +741,7 @@ class StreamingVideoEncoder: self._video_paths: dict[str, Path] = {} self._dropped_frames: dict[str, int] = {} self._episode_active = False + self._closed = False def start_episode(self, video_keys: list[str], temp_dir: Path) -> None: """Start encoder threads for a new episode. 
@@ -895,8 +896,11 @@ class StreamingVideoEncoder: def close(self) -> None: """Close the encoder, canceling any in-progress episode.""" + if self._closed: + return if self._episode_active: self.cancel_episode() + self._closed = True def _cleanup(self) -> None: """Clean up queues and thread tracking dicts.""" @@ -1063,43 +1067,19 @@ class VideoEncodingManager: return self def __exit__(self, exc_type, exc_val, exc_tb): - streaming_encoder = getattr(self.dataset, "_streaming_encoder", None) + writer = self.dataset.writer + if writer is not None: + if exc_type is not None and writer._streaming_encoder is not None: + writer.cancel_pending_videos() - if streaming_encoder is not None: - # Handle streaming encoder cleanup - if exc_type is not None: - streaming_encoder.cancel_episode() - streaming_encoder.close() - elif self.dataset.episodes_since_last_encoding > 0: - # Handle any remaining episodes that haven't been batch encoded - if exc_type is not None: - logger.info("Exception occurred. Encoding remaining episodes before exit...") - else: - logger.info("Recording stopped. 
Encoding remaining episodes...") + # finalize() handles flush_pending_videos + parquet + metadata + self.dataset.finalize() - start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding - end_ep = self.dataset.num_episodes - logger.info( - f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, " - f"from episode {start_ep} to {end_ep - 1}" - ) - self.dataset._batch_save_episode_video(start_ep, end_ep) - - # Finalize the dataset to properly close all writers - self.dataset.finalize() - - # Clean up episode images if recording was interrupted (only for non-streaming mode) - if exc_type is not None and streaming_encoder is None: - interrupted_episode_index = self.dataset.num_episodes - for key in self.dataset.meta.video_keys: - img_dir = self.dataset._get_image_file_path( - episode_index=interrupted_episode_index, image_key=key, frame_index=0 - ).parent - if img_dir.exists(): - logger.debug( - f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}" - ) - shutil.rmtree(img_dir) + # Clean up episode images if recording was interrupted (only for non-streaming mode) + if exc_type is not None and writer._streaming_encoder is None: + writer.cleanup_interrupted_episode(self.dataset.num_episodes) + else: + self.dataset.finalize() # Clean up any remaining images directory if it's empty img_dir = self.dataset.root / "images" diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 81aa29c48..68954162d 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -563,7 +563,7 @@ class ReplayBuffer: ) # Start writing images if needed - lerobot_dataset.start_image_writer(num_processes=0, num_threads=3) + lerobot_dataset.writer.start_image_writer(num_processes=0, num_threads=3) # Convert transitions into episodes and frames @@ -603,10 +603,10 @@ class ReplayBuffer: lerobot_dataset.save_episode() # Save any remaining frames in the buffer - if 
lerobot_dataset.episode_buffer["size"] > 0: + if lerobot_dataset.has_pending_frames(): lerobot_dataset.save_episode() - lerobot_dataset.stop_image_writer() + lerobot_dataset.writer.stop_image_writer() lerobot_dataset.finalize() return lerobot_dataset diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index f5fcb7437..bd64d205f 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -752,8 +752,7 @@ def replay_trajectory( episodes=[cfg.dataset.replay_episode], download_videos=False, ) - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) _, info = env.reset() diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 819634ba2..ac01c9319 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -468,7 +468,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset: try: if cfg.resume: - dataset = LeRobotDataset( + num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0 + dataset = LeRobotDataset.resume( cfg.dataset.repo_id, root=cfg.dataset.root, batch_encoding_size=cfg.dataset.video_encoding_batch_size, @@ -476,13 +477,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset: streaming_encoding=cfg.dataset.streaming_encoding, encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, encoder_threads=cfg.dataset.encoder_threads, + image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras + if num_cameras > 0 + else 0, ) - - if hasattr(robot, "cameras") and len(robot.cameras) > 0: - dataset.start_image_writer( - num_processes=cfg.dataset.num_image_writer_processes, - num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), - ) 
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features) else: # Create empty dataset or load existing saved episodes diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 7c0b5b96b..09e7d4e8b 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -104,15 +104,13 @@ def replay(cfg: ReplayConfig): robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) - # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 - episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode) - actions = episode_frames.select_columns(ACTION) + actions = dataset.select_columns(ACTION) robot.connect() try: log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(len(episode_frames)): + for idx in range(dataset.num_frames): start_episode_t = time.perf_counter() action_array = actions[idx][ACTION] diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py index 807d48333..70185fc51 100644 --- a/src/lerobot/scripts/lerobot_train_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -204,15 +204,15 @@ def process_episode(args): for abs_idx in range(from_idx, to_idx): # map absolute index to relative index if needed - if dataset._absolute_to_relative_idx is not None: - if abs_idx not in dataset._absolute_to_relative_idx: + if dataset.reader._absolute_to_relative_idx is not None: + if abs_idx not in dataset.reader._absolute_to_relative_idx: # this episode's frames aren't in the filtered dataset return None - rel_idx = dataset._absolute_to_relative_idx[abs_idx] + rel_idx = dataset.reader._absolute_to_relative_idx[abs_idx] else: rel_idx = abs_idx - frame = dataset.hf_dataset[rel_idx] + frame = 
dataset.get_raw_item(rel_idx) # get state (could be from observation.state or other state key) if state_key in frame: diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index 64b125cc9..7359f6169 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -80,7 +80,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): # HACK: We reload a batch with no delta_indices as `select_action` won't expect a timestamps dimension # We simulate having an environment using a dataset by setting delta_indices to None and dropping tensors # indicating padding (those ending with "_is_pad") - dataset.delta_indices = None + dataset.reader.delta_indices = None batch = next(iter(dataloader)) obs = {} for k in batch: diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py new file mode 100644 index 000000000..3f3971e15 --- /dev/null +++ b/tests/datasets/test_dataset_metadata.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contract tests for LeRobotDatasetMetadata.""" + +import json + +import numpy as np +import pytest + +from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata +from lerobot.datasets.utils import INFO_PATH +from tests.fixtures.constants import DEFAULT_FPS, DUMMY_ROBOT_TYPE + +# ── helpers ────────────────────────────────────────────────────────── + +SIMPLE_FEATURES = { + "state": {"dtype": "float32", "shape": (6,), "names": None}, + "action": {"dtype": "float32", "shape": (6,), "names": None}, +} + +VIDEO_FEATURES = { + **SIMPLE_FEATURES, + "observation.images.laptop": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + "info": None, + }, +} + +IMAGE_FEATURES = { + **SIMPLE_FEATURES, + "observation.images.laptop": { + "dtype": "image", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + "info": None, + }, +} + + +def _make_dummy_stats(features: dict) -> dict: + """Create minimal episode stats matching the given features.""" + stats = {} + for key, ft in features.items(): + if ft["dtype"] in ("image", "video"): + stats[key] = { + "max": np.ones((3, 1, 1), dtype=np.float32), + "mean": np.full((3, 1, 1), 0.5, dtype=np.float32), + "min": np.zeros((3, 1, 1), dtype=np.float32), + "std": np.full((3, 1, 1), 0.25, dtype=np.float32), + "count": np.array([5]), + } + elif ft["dtype"] in ("float32", "float64", "int64"): + stats[key] = { + "max": np.ones(ft["shape"], dtype=np.float32), + "mean": np.full(ft["shape"], 0.5, dtype=np.float32), + "min": np.zeros(ft["shape"], dtype=np.float32), + "std": np.full(ft["shape"], 0.25, dtype=np.float32), + "count": np.array([5]), + } + return stats + + +# ── Construction contracts ─────────────────────────────────────────── + + +def test_create_produces_valid_info_on_disk(tmp_path): + """create() writes info.json and the returned object reflects the provided settings.""" + root = tmp_path / "new_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/meta", + 
fps=DEFAULT_FPS, + features=SIMPLE_FEATURES, + robot_type=DUMMY_ROBOT_TYPE, + root=root, + use_videos=False, + ) + + # info.json was written to disk + assert (root / INFO_PATH).exists() + with open(root / INFO_PATH) as f: + info_on_disk = json.load(f) + + assert meta.fps == DEFAULT_FPS + assert meta.robot_type == DUMMY_ROBOT_TYPE + assert "state" in meta.features + assert "action" in meta.features + assert info_on_disk["fps"] == DEFAULT_FPS + + +def test_create_starts_with_zero_counts(tmp_path): + """A freshly created metadata has zero episode/frame/task counts.""" + root = tmp_path / "empty_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/empty", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + assert meta.total_episodes == 0 + assert meta.total_frames == 0 + assert meta.total_tasks == 0 + assert meta.tasks is None + assert meta.episodes is None + assert meta.stats is None + + +def test_create_with_videos_sets_video_path(tmp_path): + """When features include video-dtype keys, create() produces a non-None video_path.""" + root = tmp_path / "video_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/video", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=root, use_videos=True + ) + + assert meta.video_path is not None + assert len(meta.video_keys) == 1 + assert "observation.images.laptop" in meta.video_keys + + +def test_create_without_videos_has_no_video_path(tmp_path): + """When use_videos=False and no video features, video_path is None.""" + root = tmp_path / "no_video" + meta = LeRobotDatasetMetadata.create( + repo_id="test/novid", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + assert meta.video_path is None + assert meta.video_keys == [] + + +def test_create_raises_on_existing_directory(tmp_path): + """create() raises if root directory already exists.""" + root = tmp_path / "existing" + root.mkdir() + + with pytest.raises(FileExistsError): + LeRobotDatasetMetadata.create( + 
repo_id="test/exists", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + +def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory, info_factory): + """When metadata files exist on disk, __init__ loads them correctly.""" + root = tmp_path / "load_test" + info = info_factory(total_episodes=3, total_frames=150, total_tasks=1, use_videos=False) + meta = lerobot_dataset_metadata_factory(root=root, info=info) + + assert meta.total_episodes == 3 + assert meta.total_frames == 150 + assert meta.fps == info["fps"] + + +# ── Property accessors ─────────────────────────────────────────────── + + +def test_property_accessors_reflect_info(tmp_path): + """Properties return values consistent with the info dict.""" + root = tmp_path / "props_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/props", + fps=DEFAULT_FPS, + features=IMAGE_FEATURES, + robot_type=DUMMY_ROBOT_TYPE, + root=root, + use_videos=False, + ) + + assert meta.fps == DEFAULT_FPS + assert meta.robot_type == DUMMY_ROBOT_TYPE + # shapes should be tuples + for _key, shape in meta.shapes.items(): + assert isinstance(shape, tuple) + # image_keys should contain the image feature + assert "observation.images.laptop" in meta.image_keys + # camera_keys is a superset of image_keys and video_keys + assert set(meta.image_keys + meta.video_keys) == set(meta.camera_keys) + + +def test_data_path_is_formattable(tmp_path): + """data_path contains format placeholders that can be .format()-ed.""" + root = tmp_path / "fmt_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/fmt", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + formatted = meta.data_path.format(chunk_index=0, file_index=0) + assert "chunk" in formatted.lower() or "0" in formatted + + +# ── Task management ────────────────────────────────────────────────── + + +def test_save_episode_tasks_creates_tasks_dataframe(tmp_path): + """On a fresh metadata, save_episode_tasks() 
creates the tasks DataFrame.""" + root = tmp_path / "task_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/task", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + assert meta.tasks is None + + meta.save_episode_tasks(["Pick up the cube"]) + + assert meta.tasks is not None + assert len(meta.tasks) == 1 + assert "Pick up the cube" in meta.tasks.index + + +def test_save_episode_tasks_is_additive(tmp_path): + """New tasks are added; existing tasks keep their original index.""" + root = tmp_path / "additive_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/add", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + meta.save_episode_tasks(["Task A"]) + idx_a = meta.get_task_index("Task A") + + meta.save_episode_tasks(["Task A", "Task B"]) + assert meta.get_task_index("Task A") == idx_a # unchanged + assert meta.get_task_index("Task B") is not None + assert len(meta.tasks) == 2 + + +def test_get_task_index_returns_none_for_unknown(tmp_path): + """get_task_index() returns None for an unknown task.""" + root = tmp_path / "unknown_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/unknown", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + meta.save_episode_tasks(["Known task"]) + + assert meta.get_task_index("Known task") == 0 + assert meta.get_task_index("Unknown task") is None + + +def test_save_episode_tasks_rejects_duplicates(tmp_path): + """save_episode_tasks() raises ValueError on duplicate task strings.""" + root = tmp_path / "dup_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/dup", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + with pytest.raises(ValueError): + meta.save_episode_tasks(["Same task", "Same task"]) + + +# ── Episode saving ─────────────────────────────────────────────────── + + +def test_save_episode_increments_counters(tmp_path): + """After save_episode(), total_episodes and total_frames 
increase.""" + root = tmp_path / "ep_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/ep", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + meta.save_episode_tasks(["Task 1"]) + stats = _make_dummy_stats(meta.features) + + meta.save_episode( + episode_index=0, + episode_length=10, + episode_tasks=["Task 1"], + episode_stats=stats, + episode_metadata={}, + ) + + assert meta.total_episodes == 1 + assert meta.total_frames == 10 + + +def test_save_episode_updates_stats(tmp_path): + """After save_episode(), .stats is non-None and has feature keys.""" + root = tmp_path / "stats_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/stats", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + meta.save_episode_tasks(["Task 1"]) + stats = _make_dummy_stats(meta.features) + + meta.save_episode( + episode_index=0, + episode_length=5, + episode_tasks=["Task 1"], + episode_stats=stats, + episode_metadata={}, + ) + + assert meta.stats is not None + # Stats should contain at least the user-defined feature keys + for key in SIMPLE_FEATURES: + assert key in meta.stats + + +# ── Chunk settings ─────────────────────────────────────────────────── + + +def test_update_chunk_settings_persists(tmp_path): + """update_chunk_settings() changes values and writes info.json.""" + root = tmp_path / "chunk_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/chunk", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + original = meta.get_chunk_settings() + + meta.update_chunk_settings(chunks_size=500) + assert meta.chunks_size == 500 + assert meta.chunks_size != original["chunks_size"] or original["chunks_size"] == 500 + + # Verify persisted + with open(root / INFO_PATH) as f: + info_on_disk = json.load(f) + assert info_on_disk["chunks_size"] == 500 + + +def test_update_chunk_settings_rejects_non_positive(tmp_path): + """update_chunk_settings() raises ValueError for <= 0 values.""" + root = 
tmp_path / "bad_chunk" + meta = LeRobotDatasetMetadata.create( + repo_id="test/bad", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + with pytest.raises(ValueError): + meta.update_chunk_settings(chunks_size=0) + with pytest.raises(ValueError): + meta.update_chunk_settings(data_files_size_in_mb=-1) + + +# ── Finalization ───────────────────────────────────────────────────── + + +def test_finalize_is_idempotent(tmp_path): + """Calling finalize() multiple times does not raise.""" + root = tmp_path / "fin_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/fin", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False + ) + + meta.finalize() + meta.finalize() # second call should not raise + + +def test_finalize_flushes_buffered_metadata(tmp_path): + """Episodes saved before finalize() are written to parquet.""" + root = tmp_path / "flush_ds" + meta = LeRobotDatasetMetadata.create( + repo_id="test/flush", + fps=DEFAULT_FPS, + features=SIMPLE_FEATURES, + root=root, + use_videos=False, + metadata_buffer_size=100, # large buffer so nothing auto-flushes + ) + meta.save_episode_tasks(["Task 1"]) + stats = _make_dummy_stats(meta.features) + + # Save a few episodes (won't auto-flush since buffer_size=100) + for i in range(3): + meta.save_episode( + episode_index=i, + episode_length=5, + episode_tasks=["Task 1"], + episode_stats=stats, + episode_metadata={}, + ) + + # Before finalize, the parquet might not exist yet + meta.finalize() + + # After finalize, episodes parquet should exist + episodes_dir = root / "meta" / "episodes" + assert episodes_dir.exists() + parquet_files = list(episodes_dir.rglob("*.parquet")) + assert len(parquet_files) > 0 diff --git a/tests/datasets/test_dataset_reader.py b/tests/datasets/test_dataset_reader.py new file mode 100644 index 000000000..4c8a8b23f --- /dev/null +++ b/tests/datasets/test_dataset_reader.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for DatasetReader.""" + +from lerobot.datasets.dataset_reader import DatasetReader +from lerobot.datasets.video_utils import get_safe_default_codec + +# ── Loading ────────────────────────────────────────────────────────── + + +def test_try_load_returns_true_when_data_exists(tmp_path, lerobot_dataset_factory): + """Given a fully written dataset, try_load() returns True.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False + ) + reader = DatasetReader( + meta=dataset.meta, + root=dataset.root, + episodes=None, + tolerance_s=1e-4, + video_backend=get_safe_default_codec(), + delta_timestamps=None, + image_transforms=None, + ) + assert reader.try_load() is True + assert reader.hf_dataset is not None + + +def test_try_load_returns_false_when_no_data(tmp_path): + """When only metadata exists (no data/ parquets), try_load() returns False.""" + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata + + root = tmp_path / "meta_only" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + meta = LeRobotDatasetMetadata.create( + repo_id="test/meta_only", fps=30, features=features, root=root, use_videos=False + ) + + reader = DatasetReader( + meta=meta, + root=meta.root, + episodes=None, + tolerance_s=1e-4, + video_backend=get_safe_default_codec(), + delta_timestamps=None, + image_transforms=None, 
+ ) + assert reader.try_load() is False + assert reader.hf_dataset is None + + +# ── Counts ─────────────────────────────────────────────────────────── + + +def test_num_frames_without_filter(tmp_path, lerobot_dataset_factory): + """With episodes=None, num_frames equals total_frames.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False + ) + assert dataset.reader.num_frames == dataset.meta.total_frames + + +def test_num_episodes_without_filter(tmp_path, lerobot_dataset_factory): + """With episodes=None, num_episodes equals total_episodes.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=3, total_frames=60, use_videos=False + ) + assert dataset.reader.num_episodes == dataset.meta.total_episodes + + +def test_num_frames_with_episode_filter(tmp_path, lerobot_dataset_factory): + """When filtering to a subset, only those episodes' frames are counted.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=5, total_frames=100, episodes=[0, 2], use_videos=False + ) + # Filtered frames should be less than total + assert dataset.reader.num_frames <= dataset.meta.total_frames + assert dataset.reader.num_episodes == 2 + + +# ── get_item ───────────────────────────────────────────────────────── + + +def test_get_item_returns_expected_keys(tmp_path, lerobot_dataset_factory): + """get_item(0) returns a dict with expected keys.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False + ) + item = dataset.reader.get_item(0) + + # Standard keys that must always be present + for key in ["index", "episode_index", "frame_index", "timestamp", "task_index", "task"]: + assert key in item, f"Missing key: {key}" + + +def test_get_item_values_are_correct(tmp_path, lerobot_dataset_factory): + """get_item() returns correct index and episode_index.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", 
total_episodes=2, total_frames=20, use_videos=False + ) + item_0 = dataset.reader.get_item(0) + + assert item_0["index"].item() == 0 + assert item_0["episode_index"].item() == 0 + + +# ── Transforms ─────────────────────────────────────────────────────── + + +def test_image_transforms_are_applied(tmp_path, lerobot_dataset_factory): + """When image_transforms is provided, get_item() applies it to camera keys.""" + transform_called = {"count": 0} + + def sentinel_transform(img): + transform_called["count"] += 1 + return img + + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", + total_episodes=1, + total_frames=5, + use_videos=False, + image_transforms=sentinel_transform, + ) + item = dataset[0] # noqa: F841 + + # Should have been called once per camera key per frame + num_cameras = len(dataset.meta.camera_keys) + if num_cameras > 0: + assert transform_called["count"] >= 1 + + +# ── File paths ─────────────────────────────────────────────────────── + + +def test_get_episodes_file_paths_returns_data_paths(tmp_path, lerobot_dataset_factory): + """get_episodes_file_paths() returns paths including data/ paths.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False + ) + paths = dataset.reader.get_episodes_file_paths() + + assert len(paths) > 0 + assert any("data/" in str(p) for p in paths) + + +def test_get_episodes_file_paths_includes_video_paths(tmp_path, lerobot_dataset_factory): + """When dataset has video keys, file paths include video/ paths.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=True + ) + + if len(dataset.meta.video_keys) > 0: + paths = dataset.reader.get_episodes_file_paths() + assert any("video" in str(p).lower() for p in paths) diff --git a/tests/datasets/test_dataset_writer.py b/tests/datasets/test_dataset_writer.py new file mode 100644 index 000000000..8c6ee68bd --- /dev/null +++ b/tests/datasets/test_dataset_writer.py 
@@ -0,0 +1,226 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for DatasetWriter.""" + +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import pytest +import torch +from PIL import Image + +from lerobot.datasets.dataset_writer import _encode_video_worker +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import DEFAULT_IMAGE_PATH +from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID + +SIMPLE_FEATURES = { + "state": {"dtype": "float32", "shape": (6,), "names": None}, + "action": {"dtype": "float32", "shape": (6,), "names": None}, +} + + +def _make_frame(features: dict, task: str = "Dummy task") -> dict: + """Build a valid frame dict for the given features.""" + frame = {"task": task} + for key, ft in features.items(): + if ft["dtype"] in ("image", "video"): + frame[key] = np.random.randint(0, 256, size=ft["shape"], dtype=np.uint8) + elif ft["dtype"] in ("float32", "float64"): + frame[key] = torch.randn(ft["shape"]) + elif ft["dtype"] == "int64": + frame[key] = torch.zeros(ft["shape"], dtype=torch.int64) + return frame + + +# ── Existing encode_video_worker tests ─────────────────────────────── + + +def test_encode_video_worker_forwards_vcodec(tmp_path): + """_encode_video_worker correctly forwards the vcodec parameter.""" + video_key = "observation.images.laptop" + fpath = 
DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0) + img_dir = tmp_path / Path(fpath).parent + img_dir.mkdir(parents=True, exist_ok=True) + Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png") + + captured_kwargs = {} + + def mock_encode(imgs_dir, video_path, fps, **kwargs): + captured_kwargs.update(kwargs) + Path(video_path).parent.mkdir(parents=True, exist_ok=True) + Path(video_path).touch() + + with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode): + _encode_video_worker(video_key, 0, tmp_path, fps=30, vcodec="h264") + + assert captured_kwargs["vcodec"] == "h264" + + +def test_encode_video_worker_default_vcodec(tmp_path): + """_encode_video_worker uses libsvtav1 as the default codec.""" + video_key = "observation.images.laptop" + fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0) + img_dir = tmp_path / Path(fpath).parent + img_dir.mkdir(parents=True, exist_ok=True) + Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png") + + captured_kwargs = {} + + def mock_encode(imgs_dir, video_path, fps, **kwargs): + captured_kwargs.update(kwargs) + Path(video_path).parent.mkdir(parents=True, exist_ok=True) + Path(video_path).touch() + + with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode): + _encode_video_worker(video_key, 0, tmp_path, fps=30) + + assert captured_kwargs["vcodec"] == "libsvtav1" + + +# ── add_frame contracts ────────────────────────────────────────────── + + +def test_add_frame_increments_buffer_size(tmp_path): + """Each add_frame() call increases episode_buffer['size'] by 1.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert dataset.writer.episode_buffer["size"] == 0 + + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + assert dataset.writer.episode_buffer["size"] == 1 + + 
dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + assert dataset.writer.episode_buffer["size"] == 2 + + +def test_add_frame_rejects_missing_feature(tmp_path): + """add_frame() raises ValueError when a required feature is missing.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + with pytest.raises(ValueError, match="Missing features"): + dataset.add_frame({"task": "Dummy task", "state": torch.randn(6)}) + # missing 'action' + + +# ── save_episode contracts ─────────────────────────────────────────── + + +def test_save_episode_writes_parquet(tmp_path): + """After save_episode(), at least one .parquet file exists under data/.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + parquet_files = list((tmp_path / "ds" / "data").rglob("*.parquet")) + assert len(parquet_files) > 0 + + +def test_save_episode_updates_counters(tmp_path): + """After save_episode(), metadata counters are updated.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(5): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + assert dataset.meta.total_episodes == 1 + assert dataset.meta.total_frames == 5 + + +def test_save_episode_resets_buffer(tmp_path): + """After save_episode(), the episode buffer is reset.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + assert dataset.writer.episode_buffer["size"] == 0 + + +def test_save_multiple_episodes(tmp_path): + """Recording 3 episodes results in correct total counts.""" + dataset = 
LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + total_frames = 0 + for ep in range(3): + n_frames = ep + 2 # 2, 3, 4 + for _ in range(n_frames): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + total_frames += n_frames + + assert dataset.meta.total_episodes == 3 + assert dataset.meta.total_frames == total_frames + + +# ── clear / lifecycle ──────────────────────────────────────────────── + + +def test_clear_resets_buffer(tmp_path): + """clear_episode_buffer() resets the buffer size to 0.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + assert dataset.writer.episode_buffer["size"] == 1 + + dataset.clear_episode_buffer() + assert dataset.writer.episode_buffer["size"] == 0 + + +def test_finalize_is_idempotent(tmp_path): + """Calling finalize() twice does not raise.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame(SIMPLE_FEATURES)) + dataset.save_episode() + + dataset.finalize() + dataset.finalize() # second call should not raise + + +def test_finalize_then_read_roundtrip(tmp_path): + """Write data, finalize, re-open, and verify data matches.""" + root = tmp_path / "roundtrip" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=features, root=root) + + # Record known values + known_states = [] + for i in range(5): + state = torch.tensor([float(i), float(i * 10)]) + known_states.append(state) + dataset.add_frame({"task": "Test task", "state": state}) + dataset.save_episode() + dataset.finalize() + + # Read back + for i in range(5): + item = dataset[i] + assert torch.allclose(item["state"], known_states[i], 
atol=1e-5) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 67878d8f6..b2518149f 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -32,10 +32,7 @@ from lerobot.datasets.factory import make_dataset from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features from lerobot.datasets.image_writer import image_array_to_pil_image from lerobot.datasets.io_utils import hf_transform_to_torch -from lerobot.datasets.lerobot_dataset import ( - LeRobotDataset, - _encode_video_worker, -) +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.multi_dataset import MultiLeRobotDataset from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, @@ -72,7 +69,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory): def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated - objects have the same sets of attributes defined. + objects have the same sets of facade-level attributes defined. 
""" # Instantiate both ways robot = make_robot_from_config(MockRobotConfig()) @@ -87,6 +84,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1) + # Facade-level attributes should match between __init__ and create() init_attr = set(vars(dataset_init).keys()) create_attr = set(vars(dataset_create).keys()) @@ -214,6 +212,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert len(dataset) == 1 assert dataset[0]["task"] == "Dummy task" @@ -226,6 +225,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2]) @@ -235,6 +235,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4]) @@ -244,6 +245,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) @@ -253,6 +255,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) 
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) @@ -262,6 +265,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) @@ -271,6 +275,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["state"].ndim == 0 @@ -280,6 +285,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["caption"] == "Dummy caption" @@ -315,6 +321,7 @@ def test_add_frame_image(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -323,6 +330,7 @@ def test_add_frame_image_h_w_c(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -332,6 +340,7 @@ def test_add_frame_image_uint8(image_dataset): image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": image, "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert 
dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -341,6 +350,7 @@ def test_add_frame_image_pil(image_dataset): image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) dataset.save_episode() + dataset.finalize() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -361,7 +371,7 @@ def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory): ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image) ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) ds_img.save_episode() - img_dir = ds_img._get_image_file_dir(0, image_key) + img_dir = ds_img.writer._get_image_file_dir(0, image_key) assert not img_dir.exists(), "Temporary image directory should be removed for image features" @@ -374,10 +384,10 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory): } ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video) - ds_vid.batch_encoding_size = 1 + ds_vid.writer._batch_encoding_size = 1 ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) ds_vid.save_episode() - vid_img_dir = ds_vid._get_image_file_dir(0, vid_key) + vid_img_dir = ds_vid.writer._get_image_file_dir(0, vid_key) assert not vid_img_dir.exists(), ( "Temporary image directory should be removed when batch_encoding_size == 1" ) @@ -402,8 +412,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): } ) ds_mixed.save_episode() - img_dir = ds_mixed._get_image_file_dir(0, image_key) - vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key) + img_dir = ds_mixed.writer._get_image_file_dir(0, image_key) + vid_img_dir = ds_mixed.writer._get_image_file_dir(0, vid_key) assert not img_dir.exists(), "Temporary image directory should be removed for image features" assert vid_img_dir.exists(), ( "Temporary image directory should not be removed for video features when batch_encoding_size == 
2" @@ -631,29 +641,29 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): ) # Test hf_dataset is None - dataset.hf_dataset = None - assert dataset._check_cached_episodes_sufficient() is False + dataset.reader.hf_dataset = None + assert dataset.reader._check_cached_episodes_sufficient() is False # Test hf_dataset is empty import datasets empty_features = get_hf_features_from_features(dataset.features) - dataset.hf_dataset = datasets.Dataset.from_dict( + dataset.reader.hf_dataset = datasets.Dataset.from_dict( {key: [] for key in empty_features}, features=empty_features ) - dataset.hf_dataset.set_transform(hf_transform_to_torch) - assert dataset._check_cached_episodes_sufficient() is False + dataset.reader.hf_dataset.set_transform(hf_transform_to_torch) + assert dataset.reader._check_cached_episodes_sufficient() is False # Restore the original dataset for remaining tests - dataset.hf_dataset = dataset.load_hf_dataset() + dataset.reader.hf_dataset = dataset.reader._load_hf_dataset() # Test all episodes requested (self.episodes = None) and all are available - dataset.episodes = None - assert dataset._check_cached_episodes_sufficient() is True + dataset.reader.episodes = None + assert dataset.reader._check_cached_episodes_sufficient() is True # Test specific episodes requested that are all available - dataset.episodes = [0, 2, 4] - assert dataset._check_cached_episodes_sufficient() is True + dataset.reader.episodes = [0, 2, 4] + assert dataset.reader._check_cached_episodes_sufficient() is True # Test request episodes that don't exist in the cached dataset # Create a dataset with only episodes 0, 1, 2 @@ -665,8 +675,8 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): ) # Request episodes that include non-existent ones - limited_dataset.episodes = [0, 1, 2, 3, 4] - assert limited_dataset._check_cached_episodes_sufficient() is False + limited_dataset.reader.episodes = [0, 1, 2, 3, 4] + assert 
limited_dataset.reader._check_cached_episodes_sufficient() is False # Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4) # First create the full dataset structure @@ -702,22 +712,22 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): filtered_data[key] = filtered_values - sparse_dataset.hf_dataset = datasets.Dataset.from_dict( + sparse_dataset.reader.hf_dataset = datasets.Dataset.from_dict( filtered_data, features=get_hf_features_from_features(sparse_dataset.features) ) - sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch) + sparse_dataset.reader.hf_dataset.set_transform(hf_transform_to_torch) # Test requesting all episodes when only some are cached - sparse_dataset.episodes = None - assert sparse_dataset._check_cached_episodes_sufficient() is False + sparse_dataset.reader.episodes = None + assert sparse_dataset.reader._check_cached_episodes_sufficient() is False # Test requesting only the available episodes - sparse_dataset.episodes = [0, 2, 4] - assert sparse_dataset._check_cached_episodes_sufficient() is True + sparse_dataset.reader.episodes = [0, 2, 4] + assert sparse_dataset.reader._check_cached_episodes_sufficient() is True # Test requesting a mix of available and unavailable episodes - sparse_dataset.episodes = [0, 1, 2] - assert sparse_dataset._check_cached_episodes_sufficient() is False + sparse_dataset.reader.episodes = [0, 1, 2] + assert sparse_dataset.reader._check_cached_episodes_sufficient() is False def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): @@ -1189,13 +1199,13 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory): del dataset_verify # Phase 3: Resume recording - add more episodes - dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0") + dataset_resumed = LeRobotDataset.resume(initial_repo_id, root=initial_root, revision="v3.0") assert dataset_resumed.meta.total_episodes == initial_episodes assert 
dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode - assert dataset_resumed.latest_episode is None # Not recording yet - assert dataset_resumed.writer is None - assert dataset_resumed.meta.writer is None + assert dataset_resumed.writer._latest_episode is None # Not recording yet + assert dataset_resumed.writer._pq_writer is None + assert dataset_resumed.meta._pq_writer is None additional_episodes = 2 for ep_idx in range(initial_episodes, initial_episodes + additional_episodes): @@ -1271,7 +1281,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) dataset.meta.update_chunk_settings(data_files_size_in_mb=100) - assert dataset._current_file_start_frame is None + assert dataset.writer._current_file_start_frame is None frames_per_episode = 10 for _ in range(frames_per_episode): @@ -1284,7 +1294,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact ) dataset.save_episode() - assert dataset._current_file_start_frame == 0 + assert dataset.writer._current_file_start_frame == 0 assert dataset.meta.total_episodes == 1 assert dataset.meta.total_frames == frames_per_episode @@ -1298,12 +1308,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact ) dataset.save_episode() - assert dataset._current_file_start_frame == 0 + assert dataset.writer._current_file_start_frame == 0 assert dataset.meta.total_episodes == 2 assert dataset.meta.total_frames == 2 * frames_per_episode - ep1_chunk = dataset.latest_episode["data/chunk_index"] - ep1_file = dataset.latest_episode["data/file_index"] + ep1_chunk = dataset.writer._latest_episode["data/chunk_index"] + ep1_file = dataset.writer._latest_episode["data/file_index"] assert ep1_chunk == 0 assert ep1_file == 0 @@ -1317,12 +1327,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact ) 
dataset.save_episode() - assert dataset._current_file_start_frame == 0 + assert dataset.writer._current_file_start_frame == 0 assert dataset.meta.total_episodes == 3 assert dataset.meta.total_frames == 3 * frames_per_episode - ep2_chunk = dataset.latest_episode["data/chunk_index"] - ep2_file = dataset.latest_episode["data/file_index"] + ep2_chunk = dataset.writer._latest_episode["data/chunk_index"] + ep2_file = dataset.writer._latest_episode["data/file_index"] assert ep2_chunk == 0 assert ep2_file == 0 @@ -1354,82 +1364,6 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact assert frame["episode_index"].item() == expected_ep -def test_encode_video_worker_forwards_vcodec(tmp_path): - """Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames.""" - from unittest.mock import patch - - from lerobot.datasets.utils import DEFAULT_IMAGE_PATH - - # Create the expected directory structure - video_key = "observation.images.laptop" - episode_index = 0 - frame_index = 0 - - fpath = DEFAULT_IMAGE_PATH.format( - image_key=video_key, episode_index=episode_index, frame_index=frame_index - ) - img_dir = tmp_path / Path(fpath).parent - img_dir.mkdir(parents=True, exist_ok=True) - - # Create a dummy image file - dummy_img = Image.new("RGB", (64, 64), color="red") - dummy_img.save(img_dir / "frame-000000.png") - - # Track what vcodec was passed to encode_video_frames - captured_kwargs = {} - - def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs): - captured_kwargs.update(kwargs) - # Create a dummy output file so the worker doesn't fail - Path(video_path).parent.mkdir(parents=True, exist_ok=True) - Path(video_path).touch() - - with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames): - # Test with h264 codec - _encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264") - - assert "vcodec" in captured_kwargs - assert 
captured_kwargs["vcodec"] == "h264" - - -def test_encode_video_worker_default_vcodec(tmp_path): - """Test that _encode_video_worker uses libsvtav1 as the default codec.""" - from unittest.mock import patch - - from lerobot.datasets.utils import DEFAULT_IMAGE_PATH - - # Create the expected directory structure - video_key = "observation.images.laptop" - episode_index = 0 - frame_index = 0 - - fpath = DEFAULT_IMAGE_PATH.format( - image_key=video_key, episode_index=episode_index, frame_index=frame_index - ) - img_dir = tmp_path / Path(fpath).parent - img_dir.mkdir(parents=True, exist_ok=True) - - # Create a dummy image file - dummy_img = Image.new("RGB", (64, 64), color="red") - dummy_img.save(img_dir / "frame-000000.png") - - # Track what vcodec was passed to encode_video_frames - captured_kwargs = {} - - def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs): - captured_kwargs.update(kwargs) - # Create a dummy output file so the worker doesn't fail - Path(video_path).parent.mkdir(parents=True, exist_ok=True) - Path(video_path).touch() - - with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames): - # Test with default codec (no vcodec specified) - _encode_video_worker(video_key, episode_index, tmp_path, fps=30) - - assert "vcodec" in captured_kwargs - assert captured_kwargs["vcodec"] == "libsvtav1" - - def test_lerobot_dataset_vcodec_validation(): """Test that LeRobotDataset validates the vcodec parameter.""" # Test that invalid vcodec raises ValueError diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py index e02755171..55419473f 100644 --- a/tests/datasets/test_image_writer.py +++ b/tests/datasets/test_image_writer.py @@ -352,10 +352,14 @@ def test_with_different_image_formats(tmp_path, img_array_factory): def test_safe_stop_image_writer_decorator(): - class MockDataset: + class MockWriter: def __init__(self): self.image_writer = MagicMock(spec=AsyncImageWriter) + class 
MockDataset: + def __init__(self): + self.writer = MockWriter() + @safe_stop_image_writer def function_that_raises_exception(dataset=None): raise Exception("Test exception") @@ -366,7 +370,7 @@ def test_safe_stop_image_writer_decorator(): function_that_raises_exception(dataset=dataset) assert str(exc_info.value) == "Test exception" - dataset.image_writer.stop.assert_called_once() + dataset.writer.image_writer.stop.assert_called_once() def test_main_process_time(tmp_path, img_tensor_factory): diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py new file mode 100644 index 000000000..d7ce54a15 --- /dev/null +++ b/tests/datasets/test_lerobot_dataset.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contract tests for the LeRobotDataset facade. + +Tests focus on mode contracts (read-only, write-only, resume), guards, +property delegation, and the full create-record-finalize-read lifecycle. 
+""" + +import pytest +import torch + +from lerobot.datasets.dataset_reader import DatasetReader +from lerobot.datasets.dataset_writer import DatasetWriter +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID + +SIMPLE_FEATURES = { + "state": {"dtype": "float32", "shape": (2,), "names": None}, +} + + +def _make_frame(task: str = "Dummy task") -> dict: + return {"task": task, "state": torch.randn(2)} + + +# ── Read-only mode (via __init__) ──────────────────────────────────── + + +def test_init_creates_reader_no_writer(tmp_path, lerobot_dataset_factory): + """__init__() sets reader to a DatasetReader and writer to None.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False + ) + assert isinstance(dataset.reader, DatasetReader) + assert dataset.writer is None + + +def test_init_loads_data(tmp_path, lerobot_dataset_factory): + """After __init__(), the dataset has data and len > 0.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False + ) + assert len(dataset) > 0 + + +def test_getitem_works_in_read_mode(tmp_path, lerobot_dataset_factory): + """dataset[0] returns a dict with expected keys in read-only mode.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=10, use_videos=False + ) + item = dataset[0] + assert isinstance(item, dict) + assert "index" in item + assert "task" in item + + +def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory): + """len(dataset) equals dataset.num_frames.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=2, total_frames=30, use_videos=False + ) + assert len(dataset) == dataset.num_frames + + +# ── Write-only mode (via create()) ────────────────────────────────── + + +def test_create_sets_writer_no_reader(tmp_path): + """create() sets writer to a DatasetWriter and 
reader to None.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert isinstance(dataset.writer, DatasetWriter) + assert dataset.reader is None + + +def test_create_initial_counts_zero(tmp_path): + """After create(), num_episodes == 0 and num_frames == 0.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert dataset.num_episodes == 0 + assert dataset.num_frames == 0 + + +def test_add_frame_works_in_write_mode(tmp_path): + """add_frame() succeeds on a dataset created via create().""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + dataset.add_frame(_make_frame()) # should not raise + + +# ── Resume mode ────────────────────────────────────────────────────── + + +def test_resume_creates_writer(tmp_path): + """After resume(), writer is a DatasetWriter.""" + root = tmp_path / "resume_ds" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + dataset.finalize() + + resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root) + assert isinstance(resumed.writer, DatasetWriter) + + +def test_resume_preserves_episode_count(tmp_path): + """After resume(), existing episodes are counted.""" + root = tmp_path / "resume_ds" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + dataset.finalize() + + resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root) + assert resumed.meta.total_episodes == 1 + + +def test_resume_can_add_more_episodes(tmp_path): + """After resume(), new episodes can be added.""" + root = tmp_path / "resume_ds" 
+ dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + dataset.finalize() + + resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root) + for _ in range(2): + resumed.add_frame(_make_frame()) + resumed.save_episode() + + assert resumed.meta.total_episodes == 2 + + +# ── Writer guard ───────────────────────────────────────────────────── + + +def test_add_frame_raises_without_writer(tmp_path, lerobot_dataset_factory): + """add_frame() raises RuntimeError on a read-only dataset.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + with pytest.raises(RuntimeError, match="read-only"): + dataset.add_frame(_make_frame()) + + +def test_save_episode_raises_without_writer(tmp_path, lerobot_dataset_factory): + """save_episode() raises RuntimeError on a read-only dataset.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + with pytest.raises(RuntimeError, match="read-only"): + dataset.save_episode() + + +def test_clear_episode_buffer_raises_without_writer(tmp_path, lerobot_dataset_factory): + """clear_episode_buffer() raises RuntimeError on a read-only dataset.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + with pytest.raises(RuntimeError, match="read-only"): + dataset.clear_episode_buffer() + + +# ── Reader guard ───────────────────────────────────────────────────── + + +def test_getitem_raises_before_finalize(tmp_path): + """dataset[0] raises RuntimeError while recording (before finalize).""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + + with 
pytest.raises(RuntimeError, match="finalize"): + dataset[0] + + +def test_getitem_works_after_finalize(tmp_path): + """After finalize(), dataset[0] returns data.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + for _ in range(3): + dataset.add_frame(_make_frame()) + dataset.save_episode() + dataset.finalize() + + item = dataset[0] + assert "state" in item + assert "task" in item + + +# ── Property delegation ────────────────────────────────────────────── + + +def test_fps_delegates_to_meta(tmp_path, lerobot_dataset_factory): + """dataset.fps == dataset.meta.fps.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + assert dataset.fps == dataset.meta.fps + + +def test_features_delegates_to_meta(tmp_path, lerobot_dataset_factory): + """dataset.features is dataset.meta.features.""" + dataset = lerobot_dataset_factory( + root=tmp_path / "ds", total_episodes=1, total_frames=5, use_videos=False + ) + assert dataset.features is dataset.meta.features + + +def test_num_frames_uses_meta_in_write_mode(tmp_path): + """In write-only mode (reader=None), num_frames comes from metadata.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert dataset.reader is None + assert dataset.num_frames == dataset.meta.total_frames + + +# ── Lifecycle ──────────────────────────────────────────────────────── + + +def test_finalize_is_idempotent(tmp_path): + """Calling finalize() twice does not raise.""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + dataset.finalize() + dataset.finalize() + + +def test_has_pending_frames_lifecycle(tmp_path): + """has_pending_frames: False -> True (add_frame) -> False (save_episode).""" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, 
fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=tmp_path / "ds" + ) + assert dataset.has_pending_frames() is False + + dataset.add_frame(_make_frame()) + assert dataset.has_pending_frames() is True + + dataset.save_episode() + assert dataset.has_pending_frames() is False + + +def test_create_record_finalize_read_roundtrip(tmp_path): + """End-to-end: create, record 2 episodes, finalize, re-open, verify data.""" + root = tmp_path / "roundtrip" + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root + ) + + # Episode 0: 3 frames with known values + ep0_states = [] + for i in range(3): + state = torch.tensor([float(i), float(i * 2)]) + ep0_states.append(state) + dataset.add_frame({"task": "Task A", "state": state}) + dataset.save_episode() + + # Episode 1: 2 frames + ep1_states = [] + for i in range(2): + state = torch.tensor([float(i + 100), float(i + 200)]) + ep1_states.append(state) + dataset.add_frame({"task": "Task B", "state": state}) + dataset.save_episode() + + dataset.finalize() + + # Re-open as read-only + reopened = LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root) + assert len(reopened) == 5 + assert reopened.num_episodes == 2 + + # Verify episode 0 + for i in range(3): + item = reopened[i] + assert torch.allclose(item["state"], ep0_states[i], atol=1e-5) + assert item["episode_index"].item() == 0 + + # Verify episode 1 + for i in range(2): + item = reopened[3 + i] + assert torch.allclose(item["state"], ep1_states[i], atol=1e-5) + assert item["episode_index"].item() == 1 diff --git a/tests/datasets/test_streaming_video_encoder.py b/tests/datasets/test_streaming_video_encoder.py index a85db6a8d..f7e63b06f 100644 --- a/tests/datasets/test_streaming_video_encoder.py +++ b/tests/datasets/test_streaming_video_encoder.py @@ -534,7 +534,7 @@ class TestStreamingEncoderIntegration: streaming_encoding=True, ) - assert dataset._streaming_encoder is not None + assert dataset.writer._streaming_encoder is not None 
num_frames = 20 for _ in range(num_frames): @@ -580,7 +580,7 @@ class TestStreamingEncoderIntegration: streaming_encoding=False, ) - assert dataset._streaming_encoder is None + assert dataset.writer._streaming_encoder is None num_frames = 5 for _ in range(num_frames):