diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index 4f0aa5704..6b568e893 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -27,14 +27,30 @@ def aggregate_pipeline_dataset_features( use_videos: bool = True, patterns: Sequence[str] | None = None, ) -> dict[str, dict]: - """ - Aggregates the pipeline's features and returns a features dict ready for the dataset, - filtered to only those keys matching any of the given patterns (for action/state only). + """Aggregates and filters dataset features based on a data processing pipeline. - - `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...} - - `use_videos`: whether to treat image features as video streams - - `patterns`: regexes to filter action & state features; images are included - whenever use_videos=True, regardless of patterns. + This function determines the final structure of dataset features after applying a series + of processing steps defined in a pipeline. It starts with an initial set of hardware + features (e.g., camera image shapes), transforms them using the pipeline, and then + filters the results. + + Image features are controlled by the `use_videos` flag, while action and state features + can be selectively included by matching their keys against the provided regex `patterns`. + The final output is formatted to be compatible with Hugging Face Datasets feature dictionaries. + + Args: + pipeline (DataProcessorPipeline): The data processing pipeline that defines all + feature transformations. + initial_features (dict[str, Any]): A dictionary of initial hardware features, where + keys are feature names and values are their shapes or types (e.g., camera resolutions). + use_videos (bool): If `True`, includes image/video features in the output. Defaults to `True`. + patterns (Sequence[str] | None): An optional sequence of regular expression patterns. + Only action and state keys that match at least one pattern will be included. If `None`, + all action and state keys are kept. Defaults to `None`. + + Returns: + dict[str, dict]: A dictionary representing the final dataset features, structured for + use with `datasets.Features`. """ import re diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 148726aba..4e9f852f9 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -75,13 +75,20 @@ DEFAULT_FEATURES = { def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: - """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. + """Flatten a nested dictionary by joining keys with a separator. - For example: - ``` - >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` - >>> print(flatten_dict(dct)) - {"a/b": 1, "a/c/d": 2, "e": 3} + Example: + >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3} + >>> print(flatten_dict(dct)) + {'a/b': 1, 'a/c/d': 2, 'e': 3} + + Args: + d (dict): The dictionary to flatten. + parent_key (str): The base key to prepend to the keys in this level. + sep (str): The separator to use between keys. + + Returns: + dict: A flattened dictionary. """ items = [] for k, v in d.items(): @@ -94,6 +101,20 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: def unflatten_dict(d: dict, sep: str = "/") -> dict: + """Unflatten a dictionary with delimited keys into a nested dictionary. + + Example: + >>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3} + >>> print(unflatten_dict(flat_dct)) + {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} + + Args: + d (dict): A dictionary with flattened keys. + sep (str): The separator used in the keys. + + Returns: + dict: A nested dictionary. + """ outdict = {} for key, value in d.items(): parts = key.split(sep) @@ -107,6 +128,16 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict: def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any: + """Access an item in a nested dictionary using a flattened key. + + Args: + obj (DictLike): The nested dictionary-like object. + flattened_key (str): A key with parts separated by `sep`. + sep (str): The separator used in the flattened key. + + Returns: + Any: The value from the nested dictionary. + """ split_keys = flattened_key.split(sep) getter = obj[split_keys[0]] if len(split_keys) == 1: @@ -119,6 +150,19 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any: def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: + """Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible. + + Converts torch.Tensor, np.ndarray, and np.generic types to lists or native Python types. + + Args: + stats (dict): A dictionary that may contain non-serializable numeric types. + + Returns: + dict: A dictionary with all values converted to JSON-serializable types. + + Raises: + NotImplementedError: If a value has an unsupported type. + """ serialized_dict = {} for key, value in flatten_dict(stats).items(): if isinstance(value, (torch.Tensor, np.ndarray)): @@ -133,6 +177,17 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: + """Embed image bytes into the dataset table before saving to Parquet. + + This function prepares a Hugging Face dataset for serialization by converting + image objects into an embedded format that can be stored in Arrow/Parquet. + + Args: + dataset (datasets.Dataset): The input dataset, possibly containing image features. + + Returns: + datasets.Dataset: The dataset with images embedded in the table storage. + """ # Embed image bytes into the table before saving to parquet format = dataset.format dataset = dataset.with_format("arrow") @@ -142,38 +197,94 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: def load_json(fpath: Path) -> Any: + """Load data from a JSON file. + + Args: + fpath (Path): Path to the JSON file. + + Returns: + Any: The data loaded from the JSON file. + """ with open(fpath) as f: return json.load(f) def write_json(data: dict, fpath: Path) -> None: + """Write data to a JSON file. + + Creates parent directories if they don't exist. + + Args: + data (dict): The dictionary to write. + fpath (Path): The path to the output JSON file. + """ fpath.parent.mkdir(exist_ok=True, parents=True) with open(fpath, "w") as f: json.dump(data, f, indent=4, ensure_ascii=False) def load_jsonlines(fpath: Path) -> list[Any]: + """Load data from a JSON Lines file. + + Args: + fpath (Path): Path to the JSON Lines file. + + Returns: + list[Any]: A list of objects loaded from the file. + """ with jsonlines.open(fpath, "r") as reader: return list(reader) def write_jsonlines(data: dict, fpath: Path) -> None: + """Write a list of dictionaries to a JSON Lines file. + + Creates parent directories if they don't exist. + + Args: + data (dict): The list of dictionaries to write. + fpath (Path): The path to the output JSON Lines file. + """ fpath.parent.mkdir(exist_ok=True, parents=True) with jsonlines.open(fpath, "w") as writer: writer.write_all(data) def append_jsonlines(data: dict, fpath: Path) -> None: + """Append a dictionary to a JSON Lines file. + + Creates parent directories if they don't exist. + + Args: + data (dict): The dictionary to append. + fpath (Path): The path to the JSON Lines file. + """ fpath.parent.mkdir(exist_ok=True, parents=True) with jsonlines.open(fpath, "a") as writer: writer.write(data) def write_info(info: dict, local_dir: Path): + """Write dataset info metadata to its standard file path. + + Args: + info (dict): The dataset information dictionary. + local_dir (Path): The root directory of the dataset. + """ write_json(info, local_dir / INFO_PATH) def load_info(local_dir: Path) -> dict: + """Load dataset info metadata from its standard file path. + + Also converts shape lists to tuples for consistency. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + dict: The dataset information dictionary. + """ info = load_json(local_dir / INFO_PATH) for ft in info["features"].values(): ft["shape"] = tuple(ft["shape"]) @@ -181,16 +292,40 @@ def load_info(local_dir: Path) -> dict: def write_stats(stats: dict, local_dir: Path): + """Serialize and write dataset statistics to their standard file path. + + Args: + stats (dict): The statistics dictionary (can contain tensors/numpy arrays). + local_dir (Path): The root directory of the dataset. + """ serialized_stats = serialize_dict(stats) write_json(serialized_stats, local_dir / STATS_PATH) def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]: + """Recursively cast numerical values in a stats dictionary to numpy arrays. + + Args: + stats (dict): The statistics dictionary. + + Returns: + dict: The statistics dictionary with values cast to numpy arrays. + """ stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats) def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: + """Load dataset statistics and cast numerical values to numpy arrays. + + Returns None if the stats file doesn't exist. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + A dictionary of statistics or None if the file is not found. + """ if not (local_dir / STATS_PATH).exists(): return None stats = load_json(local_dir / STATS_PATH) @@ -198,6 +333,13 @@ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: def write_task(task_index: int, task: dict, local_dir: Path): + """Write a single task to the tasks metadata file. + + Args: + task_index (int): The index of the task. + task (dict): The task description dictionary. + local_dir (Path): The root directory of the dataset. + """ task_dict = { "task_index": task_index, "task": task, @@ -206,6 +348,16 @@ def write_task(task_index: int, task: dict, local_dir: Path): def load_tasks(local_dir: Path) -> tuple[dict, dict]: + """Load tasks from the tasks metadata file. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + A tuple containing: + - A dictionary mapping task index to task description. + - A dictionary mapping task description to task index. + """ tasks = load_jsonlines(local_dir / TASKS_PATH) tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} task_to_task_index = {task: task_index for task_index, task in tasks.items()} @@ -213,15 +365,36 @@ def load_tasks(local_dir: Path) -> tuple[dict, dict]: def write_episode(episode: dict, local_dir: Path): + """Write a single episode's metadata to the episodes metadata file. + + Args: + episode (dict): The episode metadata dictionary. + local_dir (Path): The root directory of the dataset. + """ append_jsonlines(episode, local_dir / EPISODES_PATH) def load_episodes(local_dir: Path) -> dict: + """Load episode metadata from the episodes metadata file. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + dict: A dictionary mapping episode index to episode metadata. + """ episodes = load_jsonlines(local_dir / EPISODES_PATH) return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])} def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path): + """Write statistics for a single episode to the episode stats file. + + Args: + episode_index (int): The index of the episode. + episode_stats (dict): The statistics for the episode. + local_dir (Path): The root directory of the dataset. + """ # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]` # is a dictionary of stats and not an integer. episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)} @@ -229,6 +402,14 @@ def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path def load_episodes_stats(local_dir: Path) -> dict: + """Load per-episode statistics from the episode stats file. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + dict: A dictionary mapping episode index to its statistics dictionary. + """ episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH) return { item["episode_index"]: cast_stats_to_numpy(item["stats"]) @@ -239,12 +420,35 @@ def load_episodes_stats(local_dir: Path) -> dict: def backward_compatible_episodes_stats( stats: dict[str, dict[str, np.ndarray]], episodes: list[int] ) -> dict[str, dict[str, np.ndarray]]: + """Create a per-episode stats dictionary from a global stats dictionary. + + This is used for backward compatibility with older datasets that only had global stats. + + Args: + stats (dict): The global dataset statistics. + episodes (list[int]): A list of episode indices. + + Returns: + dict: A dictionary mapping each episode index to the global stats. + """ return dict.fromkeys(episodes, stats) def load_image_as_numpy( fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True ) -> np.ndarray: + """Load an image from a file into a numpy array. + + Args: + fpath (str | Path): Path to the image file. + dtype (np.dtype): The desired data type of the output array. If floating, + pixels are scaled to [0, 1]. + channel_first (bool): If True, converts the image to (C, H, W) format. + Otherwise, it remains in (H, W, C) format. + + Returns: + np.ndarray: The image as a numpy array. + """ img = PILImage.open(fpath).convert("RGB") img_array = np.array(img, dtype=dtype) if channel_first: # (H, W, C) -> (C, H, W) @@ -255,10 +459,19 @@ def load_image_as_numpy( def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): - """Get a transform function that convert items from Hugging Face dataset (pyarrow) - to torch tensors. Importantly, images are converted from PIL, which corresponds to - a channel last representation (h w c) of uint8 type, to a torch image representation - with channel first (c h w) of float32 type in range [0,1]. + """Convert a batch from a Hugging Face dataset to torch tensors. + + This transform function converts items from Hugging Face dataset format (pyarrow) + to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) + to a torch image representation (C, H, W, float32) in the range [0, 1]. Other + types are converted to torch.tensor. + + Args: + items_dict (dict): A dictionary representing a batch of data from a + Hugging Face dataset. + + Returns: + dict: The batch with items converted to torch tensors. """ for key in items_dict: first_item = items_dict[key][0] @@ -273,6 +486,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): def is_valid_version(version: str) -> bool: + """Check if a string is a valid PEP 440 version. + + Args: + version (str): The version string to check. + + Returns: + bool: True if the version string is valid, False otherwise. + """ try: packaging.version.parse(version) return True @@ -286,6 +507,18 @@ def check_version_compatibility( current_version: str | packaging.version.Version, enforce_breaking_major: bool = True, ) -> None: + """Check for version compatibility between a dataset and the current codebase. + + Args: + repo_id (str): The repository ID for logging purposes. + version_to_check (str | packaging.version.Version): The version of the dataset. + current_version (str | packaging.version.Version): The current version of the codebase. + enforce_breaking_major (bool): If True, raise an error on major version mismatch. + + Raises: + BackwardCompatibilityError: If the dataset version is from a newer, incompatible + major version of the codebase. + """ v_check = ( packaging.version.parse(version_to_check) if not isinstance(version_to_check, packaging.version.Version) @@ -303,7 +536,14 @@ def check_version_compatibility( def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: - """Returns available valid versions (branches and tags) on given repo.""" + """Return available valid versions (branches and tags) on a given Hub repo. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + + Returns: + list[packaging.version.Version]: A list of valid versions found. + """ api = HfApi() repo_refs = api.list_repo_refs(repo_id, repo_type="dataset") repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags] @@ -316,9 +556,22 @@ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str: - """ - Returns the version if available on repo or the latest compatible one. - Otherwise, will throw a `CompatibilityError`. + """Return the specified version if available on repo, or the latest compatible one. + + If the exact version is not found, it looks for the latest version with the + same major version number that is less than or equal to the target minor version. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + version (str | packaging.version.Version): The target version. + + Returns: + str: The safe version string (e.g., "v1.2.3") to use as a revision. + + Raises: + RevisionNotFoundError: If the repo has no version tags. + BackwardCompatibilityError: If only older major versions are available. + ForwardCompatibilityError: If only newer major versions are available. """ target_version = ( packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version @@ -360,6 +613,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> def get_hf_features_from_features(features: dict) -> datasets.Features: + """Convert a LeRobot features dictionary to a `datasets.Features` object. + + Args: + features (dict): A LeRobot-style feature dictionary. + + Returns: + datasets.Features: The corresponding Hugging Face `datasets.Features` object. + + Raises: + ValueError: If a feature has an unsupported shape. + """ hf_features = {} for key, ft in features.items(): if ft["dtype"] == "video": @@ -387,6 +651,14 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: def _validate_feature_names(features: dict[str, dict]) -> None: + """Validate that feature names do not contain invalid characters. + + Args: + features (dict): The LeRobot features dictionary. + + Raises: + ValueError: If any feature name contains '/'. + """ invalid_features = {name: ft for name, ft in features.items() if "/" in name} if invalid_features: raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") @@ -395,6 +667,22 @@ def _validate_feature_names(features: dict[str, dict]) -> None: def hw_to_dataset_features( hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True ) -> dict[str, dict]: + """Convert hardware-specific features to a LeRobot dataset feature dictionary. + + This function takes a dictionary describing hardware outputs (like joint states + or camera image shapes) and formats it into the standard LeRobot feature + specification. + + Args: + hw_features (dict): Dictionary mapping feature names to their type (float for + joints) or shape (tuple for images). + prefix (str): The prefix to add to the feature keys (e.g., "observation" + or "action"). + use_video (bool): If True, image features are marked as "video", otherwise "image". + + Returns: + dict: A LeRobot features dictionary. + """ features = {} joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float} cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} @@ -427,6 +715,20 @@ def hw_to_dataset_features( def build_dataset_frame( ds_features: dict[str, dict], values: dict[str, Any], prefix: str ) -> dict[str, np.ndarray]: + """Construct a single data frame from raw values based on dataset features. + + A "frame" is a dictionary containing all the data for a single timestep, + formatted as numpy arrays according to the feature specification. + + Args: + ds_features (dict): The LeRobot dataset features dictionary. + values (dict): A dictionary of raw values from the hardware/environment. + prefix (str): The prefix to filter features by (e.g., "observation" + or "action"). + + Returns: + dict: A dictionary representing a single frame of data. + """ frame = {} for key, ft in ds_features.items(): if key in DEFAULT_FEATURES or not key.startswith(prefix): @@ -440,6 +742,21 @@ def build_dataset_frame( def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + """Convert dataset features to policy features. + + This function transforms the dataset's feature specification into a format + that a policy can use, classifying features by type (e.g., visual, state, + action) and ensuring correct shapes (e.g., channel-first for images). + + Args: + features (dict): The LeRobot dataset features dictionary. + + Returns: + dict: A dictionary mapping feature keys to `PolicyFeature` objects. + + Raises: + ValueError: If an image feature does not have a 3D shape. + """ # TODO(aliberts): Implement "type" in dataset features and simplify this policy_features = {} for key, ft in features.items(): @@ -471,11 +788,19 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea def combine_feature_dicts(*dicts: dict) -> dict: - """ - Merge LeRobot grouped feature dicts. + """Merge LeRobot grouped feature dicts. - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. - - For others (observation.images.*), last one wins (if they are identical). + - For others (e.g. `observation.images.*`), the last one wins (if they are identical). + + Args: + *dicts: A variable number of LeRobot feature dictionaries to merge. + + Returns: + dict: A single merged feature dictionary. + + Raises: + ValueError: If there's a dtype mismatch for a feature being merged. """ out: dict = {} for d in dicts: @@ -521,6 +846,18 @@ def create_empty_dataset_info( use_videos: bool, robot_type: str | None = None, ) -> dict: + """Create a template dictionary for a new dataset's `info.json`. + + Args: + codebase_version (str): The version of the LeRobot codebase. + fps (int): The frames per second of the data. + features (dict): The LeRobot features dictionary for the dataset. + use_videos (bool): Whether the dataset will store videos. + robot_type (str | None): The type of robot used, if any. + + Returns: + dict: A dictionary with the initial dataset metadata. + """ return { "codebase_version": codebase_version, "robot_type": robot_type, @@ -541,6 +878,18 @@ def create_empty_dataset_info( def get_episode_data_index( episode_dicts: dict[dict], episodes: list[int] | None = None ) -> dict[str, torch.Tensor]: + """Calculate the start and end indices for each episode in a flattened dataset. + + Args: + episode_dicts (dict): A dictionary mapping episode index to episode metadata, + which must contain a "length" key. + episodes (list[int] | None): An optional list of episode indices to consider. + If None, all episodes are used. + + Returns: + dict: A dictionary with "from" and "to" keys, containing torch tensors + with the start and end indices for each episode. + """ episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()} if episodes is not None: episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} @@ -560,16 +909,19 @@ def check_timestamps_sync( tolerance_s: float, raise_value_error: bool = True, ) -> bool: - """ - This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance - to account for possible numerical error. + """Check if timestamps are separated by (1/fps) +/- tolerance. + + This check ensures that consecutive timestamps within an episode are spaced + correctly, accounting for possible numerical errors. It ignores the boundaries + between episodes. Args: timestamps (np.ndarray): Array of timestamps in seconds. episode_indices (np.ndarray): Array indicating the episode index for each timestamp. - episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to', + episode_data_index (dict): A dictionary that includes 'to', which identifies indices for the end of each episode. - fps (int): Frames per second. Used to check the expected difference between consecutive timestamps. + fps (int): Frames per second. Used to check the expected difference between + consecutive timestamps. tolerance_s (float): Allowed deviation from the expected (1/fps) difference. raise_value_error (bool): Whether to raise a ValueError if the check fails. @@ -577,7 +929,8 @@ def check_timestamps_sync( bool: True if all checked timestamp differences lie within tolerance, False otherwise. Raises: - ValueError: If the check fails and `raise_value_error` is True. + ValueError: If `timestamps` and `episode_indices` shapes do not match, or if + the check fails and `raise_value_error` is True. """ if timestamps.shape != episode_indices.shape: raise ValueError( @@ -628,9 +981,23 @@ def check_timestamps_sync( def check_delta_timestamps( delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True ) -> bool: - """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance. - This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be - actual timestamps from the dataset. + """Check if delta timestamps are multiples of 1/fps +/- tolerance. + + This ensures that adding these delta timestamps to any existing timestamp in + the dataset will result in a value that aligns with the dataset's frame rate. + + Args: + delta_timestamps (dict): A dictionary where values are lists of time + deltas in seconds. + fps (int): The frames per second of the dataset. + tolerance_s (float): The allowed tolerance in seconds. + raise_value_error (bool): If True, raises an error on failure. + + Returns: + bool: True if all deltas are valid, False otherwise. + + Raises: + ValueError: If any delta is outside the tolerance and `raise_value_error` is True. """ outside_tolerance = {} for key, delta_ts in delta_timestamps.items(): @@ -656,6 +1023,15 @@ def check_delta_timestamps( def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: + """Convert delta timestamps in seconds to delta indices in frames. + + Args: + delta_timestamps (dict): A dictionary of time deltas in seconds. + fps (int): The frames per second of the dataset. + + Returns: + dict: A dictionary of frame delta indices. + """ delta_indices = {} for key, delta_ts in delta_timestamps.items(): delta_indices[key] = [round(d * fps) for d in delta_ts] @@ -664,9 +1040,17 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic def cycle(iterable): - """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. + """Create a dataloader-safe cyclical iterator. - See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. + This is an equivalent of `itertools.cycle` but is safe for use with + PyTorch DataLoaders with multiple workers. + See https://github.com/pytorch/pytorch/issues/23900 for details. + + Args: + iterable: The iterable to cycle over. + + Yields: + Items from the iterable, restarting from the beginning when exhausted. """ iterator = iter(iterable) while True: @@ -677,8 +1061,14 @@ def cycle(iterable): def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None: - """Create a branch on a existing Hugging Face repo. Delete the branch if it already - exists before creating it. + """Create a branch on an existing Hugging Face repo. + + Deletes the branch if it already exists before creating it. + + Args: + repo_id (str): The ID of the repository. + branch (str): The name of the branch to create. + repo_type (str | None): The type of the repository (e.g., "dataset"). """ api = HfApi() @@ -696,9 +1086,20 @@ def create_lerobot_dataset_card( dataset_info: dict | None = None, **kwargs, ) -> DatasetCard: - """ - Keyword arguments will be used to replace values in src/lerobot/datasets/card_template.md. - Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses. + """Create a `DatasetCard` for a LeRobot dataset. + + Keyword arguments are used to replace values in the card template. + Note: If specified, `license` must be a valid license identifier from + https://huggingface.co/docs/hub/repositories-licenses. + + Args: + tags (list | None): A list of tags to add to the dataset card. + dataset_info (dict | None): The dataset's info dictionary, which will + be displayed on the card. + **kwargs: Additional keyword arguments to populate the card template. + + Returns: + DatasetCard: The generated dataset card object. """ card_tags = ["LeRobot"] @@ -730,19 +1131,16 @@ def create_lerobot_dataset_card( class IterableNamespace(SimpleNamespace): - """ - A namespace object that supports both dictionary-like iteration and dot notation access. - Automatically converts nested dictionaries into IterableNamespaces. + """A namespace object that supports both dictionary-like iteration and dot notation. - This class extends SimpleNamespace to provide: - - Dictionary-style iteration over keys - - Access to items via both dot notation (obj.key) and brackets (obj["key"]) - - Dictionary-like methods: items(), keys(), values() - - Recursive conversion of nested dictionaries + This class extends `SimpleNamespace` to provide dictionary-style iteration, + access to items via brackets (`obj["key"]`), and dictionary-like methods + (`items()`, `keys()`, `values()`). Nested dictionaries are recursively + converted to `IterableNamespace` objects. Args: - dictionary: Optional dictionary to initialize the namespace - **kwargs: Additional keyword arguments passed to SimpleNamespace + dictionary (dict, optional): A dictionary to initialize the namespace with. + **kwargs: Additional keyword arguments to initialize the namespace. Examples: >>> data = {"name": "Alice", "details": {"age": 25}} @@ -756,10 +1154,16 @@ class IterableNamespace(SimpleNamespace): >>> for key, value in ns.items(): ... print(f"{key}: {value}") name: Alice - details: IterableNamespace(age=25) + details: <__main__.IterableNamespace object at ...> """ def __init__(self, dictionary: dict[str, Any] = None, **kwargs): + """Initialize the IterableNamespace. + + Args: + dictionary (dict, optional): Dictionary to populate the namespace. + **kwargs: Keyword arguments to populate the namespace. + """ super().__init__(**kwargs) if dictionary is not None: for key, value in dictionary.items(): @@ -769,22 +1173,46 @@ class IterableNamespace(SimpleNamespace): setattr(self, key, value) def __iter__(self) -> Iterator[str]: + """Return an iterator over the keys of the namespace.""" return iter(vars(self)) def __getitem__(self, key: str) -> Any: + """Allow bracket-style access to attributes. + + Args: + key (str): The name of the attribute. + + Returns: + Any: The value of the attribute. + """ return vars(self)[key] def items(self): + """Return a view of the namespace's (key, value) pairs.""" return vars(self).items() def values(self): + """Return a view of the namespace's values.""" return vars(self).values() def keys(self): + """Return a view of the namespace's keys.""" return vars(self).keys() def validate_frame(frame: dict, features: dict): + """Validate a single data frame against the dataset's feature specification. + + Checks for missing/extra features, and validates the dtype and shape of each + provided feature. + + Args: + frame (dict): The data frame to validate. + features (dict): The LeRobot features dictionary for the dataset. + + Raises: + ValueError: If the frame does not match the feature specification. + """ expected_features = set(features) - set(DEFAULT_FEATURES) actual_features = set(frame) @@ -799,6 +1227,15 @@ def validate_frame(frame: dict, features: dict): def validate_features_presence(actual_features: set[str], expected_features: set[str]): + """Check for missing or extra features in a frame. + + Args: + actual_features (set[str]): The set of feature names present in the frame. + expected_features (set[str]): The set of feature names expected in the frame. + + Returns: + str: An error message string if there's a mismatch, otherwise an empty string. + """ error_message = "" missing_features = expected_features - actual_features extra_features = actual_features - expected_features @@ -814,6 +1251,19 @@ def validate_features_presence(actual_features: set[str], expected_features: set def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): + """Validate the dtype and shape of a single feature's value. + + Args: + name (str): The name of the feature. + feature (dict): The feature specification from the LeRobot features dictionary. + value: The value of the feature to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + + Raises: + NotImplementedError: If the feature dtype is not supported for validation. + """ expected_dtype = feature["dtype"] expected_shape = feature["shape"] if is_valid_numpy_dtype_string(expected_dtype): @@ -829,6 +1279,17 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray def validate_feature_numpy_array( name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray ): + """Validate a feature that is expected to be a numpy array. + + Args: + name (str): The name of the feature. + expected_dtype (str): The expected numpy dtype as a string. + expected_shape (list[int]): The expected shape. + value (np.ndarray): The numpy array to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ error_message = "" if isinstance(value, np.ndarray): actual_dtype = value.dtype @@ -846,6 +1307,18 @@ def validate_feature_numpy_array( def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): + """Validate a feature that is expected to be an image or video frame. + + Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`. + + Args: + name (str): The name of the feature. + expected_shape (list[str]): The expected shape (C, H, W). + value: The image data to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): @@ -862,12 +1335,35 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value: def validate_feature_string(name: str, value: str): + """Validate a feature that is expected to be a string. + + Args: + name (str): The name of the feature. + value (str): The value to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" return "" def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict): + """Validate the episode buffer before it's written to disk. + + Ensures the buffer has the required keys, contains at least one frame, and + has features consistent with the dataset's specification. + + Args: + episode_buffer (dict): The buffer containing data for a single episode. + total_episodes (int): The current total number of episodes in the dataset. + features (dict): The LeRobot features dictionary for the dataset. + + Raises: + ValueError: If the buffer is invalid. + NotImplementedError: If the episode index is manually set and doesn't match. + """ if "size" not in episode_buffer: raise ValueError("size key not found in episode_buffer") diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index aec922839..2e17c9a89 100644 --- a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -34,6 +34,24 @@ def make_act_pre_post_processors( preprocessor_kwargs: ProcessorKwargs | None = None, postprocessor_kwargs: ProcessorKwargs | None = None, ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """Creates the pre- and post-processing pipelines for the ACT policy. + + The pre-processing pipeline handles normalization, batching, and device placement for the model inputs. + The post-processing pipeline handles unnormalization and moves the model outputs back to the CPU. + + Args: + config (ACTConfig): The ACT policy configuration object. + dataset_stats (dict[str, dict[str, torch.Tensor]] | None): A dictionary containing dataset + statistics (e.g., mean and std) used for normalization. Defaults to None. + preprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the + preprocessor pipeline's constructor. Defaults to None. + postprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the + postprocessor pipeline's constructor. Defaults to None. + + Returns: + tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: A tuple containing the + pre-processor pipeline and the post-processor pipeline. + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} if postprocessor_kwargs is None: diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py index 2d7868b25..3c73e7bc1 100644 --- a/src/lerobot/policies/diffusion/processor_diffusion.py +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -35,6 +35,32 @@ def make_diffusion_pre_post_processors( preprocessor_kwargs: ProcessorKwargs | None = None, postprocessor_kwargs: ProcessorKwargs | None = None, ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """ + Constructs pre-processor and post-processor pipelines for a diffusion policy. + + The pre-processing pipeline prepares the input data for the model by: + 1. Renaming features (if a `rename_map` is provided in `preprocessor_kwargs`). + 2. Normalizing the input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving the data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving the data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the diffusion policy, + containing feature definitions, normalization mappings, and device information. + dataset_stats: A dictionary of statistics used for normalization. + Defaults to None. + preprocessor_kwargs: Additional keyword arguments + for the pre-processor pipeline. Defaults to an empty dictionary. + postprocessor_kwargs: Additional keyword arguments + for the post-processor pipeline. Defaults to an empty dictionary. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} if postprocessor_kwargs is None: diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index c251210b3..520f2342d 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -43,7 +43,22 @@ from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs def get_policy_class(name: str) -> type[PreTrainedPolicy]: - """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" + """ + Retrieves a policy class by its registered name. + + This function uses dynamic imports to avoid loading all policy classes into memory + at once, improving startup time and reducing dependencies. + + Args: + name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", + "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla". + + Returns: + The policy class corresponding to the given name. + + Raises: + NotImplementedError: If the policy name is not recognized. + """ if name == "tdmpc": from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy @@ -85,6 +100,24 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: + """ + Instantiates a policy configuration object based on the policy type. + + This factory function simplifies the creation of policy configuration objects by + mapping a string identifier to the corresponding config class. + + Args: + policy_type: The type of the policy. Supported types include "tdmpc", + "diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla", + "reward_classifier". + **kwargs: Keyword arguments to be passed to the configuration class constructor. + + Returns: + An instance of a `PreTrainedConfig` subclass. + + Raises: + ValueError: If the `policy_type` is not recognized. + """ if policy_type == "tdmpc": return TDMPCConfig(**kwargs) elif policy_type == "diffusion": @@ -108,7 +141,21 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: class ProcessorConfigKwargs(TypedDict, total=False): - """Keyword arguments for the processor config.""" + """ + A TypedDict defining the keyword arguments for processor configuration. + + This provides type hints for the optional arguments passed to `make_pre_post_processors`, + improving code clarity and enabling static analysis. + + Attributes: + preprocessor_config_filename: The filename for the preprocessor configuration. + postprocessor_config_filename: The filename for the postprocessor configuration. + preprocessor_overrides: A dictionary of overrides for the preprocessor configuration. + postprocessor_overrides: A dictionary of overrides for the postprocessor configuration. + dataset_stats: Dataset statistics for normalization. + preprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`. + postprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`. + """ preprocessor_config_filename: str | None postprocessor_config_filename: str | None @@ -124,22 +171,27 @@ def make_pre_post_processors( pretrained_path: str | None = None, **kwargs: Unpack[ProcessorConfigKwargs], ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: - """Make a processor instance for a given policy type. + """ + Create or load pre- and post-processor pipelines for a given policy. - This function creates the appropriate processor configuration based on the policy type. - Each policy type has its own processor with specific preprocessing steps. + This function acts as a factory. It can either load existing processor pipelines + from a pretrained path or create new ones from scratch based on the policy + configuration. Each policy type has a dedicated factory function for its + processors (e.g., `make_tdmpc_pre_post_processors`). Args: - policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.) - pretrained_path: Optional path to load a pretrained processor from. If provided, loads - the processor from this path instead of creating a new one. - **kwargs: Additional keyword arguments passed to the processor creation. + policy_cfg: The configuration of the policy for which to create processors. + pretrained_path: An optional path to load pretrained processor pipelines from. + If provided, pipelines are loaded from this path. + **kwargs: Keyword arguments for processor configuration, as defined in + `ProcessorConfigKwargs`. Returns: - Tuple of (input_processor, output_processor) for the policy. + A tuple containing the input (pre-processor) and output (post-processor) pipelines. Raises: - NotImplementedError: If the policy type doesn't have a processor implemented. + NotImplementedError: If a processor factory is not implemented for the given + policy configuration type. """ if pretrained_path: # Extract preprocessor and postprocessor kwargs @@ -269,25 +321,29 @@ def make_policy( ds_meta: LeRobotDatasetMetadata | None = None, env_cfg: EnvConfig | None = None, ) -> PreTrainedPolicy: - """Make an instance of a policy class. + """ + Instantiate a policy model. - This function exists because (for now) we need to parse features from either a dataset or an environment - in order to properly dimension and instantiate a policy for that dataset or environment. + This factory function handles the logic of creating a policy, which requires + determining the input and output feature shapes. These shapes can be derived + either from a `LeRobotDatasetMetadata` object or an `EnvConfig` object. The function + can either initialize a new policy from scratch or load a pretrained one. Args: - cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will - be loaded with the weights from that path. - ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and - statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None. - env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be - provided if ds_meta is not. Defaults to None. - - Raises: - ValueError: Either ds_meta or env and env_cfg must be provided. - NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility) + cfg: The configuration for the policy to be created. If `cfg.pretrained_path` is + set, the policy will be loaded with weights from that path. + ds_meta: Dataset metadata used to infer feature shapes and types. Also provides + statistics for normalization layers. + env_cfg: Environment configuration used to infer feature shapes and types. + One of `ds_meta` or `env_cfg` must be provided. Returns: - PreTrainedPolicy: _description_ + An instantiated and device-placed policy model. + + Raises: + ValueError: If both or neither of `ds_meta` and `env_cfg` are provided. + NotImplementedError: If attempting to use an unsupported policy-backend + combination (e.g., VQBeT with 'mps'). """ if bool(ds_meta) == bool(env_cfg): raise ValueError("Either one of a dataset metadata or a sim env must be provided.") diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index acdc0dca9..766b7d7f9 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -37,11 +37,25 @@ from lerobot.processor import ( @ProcessorStepRegistry.register(name="pi0_new_line_processor") class Pi0NewLineProcessor(ComplementaryDataProcessorStep): - """Add a new line to the end of the task if it doesn't have one. - This is required for the PaliGemma tokenizer. + """ + Ensures that the task description string ends with a newline character. + + This processing step is required for compatibility with the PaliGemma tokenizer, + which expects a newline at the end of the text prompt. It handles both single + strings and lists of strings for the 'task' key in complementary data. """ def complementary_data(self, complementary_data): + """ + Adds a newline to the 'task' field if it doesn't already have one. + + Args: + complementary_data: A dictionary that may contain a 'task' key with a + string or list of strings. + + Returns: + A new dictionary with the modified 'task' field. + """ if "task" not in complementary_data: return complementary_data @@ -64,6 +78,15 @@ class Pi0NewLineProcessor(ComplementaryDataProcessorStep): return new_complementary_data def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + This step does not alter the feature definitions. + + Args: + features: The input feature dictionary. + + Returns: + The unchanged feature dictionary. + """ return features @@ -73,6 +96,30 @@ def make_pi0_pre_post_processors( preprocessor_kwargs: ProcessorKwargs | None = None, postprocessor_kwargs: ProcessorKwargs | None = None, ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """ + Constructs pre-processor and post-processor pipelines for the PI0 policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Appending a newline character to the task description for tokenizer compatibility. + 5. Tokenizing the text prompt using the PaliGemma tokenizer. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0 policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} if postprocessor_kwargs is None: diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py index 38882c21f..62d255686 100644 --- a/src/lerobot/policies/pi0fast/processor_pi0fast.py +++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py @@ -17,7 +17,7 @@ import torch from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME -from lerobot.policies.pi0.configuration_pi0 import PI0Config +from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -30,11 +30,33 @@ from lerobot.processor import ( def make_pi0fast_pre_post_processors( - config: PI0Config, + config: PI0FASTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, preprocessor_kwargs: ProcessorKwargs | None = None, postprocessor_kwargs: ProcessorKwargs | None = None, ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """ + Constructs pre-processor and post-processor pipelines for the PI0Fast policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0Fast policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} if postprocessor_kwargs is None: diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py index 9130a196d..0098f1999 100644 --- a/src/lerobot/policies/sac/processor_sac.py +++ b/src/lerobot/policies/sac/processor_sac.py @@ -36,6 +36,28 @@ def make_sac_pre_post_processors( preprocessor_kwargs: ProcessorKwargs | None = None, postprocessor_kwargs: ProcessorKwargs | None = None, ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """ + Constructs pre-processor and post-processor pipelines for the SAC policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the SAC policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} if postprocessor_kwargs is None: diff --git a/src/lerobot/policies/sac/reward_model/processor_classifier.py b/src/lerobot/policies/sac/reward_model/processor_classifier.py index 70195d69d..571ccdfd9 100644 --- a/src/lerobot/policies/sac/reward_model/processor_classifier.py +++ b/src/lerobot/policies/sac/reward_model/processor_classifier.py @@ -31,6 +31,26 @@ def make_classifier_processor( preprocessor_kwargs: ProcessorKwargs | None = None, postprocessor_kwargs: ProcessorKwargs | None = None, ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """ + Constructs pre-processor and post-processor pipelines for the reward classifier. + + The pre-processing pipeline prepares input data for the classifier by: + 1. Normalizing both input and output features based on dataset statistics. + 2. Moving the data to the specified device. + + The post-processing pipeline handles the classifier's output by: + 1. Moving the data to the CPU. + 2. Applying an identity step, as no unnormalization is needed for the output logits. + + Args: + config: The configuration object for the RewardClassifier. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} if postprocessor_kwargs is None: diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index 1f2abcede..00b479f42 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -39,6 +39,30 @@ def make_smolvla_pre_post_processors( preprocessor_kwargs: ProcessorKwargs | None = None, postprocessor_kwargs: ProcessorKwargs | None = None, ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """ + Constructs pre-processor and post-processor pipelines for the SmolVLA policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Ensuring the language task description ends with a newline character. + 5. Tokenizing the language task description. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output actions to their original scale. + + Args: + config: The configuration object for the SmolVLA policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} if postprocessor_kwargs is None: @@ -83,7 +107,13 @@ def make_smolvla_pre_post_processors( @ProcessorStepRegistry.register(name="smolvla_new_line_processor") class SmolVLANewLineProcessor(ComplementaryDataProcessorStep): - """Add a new line to the end of the task if it doesn't have one.""" + """ + A processor step that ensures the 'task' description ends with a newline character. + + This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a + newline at the end of the prompt. It handles both single string tasks and lists + of string tasks. + """ def complementary_data(self, complementary_data): if "task" not in complementary_data: diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py index d131972cb..77497bd23 100644 --- a/src/lerobot/policies/tdmpc/processor_tdmpc.py +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -35,6 +35,28 @@ def make_tdmpc_pre_post_processors( preprocessor_kwargs: ProcessorKwargs | None = None, postprocessor_kwargs: ProcessorKwargs | None = None, ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """ + Constructs pre-processor and post-processor pipelines for the TDMPC policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the TDMPC policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} if postprocessor_kwargs is None: diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py index ad78d1b6a..08d1de334 100644 --- a/src/lerobot/policies/vqbet/processor_vqbet.py +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -36,6 +36,28 @@ def make_vqbet_pre_post_processors( preprocessor_kwargs: ProcessorKwargs | None = None, postprocessor_kwargs: ProcessorKwargs | None = None, ) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: + """ + Constructs pre-processor and post-processor pipelines for the VQ-BeT policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features, allowing customization to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the VQ-BeT policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ if preprocessor_kwargs is None: preprocessor_kwargs = {} if postprocessor_kwargs is None: diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index 02d1dd7e7..d0956d5f3 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python + # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,6 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +This script defines processor steps for adding a batch dimension to various components of an environment transition. + +These steps are designed to process actions, observations, and complementary data, making them suitable for batch processing by adding a leading dimension. This is a common requirement before feeding data into a neural network model. +""" + from dataclasses import dataclass, field from torch import Tensor @@ -31,24 +40,63 @@ from .pipeline import ( @dataclass @ProcessorStepRegistry.register(name="to_batch_processor_action") class AddBatchDimensionActionStep(ActionProcessorStep): - """Process action component in-place, adding batch dimension if needed.""" + """ + Processor step to add a batch dimension to a 1D tensor action. - def action(self, action): + This is useful for creating a batch of size 1 from a single action sample. + """ + + def action(self, action: Tensor) -> Tensor: + """ + Adds a batch dimension to the action if it's a 1D tensor. + + Args: + action: The action tensor. + + Returns: + The action tensor with an added batch dimension. + """ if not isinstance(action, Tensor) or action.dim() != 1: return action - return action.unsqueeze(0) def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ return features @dataclass @ProcessorStepRegistry.register(name="to_batch_processor_observation") class AddBatchDimensionObservationStep(ObservationProcessorStep): - """Process observation component in-place, adding batch dimensions where needed.""" + """ + Processor step to add a batch dimension to observations. - def observation(self, observation): + It handles different types of observations: + - State vectors (1D tensors). + - Single images (3D tensors). + - Dictionaries of multiple images (3D tensors). + """ + + def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]: + """ + Adds a batch dimension to tensor-based observations in the observation dictionary. + + Args: + observation: The observation dictionary. + + Returns: + The observation dictionary with batch dimensions added to tensors. + """ # Process state observations - add batch dim if 1D for state_key in [OBS_STATE, OBS_ENV_STATE]: if state_key in observation: @@ -69,15 +117,41 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep): return observation def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ return features @dataclass @ProcessorStepRegistry.register(name="to_batch_processor_complementary_data") class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep): - """Process complementary data in-place, handling task field batching.""" + """ + Processor step to add a batch dimension to complementary data fields. - def complementary_data(self, complementary_data): + Handles specific keys like 'task', 'index', and 'task_index' to make them batched. + - 'task' (str) is wrapped in a list. + - 'index' and 'task_index' (0D tensors) get a batch dimension. + """ + + def complementary_data(self, complementary_data: dict) -> dict: + """ + Adds a batch dimension to specific fields in the complementary data dictionary. + + Args: + complementary_data: The complementary data dictionary. + + Returns: + The complementary data dictionary with batch dimensions added. + """ # Process task field - wrap string in list to add batch dimension if "task" in complementary_data: task_value = complementary_data["task"] @@ -98,44 +172,33 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep): return complementary_data def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ return features @dataclass @ProcessorStepRegistry.register(name="to_batch_processor") class AddBatchDimensionProcessorStep(ProcessorStep): - """Processor that adds batch dimensions to observations and actions when needed. + """ + A composite processor step that adds a batch dimension to the entire environment transition. - This processor ensures that observations and actions have proper batch dimensions for model processing: + This step combines individual processors for actions, observations, and complementary data + to create a batched transition (batch size 1) from a single-instance transition. - - For state observations (observation.state, observation.environment_state): - Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional - - - For image observations (observation.image, observation.images.*): - Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C) - - - For actions: - Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional - - - For task field in complementary data: - Wraps string task in a list to add batch dimension - (task must be a string or list of strings) - - This is useful when processing single transitions that need to be batched for - model inference or when converting from unbatched environment outputs to - batched model inputs. - - The processor only modifies tensors that need batching and leaves already - batched tensors unchanged. - - Example: - ```python - # State: (7,) -> (1, 7) - # Image: (224, 224, 3) -> (1, 224, 224, 3) - # Action: (4,) -> (1, 4) - # Task: "pick_cube" -> ["pick_cube"] - # Already batched: (1, 7) -> (1, 7) [unchanged] - ``` + Attributes: + to_batch_action_processor: Processor for the action component. + to_batch_observation_processor: Processor for the observation component. + to_batch_complementary_data_processor: Processor for the complementary data component. """ to_batch_action_processor: AddBatchDimensionActionStep = field( @@ -149,11 +212,31 @@ class AddBatchDimensionProcessorStep(ProcessorStep): ) def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Applies the batching process to all relevant parts of an environment transition. + + Args: + transition: The environment transition to process. + + Returns: + The environment transition with a batch dimension added. + """ transition = self.to_batch_action_processor(transition) transition = self.to_batch_observation_processor(transition) transition = self.to_batch_complementary_data_processor(transition) return transition def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ # NOTE: We ignore the batch dimension when transforming features return features diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index ceec211aa..2551b2b6b 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -44,12 +44,12 @@ def to_tensor( different input types appropriately. Args: - value: Input value to convert (tensor, array, scalar, sequence, etc.) + value: Input value to convert (tensor, array, scalar, sequence, etc.). dtype: Target tensor dtype. If None, preserves original dtype. device: Target device for the tensor. Returns: - PyTorch tensor. + A PyTorch tensor. Raises: TypeError: If the input type is not supported. @@ -59,7 +59,7 @@ def to_tensor( @to_tensor.register(torch.Tensor) def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: - """Handle existing PyTorch tensors.""" + """Handle conversion for existing PyTorch tensors.""" if dtype is not None: value = value.to(dtype=dtype) if device is not None: @@ -75,17 +75,17 @@ def _( device=None, **kwargs, ) -> torch.Tensor: - """Handle numpy arrays.""" - # Check for numpy scalars (0-dimensional arrays) and treat them as scalars + """Handle conversion for numpy arrays.""" + # Check for numpy scalars (0-dimensional arrays) and treat them as scalars. if value.ndim == 0: - # Numpy scalars should be converted to 0-dimensional tensors + # Numpy scalars should be converted to 0-dimensional tensors. scalar_value = value.item() return torch.tensor(scalar_value, dtype=dtype, device=device) - # Create tensor from numpy array (torch.from_numpy handles contiguity automatically) + # Create tensor from numpy array. tensor = torch.from_numpy(value) - # Apply dtype conversion if specified + # Apply dtype and device conversion if specified. if dtype is not None: tensor = tensor.to(dtype=dtype) if device is not None: @@ -99,20 +99,20 @@ def _( @to_tensor.register(np.integer) @to_tensor.register(np.floating) def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: - """Handle scalar values including numpy scalars.""" + """Handle conversion for scalar values including numpy scalars.""" return torch.tensor(value, dtype=dtype, device=device) @to_tensor.register(list) @to_tensor.register(tuple) def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: - """Handle sequences (lists, tuples).""" + """Handle conversion for sequences (lists, tuples).""" return torch.tensor(value, dtype=dtype, device=device) @to_tensor.register(dict) def _(value: dict, *, device=None, **kwargs) -> dict: - """Handle dictionaries by recursively converting values to tensors.""" + """Handle conversion for dictionaries by recursively converting their values to tensors.""" if not value: return {} @@ -122,7 +122,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict: continue if isinstance(sub_value, dict): - # Recursively process nested dictionaries + # Recursively process nested dictionaries. result[key] = to_tensor( sub_value, device=device, @@ -130,7 +130,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict: ) continue - # Convert individual values to tensors + # Convert individual values to tensors. result[key] = to_tensor( sub_value, device=device, @@ -140,17 +140,45 @@ def _(value: dict, *, device=None, **kwargs) -> dict: def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any: - """Convert tensor to numpy/scalar if needed.""" + """ + Convert a PyTorch tensor to a numpy array or scalar if applicable. + + If the input is not a tensor, it is returned unchanged. + + Args: + x: The input, which can be a tensor or any other type. + + Returns: + A numpy array, a scalar, or the original input. + """ if isinstance(x, torch.Tensor): return x.item() if x.numel() == 1 else x.detach().cpu().numpy() return x def _is_image(arr: Any) -> bool: + """ + Check if a given array is likely an image (uint8, 3D). + + Args: + arr: The array to check. + + Returns: + True if the array matches the image criteria, False otherwise. + """ return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3 def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Separate an observation dictionary into state and image components. + + Args: + obs: The observation dictionary. + + Returns: + A tuple containing two dictionaries: one for state and one for images. + """ state, images = {}, {} for k, v in obs.items(): if "image" in k.lower() or _is_image(v): @@ -160,13 +188,21 @@ def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], return state, images -# ============================================================================ # Private Helper Functions (Common Logic) -# ============================================================================ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: - """Extract complementary data (pad flags, task, index, task_index).""" + """ + Extract complementary data from a batch dictionary. + + This includes padding flags, task description, and indices. + + Args: + batch: The batch dictionary. + + Returns: + A dictionary with the extracted complementary data. + """ pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} task_key = {"task": batch["task"]} if "task" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {} @@ -176,7 +212,16 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition: - """Merge two transitions, with other taking precedence.""" + """ + Merge two transitions, with the second one taking precedence in case of conflicts. + + Args: + base: The base transition. + other: The transition to merge, which will overwrite base values. + + Returns: + The merged transition dictionary. + """ out = deepcopy(base) for key in ( @@ -194,9 +239,7 @@ def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransiti return out -# ============================================================================ # Core Conversion Functions -# ============================================================================ def create_transition( @@ -208,7 +251,8 @@ def create_transition( info: dict[str, Any] | None = None, complementary_data: dict[str, Any] | None = None, ) -> EnvTransition: - """Create an EnvTransition with sensible defaults. + """ + Create an `EnvTransition` dictionary with sensible defaults. Args: observation: Observation dictionary. @@ -220,7 +264,7 @@ def create_transition( complementary_data: Complementary data dictionary. Returns: - Complete EnvTransition dictionary. + A complete `EnvTransition` dictionary. """ return { TransitionKey.OBSERVATION: observation, @@ -233,9 +277,19 @@ def create_transition( } -def action_to_transition(action: dict[str, Any]) -> EnvTransition: # action_to_transition +def action_to_transition(action: dict[str, Any]) -> EnvTransition: """ - Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey. + Convert a raw action dictionary into a standardized `EnvTransition`. + + The keys in the action dictionary are prefixed with "action." and stored under + the `ACTION` key in the transition. Values are converted to tensors, except for + special types like `Rotation`. + + Args: + action: The raw action dictionary from a teleoperation device or controller. + + Returns: + An `EnvTransition` containing the formatted action. """ act_dict: dict[str, Any] = {} for k, v in action.items(): @@ -250,10 +304,19 @@ def action_to_transition(action: dict[str, Any]) -> EnvTransition: # action_to_ return create_transition(observation={}, action=act_dict) -# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself def observation_to_transition(observation: dict[str, Any]) -> EnvTransition: """ - Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey. + Convert a raw robot observation dictionary into a standardized `EnvTransition`. + + The observation is split into state and image components. State keys are prefixed + with "observation.state." and image keys with "observation.images.". The result is + stored under the `OBSERVATION` key in the transition. + + Args: + observation: The raw observation dictionary from the environment. + + Returns: + An `EnvTransition` containing the formatted observation. """ state, images = _split_obs_to_state_and_images(observation) @@ -270,7 +333,16 @@ def observation_to_transition(observation: dict[str, Any]) -> EnvTransition: def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]: """ - Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions. + Extract a raw action dictionary for a robot from an `EnvTransition`. + + This function searches for keys in the format "action.*.pos" or "action.*.vel" + and converts them into a flat dictionary suitable for sending to a robot controller. + + Args: + transition: The `EnvTransition` containing the action. + + Returns: + A dictionary representing the raw robot action. """ out: dict[str, Any] = {} action_dict = transition.get(TransitionKey.ACTION) or {} @@ -287,13 +359,21 @@ def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]: def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition: - """Merge multiple transitions or return single transition. + """ + Merge a sequence of transitions into a single one. + + If a single transition is provided, it is returned as is. For a sequence, + transitions are merged sequentially, with later transitions in the sequence + overwriting earlier ones. Args: - transitions: Either a single transition or iterable of transitions. + transitions: A single transition or a sequence of them. Returns: - Merged EnvTransition. + A single merged `EnvTransition`. + + Raises: + ValueError: If an empty sequence of transitions is provided. """ if not isinstance(transitions, Sequence): # Single transition @@ -312,26 +392,18 @@ def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> E def transition_to_dataset_frame( transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict] ) -> dict[str, Any]: - """Convert a single EnvTransition or an iterable of them into a flat, dataset-friendly dictionary for training or evaluation. + """ + Convert one or more transitions into a flat dictionary suitable for a dataset frame. - Processes transitions according to the provided feature specification and returns - data in the format expected by machine learning models and datasets. + This function processes `EnvTransition` objects according to a feature + specification, producing a format ready for training or evaluation. Args: - transitions_or_transition: Either a single EnvTransition dict or an iterable of them - (which will be merged using merge_transitions). - features: Feature specification dictionary with the following structure: - - 'action': dict with 'names': list of action feature names - - 'observation.state': dict with 'names': list of state feature names - - keys starting with 'observation.images.' are passed through as-is + transitions_or_transition: A single `EnvTransition` or a sequence to be merged. + features: A feature specification dictionary. Returns: - Flat dictionary containing: - - numpy arrays for "observation.state" and "action" (vectorized from feature names) - - any image tensors defined in features (passed through unchanged) - - next.{reward,done,truncated} scalar values - - info dict - - *_is_pad flags and task from complementary_data + A flat dictionary representing a single frame of data for a dataset. """ action_names = features.get(ACTION, {}).get("names", []) obs_state_names = features.get(OBS_STATE, {}).get("names", []) @@ -342,25 +414,25 @@ def transition_to_dataset_frame( act = tr.get(TransitionKey.ACTION, {}) or {} batch: dict[str, Any] = {} - # Images passthrough + # Passthrough for images. for k in image_keys: if k in obs: batch[k] = obs[k] - # Observation.state vector + # Create observation.state vector. if obs_state_names: vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names] batch[OBS_STATE] = np.asarray(vals, dtype=np.float32) - # Action vector + # Create action vector. if action_names: vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names] batch[ACTION] = np.asarray(vals, dtype=np.float32) - # Add transition metadata + # Add transition metadata. if tr.get(TransitionKey.REWARD) is not None: reward_val = _from_tensor(tr[TransitionKey.REWARD]) - # Check if features expect array format, otherwise keep as scalar + # Check if features expect array format, otherwise keep as scalar. if REWARD in features and features[REWARD].get("shape") == (1,): batch[REWARD] = np.array([reward_val], dtype=np.float32) else: @@ -380,14 +452,14 @@ def transition_to_dataset_frame( else: batch[TRUNCATED] = truncated_val - # Complementary data flags and task + # Add complementary data flags and task. comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {} if comp: - # pad flags + # Padding flags. for k, v in comp.items(): if k.endswith("_is_pad"): batch[k] = v - # task label + # Task label. if comp.get("task") is not None: batch["task"] = comp["task"] @@ -395,36 +467,27 @@ def transition_to_dataset_frame( def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: - """Convert a batch dict coming from LeRobot replay/dataset code into an EnvTransition dictionary. + """ + Convert a batch dictionary from a dataset/dataloader into an `EnvTransition`. - The function maps well known keys to the EnvTransition structure. Missing keys are - filled with sane defaults (None or 0.0/False). - - Keys recognised (case-sensitive): - * "observation.*" (keys starting with "observation." are grouped into observation dict) - * "action" - * "next.reward" - * "next.done" - * "next.truncated" - * "info" - * "_is_pad" patterns (padding flags) - * "task", "index", "task_index" (complementary data) - - Additional keys are ignored so that existing dataloaders can carry extra - metadata without breaking the processor. + This function maps recognized keys from a batch to the `EnvTransition` structure, + filling in missing keys with sensible defaults. Args: - batch: Batch dictionary from datasets or dataloaders containing the above keys. + batch: A batch dictionary. Returns: - EnvTransition dictionary with properly structured transition data. + An `EnvTransition` dictionary. + + Raises: + ValueError: If the input is not a dictionary. """ - # Validate input type + # Validate input type. if not isinstance(batch, dict): raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}") - # Extract observation keys + # Extract observation and complementary data keys. observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} complementary_data = _extract_complementary_data(batch) @@ -440,25 +503,16 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: - """Inverse of batch_to_transition. Returns a dict with canonical field names used throughout LeRobot. + """ + Convert an `EnvTransition` back to the canonical batch format used in LeRobot. - Converts an EnvTransition back to the batch format expected by datasets, dataloaders, - and other LeRobot components. - - Output format: - * "action": Action data from transition - * "next.reward": Reward value (defaults to 0.0) - * "next.done": Done flag (defaults to False) - * "next.truncated": Truncated flag (defaults to False) - * "info": Info dictionary (defaults to {}) - * Flattened observation keys (e.g., "observation.state", "observation.images.cam1") - * Complementary data fields ("task", "index", "task_index", padding flags) + This is the inverse of `batch_to_transition`. Args: - transition: EnvTransition dictionary to convert. + transition: The `EnvTransition` to convert. Returns: - Batch dictionary with canonical LeRobot field names suitable for dataloaders. + A batch dictionary with canonical LeRobot field names. """ batch = { "action": transition.get(TransitionKey.ACTION), @@ -468,12 +522,12 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: "info": transition.get(TransitionKey.INFO, {}), } - # Add complementary data + # Add complementary data. comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) if comp_data: batch.update(comp_data) - # Flatten observation dict + # Flatten observation dictionary. observation = transition.get(TransitionKey.OBSERVATION) if isinstance(observation, dict): batch.update(observation) @@ -482,4 +536,15 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: def identity_transition(tr: EnvTransition) -> EnvTransition: + """ + An identity function for transitions, returning the input unchanged. + + Useful as a default or placeholder in processing pipelines. + + Args: + tr: An `EnvTransition`. + + Returns: + The same `EnvTransition`. + """ return tr diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index 1aeadba2d..0af982f82 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -1,4 +1,4 @@ -# !/usr/bin/env python +#!/usr/bin/env python # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # @@ -28,7 +28,15 @@ from .pipeline import ActionProcessorStep, ProcessorStepRegistry @dataclass class MapTensorToDeltaActionDictStep(ActionProcessorStep): """ - Map a tensor to a delta action dictionary. + Maps a flat action tensor from a policy to a structured delta action dictionary. + + This step is typically used after a policy outputs a continuous action vector. + It decomposes the vector into named components for delta movements of the + end-effector (x, y, z) and optionally the gripper. + + Attributes: + use_gripper: If True, assumes the 4th element of the tensor is the + gripper action. """ use_gripper: bool = True @@ -60,28 +68,17 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep): @dataclass class MapDeltaActionToRobotActionStep(ActionProcessorStep): """ - Map delta actions from teleoperators (gamepad, keyboard) to robot target actions - for use with inverse kinematics processors. + Maps delta actions from teleoperators to robot target actions for inverse kinematics. - Expected input ACTION keys: - { - "action.delta_x": float, - "action.delta_y": float, - "action.delta_z": float, - "action.gripper": float (optional), - } + This step converts a dictionary of delta movements (e.g., from a gamepad) + into a target action format that includes an "enabled" flag and target + end-effector positions. It also handles scaling and noise filtering. - Output ACTION keys: - { - "action.enabled": bool, - "action.target_x": float, - "action.target_y": float, - "action.target_z": float, - "action.target_wx": float, - "action.target_wy": float, - "action.target_wz": float, - "action.gripper": float, - } + Attributes: + position_scale: A factor to scale the delta position inputs. + rotation_scale: A factor to scale the delta rotation inputs (currently unused). + noise_threshold: The magnitude below which delta inputs are considered noise + and do not trigger an "enabled" state. """ # Scale factors for delta movements diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index ec7156632..5f1a190b7 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -13,6 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +This script defines a processor step for moving environment transition data to a specific torch device and casting +its floating-point precision. +""" + from dataclasses import dataclass from typing import Any @@ -28,12 +34,16 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry @ProcessorStepRegistry.register("device_processor") @dataclass class DeviceProcessorStep(ProcessorStep): - """Processes transitions by moving tensors to the specified device and optionally converting float dtypes. + """ + Processor step to move all tensors within an `EnvTransition` to a specified device and optionally cast their + floating-point data type. - This processor ensures that all tensors in the transition are moved to the - specified device (CPU or GPU) before they are returned. It can also convert - floating-point tensors to a specified dtype while preserving non-float types - (int, long, bool, etc.). + This is crucial for preparing data for model training or inference on hardware like GPUs. + + Attributes: + device: The target device for tensors (e.g., "cpu", "cuda", "cuda:0"). + float_dtype: The target floating-point dtype as a string (e.g., "float32", "float16", "bfloat16"). + If None, the dtype is not changed. """ device: str = "cpu" @@ -50,8 +60,15 @@ class DeviceProcessorStep(ProcessorStep): } def __post_init__(self): + """ + Initializes the processor by converting string configurations to torch objects. + + This method sets up the `torch.device`, determines if transfers can be non-blocking, and validates the + `float_dtype` string, converting it to a `torch.dtype` object. + """ self.tensor_device: torch.device = get_safe_torch_device(self.device) - self.device = self.tensor_device.type # cuda might have changed to cuda:1 + # Update device string in case a specific GPU was selected (e.g., "cuda" -> "cuda:0") + self.device = self.tensor_device.type self.non_blocking = "cuda" in str(self.device) # Validate and convert float_dtype string to torch dtype @@ -60,27 +77,32 @@ class DeviceProcessorStep(ProcessorStep): raise ValueError( f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}" ) - self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype] else: self._target_float_dtype = None def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor: - """Process a tensor by moving to device and optionally converting float dtype. + """ + Moves a single tensor to the target device and casts its dtype. - If the tensor is already on a GPU and we're configured for a GPU, it preserves - that GPU placement (useful for multi-GPU training with Accelerate). - Otherwise, it moves to the configured device. + Handles multi-GPU scenarios by not moving a tensor if it's already on a different CUDA device than + the target, which is useful when using frameworks like Accelerate. + + Args: + tensor: The input torch.Tensor. + + Returns: + The processed tensor on the correct device and with the correct dtype. """ # Determine target device if tensor.is_cuda and self.tensor_device.type == "cuda": - # Both tensor and target are on GPU - preserve tensor's GPU placement + # Both tensor and target are on GPU - preserve tensor's GPU placement. # This handles multi-GPU scenarios where Accelerate has already placed - # tensors on the correct GPU for each process + # tensors on the correct GPU for each process. target_device = tensor.device else: - # Either tensor is on CPU, or we're configured for CPU - # In both cases, use the configured device + # Either tensor is on CPU, or we're configured for CPU. + # In both cases, use the configured device. target_device = self.tensor_device # Only move if necessary @@ -94,6 +116,18 @@ class DeviceProcessorStep(ProcessorStep): return tensor def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Applies device and dtype conversion to all tensors in an environment transition. + + It iterates through the transition, finds all `torch.Tensor` objects (including those nested in + dictionaries like `observation`), and processes them. + + Args: + transition: The input `EnvTransition` object. + + Returns: + A new `EnvTransition` object with all tensors moved to the target device and dtype. + """ new_transition = transition.copy() simple_tensor_keys = [ @@ -108,13 +142,13 @@ class DeviceProcessorStep(ProcessorStep): TransitionKey.COMPLEMENTARY_DATA, ] - # Process simple tensors + # Process simple, top-level tensors for key in simple_tensor_keys: value = transition.get(key) if isinstance(value, torch.Tensor): new_transition[key] = self._process_tensor(value) - # Process dictionary-like tensors + # Process tensors nested within dictionaries for key in dict_tensor_keys: data_dict = transition.get(key) if data_dict is not None: @@ -127,8 +161,24 @@ class DeviceProcessorStep(ProcessorStep): return new_transition def get_config(self) -> dict[str, Any]: - """Return configuration for serialization.""" + """ + Returns the serializable configuration of the processor. + + Returns: + A dictionary containing the device and float_dtype settings. + """ return {"device": self.device, "float_dtype": self.float_dtype} def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Returns the input features unchanged. + + Device and dtype transformations do not alter the fundamental definition of the features (e.g., shape). + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ return features diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index a5e7b1777..e4f57c4d9 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -1,4 +1,4 @@ -#! /usr/bin/env python +#!/usr/bin/env python # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # @@ -10,6 +10,9 @@ # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass @@ -25,7 +28,17 @@ from .pipeline import ActionProcessorStep, ProcessorStepRegistry @ProcessorStepRegistry.register("torch2numpy_action_processor") @dataclass class Torch2NumpyActionProcessorStep(ActionProcessorStep): - """Convert PyTorch tensor actions to NumPy arrays.""" + """ + Converts a PyTorch tensor action to a NumPy array. + + This step is useful when the output of a policy (typically a torch.Tensor) + needs to be passed to an environment or component that expects a NumPy array. + + Attributes: + squeeze_batch_dim: If True, removes the first dimension of the array + if it is of size 1. This is useful for converting a + batched action of size (1, D) to a single action of size (D,). + """ squeeze_batch_dim: bool = True @@ -38,8 +51,8 @@ class Torch2NumpyActionProcessorStep(ActionProcessorStep): numpy_action = action.detach().cpu().numpy() - # Remove batch dimensions but preserve action dimensions - # Only squeeze if there's a batch dimension (first dim == 1) + # Remove batch dimensions but preserve action dimensions. + # Only squeeze if there's a batch dimension (first dim == 1). if ( self.squeeze_batch_dim and numpy_action.shape @@ -57,7 +70,13 @@ class Torch2NumpyActionProcessorStep(ActionProcessorStep): @ProcessorStepRegistry.register("numpy2torch_action_processor") @dataclass class Numpy2TorchActionProcessorStep(ActionProcessorStep): - """Convert NumPy array action to PyTorch tensor.""" + """ + Converts a NumPy array action to a PyTorch tensor. + + This step is useful for converting actions from environments or hardware, + which are often NumPy arrays, into PyTorch tensors that can be processed + by a policy or model. + """ def action(self, action: np.ndarray) -> torch.Tensor: if not isinstance(action, np.ndarray): diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 70dc416bf..fe12b27e2 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -1,3 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import time from dataclasses import dataclass @@ -29,21 +46,25 @@ TELEOP_ACTION_KEY = "teleop_action" @runtime_checkable class HasTeleopEvents(Protocol): - """Minimal protocol for objects that provide teleoperation events. + """ + Minimal protocol for objects that provide teleoperation events. - This protocol only defines the additional get_teleop_events() method, - avoiding duplication of the entire Teleoperator interface. + This protocol defines the `get_teleop_events()` method, allowing processor + steps to interact with teleoperators that support event-based controls + (like episode termination or success flagging) without needing to know the + teleoperator's specific class. """ def get_teleop_events(self) -> dict[str, Any]: - """Get extra control events from the teleoperator. + """ + Get extra control events from the teleoperator. Returns: - Dictionary containing control events such as: - - is_intervention: bool - Whether human is currently intervening - - terminate_episode: bool - Whether to terminate the current episode - - success: bool - Whether the episode was successful - - rerecord_episode: bool - Whether to rerecord the episode + A dictionary containing control events such as: + - `is_intervention`: bool - Whether the human is currently intervening. + - `terminate_episode`: bool - Whether to terminate the current episode. + - `success`: bool - Whether the episode was successful. + - `rerecord_episode`: bool - Whether to rerecord the episode. """ ... @@ -53,7 +74,15 @@ TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator) def _check_teleop_with_events(teleop: Teleoperator) -> None: - """Runtime check that a teleoperator implements get_teleop_events.""" + """ + Runtime check that a teleoperator implements the `HasTeleopEvents` protocol. + + Args: + teleop: The teleoperator instance to check. + + Raises: + TypeError: If the teleoperator does not have a `get_teleop_events` method. + """ if not isinstance(teleop, HasTeleopEvents): raise TypeError( f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. " @@ -64,11 +93,30 @@ def _check_teleop_with_events(teleop: Teleoperator) -> None: @ProcessorStepRegistry.register("add_teleop_action_as_complementary_data") @dataclass class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep): - """Add teleoperator action to transition complementary data.""" + """ + Adds the raw action from a teleoperator to the transition's complementary data. + + This is useful for human-in-the-loop scenarios where the human's input needs to + be available to downstream processors, for example, to override a policy's action + during an intervention. + + Attributes: + teleop_device: The teleoperator instance to get the action from. + """ teleop_device: Teleoperator def complementary_data(self, complementary_data: dict) -> dict: + """ + Retrieves the teleoperator's action and adds it to the complementary data. + + Args: + complementary_data: The incoming complementary data dictionary. + + Returns: + A new dictionary with the teleoperator action added under the + `teleop_action` key. + """ new_complementary_data = dict(complementary_data) new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action() return new_complementary_data @@ -80,26 +128,33 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep): @ProcessorStepRegistry.register("add_teleop_action_as_info") @dataclass class AddTeleopEventsAsInfoStep(InfoProcessorStep): - """Add teleoperator control events to transition info. + """ + Adds teleoperator control events (e.g., terminate, success) to the transition's info. - This processor step extracts control events from teleoperators that support - event-based interaction (intervention detection, episode termination, etc.). + This step extracts control events from teleoperators that support event-based + interaction, making these signals available to other parts of the system. - Works with any teleoperator that inherits from Teleoperator and implements the - get_teleop_events() method, including custom user-defined teleoperators. - - Built-in compatible teleoperators: - - GamepadTeleop: Uses gamepad buttons for control events - - KeyboardEndEffectorTeleop: Uses keyboard keys for control events + Attributes: + teleop_device: An instance of a teleoperator that implements the + `HasTeleopEvents` protocol. """ teleop_device: TeleopWithEvents def __post_init__(self): - """Validate that the teleoperator supports events.""" + """Validates that the provided teleoperator supports events after initialization.""" _check_teleop_with_events(self.teleop_device) def info(self, info: dict) -> dict: + """ + Retrieves teleoperator events and updates the info dictionary. + + Args: + info: The incoming info dictionary. + + Returns: + A new dictionary including the teleoperator events. + """ new_info = dict(info) teleop_events = self.teleop_device.get_teleop_events() @@ -113,12 +168,32 @@ class AddTeleopEventsAsInfoStep(InfoProcessorStep): @ProcessorStepRegistry.register("image_crop_resize_processor") @dataclass class ImageCropResizeProcessorStep(ObservationProcessorStep): - """Crop and resize image observations.""" + """ + Crops and/or resizes image observations. + + This step iterates through all image keys in an observation dictionary and applies + the specified transformations. It handles device placement, moving tensors to the + CPU if necessary for operations not supported on certain accelerators like MPS. + + Attributes: + crop_params_dict: A dictionary mapping image keys to cropping parameters + (top, left, height, width). + resize_size: A tuple (height, width) to resize all images to. + """ crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None resize_size: tuple[int, int] | None = None def observation(self, observation: dict) -> dict: + """ + Applies cropping and resizing to all images in the observation dictionary. + + Args: + observation: The observation dictionary, potentially containing image tensors. + + Returns: + A new observation dictionary with transformed images. + """ if self.resize_size is None and not self.crop_params_dict: return observation @@ -146,12 +221,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep): return new_observation def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary with the crop parameters and resize dimensions. + """ return { "crop_params_dict": self.crop_params_dict, "resize_size": self.resize_size, } def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Updates the image feature shapes in the policy features dictionary if resizing is applied. + + Args: + features: The policy features dictionary. + + Returns: + The updated policy features dictionary with new image shapes. + """ if self.resize_size is None: return features for key in features: @@ -163,12 +253,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep): @dataclass @ProcessorStepRegistry.register("time_limit_processor") class TimeLimitProcessorStep(TruncatedProcessorStep): - """Track episode steps and enforce time limits.""" + """ + Tracks episode steps and enforces a time limit by truncating the episode. + + Attributes: + max_episode_steps: The maximum number of steps allowed per episode. + current_step: The current step count for the active episode. + """ max_episode_steps: int current_step: int = 0 - def truncated(self, truncated): + def truncated(self, truncated: bool) -> bool: + """ + Increments the step counter and sets the truncated flag if the time limit is reached. + + Args: + truncated: The incoming truncated flag. + + Returns: + True if the episode step limit is reached, otherwise the incoming value. + """ self.current_step += 1 if self.current_step >= self.max_episode_steps: truncated = True @@ -176,11 +281,18 @@ class TimeLimitProcessorStep(TruncatedProcessorStep): return truncated def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the `max_episode_steps`. + """ return { "max_episode_steps": self.max_episode_steps, } def reset(self) -> None: + """Resets the step counter, typically called at the start of a new episode.""" self.current_step = 0 def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: @@ -190,13 +302,31 @@ class TimeLimitProcessorStep(TruncatedProcessorStep): @dataclass @ProcessorStepRegistry.register("gripper_penalty_processor") class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): - """Apply penalty for inappropriate gripper usage.""" + """ + Applies a penalty for inefficient gripper usage. + + This step penalizes actions that attempt to close an already closed gripper or + open an already open one, based on position thresholds. + + Attributes: + penalty: The negative reward value to apply. + max_gripper_pos: The maximum position value for the gripper, used for normalization. + """ penalty: float = -0.01 max_gripper_pos: float = 30.0 - def complementary_data(self, complementary_data): - """Calculate gripper penalty and add to complementary data.""" + def complementary_data(self, complementary_data: dict) -> dict: + """ + Calculates the gripper penalty and adds it to the complementary data. + + Args: + complementary_data: The incoming complementary data, which should contain + raw joint positions. + + Returns: + A new complementary data dictionary with the `discrete_penalty` key added. + """ action = self.transition.get(TransitionKey.ACTION) current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None) @@ -223,14 +353,20 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): return new_complementary_data def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the penalty value and max gripper position. + """ return { "penalty": self.penalty, "max_gripper_pos": self.max_gripper_pos, } def reset(self) -> None: - """Reset the processor state.""" - self.last_gripper_state = None + """Resets the processor's internal state.""" + pass def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: return features @@ -239,12 +375,33 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): @dataclass @ProcessorStepRegistry.register("intervention_action_processor") class InterventionActionProcessorStep(ProcessorStep): - """Handle human intervention actions and episode termination.""" + """ + Handles human intervention, overriding policy actions and managing episode termination. + + When an intervention is detected (via teleoperator events in the `info` dict), + this step replaces the policy's action with the human's teleoperated action. + It also processes signals to terminate the episode or flag success. + + Attributes: + use_gripper: Whether to include the gripper in the teleoperated action. + terminate_on_success: If True, automatically sets the `done` flag when a + `success` event is received. + """ use_gripper: bool = False terminate_on_success: bool = True def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Processes the transition to handle interventions. + + Args: + transition: The incoming environment transition. + + Returns: + The modified transition, potentially with an overridden action, updated + reward, and termination status. + """ action = transition.get(TransitionKey.ACTION) if action is None: return transition @@ -300,6 +457,12 @@ class InterventionActionProcessorStep(ProcessorStep): return new_transition def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the step's configuration attributes. + """ return { "use_gripper": self.use_gripper, "terminate_on_success": self.terminate_on_success, @@ -312,7 +475,20 @@ class InterventionActionProcessorStep(ProcessorStep): @dataclass @ProcessorStepRegistry.register("reward_classifier_processor") class RewardClassifierProcessorStep(ProcessorStep): - """Apply reward classification to image observations.""" + """ + Applies a pretrained reward classifier to image observations to predict success. + + This step uses a model to determine if the current state is successful, updating + the reward and potentially terminating the episode. + + Attributes: + pretrained_path: Path to the pretrained reward classifier model. + device: The device to run the classifier on. + success_threshold: The probability threshold to consider a prediction as successful. + success_reward: The reward value to assign on success. + terminate_on_success: If True, terminates the episode upon successful classification. + reward_classifier: The loaded classifier model instance. + """ pretrained_path: str | None = None device: str = "cpu" @@ -323,7 +499,7 @@ class RewardClassifierProcessorStep(ProcessorStep): reward_classifier: Any = None def __post_init__(self): - """Initialize the reward classifier after dataclass initialization.""" + """Initializes the reward classifier model after the dataclass is created.""" if self.pretrained_path is not None: from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -332,6 +508,16 @@ class RewardClassifierProcessorStep(ProcessorStep): self.reward_classifier.eval() def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Processes a transition, applying the reward classifier to its image observations. + + Args: + transition: The incoming environment transition. + + Returns: + The modified transition with an updated reward and done flag based on the + classifier's prediction. + """ new_transition = transition.copy() observation = new_transition.get(TransitionKey.OBSERVATION) if observation is None or self.reward_classifier is None: @@ -371,6 +557,12 @@ class RewardClassifierProcessorStep(ProcessorStep): return new_transition def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the step's configuration attributes. + """ return { "device": self.device, "success_threshold": self.success_threshold, diff --git a/src/lerobot/processor/joint_observations_processor.py b/src/lerobot/processor/joint_observations_processor.py index dc91d725e..81ba66b53 100644 --- a/src/lerobot/processor/joint_observations_processor.py +++ b/src/lerobot/processor/joint_observations_processor.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass from typing import Any @@ -15,17 +31,42 @@ from lerobot.robots import Robot @dataclass @ProcessorStepRegistry.register("joint_velocity_processor") class JointVelocityProcessorStep(ObservationProcessorStep): - """Add joint velocity information to observations.""" + """ + Calculates and appends joint velocity information to the observation state. + + This step computes the velocity of each joint by calculating the finite + difference between the current and the last observed joint positions. The + resulting velocity vector is then concatenated to the original state vector. + + Attributes: + dt: The time step (delta time) in seconds between observations, used for + calculating velocity. + last_joint_positions: Stores the joint positions from the previous step + to enable velocity calculation. + """ dt: float = 0.1 last_joint_positions: torch.Tensor | None = None def observation(self, observation: dict) -> dict: + """ + Computes joint velocities and adds them to the observation state. + + Args: + observation: The input observation dictionary, expected to contain + an `observation.state` key with joint positions. + + Returns: + A new observation dictionary with the `observation.state` tensor + extended to include joint velocities. + + Raises: + ValueError: If `observation.state` is not found in the observation. + """ # Get current joint positions (assuming they're in observation.state) current_positions = observation.get(OBS_STATE) if current_positions is None: - # TODO(steven): if we get here, then the transform_features method will not hold raise ValueError(f"{OBS_STATE} is not in observation") # Initialize last joint positions if not already set @@ -48,14 +89,33 @@ class JointVelocityProcessorStep(ObservationProcessorStep): return new_observation def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the time step `dt`. + """ return { "dt": self.dt, } def reset(self) -> None: + """Resets the internal state, clearing the last known joint positions.""" self.last_joint_positions = None def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Updates the `observation.state` feature to reflect the added velocities. + + This method doubles the size of the first dimension of the `observation.state` + shape to account for the concatenation of position and velocity vectors. + + Args: + features: The policy features dictionary. + + Returns: + The updated policy features dictionary. + """ if OBS_STATE in features: original_feature = features[OBS_STATE] # Double the shape to account for positions + velocities @@ -68,11 +128,33 @@ class JointVelocityProcessorStep(ObservationProcessorStep): @dataclass @ProcessorStepRegistry.register("current_processor") class MotorCurrentProcessorStep(ObservationProcessorStep): - """Add motor current information to observations.""" + """ + Reads motor currents from a robot and appends them to the observation state. + + This step queries the robot's hardware interface to get the present current + for each motor and concatenates this information to the existing state vector. + + Attributes: + robot: An instance of a `lerobot` Robot class that provides access to + the hardware bus. + """ robot: Robot | None = None def observation(self, observation: dict) -> dict: + """ + Fetches motor currents and adds them to the observation state. + + Args: + observation: The input observation dictionary. + + Returns: + A new observation dictionary with the `observation.state` tensor + extended to include motor currents. + + Raises: + ValueError: If the `robot` attribute has not been set. + """ # Get current values from robot state if self.robot is None: raise ValueError("Robot is not set") @@ -96,6 +178,18 @@ class MotorCurrentProcessorStep(ObservationProcessorStep): return new_observation def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Updates the `observation.state` feature to reflect the added motor currents. + + This method increases the size of the first dimension of the `observation.state` + shape by the number of motors in the robot. + + Args: + features: The policy features dictionary. + + Returns: + The updated policy features dictionary. + """ if OBS_STATE in features and self.robot is not None: original_feature = features[OBS_STATE] # Add motor current dimensions to the original state shape diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index 9690f6b98..659a43856 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -15,16 +15,22 @@ # limitations under the License. """ -Generic script to migrate any policy model with normalization layers to the new pipeline-based system. +A generic script to migrate LeRobot policies with built-in normalization layers to the new +pipeline-based processor system. -This script: -1. Loads an existing pretrained policy model -2. Extracts normalization statistics from the model -3. Creates both preprocessor and postprocessor: - - Preprocessor: normalizes both inputs (observations) and outputs (actions) for training - - Postprocessor: unnormalizes outputs (actions) for inference -4. Removes normalization layers from the model state_dict -5. Saves the new model and both processors +This script performs the following steps: +1. Loads a pretrained policy model and its configuration from a local path or the + Hugging Face Hub. +2. Scans the model's state dictionary to extract normalization statistics (e.g., mean, + std, min, max) for all features. +3. Creates two new processor pipelines: + - A preprocessor that normalizes inputs (observations) and outputs (actions). + - A postprocessor that unnormalizes outputs (actions) for inference. +4. Removes the original normalization layers from the model's state dictionary, + creating a "clean" model. +5. Saves the new clean model, the preprocessor, the postprocessor, and a generated + model card to a new directory. +6. Optionally pushes all the new artifacts to the Hugging Face Hub. Usage: python src/lerobot/processor/migrate_policy_normalization.py \ @@ -68,7 +74,21 @@ POLICY_CLASSES = { def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: - """Extract normalization statistics from model state_dict.""" + """ + Scans a model's state_dict to find and extract normalization statistics. + + This function identifies keys corresponding to normalization layers (e.g., those + for mean, std, min, max) based on a set of predefined patterns and organizes + them into a nested dictionary. + + Args: + state_dict: The state dictionary of a pretrained policy model. + + Returns: + A nested dictionary where outer keys are feature names (e.g., + 'observation.state') and inner keys are statistic types ('mean', 'std'), + mapping to their corresponding tensor values. + """ stats = {} # Define patterns to match and their prefixes to remove @@ -112,7 +132,25 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str def detect_features_and_norm_modes( config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]] ) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]: - """Detect features and normalization modes from config and stats.""" + """ + Infers policy features and normalization modes from the model config and stats. + + This function first attempts to find feature definitions and normalization + mappings directly from the policy's configuration file. If this information is + not present, it infers it from the extracted normalization statistics, using + tensor shapes to determine feature shapes and the presence of specific stat + keys (e.g., 'mean'/'std' vs 'min'/'max') to determine the normalization mode. + It applies sensible defaults if inference is not possible. + + Args: + config: The policy's configuration dictionary from `config.json`. + stats: The normalization statistics extracted from the model's state_dict. + + Returns: + A tuple containing: + - A dictionary mapping feature names to `PolicyFeature` objects. + - A dictionary mapping `FeatureType` enums to `NormalizationMode` enums. + """ features = {} norm_modes = {} @@ -204,7 +242,19 @@ def detect_features_and_norm_modes( def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """Remove normalization layers from state_dict.""" + """ + Creates a new state_dict with all normalization-related layers removed. + + This function filters the original state dictionary, excluding any keys that + match a set of predefined patterns associated with normalization modules. + + Args: + state_dict: The original model state dictionary. + + Returns: + A new state dictionary containing only the core model weights, without + any normalization parameters. + """ new_state_dict = {} # Patterns to remove @@ -228,7 +278,16 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]: - """Convert features from old format to PolicyFeature objects.""" + """ + Converts a feature dictionary from the old config format to the new `PolicyFeature` format. + + Args: + features_dict: The feature dictionary in the old format, where values are + simple dictionaries (e.g., `{"shape": [7]}`). + + Returns: + A dictionary mapping feature names to `PolicyFeature` dataclass objects. + """ converted_features = {} for key, feature_dict in features_dict.items(): @@ -254,8 +313,18 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[ def load_model_from_hub( repo_id: str, revision: str = None ) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: - """Load model state_dict and config from hub.""" - # Download files + """ + Downloads and loads a model's state_dict and configs from the Hugging Face Hub. + + Args: + repo_id: The repository ID on the Hub (e.g., 'lerobot/aloha'). + revision: The specific git revision (branch, tag, or commit hash) to use. + + Returns: + A tuple containing the model's state dictionary, the policy configuration, + and the training configuration. + """ + # Download files. safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision) config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision) diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index ab67b1708..d61b84660 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -1,3 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations from copy import deepcopy @@ -20,9 +37,26 @@ class _NormalizationMixin: """ A mixin class providing core functionality for normalization and unnormalization. - This class manages normalization statistics, their conversion to tensors, device placement, - and the application of normalization transformations. It is designed to be inherited by - concrete ProcessorStep implementations. + This class manages normalization statistics (`stats`), converts them to tensors for + efficient computation, handles device placement, and implements the logic for + applying normalization transformations (mean/std and min/max). It is designed to + be inherited by concrete `ProcessorStep` implementations and should not be used + directly. + + Attributes: + features: A dictionary mapping feature names to `PolicyFeature` objects, defining + the data structure to be processed. + norm_map: A dictionary mapping `FeatureType` to `NormalizationMode`, specifying + which normalization method to use for each type of feature. + stats: A dictionary containing the normalization statistics (e.g., mean, std, + min, max) for each feature. + device: The PyTorch device on which to store and perform tensor operations. + eps: A small epsilon value to prevent division by zero in normalization + calculations. + normalize_observation_keys: An optional set of keys to selectively apply + normalization to specific observation features. + _tensor_stats: An internal dictionary holding the normalization statistics as + PyTorch tensors. """ features: dict[str, PolicyFeature] @@ -36,7 +70,15 @@ class _NormalizationMixin: _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) def __post_init__(self): - # Robust JSON deserialization handling (guard empty maps) + """ + Initializes the mixin after dataclass construction. + + This method handles the robust deserialization of `features` and `norm_map` + from JSON-compatible formats (where enums become strings and tuples become + lists) and converts the provided `stats` dictionary into a dictionary of + tensors (`_tensor_stats`) on the specified device. + """ + # Robust JSON deserialization handling (guard empty maps). if self.features: first_val = next(iter(self.features.values())) if isinstance(first_val, dict): @@ -65,7 +107,15 @@ class _NormalizationMixin: def to( self, device: torch.device | str | None = None, dtype: torch.dtype | None = None ) -> _NormalizationMixin: - """Moves the processor's normalization stats to the specified device and returns self.""" + """ + Moves the processor's normalization stats to the specified device. + + Args: + device: The target PyTorch device. + + Returns: + The instance of the class, allowing for method chaining. + """ if device is not None: self.device = device if dtype is not None: @@ -74,6 +124,16 @@ class _NormalizationMixin: return self def state_dict(self) -> dict[str, Tensor]: + """ + Returns the normalization statistics as a flat state dictionary. + + All tensors are moved to the CPU before being returned, which is standard practice + for saving state dictionaries. + + Returns: + A flat dictionary mapping from `'feature_name.stat_name'` to the + corresponding statistics tensor on the CPU. + """ flat: dict[str, Tensor] = {} for key, sub in self._tensor_stats.items(): for stat_name, tensor in sub.items(): @@ -81,6 +141,15 @@ class _NormalizationMixin: return flat def load_state_dict(self, state: dict[str, Tensor]) -> None: + """ + Loads normalization statistics from a state dictionary. + + The loaded tensors are moved to the processor's configured device. + + Args: + state: A flat state dictionary with keys in the format + `'feature_name.stat_name'`. + """ self._tensor_stats.clear() for flat_key, tensor in state.items(): key, stat_name = flat_key.rsplit(".", 1) @@ -90,6 +159,15 @@ class _NormalizationMixin: ) def get_config(self) -> dict[str, Any]: + """ + Returns a serializable dictionary of the processor's configuration. + + This method is used when saving the processor to disk, ensuring that its + configuration can be reconstructed later. + + Returns: + A JSON-serializable dictionary containing the configuration. + """ config = { "eps": self.eps, "features": { @@ -102,6 +180,16 @@ class _NormalizationMixin: return config def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]: + """ + Applies (un)normalization to all relevant features in an observation dictionary. + + Args: + observation: The observation dictionary to process. + inverse: If `True`, applies unnormalization; otherwise, applies normalization. + + Returns: + A new observation dictionary with the transformed tensor values. + """ new_observation = dict(observation) for key, feature in self.features.items(): if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys: @@ -114,6 +202,16 @@ class _NormalizationMixin: def _normalize_action(self, action: Any, inverse: bool) -> Tensor: # Convert to tensor but preserve original dtype for adaptation logic + """ + Applies (un)normalization to an action tensor. + + Args: + action: The action tensor to process. + inverse: If `True`, applies unnormalization; otherwise, applies normalization. + + Returns: + The transformed action tensor. + """ tensor = torch.as_tensor(action) processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse) return processed_action @@ -121,7 +219,24 @@ class _NormalizationMixin: def _apply_transform( self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False ) -> Tensor: - """Core logic to apply normalization or unnormalization.""" + """ + Core logic to apply a normalization or unnormalization transformation to a tensor. + + This method selects the appropriate normalization mode (e.g., mean/std, min/max) + based on the feature type and applies the corresponding mathematical operation. + + Args: + tensor: The input tensor to transform. + key: The feature key corresponding to the tensor. + feature_type: The `FeatureType` of the tensor. + inverse: If `True`, applies the inverse transformation (unnormalization). + + Returns: + The transformed tensor. + + Raises: + ValueError: If an unsupported normalization mode is encountered. + """ norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY) if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats: return tensor @@ -168,11 +283,11 @@ class _NormalizationMixin: @ProcessorStepRegistry.register(name="normalizer_processor") class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep): """ - A processor that applies normalization to observations and actions in a transition. + A processor step that applies normalization to observations and actions in a transition. - This class directly implements the normalization logic for both observation and action - components of an `EnvTransition`, using statistics (mean/std or min/max) provided at - initialization. + This class uses the logic from `_NormalizationMixin` to perform forward normalization + (e.g., scaling data to have zero mean and unit variance, or to the range [-1, 1]). + It is typically used in the pre-processing pipeline before feeding data to a policy. """ @classmethod @@ -186,6 +301,20 @@ class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep): eps: float = 1e-8, device: torch.device | str | None = None, ) -> NormalizerProcessorStep: + """ + Creates a `NormalizerProcessorStep` instance using statistics from a `LeRobotDataset`. + + Args: + dataset: The dataset from which to extract normalization statistics. + features: The feature definition for the processor. + norm_map: The mapping from feature types to normalization modes. + normalize_observation_keys: An optional set of observation keys to normalize. + eps: A small epsilon value for numerical stability. + device: The target device for the processor. + + Returns: + A new instance of `NormalizerProcessorStep`. + """ return cls( features=features, norm_map=norm_map, @@ -220,11 +349,12 @@ class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep): @ProcessorStepRegistry.register(name="unnormalizer_processor") class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep): """ - A processor that applies unnormalization (the inverse of normalization) to - observations and actions in a transition. + A processor step that applies unnormalization to observations and actions. - This is typically used to transform actions from a normalized policy output back into - the original scale for execution in an environment. + This class inverts the normalization process, scaling data back to its original + range. It is typically used in the post-processing pipeline to convert a policy's + normalized action output into a format that can be executed by a robot or + environment. """ @classmethod @@ -236,6 +366,18 @@ class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep): *, device: torch.device | str | None = None, ) -> UnnormalizerProcessorStep: + """ + Creates an `UnnormalizerProcessorStep` using statistics from a `LeRobotDataset`. + + Args: + dataset: The dataset from which to extract normalization statistics. + features: The feature definition for the processor. + norm_map: The mapping from feature types to normalization modes. + device: The target device for the processor. + + Returns: + A new instance of `UnnormalizerProcessorStep`. + """ return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device) def __call__(self, transition: EnvTransition) -> EnvTransition: @@ -261,12 +403,20 @@ def hotswap_stats( policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]] ) -> PolicyProcessorPipeline: """ - Replaces normalization statistics in a PolicyProcessor pipeline. + Replaces normalization statistics in an existing `PolicyProcessorPipeline` instance. - This function creates a deep copy of the provided `PolicyProcessorPipeline` and updates the - statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` steps within it. - It's useful for adapting a trained policy to a new environment or dataset with - different data distributions. + This function creates a deep copy of the provided pipeline and updates the + statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` it + contains. This is useful for adapting a trained policy to a new environment or + dataset with different data distributions without having to reconstruct the entire + pipeline. + + Args: + policy_processor: The policy processor pipeline to modify. + stats: The new dictionary of normalization statistics to apply. + + Returns: + A new `PolicyProcessorPipeline` instance with the updated statistics. """ rp = deepcopy(policy_processor) for step in rp.steps: diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index b5502f769..bcb351669 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -30,23 +30,44 @@ from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @ProcessorStepRegistry.register(name="observation_processor") class VanillaObservationProcessorStep(ObservationProcessorStep): """ - Processes environment observations into the LeRobot format by handling both images and states. + Processes standard Gymnasium observations into the LeRobot format. - Image processing: - - Converts channel-last (H, W, C) images to channel-first (C, H, W) - - Normalizes uint8 images ([0, 255]) to float32 ([0, 1]) - - Adds a batch dimension if missing - - Supports single images and image dictionaries + This step handles both image and state data from a typical observation dictionary, + preparing it for use in a LeRobot policy. - State processing: - - Maps 'environment_state' to observation.environment_state - - Maps 'agent_pos' to observation.state - - Converts numpy arrays to tensors - - Adds a batch dimension if missing + **Image Processing:** + - Converts channel-last (H, W, C), `uint8` images to channel-first (C, H, W), + `float32` tensors. + - Normalizes pixel values from the [0, 255] range to [0, 1]. + - Adds a batch dimension if one is not already present. + - Recognizes a single image under the key `"pixels"` and maps it to + `"observation.image"`. + - Recognizes a dictionary of images under the key `"pixels"` and maps them + to `"observation.images.{camera_name}"`. + + **State Processing:** + - Maps the `"environment_state"` key to `"observation.environment_state"`. + - Maps the `"agent_pos"` key to `"observation.state"`. + - Converts NumPy arrays to PyTorch tensors. + - Adds a batch dimension if one is not already present. """ def _process_single_image(self, img: np.ndarray) -> Tensor: - """Process a single image array.""" + """ + Processes a single NumPy image array into a channel-first, normalized tensor. + + Args: + img: A NumPy array representing the image, expected to be in channel-last + (H, W, C) format with a `uint8` dtype. + + Returns: + A `float32` PyTorch tensor in channel-first (B, C, H, W) format, with + pixel values normalized to the [0, 1] range. + + Raises: + ValueError: If the input image does not appear to be in channel-last + format or is not of `uint8` dtype. + """ # Convert to tensor img_tensor = torch.from_numpy(img) @@ -108,16 +129,24 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): return self._process_observation(observation) def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Transforms feature keys to a standardized contract. - This method handles several renaming patterns: - - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE'). - - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE'). - - Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1'). - - Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1'). - - environment_state -> OBS_ENV_STATE, - - agent_pos -> OBS_STATE, - - observation.environment_state -> OBS_ENV_STATE, - - observation.agent_pos -> OBS_STATE + """ + Transforms feature keys from the Gym standard to the LeRobot standard. + + This method standardizes the feature dictionary by renaming keys according + to LeRobot's conventions, ensuring that policies can be constructed correctly. + It handles various raw key formats, including those with an "observation." prefix. + + **Renaming Rules:** + - `pixels` or `observation.pixels` -> `observation.image` + - `pixels.{cam}` or `observation.pixels.{cam}` -> `observation.images.{cam}` + - `environment_state` or `observation.environment_state` -> `observation.environment_state` + - `agent_pos` or `observation.agent_pos` -> `observation.state` + + Args: + features: The policy features dictionary with Gym-style keys. + + Returns: + The policy features dictionary with standardized LeRobot keys. """ exact_pairs = { "pixels": OBS_IMAGE, diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 215365803..f233e1881 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -25,7 +25,18 @@ from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @dataclass @ProcessorStepRegistry.register(name="rename_processor") class RenameProcessorStep(ObservationProcessorStep): - """Rename processor that renames keys in the observation.""" + """ + A processor step that renames keys in an observation dictionary. + + This step is useful for creating a standardized data interface by mapping keys + from an environment's format to the format expected by a LeRobot policy or + other downstream components. + + Attributes: + rename_map: A dictionary mapping from old key names to new key names. + Keys present in an observation that are not in this map will + be kept with their original names. + """ rename_map: dict[str, str] = field(default_factory=dict) @@ -51,7 +62,22 @@ class RenameProcessorStep(ObservationProcessorStep): def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]: - """Rename keys in the stats dictionary according to rename_map (defensive copy).""" + """ + Renames the top-level keys in a statistics dictionary using a provided mapping. + + This is a helper function typically used to keep normalization statistics + consistent with renamed observation or action features. It performs a defensive + deep copy to avoid modifying the original `stats` dictionary. + + Args: + stats: A nested dictionary of statistics, where top-level keys are + feature names (e.g., `{"observation.state": {"mean": 0.5}}`). + rename_map: A dictionary mapping old feature names to new feature names. + + Returns: + A new statistics dictionary with its top-level keys renamed. Returns an + empty dictionary if the input `stats` is empty. + """ if not stats: return {} renamed: dict[str, dict[str, Any]] = {} diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 54ca17098..3ab21ecdd 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -1,5 +1,24 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ -Tokenizer processor for handling text tokenization in robot transitions. +This script defines a processor for tokenizing natural language instructions from an environment transition. + +It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into +token IDs and attention masks, which are then added to the observation dictionary. """ from __future__ import annotations @@ -16,6 +35,7 @@ from lerobot.utils.import_utils import _transformers_available from .core import EnvTransition, TransitionKey from .pipeline import ObservationProcessorStep, ProcessorStepRegistry +# Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: from transformers import AutoTokenizer else: @@ -25,54 +45,48 @@ else: @dataclass @ProcessorStepRegistry.register(name="tokenizer_processor") class TokenizerProcessorStep(ObservationProcessorStep): - """Tokenizes text tasks in complementary data using a huggingface tokenizer. + """ + Processor step to tokenize a natural language task description. - This processor handles tokenization of task strings found in the complementary_data - using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions - to the observation data for model processing while preserving the original task string. + This step extracts a task string from the `complementary_data` of an `EnvTransition`, + tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting + token IDs and attention mask to the `observation` dictionary. - The processor supports both single strings and lists of strings as task inputs. + Requires the `transformers` library to be installed. - Args: - tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub - (e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used - with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored. - tokenizer: A tokenizer object (e.g., from transformers library) that implements - the __call__ method. If provided, tokenizer_name is ignored. This parameter - is not serialized and must be provided via overrides when loading. - max_length: Maximum sequence length for tokenization. Defaults to 512. - task_key: Key in complementary_data containing the task text. Defaults to "task". - padding: Padding strategy for tokenization. Defaults to "max_length". - truncation: Whether to truncate sequences longer than max_length. Defaults to True. - - Examples: - Using tokenizer name (auto-loaded): - ```python - processor = TokenizerProcessorStep(tokenizer_name="bert-base-uncased", max_length=128) - ``` - - Using custom tokenizer object: - ```python - from transformers import AutoTokenizer - - custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - processor = TokenizerProcessorStep(tokenizer=custom_tokenizer, max_length=128) - ``` + Attributes: + tokenizer_name: The name of a pretrained tokenizer from the Hugging Face Hub (e.g., "bert-base-uncased"). + tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored. + max_length: The maximum length to pad or truncate sequences to. + task_key: The key in `complementary_data` where the task string is stored. + padding_side: The side to pad on ('left' or 'right'). + padding: The padding strategy ('max_length', 'longest', etc.). + truncation: Whether to truncate sequences longer than `max_length`. + input_tokenizer: The internal tokenizer instance, loaded during initialization. """ tokenizer_name: str | None = None - tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies + tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency max_length: int = 512 task_key: str = "task" padding_side: str = "right" padding: str = "max_length" truncation: bool = True - # Internal tokenizer instance (not serialized) + # Internal tokenizer instance (not part of the config) input_tokenizer: Any = field(default=None, init=False, repr=False) def __post_init__(self): - """Initialize the tokenizer from the provided tokenizer or tokenizer name.""" + """ + Initializes the tokenizer after the dataclass is created. + + It checks for the availability of the `transformers` library and loads the tokenizer + either from a provided object or by name from the Hugging Face Hub. + + Raises: + ImportError: If the `transformers` library is not installed. + ValueError: If neither `tokenizer` nor `tokenizer_name` is provided. + """ if not _transformers_available: raise ImportError( "The 'transformers' library is not installed. " @@ -93,13 +107,14 @@ class TokenizerProcessorStep(ObservationProcessorStep): ) def get_task(self, transition: EnvTransition) -> list[str] | None: - """Extract and normalize task from complementary data. + """ + Extracts the task description(s) from the transition's complementary data. Args: - transition: Input transition containing complementary_data. + transition: The environment transition. Returns: - List of task strings if task is present, None otherwise. + A list of task strings, or None if the task key is not found or the value is None. """ complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) if complementary_data is None: @@ -112,7 +127,7 @@ class TokenizerProcessorStep(ObservationProcessorStep): if task is None: return None - # Convert to list of strings + # Standardize to a list of strings for the tokenizer if isinstance(task, str): return [task] elif isinstance(task, list) and all(isinstance(t, str) for t in task): @@ -120,78 +135,80 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None - def observation(self, observation): - """Process the transition by tokenizing the task text. + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + """ + Tokenizes the task description and adds it to the observation dictionary. + + This method retrieves the task, tokenizes it, moves the resulting tensors to the + same device as other data in the transition, and updates the observation. Args: - transition: Input transition containing complementary_data with task text. + observation: The original observation dictionary. Returns: - Modified transition with tokenized task added to observation. - - Raises: - ValueError: If tokenizer initialization failed. + The updated observation dictionary including token IDs and an attention mask. """ task = self.get_task(self.transition) if task is None: return observation - # Tokenize the task (creates CPU tensors) + # Tokenize the task (this will create CPU tensors) tokenized_prompt = self._tokenize_text(task) - # Detect device from existing tensors in the transition + # Detect the device from existing tensors in the transition to ensure consistency target_device = self._detect_device(self.transition) - # Move tokenized tensors to match the device of other data + # Move new tokenized tensors to the detected device if target_device is not None: tokenized_prompt = { k: v.to(target_device) if isinstance(v, torch.Tensor) else v for k, v in tokenized_prompt.items() } - # Get or create observation dict + # Create a new observation dict to avoid modifying the original in place new_observation = dict(observation) - # Add tokenized data to observation + # Add tokenized data to the observation new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) return new_observation def _detect_device(self, transition: EnvTransition) -> torch.device | None: - """Detect device from existing tensors in the transition. + """ + Detects the torch.device from existing tensors in the transition. - This allows the tokenized tensors to match the device of other data, - which is especially important for multi-GPU training with Accelerate. + It checks tensors in the observation dictionary first, then the action tensor. Args: - transition: The transition to search for existing tensors. + transition: The environment transition. Returns: - The device of the first tensor found, or None if no tensors exist. + The detected `torch.device`, or None if no tensors are found. """ - # Check observation tensors first (most likely to exist) + # Check observation tensors first (most likely place to find tensors) observation = transition.get(TransitionKey.OBSERVATION) if observation: for value in observation.values(): if isinstance(value, torch.Tensor): return value.device - # Check action tensor + # Fallback to checking the action tensor action = transition.get(TransitionKey.ACTION) if isinstance(action, torch.Tensor): return action.device - return None # No tensors found, keep on CPU + return None # No tensors found, default will be CPU def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]: - """Tokenize text using the configured tokenizer. + """ + A wrapper around the tokenizer call. Args: - text: Text string or list of strings to tokenize. + text: A string or list of strings to tokenize. Returns: - Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'. + A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors. """ return self.input_tokenizer( text, @@ -203,10 +220,14 @@ class TokenizerProcessorStep(ObservationProcessorStep): ) def get_config(self) -> dict[str, Any]: - """Return configuration for serialization. + """ + Returns the serializable configuration of the processor. - Note: Only tokenizer_name is saved, not the tokenizer object itself. - When loading, provide the tokenizer via overrides if needed. + Note: The tokenizer object itself is not serialized. If the processor was initialized + with a tokenizer name, that name will be included in the config. + + Returns: + A dictionary with the processor's configuration parameters. """ config = { "max_length": self.max_length, @@ -216,28 +237,30 @@ class TokenizerProcessorStep(ObservationProcessorStep): "truncation": self.truncation, } - # Only include tokenizer_name if it was used (not when tokenizer object was provided) - # TODO(steven): Consider saving the name of the _tokenizer if it was loaded + # Only save tokenizer_name if it was used to create the tokenizer if self.tokenizer_name is not None and self.tokenizer is None: config["tokenizer_name"] = self.tokenizer_name return config def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Add tokenized task features to the feature contract. + """ + Adds feature definitions for the language tokens and attention mask. + + This updates the policy features dictionary to include the new data added to the + observation, ensuring downstream components are aware of their shape and type. Args: - features: Input feature dictionary. + features: The dictionary of existing policy features. Returns: - Updated feature dictionary with tokenized task features added. + The updated dictionary of policy features. """ - # Add features for tokenized output if they don't exist - # Standard tokenizer output includes tokens and attention_mask - + # Add a feature for the token IDs if it doesn't already exist if OBS_LANGUAGE_TOKENS not in features: features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,)) + # Add a feature for the attention mask if it doesn't already exist if OBS_LANGUAGE_ATTENTION_MASK not in features: features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature( type=FeatureType.LANGUAGE, shape=(self.max_length,) diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py index b6ac6a6c0..0eb5ad5cd 100644 --- a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -1,4 +1,4 @@ -# !/usr/bin/env python +#!/usr/bin/env python # Copyright 2025 The HuggingFace Inc. team. All rights reserved. # @@ -38,18 +38,27 @@ from lerobot.utils.rotation import Rotation @dataclass class EEReferenceAndDelta(ActionProcessorStep): """ - Compute the desired end-effector pose from the target pose and the current pose. + Computes a target end-effector pose from a relative delta command. - Input ACTION keys: - { - "action.ee.{x,y,z,wx,wy,wz}" : float - "complementary_data.raw_joint_positions": dict, - } + This step takes a desired change in position and orientation (`target_*`) and applies it to a + reference end-effector pose to calculate an absolute target pose. The reference pose is derived + from the current robot joint positions using forward kinematics. - Output ACTION keys: - { - "action.ee.{x,y,z,wx,wy,wz}" : float - } + The processor can operate in two modes: + 1. `use_latched_reference=True`: The reference pose is "latched" or saved at the moment the action + is first enabled. Subsequent commands are relative to this fixed reference. + 2. `use_latched_reference=False`: The reference pose is updated to the robot's current pose at + every step. + + Attributes: + kinematics: The robot's kinematic model for forward kinematics. + end_effector_step_sizes: A dictionary scaling the input delta commands. + motor_names: A list of motor names required for forward kinematics. + use_latched_reference: If True, latch the reference pose on enable; otherwise, always use the + current pose as the reference. + reference_ee_pose: Internal state storing the latched reference pose. + _prev_enabled: Internal state to detect the rising edge of the enable signal. + _command_when_disabled: Internal state to hold the last command while disabled. """ kinematics: RobotKinematics @@ -135,6 +144,7 @@ class EEReferenceAndDelta(ActionProcessorStep): return new_action def reset(self): + """Resets the internal state of the processor.""" self._prev_enabled = False self.reference_ee_pose = None self._command_when_disabled = None @@ -161,17 +171,17 @@ class EEReferenceAndDelta(ActionProcessorStep): @dataclass class EEBoundsAndSafety(ActionProcessorStep): """ - Clip the end-effector pose to the bounds and check for jumps. + Clips the end-effector pose to predefined bounds and checks for unsafe jumps. - Input ACTION keys: - { - "action.ee.{x,y,z,wx,wy,wz}" : float - } + This step ensures that the target end-effector pose remains within a safe operational workspace. + It also moderates the command to prevent large, sudden movements between consecutive steps. - Output ACTION keys: - { - "action.ee.{x,y,z,wx,wy,wz}" : float - } + Attributes: + end_effector_bounds: A dictionary with "min" and "max" keys for position clipping. + max_ee_step_m: The maximum allowed change in position (in meters) between steps. + max_ee_twist_step_rad: The maximum allowed change in orientation (in radians) between steps. + _last_pos: Internal state storing the last commanded position. + _last_twist: Internal state storing the last commanded orientation. """ end_effector_bounds: dict @@ -219,6 +229,7 @@ class EEBoundsAndSafety(ActionProcessorStep): return act def reset(self): + """Resets the last known position and orientation.""" self._last_pos = None self._last_twist = None @@ -232,21 +243,17 @@ class EEBoundsAndSafety(ActionProcessorStep): @dataclass class InverseKinematicsEEToJoints(ProcessorStep): """ - Compute the desired joint positions from the desired end-effector pose. + Computes desired joint positions from a target end-effector pose using inverse kinematics (IK). - Input ACTION keys: - { - "action.ee.{x,y,z,wx,wy,wz}" : float - "complementary_data.raw_joint_positions": dict, - } + This step translates a Cartesian command (position and orientation of the end-effector) into + the corresponding joint-space commands for each motor. - Output ACTION keys: - { - "action.joint_name_1.pos": float, - "action.joint_name_2.pos": float, - ... - "action.joint_name_n.pos": float, - } + Attributes: + kinematics: The robot's kinematic model for inverse kinematics. + motor_names: A list of motor names for which to compute joint positions. + q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver. + initial_guess_current_joints: If True, use the robot's current joint state as the IK guess. + If False, use the solution from the previous step. """ kinematics: RobotKinematics @@ -312,6 +319,7 @@ class InverseKinematicsEEToJoints(ProcessorStep): return features def reset(self): + """Resets the initial guess for the IK solver.""" self.q_curr = None @@ -319,17 +327,18 @@ class InverseKinematicsEEToJoints(ProcessorStep): @dataclass class GripperVelocityToJoint(ProcessorStep): """ - Convert the gripper velocity to a joint velocity. + Converts a gripper velocity command into a target gripper joint position. - Input ACTION keys: - { - "action.gripper": float, - } + This step integrates a normalized velocity command over time to produce a position command, + taking the current gripper position as a starting point. It also supports a discrete mode + where integer actions map to open, close, or no-op. - Output ACTION keys: - { - "action.gripper.pos": float, - } + Attributes: + motor_names: A list of motor names, which must include 'gripper'. + speed_factor: A scaling factor to convert the normalized velocity command to a position change. + clip_min: The minimum allowed gripper joint position. + clip_max: The maximum allowed gripper joint position. + discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay). """ motor_names: list[str] @@ -365,7 +374,7 @@ class GripperVelocityToJoint(ProcessorStep): raw = comp.get("raw_joint_positions") or {} curr_pos = float(raw.get("gripper")) - # Compute desired gripper velocity + # Compute desired gripper position u = float(act.get(f"{ACTION}.gripper", 0.0)) delta = u * float(self.speed_factor) gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max)) @@ -391,17 +400,14 @@ class GripperVelocityToJoint(ProcessorStep): @dataclass class ForwardKinematicsJointsToEE(ObservationProcessorStep): """ - Compute the end-effector pose from the joint positions. + Computes the end-effector pose from joint positions using forward kinematics (FK). - Input OBSERVATION keys: - { - "observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float, - } + This step is typically used to add the robot's Cartesian pose to the observation space, + which can be useful for visualization or as an input to a policy. - Output OBSERVATION keys: - { - "observation.state.ee.{x,y,z,wx,wy,wz}" : float - } + Attributes: + kinematics: The robot's kinematic model. + motor_names: A list of motor names whose joint positions are used for FK. """ kinematics: RobotKinematics @@ -435,10 +441,14 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep): @dataclass class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessorStep): """ - Read the robot's current observation and insert it into the transition as complementary data. + Reads the robot's current observation and adds it to the transition's complementary data. - - Joint positions are added under complementary_data["raw_joint_positions"] as a dict: - { "": , ... } + This step acts as a bridge to the physical robot, injecting its real-time sensor readings + (like raw joint positions) into the data processing pipeline. This data is then available + for other processing steps. + + Attributes: + robot: An instance of a `Robot` class used to get observations from hardware. """ robot: Robot diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index ee2f6c6b1..baa284c4a 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -98,9 +98,7 @@ from lerobot.utils.utils import ( ACTOR_SHUTDOWN_TIMEOUT = 30 -################################################# -# Main entry point # -################################################# +# Main entry point @parser.wrap() @@ -207,9 +205,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig): logging.info("[ACTOR] queues closed") -################################################# -# Core algorithm functions # -################################################# +# Core algorithm functions def act_with_policy( @@ -406,9 +402,7 @@ def act_with_policy( busy_wait(1 / cfg.env.fps - dt_time) -################################################# -# Communication Functions - Group all gRPC/messaging functions # -################################################# +# Communication Functions - Group all gRPC/messaging functions def establish_learner_connection( @@ -653,9 +647,7 @@ def interactions_stream( return services_pb2.Empty() -################################################# -# Policy functions # -################################################# +# Policy functions def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): @@ -687,9 +679,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) logging.info("[ACTOR] Loaded discrete critic parameters from Learner.") -################################################# -# Utilities functions # -################################################# +# Utilities functions def push_transitions_to_transport_queue(transitions: list, transitions_queue): diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/scripts/rl/learner.py index 840f0b96e..5d9953827 100644 --- a/src/lerobot/scripts/rl/learner.py +++ b/src/lerobot/scripts/rl/learner.py @@ -103,11 +103,6 @@ from lerobot.utils.wandb_utils import WandBLogger LOG_PREFIX = "[LEARNER]" -################################################# -# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS # -################################################# - - @parser.wrap() def train_cli(cfg: TrainRLServerPipelineConfig): if not use_threads(cfg): @@ -250,9 +245,7 @@ def start_learner_threads( logging.info("[LEARNER] queues closed") -################################################# -# Core algorithm functions # -################################################# +# Core algorithm functions def add_actor_information_and_train( @@ -820,9 +813,7 @@ def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.M return optimizers, lr_scheduler -################################################# -# Training setup functions # -################################################# +# Training setup functions def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig: @@ -1023,9 +1014,7 @@ def initialize_offline_replay_buffer( return offline_replay_buffer -################################################# -# Utilities/Helpers functions # -################################################# +# Utilities/Helpers functions def get_observation_features( diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index d0240b427..a2b92bc31 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -65,6 +65,28 @@ def update_policy( use_amp: bool = False, lock=None, ) -> tuple[MetricsTracker, dict]: + """ + Performs a single training step to update the policy's weights. + + This function executes the forward and backward passes, clips gradients, and steps the optimizer and + learning rate scheduler. It also handles mixed-precision training via a GradScaler. + + Args: + train_metrics: A MetricsTracker instance to record training statistics. + policy: The policy model to be trained. + batch: A batch of training data. + optimizer: The optimizer used to update the policy's parameters. + grad_clip_norm: The maximum norm for gradient clipping. + grad_scaler: The GradScaler for automatic mixed-precision training. + lr_scheduler: An optional learning rate scheduler. + use_amp: A boolean indicating whether to use automatic mixed precision. + lock: An optional lock for thread-safe optimizer updates. + + Returns: + A tuple containing: + - The updated MetricsTracker with new statistics for this step. + - A dictionary of outputs from the policy's forward pass, for logging purposes. + """ start_time = time.perf_counter() device = get_device_from_parameters(policy) policy.train() @@ -108,6 +130,20 @@ def update_policy( @parser.wrap() def train(cfg: TrainPipelineConfig): + """ + Main function to train a policy. + + This function orchestrates the entire training pipeline, including: + - Setting up logging, seeding, and device configuration. + - Creating the dataset, evaluation environment (if applicable), policy, and optimizer. + - Handling resumption from a checkpoint. + - Running the main training loop, which involves fetching data batches and calling `update_policy`. + - Periodically logging metrics, saving model checkpoints, and evaluating the policy. + - Pushing the final trained model to the Hugging Face Hub if configured. + + Args: + cfg: A `TrainPipelineConfig` object containing all training configurations. + """ cfg.validate() logging.info(pformat(cfg.to_dict())) diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 8f5236d96..ff57511dd 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -121,6 +121,21 @@ def teleop_loop( robot_action_processor: RobotProcessorPipeline[dict[str, Any]] | None = None, robot_observation_processor: RobotProcessorPipeline[EnvTransition] | None = None, ): + """ + This function continuously reads actions from a teleoperation device, processes them through optional + pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a + specified frequency until a set duration is reached or it is manually interrupted. + + Args: + teleop: The teleoperator device instance providing control actions. + robot: The robot instance being controlled. + fps: The target frequency for the control loop in frames per second. + display_data: If True, fetches robot observations and displays them in the console and Rerun. + duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely. + teleop_action_processor: An optional pipeline to process raw actions from the teleoperator. + robot_action_processor: An optional pipeline to process actions before they are sent to the robot. + robot_observation_processor: An optional pipeline to process raw observations from the robot. + """ # Initialize processors with defaults if not provided teleop_action_processor: RobotProcessorPipeline[EnvTransition] = ( teleop_action_processor diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py index f719e8dbc..8b3a3d3a7 100644 --- a/src/lerobot/teleoperators/phone/phone_processor.py +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -26,28 +26,35 @@ from lerobot.teleoperators.phone.config_phone import PhoneOS @dataclass class MapPhoneActionToRobotAction(ActionProcessorStep): """ - Map calibrated phone pose (actions) to the inputs for robot actions + Maps calibrated phone pose actions to standardized robot action inputs. - Expected input ACTION keys: - { - "action.phone.enabled": bool, - "action.phone.pos": np.ndarray, - "action.phone.rot": Rotation, - "action.phone.raw_inputs": dict, - } + This processor step acts as a bridge between the phone teleoperator's output + and the robot's expected action format. It remaps the phone's 6-DoF pose + (position and rotation) to the robot's target end-effector pose, applying + necessary axis inversions and swaps. It also interprets platform-specific + button presses to generate a gripper command. - Output ACTION keys: - { - "action.enabled": bool, - "action.ee.{x,y,z,wx,wy,wz}" : float - "action.gripper": float, - } + Attributes: + platform: The operating system of the phone (iOS or Android), used + to determine the correct button mappings for the gripper. """ platform: PhoneOS _enabled_prev: bool = field(default=False, init=False, repr=False) def action(self, act: dict) -> dict: + """ + Processes the phone action dictionary to create a robot action dictionary. + + Args: + act: The input action dictionary from the phone teleoperator. + + Returns: + A new action dictionary formatted for the robot controller. + + Raises: + ValueError: If 'pos' or 'rot' keys are missing from the input action. + """ # Pop them from the action enabled = bool(act.pop(f"{ACTION}.phone.enabled", 0)) pos = act.pop(f"{ACTION}.phone.pos", None) diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py index f526e6398..c90729efa 100644 --- a/src/lerobot/teleoperators/phone/teleop_phone.py +++ b/src/lerobot/teleoperators/phone/teleop_phone.py @@ -108,7 +108,17 @@ class IOSPhone(BasePhone, Teleoperator): print("Calibration done\n") def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: - """Wait trigger for calibration: iOS: B1. Android: 'move'.""" + """ + Blocks execution until the calibration trigger is detected from the iOS device. + + This method enters a loop, continuously reading the phone's state. It waits for the user to press + and hold the 'B1' button in the HEBI Mobile I/O app. Once B1 is pressed, the loop breaks and + returns the phone's pose at that exact moment. + + Returns: + A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the + moment the trigger was activated. + """ while True: has_pose, position, rotation, fb_pose = self._read_current_pose() if not has_pose: @@ -126,6 +136,21 @@ class IOSPhone(BasePhone, Teleoperator): time.sleep(0.01) def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + """ + Reads the instantaneous 6-DoF pose from the connected iOS device via the HEBI SDK. + + This method fetches the latest feedback packet from the HEBI group, extracts the ARKit + position and orientation, and converts them into a standard format. It also applies a + configured camera offset to adjust the pose from the camera's frame to the phone's + physical frame. + + Returns: + A tuple containing: + - A boolean indicating if a valid pose was successfully read. + - The 3D position as a NumPy array, or None if not available. + - The orientation as a `Rotation` object, or None if not available. + - The raw HEBI feedback object for accessing other data like button presses. + """ fbk = self._group.get_next_feedback() pose = fbk[0] ar_pos = getattr(pose, "ar_position", None) @@ -228,7 +253,18 @@ class AndroidPhone(BasePhone, Teleoperator): print("Calibration done\n") def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: - """Wait trigger for calibration: iOS: B1. Android: 'move'.""" + """ + Blocks execution until the calibration trigger is detected from the Android device. + + This method enters a loop, continuously checking the latest message received from the WebXR + session. It waits for the user to touch and move their finger on the screen, which generates + a `move` event. Once this event is detected, the loop breaks and returns the phone's current + pose. + + Returns: + A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the + moment the trigger was activated. + """ while True: with self._android_lock: msg = self._latest_message or {} @@ -241,6 +277,20 @@ class AndroidPhone(BasePhone, Teleoperator): time.sleep(0.01) def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + """ + Reads the latest 6-DoF pose received from the Android device's WebXR session. + + This method accesses the most recent pose data stored by the `_android_callback`. It uses a + thread lock to safely read the shared `_latest_pose` variable. The pose, a 4x4 matrix, is + then decomposed into position and rotation, and the configured camera offset is applied. + + Returns: + A tuple containing: + - A boolean indicating if a valid pose was available. + - The 3D position as a NumPy array, or None if no pose has been received yet. + - The orientation as a `Rotation` object, or None if no pose has been received. + - The raw 4x4 pose matrix as received from the teleop stream. + """ with self._android_lock: if self._latest_pose is None: return False, None, None, None @@ -251,6 +301,19 @@ class AndroidPhone(BasePhone, Teleoperator): return True, pos, rot, pose def _android_callback(self, pose: np.ndarray, message: dict) -> None: + """ + Callback function to handle incoming data from the Android teleop stream. + + This method is executed by the `teleop` package's subscriber thread whenever a new + pose and message are received from the WebXR session on the Android phone. It updates + the internal state (`_latest_pose` and `_latest_message`) with the new data. + A thread lock is used to ensure that these shared variables are updated atomically, + preventing race conditions with the main thread that reads them. + + Args: + pose: A 4x4 NumPy array representing the phone's transformation matrix. + message: A dictionary containing additional data, such as button presses or touch events. + """ with self._android_lock: self._latest_pose = pose self._latest_message = message diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index fb883d2c5..087f35732 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -36,6 +36,20 @@ from lerobot.robots import Robot def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): + """ + Logs performance metrics for a single step of the robot control loop. + + This function formats and prints a single line of log information, including episode/frame counters, + total loop time (dt), and detailed timings for various robot and camera operations. It can also + highlight performance drops in yellow if the actual FPS is lower than the target FPS. + + Args: + robot: The `Robot` instance, used to access its internal logs for detailed timings. + dt_s: The total duration of the control loop step in seconds. + episode_index: The index of the current episode. + frame_index: The index of the current frame within the episode. + fps: The target frames per second, used to check for performance degradation. + """ log_items = [] if episode_index is not None: log_items.append(f"ep:{episode_index}") @@ -81,7 +95,16 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f @cache def is_headless(): - """Detects if python is running without a monitor.""" + """ + Detects if the Python script is running in a headless environment (e.g., without a display). + + This function attempts to import `pynput`, a library that requires a graphical environment. + If the import fails, it assumes the environment is headless. The result is cached to avoid + re-running the check. + + Returns: + True if the environment is determined to be headless, False otherwise. + """ try: import pynput # noqa @@ -108,6 +131,29 @@ def predict_action( task: str | None = None, robot_type: str | None = None, ): + """ + Performs a single-step inference to predict a robot action from an observation. + + This function encapsulates the full inference pipeline: + 1. Prepares the observation by converting it to PyTorch tensors and adding a batch dimension. + 2. Runs the preprocessor pipeline on the observation. + 3. Feeds the processed observation to the policy to get a raw action. + 4. Runs the postprocessor pipeline on the raw action. + 5. Formats the final action by removing the batch dimension and moving it to the CPU. + + Args: + observation: A dictionary of NumPy arrays representing the robot's current observation. + policy: The `PreTrainedPolicy` model to use for action prediction. + device: The `torch.device` (e.g., 'cuda' or 'cpu') to run inference on. + preprocessor: The `PolicyProcessorPipeline` for preprocessing observations. + postprocessor: The `PolicyProcessorPipeline` for postprocessing actions. + use_amp: A boolean to enable/disable Automatic Mixed Precision for CUDA inference. + task: An optional string identifier for the task. + robot_type: An optional string identifier for the robot type. + + Returns: + A `torch.Tensor` containing the predicted action, ready for the robot. + """ observation = copy(observation) with ( torch.inference_mode(), @@ -143,6 +189,18 @@ def predict_action( def init_keyboard_listener(): + """ + Initializes a non-blocking keyboard listener for real-time user interaction. + + This function sets up a listener for specific keys (right arrow, left arrow, escape) to control + the program flow during execution, such as stopping recording or exiting loops. It gracefully + handles headless environments where keyboard listening is not possible. + + Returns: + A tuple containing: + - The `pynput.keyboard.Listener` instance, or `None` if in a headless environment. + - A dictionary of event flags (e.g., `exit_early`) that are set by key presses. + """ # Allow to exit early while recording an episode or resetting the environment, # by tapping the right arrow key '->'. This might require a sudo permission # to allow your terminal to monitor keyboard events. @@ -184,6 +242,19 @@ def init_keyboard_listener(): def sanity_check_dataset_name(repo_id, policy_cfg): + """ + Validates the dataset repository name against the presence of a policy configuration. + + This function enforces a naming convention: a dataset repository ID should start with "eval_" + if and only if a policy configuration is provided for evaluation purposes. + + Args: + repo_id: The Hugging Face Hub repository ID of the dataset. + policy_cfg: The configuration object for the policy, or `None`. + + Raises: + ValueError: If the naming convention is violated. + """ _, dataset_name = repo_id.split("/") # either repo_id doesnt start with "eval_" and there is no policy # or repo_id starts with "eval_" and there is a policy @@ -204,6 +275,21 @@ def sanity_check_dataset_name(repo_id, policy_cfg): def sanity_check_dataset_robot_compatibility( dataset: LeRobotDataset, robot: Robot, fps: int, features: dict ) -> None: + """ + Checks if a dataset's metadata is compatible with the current robot and recording setup. + + This function compares key metadata fields (`robot_type`, `fps`, and `features`) from the + dataset against the current configuration to ensure that appended data will be consistent. + + Args: + dataset: The `LeRobotDataset` instance to check. + robot: The `Robot` instance representing the current hardware setup. + fps: The current recording frequency (frames per second). + features: The dictionary of features for the current recording session. + + Raises: + ValueError: If any of the checked metadata fields do not match. + """ fields = [ ("robot_type", dataset.meta.robot_type, robot.robot_type), ("fps", dataset.fps, fps), diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index c4f273506..e6acc87de 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -42,7 +42,23 @@ def log_rerun_data( observation: dict[str, Any] | None = None, action: dict[str, Any] | None = None, ) -> None: - """Log observation and action data to Rerun for visualization.""" + """ + Logs observation and action data to Rerun for real-time visualization. + + This function iterates through the provided observation and action dictionaries and sends their contents + to the Rerun viewer. It handles different data types appropriately: + - Scalar values (floats, ints) are logged as `rr.Scalar`. + - 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed + from CHW to HWC format and logged as `rr.Image`. + - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed. + - Other multi-dimensional arrays are flattened and logged as individual scalars. + + Keys are automatically namespaced with "observation." or "action." if not already present. + + Args: + observation: An optional dictionary containing observation data to log. + action: An optional dictionary containing action data to log. + """ if observation: for k, v in observation.items(): if v is None: