diff --git a/src/lerobot/configs/datasets.py b/src/lerobot/configs/datasets.py index 15630991f..0014e096f 100644 --- a/src/lerobot/configs/datasets.py +++ b/src/lerobot/configs/datasets.py @@ -1,4 +1,4 @@ -from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_STATE, OBS_IMAGE_4, TASK, ROBOT +from lerobot.constants import ACTION, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3, OBS_STATE, OBS_IMAGE_4, TASK, ROBOT_TYPE IMAGES_ORDER = { OBS_IMAGE: 0, @@ -16,9 +16,9 @@ ROBOT_TYPE_KEYS_MAPPING = { "lerobot/taco_play": "static_single_arm_7statedim", } TRAINING_FEATURES = { - 0: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE], - 1: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2], - 2: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3], + 0: [ACTION, OBS_STATE, TASK, ROBOT_TYPE, OBS_IMAGE], + 1: [ACTION, OBS_STATE, TASK, ROBOT_TYPE, OBS_IMAGE, OBS_IMAGE_2], + 2: [ACTION, OBS_STATE, TASK, ROBOT_TYPE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3], } # Map to "observation.state", "action", "observation.image", etc. FEATURE_KEYS_MAPPING = { diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index e69de29bb..fb0ebf16c 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -0,0 +1,1333 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +import shutil +import os +from pathlib import Path +from typing import Callable + +import datasets +import numpy as np +import packaging.version +import PIL.Image +import torch +import torch.utils +from datasets import concatenate_datasets, load_dataset +from huggingface_hub import HfApi, snapshot_download +from huggingface_hub.constants import REPOCARD_NAME +from huggingface_hub.errors import RevisionNotFoundError + +from lerobot.constants import HF_LEROBOT_HOME +from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats #aggregate_stats_per_robot_type, +from lerobot.datasets.image_writer import AsyncImageWriter, write_image +from lerobot.datasets.utils import ( + DEFAULT_FEATURES, + DEFAULT_IMAGE_PATH, + INFO_PATH, + TASKS_PATH, + _validate_feature_names, + append_jsonlines, + backward_compatible_episodes_stats, + check_delta_timestamps, + check_timestamps_sync, + check_version_compatibility, + create_empty_dataset_info, + create_lerobot_dataset_card, + embed_images, + get_delta_indices, + get_episode_data_index, + get_features_from_robot, + get_hf_features_from_features, + get_safe_version, + hf_transform_to_torch, + is_valid_version, + load_episodes, + load_episodes_stats, + load_info, + load_stats, + load_tasks, + map_dict_keys, + validate_episode_buffer, + validate_frame, + write_episode, + write_episode_stats, + write_info, + write_json, + #keep_datasets_with_the_same_features_per_robot_type, + #map_dict_pad_keys, + #keep_datasets_with_valid_fps, + #find_start_of_motion, +) +from lerobot.datasets.video_utils import ( + VideoFrame, + decode_video_frames, + encode_video_frames, + get_safe_default_codec, + get_video_info, +) + +# from lerobot.robots.utils import Robot +from lerobot.configs.datasets import ROBOT_TYPE_KEYS_MAPPING, TASKS_KEYS_MAPPING +#FIXME: remove this import +from lerobot.datasets.collators import pad_tensor + +CODEBASE_VERSION = "v2.1" +LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser() + + +def find_start_of_motion(velocities, window_size, threshold, motion_buffer): + for t in range(len(velocities) - window_size): + window_mean = velocities[t:t+window_size].mean() + if window_mean > threshold: + return max(0, t - motion_buffer) # include slight context before motion + return 0 + +class LeRobotDatasetMetadata: + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + local_files_only: bool = False, + feature_keys_mapping: dict[str, str] | None = None, + revision: str | None = None, + force_cache_sync: bool = False, + ): + self.repo_id = repo_id + self.local_files_only = local_files_only + self.revision = revision if revision else CODEBASE_VERSION + self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + + try: + if force_cache_sync: + raise FileNotFoundError + self.load_metadata() + except (FileNotFoundError, NotADirectoryError): + if is_valid_version(self.revision): + self.revision = get_safe_version(self.repo_id, self.revision) + + (self.root / "meta").mkdir(exist_ok=True, parents=True) + self.pull_from_repo(allow_patterns="meta/") + self.load_metadata() + self.feature_keys_mapping = feature_keys_mapping.get(repo_id, None) if feature_keys_mapping else None + self.inverse_feature_keys_mapping = ( + {v: k for k, v in self.feature_keys_mapping.items() if v} if self.feature_keys_mapping else {} + ) + self.info["features"] = map_dict_keys( + self.info["features"], feature_keys_mapping=self.feature_keys_mapping + ) + def load_metadata(self): + self.info = load_info(self.root) + check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) + self.tasks, self.task_to_task_index = load_tasks(self.root) + self.episodes = load_episodes(self.root) + if self._version < packaging.version.parse("v2.1"): + self.stats = load_stats(self.root) + self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) + else: + self.episodes_stats = load_episodes_stats(self.root) + self.stats = aggregate_stats(list(self.episodes_stats.values())) + + def pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + @property + def _version(self) -> packaging.version.Version: + """Codebase version used to create this dataset.""" + return packaging.version.parse(self.info["codebase_version"]) + + def get_data_file_path(self, ep_index: int) -> Path: + ep_chunk = self.get_episode_chunk(ep_index) + fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index) + return Path(fpath) + + def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + ep_chunk = self.get_episode_chunk(ep_index) + fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) + return Path(fpath) + + def get_episode_chunk(self, ep_index: int) -> int: + return ep_index // self.chunks_size + + @property + def data_path(self) -> str: + """Formattable string for the parquet files.""" + return self.info["data_path"] + + @property + def video_path(self) -> str | None: + """Formattable string for the video files.""" + return self.info["video_path"] + + @property + def robot_type(self) -> str | None: + """Robot type used in recording this dataset.""" + return self.info["robot_type"] + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.info["fps"] + + @property + def features(self) -> dict[str, dict]: + """All features contained in the dataset.""" + return self.info["features"] + + @property + def image_keys(self) -> list[str]: + """Keys to access visual modalities stored as images.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "image"] + + @property + def video_keys(self) -> list[str]: + """Keys to access visual modalities stored as videos.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "video"] + + @property + def camera_keys(self) -> list[str]: + """Keys to access visual modalities (regardless of their storage method).""" + return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + + @property + def names(self) -> dict[str, list | dict]: + """Names of the various dimensions of vector modalities.""" + return {key: ft["names"] for key, ft in self.features.items()} + + @property + def shapes(self) -> dict: + """Shapes for the different features.""" + return {key: tuple(ft["shape"]) for key, ft in self.features.items()} + + @property + def total_episodes(self) -> int: + """Total number of episodes available.""" + return self.info["total_episodes"] + + @property + def total_frames(self) -> int: + """Total number of frames saved in this dataset.""" + return self.info["total_frames"] + + @property + def total_tasks(self) -> int: + """Total number of different tasks performed in this dataset.""" + return self.info["total_tasks"] + + @property + def total_chunks(self) -> int: + """Total number of chunks (groups of episodes).""" + return self.info["total_chunks"] + + @property + def chunks_size(self) -> int: + """Max number of episodes per chunk.""" + return self.info["chunks_size"] + + def get_task_index(self, task: str) -> int | None: + """ + Given a task in natural language, returns its task_index if the task already exists in the dataset, + otherwise return None. + """ + return self.task_to_task_index.get(task, None) + + def add_task(self, task: str): + """ + Given a task in natural language, add it to the dictionary of tasks. + """ + if task in self.task_to_task_index: + raise ValueError(f"The task '{task}' already exists and can't be added twice.") + + task_index = self.info["total_tasks"] + self.task_to_task_index[task] = task_index + self.tasks[task_index] = task + self.info["total_tasks"] += 1 + + task_dict = { + "task_index": task_index, + "task": task, + } + append_jsonlines(task_dict, self.root / TASKS_PATH) + + def save_episode( + self, + episode_index: int, + episode_length: int, + episode_tasks: list[str], + episode_stats: dict[str, dict], + ) -> None: + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + + chunk = self.get_episode_chunk(episode_index) + if chunk >= self.total_chunks: + self.info["total_chunks"] += 1 + + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + self.info["total_videos"] += len(self.video_keys) + if len(self.video_keys) > 0: + self.update_video_info() + + write_info(self.info, self.root) + + episode_dict = { + "episode_index": episode_index, + "tasks": episode_tasks, + "length": episode_length, + } + self.episodes[episode_index] = episode_dict + write_episode(episode_dict, self.root) + + self.episodes_stats[episode_index] = episode_stats + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats + write_episode_stats(episode_index, episode_stats, self.root) + + def update_video_info(self) -> None: + """ + Warning: this function writes info from first episode videos, implicitly assuming that all videos have + been encoded the same way. Also, this means it assumes the first episode exists. + """ + for key in self.video_keys: + if not self.features[key].get("info", None): + video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) + self.info["features"][key]["info"] = get_video_info(video_path) + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Total episodes: '{self.total_episodes}',\n" + f" Total frames: '{self.total_frames}',\n" + f" Features: '{feature_keys}',\n" + "})',\n" + ) + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + features: dict, + robot_type: str | None = None, + root: str | Path | None = None, + use_videos: bool = True, + ) -> "LeRobotDatasetMetadata": + """Creates metadata for a LeRobotDataset.""" + obj = cls.__new__(cls) + obj.repo_id = repo_id + obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + + obj.root.mkdir(parents=True, exist_ok=False) + + # TODO(aliberts, rcadene): implement sanity check for features + features = {**features, **DEFAULT_FEATURES} + _validate_feature_names(features) + + obj.tasks, obj.task_to_task_index = {}, {} + obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {} + obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type) + if len(obj.video_keys) > 0 and not use_videos: + raise ValueError() + write_json(obj.info, obj.root / INFO_PATH) + obj.revision = None + return obj + + +class LeRobotDataset(torch.utils.data.Dataset): + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + episodes: list[int] | None = None, + image_transforms: Callable | None = None, + delta_timestamps: dict[list[float]] | None = None, + tolerance_s: float = 1e-4, + revision: str | None = None, + force_cache_sync: bool = False, + download_videos: bool = True, + video_backend: str | None = None, + ): + """ + 2 modes are available for instantiating this class, depending on 2 different use cases: + + 1. Your dataset already exists: + - On your local disk in the 'root' folder. This is typically the case when you recorded your + dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class + with 'root' will load your dataset directly from disk. This can happen while you're offline (no + internet connection). + + - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on + your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download + the dataset from that address and load it, pending your dataset is compliant with + codebase_version v2.0. If your dataset has been created before this new format, you will be + prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at + lerobot/datasets/v2/convert_dataset_v1_to_v2.py. + + + 2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty + LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or port an + existing dataset to the LeRobotDataset format. + + + In terms of files, LeRobotDataset encapsulates 3 main things: + - metadata: + - info contains various information about the dataset like shapes, keys, fps etc. + - stats stores the dataset statistics of the different modalities for normalization + - tasks contains the prompts for each task of the dataset, which can be used for + task-conditioned training. + - hf_dataset (from datasets.Dataset), which will read any values from parquet files. + - videos (optional) from which frames are loaded to be synchronous with data from parquet files. + + A typical LeRobotDataset looks like this from its root path: + . + ├── data + │ ├── chunk-000 + │ │ ├── episode_000000.parquet + │ │ ├── episode_000001.parquet + │ │ ├── episode_000002.parquet + │ │ └── ... + │ ├── chunk-001 + │ │ ├── episode_001000.parquet + │ │ ├── episode_001001.parquet + │ │ ├── episode_001002.parquet + │ │ └── ... + │ └── ... + ├── meta + │ ├── episodes.jsonl + │ ├── info.json + │ ├── stats.json + │ └── tasks.jsonl + └── videos + ├── chunk-000 + │ ├── observation.images.laptop + │ │ ├── episode_000000.mp4 + │ │ ├── episode_000001.mp4 + │ │ ├── episode_000002.mp4 + │ │ └── ... + │ ├── observation.images.phone + │ │ ├── episode_000000.mp4 + │ │ ├── episode_000001.mp4 + │ │ ├── episode_000002.mp4 + │ │ └── ... + ├── chunk-001 + └── ... + + Note that this file-based structure is designed to be as versatile as possible. The files are split by + episodes which allows a more granular control over which episodes one wants to use and download. The + structure of the dataset is entirely described in the info.json file, which can be easily downloaded + or viewed directly on the hub before downloading any actual data. The type of files used are very + simple and do not need complex tools to be read, it only uses .parquet, .json and .mp4 files (and .md + for the README). + + Args: + repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset + will be stored under root/repo_id. + root (Path | None, optional): Local directory to use for downloading/writing files. You can also + set the LEROBOT_HOME environment variable to point to a different location. Defaults to + '~/.cache/huggingface/lerobot'. + episodes (list[int] | None, optional): If specified, this will only load episodes specified by + their episode_index in this list. Defaults to None. + image_transforms (Callable | None, optional): You can pass standard v2 image transforms from + torchvision.transforms.v2 here which will be applied to visual modalities (whether they come + from videos or images). Defaults to None. + delta_timestamps (dict[list[float]] | None, optional): _description_. Defaults to None. + tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in + sync with the fps value. It is used at the init of the dataset to make sure that each + timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames + decoded from video files. It is also used to check that `delta_timestamps` (when provided) are + multiples of 1/fps. Defaults to 1e-4. + revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a + commit hash. Defaults to current codebase version tag. + sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files + are already present in the local cache, this will be faster. However, files loaded might not + be in sync with the version on the hub, especially if you specified 'revision'. Defaults to + False. + download_videos (bool, optional): Flag to download the videos. Note that when set to True but the + video files are already present on local disk, they won't be downloaded again. Defaults to + True. + video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. + You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. + """ + super().__init__() + self.repo_id = repo_id + self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self.image_transforms = image_transforms + self.delta_timestamps = delta_timestamps + self.episodes = episodes + self.tolerance_s = tolerance_s + self.revision = revision if revision else CODEBASE_VERSION + self.video_backend = video_backend if video_backend else get_safe_default_codec() + self.delta_indices = None + + # Unused attributes + self.image_writer = None + self.episode_buffer = None + + self.root.mkdir(exist_ok=True, parents=True) + + # Load metadata + self.meta = LeRobotDatasetMetadata( + self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + ) + if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"): + episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes] + self.stats = aggregate_stats(episodes_stats) + + # Load actual data + try: + if force_cache_sync: + raise FileNotFoundError + assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths()) + self.hf_dataset = self.load_hf_dataset() + except (AssertionError, FileNotFoundError, NotADirectoryError): + self.revision = get_safe_version(self.repo_id, self.revision) + self.download_episodes(download_videos) + self.hf_dataset = self.load_hf_dataset() + + self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) + + # Check timestamps + timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy() + episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy() + ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()} + check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s) + + # Setup delta_indices + if self.delta_timestamps is not None: + check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) + self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + + def push_to_hub( + self, + branch: str | None = None, + tags: list | None = None, + license: str | None = "apache-2.0", + tag_version: bool = True, + push_videos: bool = True, + private: bool = False, + allow_patterns: list[str] | str | None = None, + upload_large_folder: bool = False, + **card_kwargs, + ) -> None: + ignore_patterns = ["images/"] + if not push_videos: + ignore_patterns.append("videos/") + + hub_api = HfApi() + hub_api.create_repo( + repo_id=self.repo_id, + private=private, + repo_type="dataset", + exist_ok=True, + ) + if branch: + hub_api.create_branch( + repo_id=self.repo_id, + branch=branch, + revision=self.revision, + repo_type="dataset", + exist_ok=True, + ) + + upload_kwargs = { + "repo_id": self.repo_id, + "folder_path": self.root, + "repo_type": "dataset", + "revision": branch, + "allow_patterns": allow_patterns, + "ignore_patterns": ignore_patterns, + } + if upload_large_folder: + hub_api.upload_large_folder(**upload_kwargs) + else: + hub_api.upload_folder(**upload_kwargs) + + if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch): + card = create_lerobot_dataset_card( + tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs + ) + card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) + + if tag_version: + with contextlib.suppress(RevisionNotFoundError): + hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + + def pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + def download_episodes(self, download_videos: bool = True) -> None: + """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this + will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole + dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present + in 'local_dir', they won't be downloaded again. + """ + # TODO(rcadene, aliberts): implement faster transfer + # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads + files = None + ignore_patterns = None if download_videos else "videos/" + if self.episodes is not None: + files = self.get_episodes_file_paths() + + self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) + + def get_episodes_file_paths(self) -> list[Path]: + episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) + fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] + if len(self.meta.video_keys) > 0: + video_files = [ + str(self.meta.get_video_file_path(ep_idx, vid_key)) + for vid_key in self.meta.video_keys + for ep_idx in episodes + ] + fpaths += video_files + + return fpaths + + def load_hf_dataset(self) -> datasets.Dataset: + """hf_dataset contains all the observations, states, actions, rewards, etc.""" + if self.episodes is None: + path = str(self.root / "data") + hf_dataset = load_dataset("parquet", data_dir=path, split="train") + else: + files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] + hf_dataset = load_dataset("parquet", data_files=files, split="train") + + # TODO(aliberts): hf_dataset.set_format("torch") + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + def create_hf_dataset(self) -> datasets.Dataset: + features = get_hf_features_from_features(self.features) + ft_dict = {col: [] for col in features} + hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") + + # TODO(aliberts): hf_dataset.set_format("torch") + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.meta.fps + + @property + def num_frames(self) -> int: + """Number of frames in selected episodes.""" + return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames + + @property + def num_episodes(self) -> int: + """Number of episodes selected.""" + return len(self.episodes) if self.episodes is not None else self.meta.total_episodes + + @property + def features(self) -> dict[str, dict]: + return self.meta.features + + @property + def hf_features(self) -> datasets.Features: + """Features of the hf_dataset.""" + if self.hf_dataset is not None: + return self.hf_dataset.features + else: + return get_hf_features_from_features(self.features) + + def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: + ep_start = self.episode_data_index["from"][ep_idx] + ep_end = self.episode_data_index["to"][ep_idx] + query_indices = { + key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx] + for key, delta_idx in self.delta_indices.items() + } + padding = { # Pad values outside of current episode range + f"{key}_is_pad": torch.BoolTensor( + [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx] + ) + for key, delta_idx in self.delta_indices.items() + } + return query_indices, padding + + def _get_query_timestamps( + self, + current_ts: float, + query_indices: dict[str, list[int]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self.meta.video_keys: + if query_indices is not None and key in query_indices: + timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps + + def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: + return { + key: torch.stack(self.hf_dataset.select(q_idx)[key]) + for key, q_idx in query_indices.items() + if key not in self.meta.video_keys + } + + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: + """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function + in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a + Segmentation Fault. This probably happens because a memory reference to the video loader is created in + the main process and a subprocess fails to access it. + """ + item = {} + for vid_key, query_ts in query_timestamps.items(): + video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) + frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend) + item[vid_key] = frames.squeeze(0) + + return item + + def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict: + for key, val in padding.items(): + item[key] = torch.BoolTensor(val) + return item + + def __len__(self): + return self.num_frames + + def __getitem__(self, idx) -> dict: + item = self.hf_dataset[idx] + ep_idx = item["episode_index"].item() + + query_indices = None + if self.delta_indices is not None: + query_indices, padding = self._get_query_indices(idx, ep_idx) + query_result = self._query_hf_dataset(query_indices) + item = {**item, **padding} + for key, val in query_result.items(): + item[key] = val + + if len(self.meta.video_keys) > 0: + current_ts = item["timestamp"].item() + query_timestamps = self._get_query_timestamps(current_ts, query_indices) + video_frames = self._query_videos(query_timestamps, ep_idx) + item = {**video_frames, **item} + + if self.image_transforms is not None: + image_keys = self.meta.camera_keys + for cam in image_keys: + item[cam] = self.image_transforms(item[cam]) + + # Add task as a string + task_idx = item["task_index"].item() + item["task"] = self.meta.tasks[task_idx] + + return item + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Number of selected episodes: '{self.num_episodes}',\n" + f" Number of selected samples: '{self.num_frames}',\n" + f" Features: '{feature_keys}',\n" + "})',\n" + ) + + def create_episode_buffer(self, episode_index: int | None = None) -> dict: + current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index + ep_buffer = {} + # size and task are special cases that are not in self.features + ep_buffer["size"] = 0 + ep_buffer["task"] = [] + for key in self.features: + ep_buffer[key] = current_ep_idx if key == "episode_index" else [] + return ep_buffer + + def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: + fpath = DEFAULT_IMAGE_PATH.format( + image_key=image_key, episode_index=episode_index, frame_index=frame_index + ) + return self.root / fpath + + def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: + if self.image_writer is None: + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + write_image(image, fpath) + else: + self.image_writer.save_image(image=image, fpath=fpath) + + def add_frame(self, frame: dict, task: str, timestamp: float | None = None) -> None: + """ + This function only adds the frame to the episode_buffer. Apart from images — which are written in a + temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method + then needs to be called. + """ + # Convert torch to numpy if needed + for name in frame: + if isinstance(frame[name], torch.Tensor): + frame[name] = frame[name].numpy() + + validate_frame(frame, self.features) + + if self.episode_buffer is None: + self.episode_buffer = self.create_episode_buffer() + + # Automatically add frame_index and timestamp to episode buffer + frame_index = self.episode_buffer["size"] + if timestamp is None: + timestamp = frame_index / self.fps + self.episode_buffer["frame_index"].append(frame_index) + self.episode_buffer["timestamp"].append(timestamp) + self.episode_buffer["task"].append(task) + + # Add frame features to episode_buffer + for key in frame: + + if key not in self.features: + raise ValueError( + f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." + ) + + if self.features[key]["dtype"] in ["image", "video"]: + img_path = self._get_image_file_path( + episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index + ) + if frame_index == 0: + img_path.parent.mkdir(parents=True, exist_ok=True) + self._save_image(frame[key], img_path) + self.episode_buffer[key].append(str(img_path)) + else: + self.episode_buffer[key].append(frame[key]) + + self.episode_buffer["size"] += 1 + + def save_episode(self, episode_data: dict | None = None) -> None: + """ + This will save to disk the current episode in self.episode_buffer. + + Args: + episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will + save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to + None. + """ + if not episode_data: + episode_buffer = self.episode_buffer + + validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) + + # size and task are special cases that won't be added to hf_dataset + episode_length = episode_buffer.pop("size") + tasks = episode_buffer.pop("task") + episode_tasks = list(set(tasks)) + episode_index = episode_buffer["episode_index"] + + episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) + episode_buffer["episode_index"] = np.full((episode_length,), episode_index) + + # Add new tasks to the tasks dictionary + for task in episode_tasks: + task_index = self.meta.get_task_index(task) + if task_index is None: + self.meta.add_task(task) + + # Given tasks in natural language, find their corresponding task indices + episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) + + for key, ft in self.features.items(): + # index, episode_index, task_index are already processed above, and image and video + # are processed separately by storing image path and frame info as meta data + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + continue + episode_buffer[key] = np.stack(episode_buffer[key]) + + self._wait_image_writer() + self._save_episode_table(episode_buffer, episode_index) + ep_stats = compute_episode_stats(episode_buffer, self.features) + + if len(self.meta.video_keys) > 0: + video_paths = self.encode_episode_videos(episode_index) + for key in self.meta.video_keys: + episode_buffer[key] = video_paths[key] + + # `meta.save_episode` be executed after encoding the videos + self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) + + ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) + ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} + check_timestamps_sync( + episode_buffer["timestamp"], + episode_buffer["episode_index"], + ep_data_index_np, + self.fps, + self.tolerance_s, + ) + + video_files = list(self.root.rglob("*.mp4")) + assert len(video_files) == self.num_episodes * len(self.meta.video_keys) + + parquet_files = list(self.root.rglob("*.parquet")) + assert len(parquet_files) == self.num_episodes + + # delete images + img_dir = self.root / "images" + if img_dir.is_dir(): + shutil.rmtree(self.root / "images") + + if not episode_data: # Reset the buffer + self.episode_buffer = self.create_episode_buffer() + + def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: + episode_dict = {key: episode_buffer[key] for key in self.hf_features} + ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") + ep_dataset = embed_images(ep_dataset) + self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset]) + self.hf_dataset.set_transform(hf_transform_to_torch) + ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index) + ep_data_path.parent.mkdir(parents=True, exist_ok=True) + ep_dataset.to_parquet(ep_data_path) + + def clear_episode_buffer(self) -> None: + episode_index = self.episode_buffer["episode_index"] + if self.image_writer is not None: + for cam_key in self.meta.camera_keys: + img_dir = self._get_image_file_path( + episode_index=episode_index, image_key=cam_key, frame_index=0 + ).parent + if img_dir.is_dir(): + shutil.rmtree(img_dir) + + # Reset the buffer + self.episode_buffer = self.create_episode_buffer() + + def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: + if isinstance(self.image_writer, AsyncImageWriter): + logging.warning( + "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." + ) + + self.image_writer = AsyncImageWriter( + num_processes=num_processes, + num_threads=num_threads, + ) + + def stop_image_writer(self) -> None: + """ + Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to + remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized. + """ + if self.image_writer is not None: + self.image_writer.stop() + self.image_writer = None + + def _wait_image_writer(self) -> None: + """Wait for asynchronous image writer to finish.""" + if self.image_writer is not None: + self.image_writer.wait_until_done() + + def encode_videos(self) -> None: + """ + Use ffmpeg to convert frames stored as png into mp4 videos. + Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, + since video encoding with ffmpeg is already using multithreading. + """ + for ep_idx in range(self.meta.total_episodes): + self.encode_episode_videos(ep_idx) + + def encode_episode_videos(self, episode_index: int) -> dict: + """ + Use ffmpeg to convert frames stored as png into mp4 videos. + Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, + since video encoding with ffmpeg is already using multithreading. + """ + video_paths = {} + for key in self.meta.video_keys: + video_path = self.root / self.meta.get_video_file_path(episode_index, key) + video_paths[key] = str(video_path) + if video_path.is_file(): + # Skip if video is already encoded. Could be the case when resuming data recording. + continue + img_dir = self._get_image_file_path( + episode_index=episode_index, image_key=key, frame_index=0 + ).parent + encode_video_frames(img_dir, video_path, self.fps, overwrite=True) + + return video_paths + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + features: dict, + root: str | Path | None = None, + robot_type: str | None = None, + use_videos: bool = True, + tolerance_s: float = 1e-4, + image_writer_processes: int = 0, + image_writer_threads: int = 0, + video_backend: str | None = None, + ) -> "LeRobotDataset": + """Create a LeRobot Dataset from scratch in order to record data.""" + obj = cls.__new__(cls) + obj.meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=fps, + robot_type=robot_type, + features=features, + root=root, + use_videos=use_videos, + ) + obj.repo_id = obj.meta.repo_id + obj.root = obj.meta.root + obj.local_files_only = obj.meta.local_files_only + obj.revision = None + obj.tolerance_s = tolerance_s + obj.image_writer = None + + if image_writer_processes or image_writer_threads: + obj.start_image_writer(image_writer_processes, image_writer_threads) + + # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer + obj.episode_buffer = obj.create_episode_buffer() + + obj.episodes = None + obj.hf_dataset = obj.create_hf_dataset() + obj.image_transforms = None + obj.delta_timestamps = None + obj.delta_indices = None + obj.episode_data_index = None + obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + return obj + +def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict: + """Reshape features to have a maximum dimension of `max_dim`.""" + reshaped_features = {} + for key in features: + if key in keys_to_max_dim and keys_to_max_dim[key] is not None: + reshaped_features[key] = features[key] + shape = list(features[key]["shape"]) + if any([k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]]): # Assume square images + shape[-3] = keys_to_max_dim[key] + shape[-2] = keys_to_max_dim[key] + else: + shape[reshape_dim] = keys_to_max_dim[key] + reshaped_features[key]["shape"] = tuple(shape) + else: + reshaped_features[key] = features[key] + return reshaped_features + +def str_to_torch_dtype(dtype_str): + """Convert a dtype string to a torch dtype.""" + mapping = { + "float32": torch.float32, + "int64": torch.int64, + "int16": torch.int16, + "bool": torch.bool, + "video": torch.float32, # Assuming video is stored as uint8 images + } + return mapping.get(dtype_str, torch.float32) +def create_padded_features(item: dict, features: dict = {}): + for key, ft in features.items(): + if any([k in key for k in ["cam", "effort", "absolute"]]): # FIXME(mshukor): temporary hack + continue + shape = ft["shape"] + if len(shape) == 3: # images to torch format (C, H, W) + shape = (shape[2], shape[0], shape[1]) + if len(shape) == 1 and shape[0] == 1: # ft with shape are actually tensor(ele) + shape = [] + if key not in item: + dtype = str_to_torch_dtype(ft["dtype"]) + item[key] = torch.zeros(shape, dtype=dtype) + item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64) + if "image" in key: # FIXME(mshukor): support other observations + item[f"{key}_is_pad"] = torch.BoolTensor([False]) + else: + item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64) + return item + +class MultiLeRobotDataset(torch.utils.data.Dataset): + """A dataset consisting of multiple underlying `LeRobotDataset`s. + + The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API + structure of `LeRobotDataset`. + """ + + def __init__( + self, + repo_ids: list[str], + root: str | Path | None = None, + episodes: dict | None = None, + image_transforms: Callable | None = None, + delta_timestamps: dict[list[float]] | None = None, + tolerances_s: dict | None = None, + download_videos: bool = True, + local_files_only: bool = False, + video_backend: str | None = None, + sampling_weights: list[float] | None = None, + feature_keys_mapping: dict[str, dict[str, str]] | None = None, + max_action_dim: int = None, + max_state_dim: int = None, + max_num_images: int = None, + max_image_dim: int = None, + train_on_all_features: bool = False, + training_features: list | None = None, + discard_first_n_frames: int = 0, + min_fps: int = 1, + max_fps: int = 100, + discard_first_idle_frames: bool = False, + motion_threshold: float = 0.05, + motion_window_size: int = 10, + motion_buffer: int = 3, + ): + super().__init__() + self.repo_ids = repo_ids + self.root = Path(root) if root else HF_LEROBOT_HOME + self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) + # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which + # are handled by this class. + self._datasets = [ + LeRobotDataset( + repo_id, + root=self.root / repo_id, + episodes=episodes[repo_id] if episodes else None, + image_transforms=image_transforms, + delta_timestamps=delta_timestamps, + tolerance_s=self.tolerances_s[repo_id], + download_videos=download_videos, + video_backend=video_backend, + ) + for repo_id in repo_ids + ] + + # Disable any data keys that are not common across all of the datasets. Note: we may relax this + # restriction in future iterations of this class. For now, this is necessary at least for being able + # to use PyTorch's default DataLoader collate function. + self.disabled_features = set() + intersection_features = set(self._datasets[0].features) + for ds in self._datasets: + intersection_features.intersection_update(ds.features) + if len(intersection_features) == 0: + raise RuntimeError( + "Multiple datasets were provided but they had no keys common to all of them. " + "The multi-dataset functionality currently only keeps common keys." + ) + for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): + extra_keys = set(ds.features).difference(intersection_features) + logging.warning( + f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " + "other datasets." + ) + self.disabled_features.update(extra_keys) + + self.image_transforms = image_transforms + self.delta_timestamps = self.delta_timestamps = delta_timestamps.get( + repo_id, None + ) + # TODO(rcadene, aliberts): We should not perform this aggregation for datasets + # with multiple robots of different ranges. Instead we should have one normalization + # per robot. + for ds in _datasets: + ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"]) + ds.robot_type = ds.meta.info["robot_type"] + #self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) + _datasets = keep_datasets_with_valid_fps(_datasets, min_fps=min_fps, max_fps=max_fps) + self._datasets, datasets_maks = keep_datasets_with_the_same_features_per_robot_type(_datasets) + self.sampling_weights = [self.sampling_weights[i] for i in range(len(_datasets)) if datasets_maks[i]] + self.repo_ids = [repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]] + self.datasets_repo_ids = [datasets_repo_ids[i] for i in range(len(_datasets)) if datasets_maks[i]] + # Compute cumulative sizes for fast indexing + self.cumulative_sizes = np.array( + [0] + list(torch.cumsum(torch.tensor([len(d) for d in self._datasets]), dim=0)) + ) + self.sampling_weights = np.array(self.sampling_weights, dtype=np.float32) + self.stats = aggregate_stats_per_robot_type(self._datasets) + self.meta = copy.deepcopy(self._datasets[0].meta) # FIXME(mshukor): aggregate meta from all datasets + self.meta.info = { + repo_id: ds.meta.info for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False) + } + # self.meta.info["features"] = self._datasets[0].meta.info["features"] # Assume all datasets have the same features + # FIXME(mshukor): pad based on types in case we have more than one state? + self.keys_to_max_dim = { + ACTION: max_action_dim, + OBS_ENV: max_state_dim, + OBS_ROBOT: max_state_dim, + OBS_IMAGE: max_image_dim, + OBS_IMAGE_2: max_image_dim, + OBS_IMAGE_3: max_image_dim, + } + # self.meta.info["features"] = reshape_features_to_max_dim(self._datasets[0].meta.info["features"], reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim) + self.meta.info["features"] = reshape_features_to_max_dim( + union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim + ) + # reshape stats + for robot_type_, stats_ in self.stats.items(): + for feat_key, feat_stats in stats_.items(): + if feat_key in [ACTION, OBS_ENV, OBS_ROBOT]: + for k, v in feat_stats.items(): + if k in ["min", "mean"]: + pad_value = 0 + elif k in ["max", "std"]: + pad_value = 1 + else: + continue + self.stats[robot_type_][feat_key][k] = pad_tensor(v, max_size=self.keys_to_max_dim.get(feat_key, -1), pad_dim=-1, pad_value=pad_value) + + self.meta.stats = self.stats + # self.meta.info["features"] = aggregate_features(self._datasets) + self.meta.tasks = { + repo_id: ds.meta.tasks for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False) + } + self.meta.episodes = { + repo_id: ds.meta.episodes for repo_id, ds in zip(self.repo_ids, self._datasets, strict=False) + } + self.robot_types = [ds.meta.info["robot_type"] for ds in self._datasets] + + @property + def repo_id_to_index(self): + """Return a mapping from dataset repo_id to a dataset index automatically created by this class. + + This index is incorporated as a data key in the dictionary returned by `__getitem__`. + """ + return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} + + @property + def repo_index_to_id(self): + """Return the inverse mapping if repo_id_to_index.""" + return {v: k for k, v in self.repo_id_to_index} + + @property + def fps(self) -> int: + """Frames per second used during data collection. + + NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].meta.info["fps"] + + @property + def video(self) -> bool: + """Returns True if this dataset loads video frames from mp4 files. + + Returns False if it only loads images from png files. + + NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].meta.info.get("video", False) + + @property + def features(self) -> datasets.Features: + features = {} + for dataset in self._datasets: + features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) + return features + + @property + def camera_keys(self) -> list[str]: + """Keys to access image and video stream from cameras.""" + keys = [] + for key, feats in self.features.items(): + if isinstance(feats, (datasets.Image, VideoFrame)): + keys.append(key) + return keys + + @property + def video_frame_keys(self) -> list[str]: + """Keys to access video frames that requires to be decoded into images. + + Note: It is empty if the dataset contains images only, + or equal to `self.cameras` if the dataset contains videos only, + or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. + """ + video_frame_keys = [] + for key, feats in self.features.items(): + if isinstance(feats, VideoFrame): + video_frame_keys.append(key) + return video_frame_keys + + @property + def num_frames(self) -> int: + """Number of samples/frames.""" + return sum(d.num_frames for d in self._datasets) + + @property + def num_episodes(self) -> int: + """Number of episodes.""" + return sum(d.num_episodes for d in self._datasets) + + @property + def tolerance_s(self) -> float: + """Tolerance in seconds used to discard loaded frames when their timestamps + are not close enough from the requested frames. It is only used when `delta_timestamps` + is provided or when loading video frames from mp4 files. + """ + # 1e-4 to account for possible numerical error + return 1 / self.fps - 1e-4 + + def __len__(self): + return self.num_frames + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self): + raise IndexError(f"Index {idx} out of bounds.") + # Determine which dataset to get an item from based on the index. + dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1 + local_idx = (idx - self.cumulative_sizes[dataset_idx]).item() + + + item = self._datasets[dataset_idx][local_idx] + item["dataset_index"] = torch.tensor(dataset_idx) + item = create_padded_features(item, self.meta.info["features"]) + for data_key in self.disabled_features: # FIXME(mshukor): not in getitem? + if data_key in item: + del item[data_key] + + return item + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Repository IDs: '{self.repo_ids}',\n" + f" Number of Samples: {self.num_frames},\n" + f" Number of Episodes: {self.num_episodes},\n" + f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" + f" Recorded Frames per Second: {self.fps},\n" + f" Camera Keys: {self.camera_keys},\n" + f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" + f" Transformations: {self.image_transforms},\n" + f")" + ) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ef56bdb61..c3a8aa80f 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -32,6 +32,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig +from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig @@ -74,6 +75,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy + elif name == "smolvla2": + from lerobot.policies.smolvla2.modeling_smolvla2 import SmolVLA2Policy + + return SmolVLA2Policy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -95,6 +100,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return SACConfig(**kwargs) elif policy_type == "smolvla": return SmolVLAConfig(**kwargs) + elif policy_type == "smolvla2": + return SmolVLA2Config(**kwargs) elif policy_type == "reward_classifier": return RewardClassifierConfig(**kwargs) else: diff --git a/lerobot/common/policies/smolvla2/configuration_smolvla2.py b/src/lerobot/policies/smolvla2/configuration_smolvla2.py similarity index 98% rename from lerobot/common/policies/smolvla2/configuration_smolvla2.py rename to src/lerobot/policies/smolvla2/configuration_smolvla2.py index dd5885841..79f179c9a 100644 --- a/lerobot/common/policies/smolvla2/configuration_smolvla2.py +++ b/src/lerobot/policies/smolvla2/configuration_smolvla2.py @@ -14,8 +14,8 @@ from dataclasses import dataclass, field -from lerobot.common.optim.optimizers import AdamWConfig -from lerobot.common.optim.schedulers import ( +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) from lerobot.configs.policies import PreTrainedConfig diff --git a/lerobot/common/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py similarity index 99% rename from lerobot/common/policies/smolvla2/modeling_smolvla2.py rename to src/lerobot/policies/smolvla2/modeling_smolvla2.py index 4d2b89307..bf3653025 100644 --- a/lerobot/common/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -64,18 +64,18 @@ import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn from transformers import AutoProcessor -from lerobot.common.constants import ACTION, OBS_STATE -from lerobot.common.policies.normalize import ( +from lerobot.constants import ACTION, OBS_STATE +from lerobot.policies.normalize import ( Normalize, Unnormalize, ) -from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config -from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel -from lerobot.common.policies.utils import ( +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config +from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel +from lerobot.policies.utils import ( populate_queues, ) -from lerobot.common.utils.utils import get_safe_dtype +from lerobot.utils.utils import get_safe_dtype from lerobot.datasets import IMAGES_ORDER # Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker diff --git a/lerobot/common/policies/smolvla2/smolvlm_with_expert2.py b/src/lerobot/policies/smolvla2/smolvlm_with_expert2.py similarity index 100% rename from lerobot/common/policies/smolvla2/smolvlm_with_expert2.py rename to src/lerobot/policies/smolvla2/smolvlm_with_expert2.py