chore(docs): update doctrines pipeline files (#1872)

* docs(processor): update docstrings batch_processor

* docs(processor): update docstrings device_processor

* docs(processor): update docstrings tokenizer_processor

* update docstrings processor_act

* update docstrings for pipeline_features

* update docstrings for utils

* update docstring for processor_diffusion

* update docstrings factory

* add docstrings to pi0 processor

* add docstring to pi0fast processor

* add docstring classifier processor

* add docstring to sac processor

* add docstring smolvla processor

* add docstring to tdmpc processor

* add docstring to vqbet processor

* add docstrings to converters

* add docstrings for delta_action_processor

* add docstring to gym action processor

* update hil processor

* add docstring to joint obs processor

* add docstring to migrate_normalize_processor

* update docstrings normalize processor

* update docstring normalize processor

* update docstrings observation processor

* update docstrings rename_processor

* add docstrings robot_kinematic_processor

* cleanup rl comments

* add docstring to train.py

* add docstring to teleoperate.py

* add docstrings to phone_processor.py

* add docstrings to teleop_phone.py

* add docstrings to control_utils.py

* add docstrings to visualization_utils.py

---------

Co-authored-by: Pepijn <pepijn@huggingface.co>
This commit is contained in:
Steven Palma
2025-09-08 18:44:15 +02:00
committed by GitHub
parent d32006440c
commit af9ddcf9a2
33 changed files with 2325 additions and 519 deletions

View File

@@ -27,14 +27,30 @@ def aggregate_pipeline_dataset_features(
use_videos: bool = True,
patterns: Sequence[str] | None = None,
) -> dict[str, dict]:
"""
Aggregates the pipeline's features and returns a features dict ready for the dataset,
filtered to only those keys matching any of the given patterns (for action/state only).
"""Aggregates and filters dataset features based on a data processing pipeline.
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
- `use_videos`: whether to treat image features as video streams
- `patterns`: regexes to filter action & state features; images are included
whenever use_videos=True, regardless of patterns.
This function determines the final structure of dataset features after applying a series
of processing steps defined in a pipeline. It starts with an initial set of hardware
features (e.g., camera image shapes), transforms them using the pipeline, and then
filters the results.
Image features are controlled by the `use_videos` flag, while action and state features
can be selectively included by matching their keys against the provided regex `patterns`.
The final output is formatted to be compatible with Hugging Face Datasets feature dictionaries.
Args:
pipeline (DataProcessorPipeline): The data processing pipeline that defines all
feature transformations.
initial_features (dict[str, Any]): A dictionary of initial hardware features, where
keys are feature names and values are their shapes or types (e.g., camera resolutions).
use_videos (bool): If `True`, includes image/video features in the output. Defaults to `True`.
patterns (Sequence[str] | None): An optional sequence of regular expression patterns.
Only action and state keys that match at least one pattern will be included. If `None`,
all action and state keys are kept. Defaults to `None`.
Returns:
dict[str, dict]: A dictionary representing the final dataset features, structured for
use with `datasets.Features`.
"""
import re

View File

@@ -75,13 +75,20 @@ DEFAULT_FEATURES = {
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
"""Flatten a nested dictionary by joining keys with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
>>> print(flatten_dict(dct))
{"a/b": 1, "a/c/d": 2, "e": 3}
Example:
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
>>> print(flatten_dict(dct))
{'a/b': 1, 'a/c/d': 2, 'e': 3}
Args:
d (dict): The dictionary to flatten.
parent_key (str): The base key to prepend to the keys in this level.
sep (str): The separator to use between keys.
Returns:
dict: A flattened dictionary.
"""
items = []
for k, v in d.items():
@@ -94,6 +101,20 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
def unflatten_dict(d: dict, sep: str = "/") -> dict:
"""Unflatten a dictionary with delimited keys into a nested dictionary.
Example:
>>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3}
>>> print(unflatten_dict(flat_dct))
{'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
Args:
d (dict): A dictionary with flattened keys.
sep (str): The separator used in the keys.
Returns:
dict: A nested dictionary.
"""
outdict = {}
for key, value in d.items():
parts = key.split(sep)
@@ -107,6 +128,16 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
"""Access an item in a nested dictionary using a flattened key.
Args:
obj (DictLike): The nested dictionary-like object.
flattened_key (str): A key with parts separated by `sep`.
sep (str): The separator used in the flattened key.
Returns:
Any: The value from the nested dictionary.
"""
split_keys = flattened_key.split(sep)
getter = obj[split_keys[0]]
if len(split_keys) == 1:
@@ -119,6 +150,19 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
"""Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible.
Converts torch.Tensor, np.ndarray, and np.generic types to lists or native Python types.
Args:
stats (dict): A dictionary that may contain non-serializable numeric types.
Returns:
dict: A dictionary with all values converted to JSON-serializable types.
Raises:
NotImplementedError: If a value has an unsupported type.
"""
serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
@@ -133,6 +177,17 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
"""Embed image bytes into the dataset table before saving to Parquet.
This function prepares a Hugging Face dataset for serialization by converting
image objects into an embedded format that can be stored in Arrow/Parquet.
Args:
dataset (datasets.Dataset): The input dataset, possibly containing image features.
Returns:
datasets.Dataset: The dataset with images embedded in the table storage.
"""
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
@@ -142,38 +197,94 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
def load_json(fpath: Path) -> Any:
"""Load data from a JSON file.
Args:
fpath (Path): Path to the JSON file.
Returns:
Any: The data loaded from the JSON file.
"""
with open(fpath) as f:
return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
"""Write data to a JSON file.
Creates parent directories if they don't exist.
Args:
data (dict): The dictionary to write.
fpath (Path): The path to the output JSON file.
"""
fpath.parent.mkdir(exist_ok=True, parents=True)
with open(fpath, "w") as f:
json.dump(data, f, indent=4, ensure_ascii=False)
def load_jsonlines(fpath: Path) -> list[Any]:
"""Load data from a JSON Lines file.
Args:
fpath (Path): Path to the JSON Lines file.
Returns:
list[Any]: A list of objects loaded from the file.
"""
with jsonlines.open(fpath, "r") as reader:
return list(reader)
def write_jsonlines(data: dict, fpath: Path) -> None:
"""Write a list of dictionaries to a JSON Lines file.
Creates parent directories if they don't exist.
Args:
data (dict): The list of dictionaries to write.
fpath (Path): The path to the output JSON Lines file.
"""
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(data)
def append_jsonlines(data: dict, fpath: Path) -> None:
"""Append a dictionary to a JSON Lines file.
Creates parent directories if they don't exist.
Args:
data (dict): The dictionary to append.
fpath (Path): The path to the JSON Lines file.
"""
fpath.parent.mkdir(exist_ok=True, parents=True)
with jsonlines.open(fpath, "a") as writer:
writer.write(data)
def write_info(info: dict, local_dir: Path):
"""Write dataset info metadata to its standard file path.
Args:
info (dict): The dataset information dictionary.
local_dir (Path): The root directory of the dataset.
"""
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
"""Load dataset info metadata from its standard file path.
Also converts shape lists to tuples for consistency.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
dict: The dataset information dictionary.
"""
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
ft["shape"] = tuple(ft["shape"])
@@ -181,16 +292,40 @@ def load_info(local_dir: Path) -> dict:
def write_stats(stats: dict, local_dir: Path):
"""Serialize and write dataset statistics to their standard file path.
Args:
stats (dict): The statistics dictionary (can contain tensors/numpy arrays).
local_dir (Path): The root directory of the dataset.
"""
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
"""Recursively cast numerical values in a stats dictionary to numpy arrays.
Args:
stats (dict): The statistics dictionary.
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
"""Load dataset statistics and cast numerical values to numpy arrays.
Returns None if the stats file doesn't exist.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
A dictionary of statistics or None if the file is not found.
"""
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
@@ -198,6 +333,13 @@ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
def write_task(task_index: int, task: dict, local_dir: Path):
"""Write a single task to the tasks metadata file.
Args:
task_index (int): The index of the task.
task (dict): The task description dictionary.
local_dir (Path): The root directory of the dataset.
"""
task_dict = {
"task_index": task_index,
"task": task,
@@ -206,6 +348,16 @@ def write_task(task_index: int, task: dict, local_dir: Path):
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
"""Load tasks from the tasks metadata file.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
A tuple containing:
- A dictionary mapping task index to task description.
- A dictionary mapping task description to task index.
"""
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
@@ -213,15 +365,36 @@ def load_tasks(local_dir: Path) -> tuple[dict, dict]:
def write_episode(episode: dict, local_dir: Path):
"""Write a single episode's metadata to the episodes metadata file.
Args:
episode (dict): The episode metadata dictionary.
local_dir (Path): The root directory of the dataset.
"""
append_jsonlines(episode, local_dir / EPISODES_PATH)
def load_episodes(local_dir: Path) -> dict:
"""Load episode metadata from the episodes metadata file.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
dict: A dictionary mapping episode index to episode metadata.
"""
episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
"""Write statistics for a single episode to the episode stats file.
Args:
episode_index (int): The index of the episode.
episode_stats (dict): The statistics for the episode.
local_dir (Path): The root directory of the dataset.
"""
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
@@ -229,6 +402,14 @@ def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path
def load_episodes_stats(local_dir: Path) -> dict:
"""Load per-episode statistics from the episode stats file.
Args:
local_dir (Path): The root directory of the dataset.
Returns:
dict: A dictionary mapping episode index to its statistics dictionary.
"""
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
return {
item["episode_index"]: cast_stats_to_numpy(item["stats"])
@@ -239,12 +420,35 @@ def load_episodes_stats(local_dir: Path) -> dict:
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
"""Create a per-episode stats dictionary from a global stats dictionary.
This is used for backward compatibility with older datasets that only had global stats.
Args:
stats (dict): The global dataset statistics.
episodes (list[int]): A list of episode indices.
Returns:
dict: A dictionary mapping each episode index to the global stats.
"""
return dict.fromkeys(episodes, stats)
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
"""Load an image from a file into a numpy array.
Args:
fpath (str | Path): Path to the image file.
dtype (np.dtype): The desired data type of the output array. If floating,
pixels are scaled to [0, 1].
channel_first (bool): If True, converts the image to (C, H, W) format.
Otherwise, it remains in (H, W, C) format.
Returns:
np.ndarray: The image as a numpy array.
"""
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
@@ -255,10 +459,19 @@ def load_image_as_numpy(
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
with channel first (c h w) of float32 type in range [0,1].
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
types are converted to torch.tensor.
Args:
items_dict (dict): A dictionary representing a batch of data from a
Hugging Face dataset.
Returns:
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
first_item = items_dict[key][0]
@@ -273,6 +486,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
def is_valid_version(version: str) -> bool:
"""Check if a string is a valid PEP 440 version.
Args:
version (str): The version string to check.
Returns:
bool: True if the version string is valid, False otherwise.
"""
try:
packaging.version.parse(version)
return True
@@ -286,6 +507,18 @@ def check_version_compatibility(
current_version: str | packaging.version.Version,
enforce_breaking_major: bool = True,
) -> None:
"""Check for version compatibility between a dataset and the current codebase.
Args:
repo_id (str): The repository ID for logging purposes.
version_to_check (str | packaging.version.Version): The version of the dataset.
current_version (str | packaging.version.Version): The current version of the codebase.
enforce_breaking_major (bool): If True, raise an error on major version mismatch.
Raises:
BackwardCompatibilityError: If the dataset version is from a newer, incompatible
major version of the codebase.
"""
v_check = (
packaging.version.parse(version_to_check)
if not isinstance(version_to_check, packaging.version.Version)
@@ -303,7 +536,14 @@ def check_version_compatibility(
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
"""Returns available valid versions (branches and tags) on given repo."""
"""Return available valid versions (branches and tags) on a given Hub repo.
Args:
repo_id (str): The repository ID on the Hugging Face Hub.
Returns:
list[packaging.version.Version]: A list of valid versions found.
"""
api = HfApi()
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
@@ -316,9 +556,22 @@ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
"""
Returns the version if available on repo or the latest compatible one.
Otherwise, will throw a `CompatibilityError`.
"""Return the specified version if available on repo, or the latest compatible one.
If the exact version is not found, it looks for the latest version with the
same major version number that is less than or equal to the target minor version.
Args:
repo_id (str): The repository ID on the Hugging Face Hub.
version (str | packaging.version.Version): The target version.
Returns:
str: The safe version string (e.g., "v1.2.3") to use as a revision.
Raises:
RevisionNotFoundError: If the repo has no version tags.
BackwardCompatibilityError: If only older major versions are available.
ForwardCompatibilityError: If only newer major versions are available.
"""
target_version = (
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
@@ -360,6 +613,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
def get_hf_features_from_features(features: dict) -> datasets.Features:
"""Convert a LeRobot features dictionary to a `datasets.Features` object.
Args:
features (dict): A LeRobot-style feature dictionary.
Returns:
datasets.Features: The corresponding Hugging Face `datasets.Features` object.
Raises:
ValueError: If a feature has an unsupported shape.
"""
hf_features = {}
for key, ft in features.items():
if ft["dtype"] == "video":
@@ -387,6 +651,14 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
def _validate_feature_names(features: dict[str, dict]) -> None:
"""Validate that feature names do not contain invalid characters.
Args:
features (dict): The LeRobot features dictionary.
Raises:
ValueError: If any feature name contains '/'.
"""
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
if invalid_features:
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
@@ -395,6 +667,22 @@ def _validate_feature_names(features: dict[str, dict]) -> None:
def hw_to_dataset_features(
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
This function takes a dictionary describing hardware outputs (like joint states
or camera image shapes) and formats it into the standard LeRobot feature
specification.
Args:
hw_features (dict): Dictionary mapping feature names to their type (float for
joints) or shape (tuple for images).
prefix (str): The prefix to add to the feature keys (e.g., "observation"
or "action").
use_video (bool): If True, image features are marked as "video", otherwise "image".
Returns:
dict: A LeRobot features dictionary.
"""
features = {}
joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
@@ -427,6 +715,20 @@ def hw_to_dataset_features(
def build_dataset_frame(
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
"""Construct a single data frame from raw values based on dataset features.
A "frame" is a dictionary containing all the data for a single timestep,
formatted as numpy arrays according to the feature specification.
Args:
ds_features (dict): The LeRobot dataset features dictionary.
values (dict): A dictionary of raw values from the hardware/environment.
prefix (str): The prefix to filter features by (e.g., "observation"
or "action").
Returns:
dict: A dictionary representing a single frame of data.
"""
frame = {}
for key, ft in ds_features.items():
if key in DEFAULT_FEATURES or not key.startswith(prefix):
@@ -440,6 +742,21 @@ def build_dataset_frame(
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert dataset features to policy features.
This function transforms the dataset's feature specification into a format
that a policy can use, classifying features by type (e.g., visual, state,
action) and ensuring correct shapes (e.g., channel-first for images).
Args:
features (dict): The LeRobot dataset features dictionary.
Returns:
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
Raises:
ValueError: If an image feature does not have a 3D shape.
"""
# TODO(aliberts): Implement "type" in dataset features and simplify this
policy_features = {}
for key, ft in features.items():
@@ -471,11 +788,19 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
def combine_feature_dicts(*dicts: dict) -> dict:
"""
Merge LeRobot grouped feature dicts.
"""Merge LeRobot grouped feature dicts.
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
- For others (observation.images.*), last one wins (if they are identical).
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
Args:
*dicts: A variable number of LeRobot feature dictionaries to merge.
Returns:
dict: A single merged feature dictionary.
Raises:
ValueError: If there's a dtype mismatch for a feature being merged.
"""
out: dict = {}
for d in dicts:
@@ -521,6 +846,18 @@ def create_empty_dataset_info(
use_videos: bool,
robot_type: str | None = None,
) -> dict:
"""Create a template dictionary for a new dataset's `info.json`.
Args:
codebase_version (str): The version of the LeRobot codebase.
fps (int): The frames per second of the data.
features (dict): The LeRobot features dictionary for the dataset.
use_videos (bool): Whether the dataset will store videos.
robot_type (str | None): The type of robot used, if any.
Returns:
dict: A dictionary with the initial dataset metadata.
"""
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
@@ -541,6 +878,18 @@ def create_empty_dataset_info(
def get_episode_data_index(
episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
"""Calculate the start and end indices for each episode in a flattened dataset.
Args:
episode_dicts (dict): A dictionary mapping episode index to episode metadata,
which must contain a "length" key.
episodes (list[int] | None): An optional list of episode indices to consider.
If None, all episodes are used.
Returns:
dict: A dictionary with "from" and "to" keys, containing torch tensors
with the start and end indices for each episode.
"""
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@@ -560,16 +909,19 @@ def check_timestamps_sync(
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
to account for possible numerical error.
"""Check if timestamps are separated by (1/fps) +/- tolerance.
This check ensures that consecutive timestamps within an episode are spaced
correctly, accounting for possible numerical errors. It ignores the boundaries
between episodes.
Args:
timestamps (np.ndarray): Array of timestamps in seconds.
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
episode_data_index (dict): A dictionary that includes 'to',
which identifies indices for the end of each episode.
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
fps (int): Frames per second. Used to check the expected difference between
consecutive timestamps.
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
raise_value_error (bool): Whether to raise a ValueError if the check fails.
@@ -577,7 +929,8 @@ def check_timestamps_sync(
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
Raises:
ValueError: If the check fails and `raise_value_error` is True.
ValueError: If `timestamps` and `episode_indices` shapes do not match, or if
the check fails and `raise_value_error` is True.
"""
if timestamps.shape != episode_indices.shape:
raise ValueError(
@@ -628,9 +981,23 @@ def check_timestamps_sync(
def check_delta_timestamps(
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
actual timestamps from the dataset.
"""Check if delta timestamps are multiples of 1/fps +/- tolerance.
This ensures that adding these delta timestamps to any existing timestamp in
the dataset will result in a value that aligns with the dataset's frame rate.
Args:
delta_timestamps (dict): A dictionary where values are lists of time
deltas in seconds.
fps (int): The frames per second of the dataset.
tolerance_s (float): The allowed tolerance in seconds.
raise_value_error (bool): If True, raises an error on failure.
Returns:
bool: True if all deltas are valid, False otherwise.
Raises:
ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
"""
outside_tolerance = {}
for key, delta_ts in delta_timestamps.items():
@@ -656,6 +1023,15 @@ def check_delta_timestamps(
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
"""Convert delta timestamps in seconds to delta indices in frames.
Args:
delta_timestamps (dict): A dictionary of time deltas in seconds.
fps (int): The frames per second of the dataset.
Returns:
dict: A dictionary of frame delta indices.
"""
delta_indices = {}
for key, delta_ts in delta_timestamps.items():
delta_indices[key] = [round(d * fps) for d in delta_ts]
@@ -664,9 +1040,17 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
def cycle(iterable):
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
"""Create a dataloader-safe cyclical iterator.
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
This is an equivalent of `itertools.cycle` but is safe for use with
PyTorch DataLoaders with multiple workers.
See https://github.com/pytorch/pytorch/issues/23900 for details.
Args:
iterable: The iterable to cycle over.
Yields:
Items from the iterable, restarting from the beginning when exhausted.
"""
iterator = iter(iterable)
while True:
@@ -677,8 +1061,14 @@ def cycle(iterable):
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
exists before creating it.
"""Create a branch on an existing Hugging Face repo.
Deletes the branch if it already exists before creating it.
Args:
repo_id (str): The ID of the repository.
branch (str): The name of the branch to create.
repo_type (str | None): The type of the repository (e.g., "dataset").
"""
api = HfApi()
@@ -696,9 +1086,20 @@ def create_lerobot_dataset_card(
dataset_info: dict | None = None,
**kwargs,
) -> DatasetCard:
"""
Keyword arguments will be used to replace values in src/lerobot/datasets/card_template.md.
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
"""Create a `DatasetCard` for a LeRobot dataset.
Keyword arguments are used to replace values in the card template.
Note: If specified, `license` must be a valid license identifier from
https://huggingface.co/docs/hub/repositories-licenses.
Args:
tags (list | None): A list of tags to add to the dataset card.
dataset_info (dict | None): The dataset's info dictionary, which will
be displayed on the card.
**kwargs: Additional keyword arguments to populate the card template.
Returns:
DatasetCard: The generated dataset card object.
"""
card_tags = ["LeRobot"]
@@ -730,19 +1131,16 @@ def create_lerobot_dataset_card(
class IterableNamespace(SimpleNamespace):
"""
A namespace object that supports both dictionary-like iteration and dot notation access.
Automatically converts nested dictionaries into IterableNamespaces.
"""A namespace object that supports both dictionary-like iteration and dot notation.
This class extends SimpleNamespace to provide:
- Dictionary-style iteration over keys
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
- Dictionary-like methods: items(), keys(), values()
- Recursive conversion of nested dictionaries
This class extends `SimpleNamespace` to provide dictionary-style iteration,
access to items via brackets (`obj["key"]`), and dictionary-like methods
(`items()`, `keys()`, `values()`). Nested dictionaries are recursively
converted to `IterableNamespace` objects.
Args:
dictionary: Optional dictionary to initialize the namespace
**kwargs: Additional keyword arguments passed to SimpleNamespace
dictionary (dict, optional): A dictionary to initialize the namespace with.
**kwargs: Additional keyword arguments to initialize the namespace.
Examples:
>>> data = {"name": "Alice", "details": {"age": 25}}
@@ -756,10 +1154,16 @@ class IterableNamespace(SimpleNamespace):
>>> for key, value in ns.items():
... print(f"{key}: {value}")
name: Alice
details: IterableNamespace(age=25)
details: <__main__.IterableNamespace object at ...>
"""
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
"""Initialize the IterableNamespace.
Args:
dictionary (dict, optional): Dictionary to populate the namespace.
**kwargs: Keyword arguments to populate the namespace.
"""
super().__init__(**kwargs)
if dictionary is not None:
for key, value in dictionary.items():
@@ -769,22 +1173,46 @@ class IterableNamespace(SimpleNamespace):
setattr(self, key, value)
def __iter__(self) -> Iterator[str]:
"""Return an iterator over the keys of the namespace."""
return iter(vars(self))
def __getitem__(self, key: str) -> Any:
"""Allow bracket-style access to attributes.
Args:
key (str): The name of the attribute.
Returns:
Any: The value of the attribute.
"""
return vars(self)[key]
def items(self):
"""Return a view of the namespace's (key, value) pairs."""
return vars(self).items()
def values(self):
"""Return a view of the namespace's values."""
return vars(self).values()
def keys(self):
"""Return a view of the namespace's keys."""
return vars(self).keys()
def validate_frame(frame: dict, features: dict):
"""Validate a single data frame against the dataset's feature specification.
Checks for missing/extra features, and validates the dtype and shape of each
provided feature.
Args:
frame (dict): The data frame to validate.
features (dict): The LeRobot features dictionary for the dataset.
Raises:
ValueError: If the frame does not match the feature specification.
"""
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
@@ -799,6 +1227,15 @@ def validate_frame(frame: dict, features: dict):
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
"""Check for missing or extra features in a frame.
Args:
actual_features (set[str]): The set of feature names present in the frame.
expected_features (set[str]): The set of feature names expected in the frame.
Returns:
str: An error message string if there's a mismatch, otherwise an empty string.
"""
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - expected_features
@@ -814,6 +1251,19 @@ def validate_features_presence(actual_features: set[str], expected_features: set
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
"""Validate the dtype and shape of a single feature's value.
Args:
name (str): The name of the feature.
feature (dict): The feature specification from the LeRobot features dictionary.
value: The value of the feature to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
Raises:
NotImplementedError: If the feature dtype is not supported for validation.
"""
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
@@ -829,6 +1279,17 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
"""Validate a feature that is expected to be a numpy array.
Args:
name (str): The name of the feature.
expected_dtype (str): The expected numpy dtype as a string.
expected_shape (list[int]): The expected shape.
value (np.ndarray): The numpy array to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
@@ -846,6 +1307,18 @@ def validate_feature_numpy_array(
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
"""Validate a feature that is expected to be an image or video frame.
Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
Args:
name (str): The name of the feature.
expected_shape (list[str]): The expected shape (C, H, W).
value: The image data to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
@@ -862,12 +1335,35 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
def validate_feature_string(name: str, value: str):
"""Validate a feature that is expected to be a string.
Args:
name (str): The name of the feature.
value (str): The value to validate.
Returns:
str: An error message if validation fails, otherwise an empty string.
"""
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
"""Validate the episode buffer before it's written to disk.
Ensures the buffer has the required keys, contains at least one frame, and
has features consistent with the dataset's specification.
Args:
episode_buffer (dict): The buffer containing data for a single episode.
total_episodes (int): The current total number of episodes in the dataset.
features (dict): The LeRobot features dictionary for the dataset.
Raises:
ValueError: If the buffer is invalid.
NotImplementedError: If the episode index is manually set and doesn't match.
"""
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")

View File

@@ -34,6 +34,24 @@ def make_act_pre_post_processors(
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""Creates the pre- and post-processing pipelines for the ACT policy.
The pre-processing pipeline handles normalization, batching, and device placement for the model inputs.
The post-processing pipeline handles unnormalization and moves the model outputs back to the CPU.
Args:
config (ACTConfig): The ACT policy configuration object.
dataset_stats (dict[str, dict[str, torch.Tensor]] | None): A dictionary containing dataset
statistics (e.g., mean and std) used for normalization. Defaults to None.
preprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the
preprocessor pipeline's constructor. Defaults to None.
postprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the
postprocessor pipeline's constructor. Defaults to None.
Returns:
tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: A tuple containing the
pre-processor pipeline and the post-processor pipeline.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:

View File

@@ -35,6 +35,32 @@ def make_diffusion_pre_post_processors(
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for a diffusion policy.
The pre-processing pipeline prepares the input data for the model by:
1. Renaming features (if a `rename_map` is provided in `preprocessor_kwargs`).
2. Normalizing the input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving the data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving the data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the diffusion policy,
containing feature definitions, normalization mappings, and device information.
dataset_stats: A dictionary of statistics used for normalization.
Defaults to None.
preprocessor_kwargs: Additional keyword arguments
for the pre-processor pipeline. Defaults to an empty dictionary.
postprocessor_kwargs: Additional keyword arguments
for the post-processor pipeline. Defaults to an empty dictionary.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:

View File

@@ -43,7 +43,22 @@ from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
"""
Retrieves a policy class by its registered name.
This function uses dynamic imports to avoid loading all policy classes into memory
at once, improving startup time and reducing dependencies.
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
Returns:
The policy class corresponding to the given name.
Raises:
NotImplementedError: If the policy name is not recognized.
"""
if name == "tdmpc":
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
@@ -85,6 +100,24 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
"""
Instantiates a policy configuration object based on the policy type.
This factory function simplifies the creation of policy configuration objects by
mapping a string identifier to the corresponding config class.
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
"reward_classifier".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
An instance of a `PreTrainedConfig` subclass.
Raises:
ValueError: If the `policy_type` is not recognized.
"""
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
@@ -108,7 +141,21 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
class ProcessorConfigKwargs(TypedDict, total=False):
"""Keyword arguments for the processor config."""
"""
A TypedDict defining the keyword arguments for processor configuration.
This provides type hints for the optional arguments passed to `make_pre_post_processors`,
improving code clarity and enabling static analysis.
Attributes:
preprocessor_config_filename: The filename for the preprocessor configuration.
postprocessor_config_filename: The filename for the postprocessor configuration.
preprocessor_overrides: A dictionary of overrides for the preprocessor configuration.
postprocessor_overrides: A dictionary of overrides for the postprocessor configuration.
dataset_stats: Dataset statistics for normalization.
preprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`.
postprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`.
"""
preprocessor_config_filename: str | None
postprocessor_config_filename: str | None
@@ -124,22 +171,27 @@ def make_pre_post_processors(
pretrained_path: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""Make a processor instance for a given policy type.
"""
Create or load pre- and post-processor pipelines for a given policy.
This function creates the appropriate processor configuration based on the policy type.
Each policy type has its own processor with specific preprocessing steps.
This function acts as a factory. It can either load existing processor pipelines
from a pretrained path or create new ones from scratch based on the policy
configuration. Each policy type has a dedicated factory function for its
processors (e.g., `make_tdmpc_pre_post_processors`).
Args:
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
the processor from this path instead of creating a new one.
**kwargs: Additional keyword arguments passed to the processor creation.
policy_cfg: The configuration of the policy for which to create processors.
pretrained_path: An optional path to load pretrained processor pipelines from.
If provided, pipelines are loaded from this path.
**kwargs: Keyword arguments for processor configuration, as defined in
`ProcessorConfigKwargs`.
Returns:
Tuple of (input_processor, output_processor) for the policy.
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
Raises:
NotImplementedError: If the policy type doesn't have a processor implemented.
NotImplementedError: If a processor factory is not implemented for the given
policy configuration type.
"""
if pretrained_path:
# Extract preprocessor and postprocessor kwargs
@@ -269,25 +321,29 @@ def make_policy(
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
"""Make an instance of a policy class.
"""
Instantiate a policy model.
This function exists because (for now) we need to parse features from either a dataset or an environment
in order to properly dimension and instantiate a policy for that dataset or environment.
This factory function handles the logic of creating a policy, which requires
determining the input and output feature shapes. These shapes can be derived
either from a `LeRobotDatasetMetadata` object or an `EnvConfig` object. The function
can either initialize a new policy from scratch or load a pretrained one.
Args:
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
provided if ds_meta is not. Defaults to None.
Raises:
ValueError: Either ds_meta or env and env_cfg must be provided.
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
cfg: The configuration for the policy to be created. If `cfg.pretrained_path` is
set, the policy will be loaded with weights from that path.
ds_meta: Dataset metadata used to infer feature shapes and types. Also provides
statistics for normalization layers.
env_cfg: Environment configuration used to infer feature shapes and types.
One of `ds_meta` or `env_cfg` must be provided.
Returns:
PreTrainedPolicy: _description_
An instantiated and device-placed policy model.
Raises:
ValueError: If both or neither of `ds_meta` and `env_cfg` are provided.
NotImplementedError: If attempting to use an unsupported policy-backend
combination (e.g., VQBeT with 'mps').
"""
if bool(ds_meta) == bool(env_cfg):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")

View File

@@ -37,11 +37,25 @@ from lerobot.processor import (
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
"""Add a new line to the end of the task if it doesn't have one.
This is required for the PaliGemma tokenizer.
"""
Ensures that the task description string ends with a newline character.
This processing step is required for compatibility with the PaliGemma tokenizer,
which expects a newline at the end of the text prompt. It handles both single
strings and lists of strings for the 'task' key in complementary data.
"""
def complementary_data(self, complementary_data):
"""
Adds a newline to the 'task' field if it doesn't already have one.
Args:
complementary_data: A dictionary that may contain a 'task' key with a
string or list of strings.
Returns:
A new dictionary with the modified 'task' field.
"""
if "task" not in complementary_data:
return complementary_data
@@ -64,6 +78,15 @@ class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
return new_complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
This step does not alter the feature definitions.
Args:
features: The input feature dictionary.
Returns:
The unchanged feature dictionary.
"""
return features
@@ -73,6 +96,30 @@ def make_pi0_pre_post_processors(
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the PI0 policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:

View File

@@ -17,7 +17,7 @@
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -30,11 +30,33 @@ from lerobot.processor import (
def make_pi0fast_pre_post_processors(
config: PI0Config,
config: PI0FASTConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the PI0Fast policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0Fast policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:

View File

@@ -36,6 +36,28 @@ def make_sac_pre_post_processors(
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the SAC policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the SAC policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:

View File

@@ -31,6 +31,26 @@ def make_classifier_processor(
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the reward classifier.
The pre-processing pipeline prepares input data for the classifier by:
1. Normalizing both input and output features based on dataset statistics.
2. Moving the data to the specified device.
The post-processing pipeline handles the classifier's output by:
1. Moving the data to the CPU.
2. Applying an identity step, as no unnormalization is needed for the output logits.
Args:
config: The configuration object for the RewardClassifier.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:

View File

@@ -39,6 +39,30 @@ def make_smolvla_pre_post_processors(
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the SmolVLA policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Ensuring the language task description ends with a newline character.
5. Tokenizing the language task description.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output actions to their original scale.
Args:
config: The configuration object for the SmolVLA policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:
@@ -83,7 +107,13 @@ def make_smolvla_pre_post_processors(
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
"""Add a new line to the end of the task if it doesn't have one."""
"""
A processor step that ensures the 'task' description ends with a newline character.
This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a
newline at the end of the prompt. It handles both single string tasks and lists
of string tasks.
"""
def complementary_data(self, complementary_data):
if "task" not in complementary_data:

View File

@@ -35,6 +35,28 @@ def make_tdmpc_pre_post_processors(
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the TDMPC policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the TDMPC policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:

View File

@@ -36,6 +36,28 @@ def make_vqbet_pre_post_processors(
preprocessor_kwargs: ProcessorKwargs | None = None,
postprocessor_kwargs: ProcessorKwargs | None = None,
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
"""
Constructs pre-processor and post-processor pipelines for the VQ-BeT policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features, allowing customization to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the VQ-BeT policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
if preprocessor_kwargs is None:
preprocessor_kwargs = {}
if postprocessor_kwargs is None:

View File

@@ -1,3 +1,5 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,6 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script defines processor steps for adding a batch dimension to various components of an environment transition.
These steps are designed to process actions, observations, and complementary data, making them suitable for batch processing by adding a leading dimension. This is a common requirement before feeding data into a neural network model.
"""
from dataclasses import dataclass, field
from torch import Tensor
@@ -31,24 +40,63 @@ from .pipeline import (
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_action")
class AddBatchDimensionActionStep(ActionProcessorStep):
"""Process action component in-place, adding batch dimension if needed."""
"""
Processor step to add a batch dimension to a 1D tensor action.
def action(self, action):
This is useful for creating a batch of size 1 from a single action sample.
"""
def action(self, action: Tensor) -> Tensor:
"""
Adds a batch dimension to the action if it's a 1D tensor.
Args:
action: The action tensor.
Returns:
The action tensor with an added batch dimension.
"""
if not isinstance(action, Tensor) or action.dim() != 1:
return action
return action.unsqueeze(0)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Returns the input features unchanged.
Adding a batch dimension does not alter the feature definition.
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
class AddBatchDimensionObservationStep(ObservationProcessorStep):
"""Process observation component in-place, adding batch dimensions where needed."""
"""
Processor step to add a batch dimension to observations.
def observation(self, observation):
It handles different types of observations:
- State vectors (1D tensors).
- Single images (3D tensors).
- Dictionaries of multiple images (3D tensors).
"""
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
"""
Adds a batch dimension to tensor-based observations in the observation dictionary.
Args:
observation: The observation dictionary.
Returns:
The observation dictionary with batch dimensions added to tensors.
"""
# Process state observations - add batch dim if 1D
for state_key in [OBS_STATE, OBS_ENV_STATE]:
if state_key in observation:
@@ -69,15 +117,41 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
return observation
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Returns the input features unchanged.
Adding a batch dimension does not alter the feature definition.
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
"""Process complementary data in-place, handling task field batching."""
"""
Processor step to add a batch dimension to complementary data fields.
def complementary_data(self, complementary_data):
Handles specific keys like 'task', 'index', and 'task_index' to make them batched.
- 'task' (str) is wrapped in a list.
- 'index' and 'task_index' (0D tensors) get a batch dimension.
"""
def complementary_data(self, complementary_data: dict) -> dict:
"""
Adds a batch dimension to specific fields in the complementary data dictionary.
Args:
complementary_data: The complementary data dictionary.
Returns:
The complementary data dictionary with batch dimensions added.
"""
# Process task field - wrap string in list to add batch dimension
if "task" in complementary_data:
task_value = complementary_data["task"]
@@ -98,44 +172,33 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
return complementary_data
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Returns the input features unchanged.
Adding a batch dimension does not alter the feature definition.
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
return features
@dataclass
@ProcessorStepRegistry.register(name="to_batch_processor")
class AddBatchDimensionProcessorStep(ProcessorStep):
"""Processor that adds batch dimensions to observations and actions when needed.
"""
A composite processor step that adds a batch dimension to the entire environment transition.
This processor ensures that observations and actions have proper batch dimensions for model processing:
This step combines individual processors for actions, observations, and complementary data
to create a batched transition (batch size 1) from a single-instance transition.
- For state observations (observation.state, observation.environment_state):
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For image observations (observation.image, observation.images.*):
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
- For actions:
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
- For task field in complementary data:
Wraps string task in a list to add batch dimension
(task must be a string or list of strings)
This is useful when processing single transitions that need to be batched for
model inference or when converting from unbatched environment outputs to
batched model inputs.
The processor only modifies tensors that need batching and leaves already
batched tensors unchanged.
Example:
```python
# State: (7,) -> (1, 7)
# Image: (224, 224, 3) -> (1, 224, 224, 3)
# Action: (4,) -> (1, 4)
# Task: "pick_cube" -> ["pick_cube"]
# Already batched: (1, 7) -> (1, 7) [unchanged]
```
Attributes:
to_batch_action_processor: Processor for the action component.
to_batch_observation_processor: Processor for the observation component.
to_batch_complementary_data_processor: Processor for the complementary data component.
"""
to_batch_action_processor: AddBatchDimensionActionStep = field(
@@ -149,11 +212,31 @@ class AddBatchDimensionProcessorStep(ProcessorStep):
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Applies the batching process to all relevant parts of an environment transition.
Args:
transition: The environment transition to process.
Returns:
The environment transition with a batch dimension added.
"""
transition = self.to_batch_action_processor(transition)
transition = self.to_batch_observation_processor(transition)
transition = self.to_batch_complementary_data_processor(transition)
return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Returns the input features unchanged.
Adding a batch dimension does not alter the feature definition.
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
# NOTE: We ignore the batch dimension when transforming features
return features

View File

@@ -44,12 +44,12 @@ def to_tensor(
different input types appropriately.
Args:
value: Input value to convert (tensor, array, scalar, sequence, etc.)
value: Input value to convert (tensor, array, scalar, sequence, etc.).
dtype: Target tensor dtype. If None, preserves original dtype.
device: Target device for the tensor.
Returns:
PyTorch tensor.
A PyTorch tensor.
Raises:
TypeError: If the input type is not supported.
@@ -59,7 +59,7 @@ def to_tensor(
@to_tensor.register(torch.Tensor)
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle existing PyTorch tensors."""
"""Handle conversion for existing PyTorch tensors."""
if dtype is not None:
value = value.to(dtype=dtype)
if device is not None:
@@ -75,17 +75,17 @@ def _(
device=None,
**kwargs,
) -> torch.Tensor:
"""Handle numpy arrays."""
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
"""Handle conversion for numpy arrays."""
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars.
if value.ndim == 0:
# Numpy scalars should be converted to 0-dimensional tensors
# Numpy scalars should be converted to 0-dimensional tensors.
scalar_value = value.item()
return torch.tensor(scalar_value, dtype=dtype, device=device)
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
# Create tensor from numpy array.
tensor = torch.from_numpy(value)
# Apply dtype conversion if specified
# Apply dtype and device conversion if specified.
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if device is not None:
@@ -99,20 +99,20 @@ def _(
@to_tensor.register(np.integer)
@to_tensor.register(np.floating)
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle scalar values including numpy scalars."""
"""Handle conversion for scalar values including numpy scalars."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(list)
@to_tensor.register(tuple)
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle sequences (lists, tuples)."""
"""Handle conversion for sequences (lists, tuples)."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(dict)
def _(value: dict, *, device=None, **kwargs) -> dict:
"""Handle dictionaries by recursively converting values to tensors."""
"""Handle conversion for dictionaries by recursively converting their values to tensors."""
if not value:
return {}
@@ -122,7 +122,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
continue
if isinstance(sub_value, dict):
# Recursively process nested dictionaries
# Recursively process nested dictionaries.
result[key] = to_tensor(
sub_value,
device=device,
@@ -130,7 +130,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
)
continue
# Convert individual values to tensors
# Convert individual values to tensors.
result[key] = to_tensor(
sub_value,
device=device,
@@ -140,17 +140,45 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
"""Convert tensor to numpy/scalar if needed."""
"""
Convert a PyTorch tensor to a numpy array or scalar if applicable.
If the input is not a tensor, it is returned unchanged.
Args:
x: The input, which can be a tensor or any other type.
Returns:
A numpy array, a scalar, or the original input.
"""
if isinstance(x, torch.Tensor):
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
return x
def _is_image(arr: Any) -> bool:
"""
Check if a given array is likely an image (uint8, 3D).
Args:
arr: The array to check.
Returns:
True if the array matches the image criteria, False otherwise.
"""
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
"""
Separate an observation dictionary into state and image components.
Args:
obs: The observation dictionary.
Returns:
A tuple containing two dictionaries: one for state and one for images.
"""
state, images = {}, {}
for k, v in obs.items():
if "image" in k.lower() or _is_image(v):
@@ -160,13 +188,21 @@ def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any],
return state, images
# ============================================================================
# Private Helper Functions (Common Logic)
# ============================================================================
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""Extract complementary data (pad flags, task, index, task_index)."""
"""
Extract complementary data from a batch dictionary.
This includes padding flags, task description, and indices.
Args:
batch: The batch dictionary.
Returns:
A dictionary with the extracted complementary data.
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
@@ -176,7 +212,16 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition:
"""Merge two transitions, with other taking precedence."""
"""
Merge two transitions, with the second one taking precedence in case of conflicts.
Args:
base: The base transition.
other: The transition to merge, which will overwrite base values.
Returns:
The merged transition dictionary.
"""
out = deepcopy(base)
for key in (
@@ -194,9 +239,7 @@ def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransiti
return out
# ============================================================================
# Core Conversion Functions
# ============================================================================
def create_transition(
@@ -208,7 +251,8 @@ def create_transition(
info: dict[str, Any] | None = None,
complementary_data: dict[str, Any] | None = None,
) -> EnvTransition:
"""Create an EnvTransition with sensible defaults.
"""
Create an `EnvTransition` dictionary with sensible defaults.
Args:
observation: Observation dictionary.
@@ -220,7 +264,7 @@ def create_transition(
complementary_data: Complementary data dictionary.
Returns:
Complete EnvTransition dictionary.
A complete `EnvTransition` dictionary.
"""
return {
TransitionKey.OBSERVATION: observation,
@@ -233,9 +277,19 @@ def create_transition(
}
def action_to_transition(action: dict[str, Any]) -> EnvTransition: # action_to_transition
def action_to_transition(action: dict[str, Any]) -> EnvTransition:
"""
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
Convert a raw action dictionary into a standardized `EnvTransition`.
The keys in the action dictionary are prefixed with "action." and stored under
the `ACTION` key in the transition. Values are converted to tensors, except for
special types like `Rotation`.
Args:
action: The raw action dictionary from a teleoperation device or controller.
Returns:
An `EnvTransition` containing the formatted action.
"""
act_dict: dict[str, Any] = {}
for k, v in action.items():
@@ -250,10 +304,19 @@ def action_to_transition(action: dict[str, Any]) -> EnvTransition: # action_to_
return create_transition(observation={}, action=act_dict)
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
"""
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
Convert a raw robot observation dictionary into a standardized `EnvTransition`.
The observation is split into state and image components. State keys are prefixed
with "observation.state." and image keys with "observation.images.". The result is
stored under the `OBSERVATION` key in the transition.
Args:
observation: The raw observation dictionary from the environment.
Returns:
An `EnvTransition` containing the formatted observation.
"""
state, images = _split_obs_to_state_and_images(observation)
@@ -270,7 +333,16 @@ def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]:
"""
Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions.
Extract a raw action dictionary for a robot from an `EnvTransition`.
This function searches for keys in the format "action.*.pos" or "action.*.vel"
and converts them into a flat dictionary suitable for sending to a robot controller.
Args:
transition: The `EnvTransition` containing the action.
Returns:
A dictionary representing the raw robot action.
"""
out: dict[str, Any] = {}
action_dict = transition.get(TransitionKey.ACTION) or {}
@@ -287,13 +359,21 @@ def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]:
def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition:
"""Merge multiple transitions or return single transition.
"""
Merge a sequence of transitions into a single one.
If a single transition is provided, it is returned as is. For a sequence,
transitions are merged sequentially, with later transitions in the sequence
overwriting earlier ones.
Args:
transitions: Either a single transition or iterable of transitions.
transitions: A single transition or a sequence of them.
Returns:
Merged EnvTransition.
A single merged `EnvTransition`.
Raises:
ValueError: If an empty sequence of transitions is provided.
"""
if not isinstance(transitions, Sequence): # Single transition
@@ -312,26 +392,18 @@ def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> E
def transition_to_dataset_frame(
transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict]
) -> dict[str, Any]:
"""Convert a single EnvTransition or an iterable of them into a flat, dataset-friendly dictionary for training or evaluation.
"""
Convert one or more transitions into a flat dictionary suitable for a dataset frame.
Processes transitions according to the provided feature specification and returns
data in the format expected by machine learning models and datasets.
This function processes `EnvTransition` objects according to a feature
specification, producing a format ready for training or evaluation.
Args:
transitions_or_transition: Either a single EnvTransition dict or an iterable of them
(which will be merged using merge_transitions).
features: Feature specification dictionary with the following structure:
- 'action': dict with 'names': list of action feature names
- 'observation.state': dict with 'names': list of state feature names
- keys starting with 'observation.images.' are passed through as-is
transitions_or_transition: A single `EnvTransition` or a sequence to be merged.
features: A feature specification dictionary.
Returns:
Flat dictionary containing:
- numpy arrays for "observation.state" and "action" (vectorized from feature names)
- any image tensors defined in features (passed through unchanged)
- next.{reward,done,truncated} scalar values
- info dict
- *_is_pad flags and task from complementary_data
A flat dictionary representing a single frame of data for a dataset.
"""
action_names = features.get(ACTION, {}).get("names", [])
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
@@ -342,25 +414,25 @@ def transition_to_dataset_frame(
act = tr.get(TransitionKey.ACTION, {}) or {}
batch: dict[str, Any] = {}
# Images passthrough
# Passthrough for images.
for k in image_keys:
if k in obs:
batch[k] = obs[k]
# Observation.state vector
# Create observation.state vector.
if obs_state_names:
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
# Action vector
# Create action vector.
if action_names:
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
batch[ACTION] = np.asarray(vals, dtype=np.float32)
# Add transition metadata
# Add transition metadata.
if tr.get(TransitionKey.REWARD) is not None:
reward_val = _from_tensor(tr[TransitionKey.REWARD])
# Check if features expect array format, otherwise keep as scalar
# Check if features expect array format, otherwise keep as scalar.
if REWARD in features and features[REWARD].get("shape") == (1,):
batch[REWARD] = np.array([reward_val], dtype=np.float32)
else:
@@ -380,14 +452,14 @@ def transition_to_dataset_frame(
else:
batch[TRUNCATED] = truncated_val
# Complementary data flags and task
# Add complementary data flags and task.
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if comp:
# pad flags
# Padding flags.
for k, v in comp.items():
if k.endswith("_is_pad"):
batch[k] = v
# task label
# Task label.
if comp.get("task") is not None:
batch["task"] = comp["task"]
@@ -395,36 +467,27 @@ def transition_to_dataset_frame(
def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
"""Convert a batch dict coming from LeRobot replay/dataset code into an EnvTransition dictionary.
"""
Convert a batch dictionary from a dataset/dataloader into an `EnvTransition`.
The function maps well known keys to the EnvTransition structure. Missing keys are
filled with sane defaults (None or 0.0/False).
Keys recognised (case-sensitive):
* "observation.*" (keys starting with "observation." are grouped into observation dict)
* "action"
* "next.reward"
* "next.done"
* "next.truncated"
* "info"
* "_is_pad" patterns (padding flags)
* "task", "index", "task_index" (complementary data)
Additional keys are ignored so that existing dataloaders can carry extra
metadata without breaking the processor.
This function maps recognized keys from a batch to the `EnvTransition` structure,
filling in missing keys with sensible defaults.
Args:
batch: Batch dictionary from datasets or dataloaders containing the above keys.
batch: A batch dictionary.
Returns:
EnvTransition dictionary with properly structured transition data.
An `EnvTransition` dictionary.
Raises:
ValueError: If the input is not a dictionary.
"""
# Validate input type
# Validate input type.
if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
# Extract observation keys
# Extract observation and complementary data keys.
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
complementary_data = _extract_complementary_data(batch)
@@ -440,25 +503,16 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"""Inverse of batch_to_transition. Returns a dict with canonical field names used throughout LeRobot.
"""
Convert an `EnvTransition` back to the canonical batch format used in LeRobot.
Converts an EnvTransition back to the batch format expected by datasets, dataloaders,
and other LeRobot components.
Output format:
* "action": Action data from transition
* "next.reward": Reward value (defaults to 0.0)
* "next.done": Done flag (defaults to False)
* "next.truncated": Truncated flag (defaults to False)
* "info": Info dictionary (defaults to {})
* Flattened observation keys (e.g., "observation.state", "observation.images.cam1")
* Complementary data fields ("task", "index", "task_index", padding flags)
This is the inverse of `batch_to_transition`.
Args:
transition: EnvTransition dictionary to convert.
transition: The `EnvTransition` to convert.
Returns:
Batch dictionary with canonical LeRobot field names suitable for dataloaders.
A batch dictionary with canonical LeRobot field names.
"""
batch = {
"action": transition.get(TransitionKey.ACTION),
@@ -468,12 +522,12 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"info": transition.get(TransitionKey.INFO, {}),
}
# Add complementary data
# Add complementary data.
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if comp_data:
batch.update(comp_data)
# Flatten observation dict
# Flatten observation dictionary.
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
batch.update(observation)
@@ -482,4 +536,15 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
def identity_transition(tr: EnvTransition) -> EnvTransition:
"""
An identity function for transitions, returning the input unchanged.
Useful as a default or placeholder in processing pipelines.
Args:
tr: An `EnvTransition`.
Returns:
The same `EnvTransition`.
"""
return tr

View File

@@ -1,4 +1,4 @@
# !/usr/bin/env python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
@@ -28,7 +28,15 @@ from .pipeline import ActionProcessorStep, ProcessorStepRegistry
@dataclass
class MapTensorToDeltaActionDictStep(ActionProcessorStep):
"""
Map a tensor to a delta action dictionary.
Maps a flat action tensor from a policy to a structured delta action dictionary.
This step is typically used after a policy outputs a continuous action vector.
It decomposes the vector into named components for delta movements of the
end-effector (x, y, z) and optionally the gripper.
Attributes:
use_gripper: If True, assumes the 4th element of the tensor is the
gripper action.
"""
use_gripper: bool = True
@@ -60,28 +68,17 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
@dataclass
class MapDeltaActionToRobotActionStep(ActionProcessorStep):
"""
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
for use with inverse kinematics processors.
Maps delta actions from teleoperators to robot target actions for inverse kinematics.
Expected input ACTION keys:
{
"action.delta_x": float,
"action.delta_y": float,
"action.delta_z": float,
"action.gripper": float (optional),
}
This step converts a dictionary of delta movements (e.g., from a gamepad)
into a target action format that includes an "enabled" flag and target
end-effector positions. It also handles scaling and noise filtering.
Output ACTION keys:
{
"action.enabled": bool,
"action.target_x": float,
"action.target_y": float,
"action.target_z": float,
"action.target_wx": float,
"action.target_wy": float,
"action.target_wz": float,
"action.gripper": float,
}
Attributes:
position_scale: A factor to scale the delta position inputs.
rotation_scale: A factor to scale the delta rotation inputs (currently unused).
noise_threshold: The magnitude below which delta inputs are considered noise
and do not trigger an "enabled" state.
"""
# Scale factors for delta movements

View File

@@ -13,6 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script defines a processor step for moving environment transition data to a specific torch device and casting
its floating-point precision.
"""
from dataclasses import dataclass
from typing import Any
@@ -28,12 +34,16 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register("device_processor")
@dataclass
class DeviceProcessorStep(ProcessorStep):
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
"""
Processor step to move all tensors within an `EnvTransition` to a specified device and optionally cast their
floating-point data type.
This processor ensures that all tensors in the transition are moved to the
specified device (CPU or GPU) before they are returned. It can also convert
floating-point tensors to a specified dtype while preserving non-float types
(int, long, bool, etc.).
This is crucial for preparing data for model training or inference on hardware like GPUs.
Attributes:
device: The target device for tensors (e.g., "cpu", "cuda", "cuda:0").
float_dtype: The target floating-point dtype as a string (e.g., "float32", "float16", "bfloat16").
If None, the dtype is not changed.
"""
device: str = "cpu"
@@ -50,8 +60,15 @@ class DeviceProcessorStep(ProcessorStep):
}
def __post_init__(self):
"""
Initializes the processor by converting string configurations to torch objects.
This method sets up the `torch.device`, determines if transfers can be non-blocking, and validates the
`float_dtype` string, converting it to a `torch.dtype` object.
"""
self.tensor_device: torch.device = get_safe_torch_device(self.device)
self.device = self.tensor_device.type # cuda might have changed to cuda:1
# Update device string in case a specific GPU was selected (e.g., "cuda" -> "cuda:0")
self.device = self.tensor_device.type
self.non_blocking = "cuda" in str(self.device)
# Validate and convert float_dtype string to torch dtype
@@ -60,27 +77,32 @@ class DeviceProcessorStep(ProcessorStep):
raise ValueError(
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
)
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
else:
self._target_float_dtype = None
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""Process a tensor by moving to device and optionally converting float dtype.
"""
Moves a single tensor to the target device and casts its dtype.
If the tensor is already on a GPU and we're configured for a GPU, it preserves
that GPU placement (useful for multi-GPU training with Accelerate).
Otherwise, it moves to the configured device.
Handles multi-GPU scenarios by not moving a tensor if it's already on a different CUDA device than
the target, which is useful when using frameworks like Accelerate.
Args:
tensor: The input torch.Tensor.
Returns:
The processed tensor on the correct device and with the correct dtype.
"""
# Determine target device
if tensor.is_cuda and self.tensor_device.type == "cuda":
# Both tensor and target are on GPU - preserve tensor's GPU placement
# Both tensor and target are on GPU - preserve tensor's GPU placement.
# This handles multi-GPU scenarios where Accelerate has already placed
# tensors on the correct GPU for each process
# tensors on the correct GPU for each process.
target_device = tensor.device
else:
# Either tensor is on CPU, or we're configured for CPU
# In both cases, use the configured device
# Either tensor is on CPU, or we're configured for CPU.
# In both cases, use the configured device.
target_device = self.tensor_device
# Only move if necessary
@@ -94,6 +116,18 @@ class DeviceProcessorStep(ProcessorStep):
return tensor
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Applies device and dtype conversion to all tensors in an environment transition.
It iterates through the transition, finds all `torch.Tensor` objects (including those nested in
dictionaries like `observation`), and processes them.
Args:
transition: The input `EnvTransition` object.
Returns:
A new `EnvTransition` object with all tensors moved to the target device and dtype.
"""
new_transition = transition.copy()
simple_tensor_keys = [
@@ -108,13 +142,13 @@ class DeviceProcessorStep(ProcessorStep):
TransitionKey.COMPLEMENTARY_DATA,
]
# Process simple tensors
# Process simple, top-level tensors
for key in simple_tensor_keys:
value = transition.get(key)
if isinstance(value, torch.Tensor):
new_transition[key] = self._process_tensor(value)
# Process dictionary-like tensors
# Process tensors nested within dictionaries
for key in dict_tensor_keys:
data_dict = transition.get(key)
if data_dict is not None:
@@ -127,8 +161,24 @@ class DeviceProcessorStep(ProcessorStep):
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
"""
Returns the serializable configuration of the processor.
Returns:
A dictionary containing the device and float_dtype settings.
"""
return {"device": self.device, "float_dtype": self.float_dtype}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Returns the input features unchanged.
Device and dtype transformations do not alter the fundamental definition of the features (e.g., shape).
Args:
features: A dictionary of policy features.
Returns:
The original dictionary of policy features.
"""
return features

View File

@@ -1,4 +1,4 @@
#! /usr/bin/env python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
@@ -10,6 +10,9 @@
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@@ -25,7 +28,17 @@ from .pipeline import ActionProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register("torch2numpy_action_processor")
@dataclass
class Torch2NumpyActionProcessorStep(ActionProcessorStep):
"""Convert PyTorch tensor actions to NumPy arrays."""
"""
Converts a PyTorch tensor action to a NumPy array.
This step is useful when the output of a policy (typically a torch.Tensor)
needs to be passed to an environment or component that expects a NumPy array.
Attributes:
squeeze_batch_dim: If True, removes the first dimension of the array
if it is of size 1. This is useful for converting a
batched action of size (1, D) to a single action of size (D,).
"""
squeeze_batch_dim: bool = True
@@ -38,8 +51,8 @@ class Torch2NumpyActionProcessorStep(ActionProcessorStep):
numpy_action = action.detach().cpu().numpy()
# Remove batch dimensions but preserve action dimensions
# Only squeeze if there's a batch dimension (first dim == 1)
# Remove batch dimensions but preserve action dimensions.
# Only squeeze if there's a batch dimension (first dim == 1).
if (
self.squeeze_batch_dim
and numpy_action.shape
@@ -57,7 +70,13 @@ class Torch2NumpyActionProcessorStep(ActionProcessorStep):
@ProcessorStepRegistry.register("numpy2torch_action_processor")
@dataclass
class Numpy2TorchActionProcessorStep(ActionProcessorStep):
"""Convert NumPy array action to PyTorch tensor."""
"""
Converts a NumPy array action to a PyTorch tensor.
This step is useful for converting actions from environments or hardware,
which are often NumPy arrays, into PyTorch tensors that can be processed
by a policy or model.
"""
def action(self, action: np.ndarray) -> torch.Tensor:
if not isinstance(action, np.ndarray):

View File

@@ -1,3 +1,20 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
from dataclasses import dataclass
@@ -29,21 +46,25 @@ TELEOP_ACTION_KEY = "teleop_action"
@runtime_checkable
class HasTeleopEvents(Protocol):
"""Minimal protocol for objects that provide teleoperation events.
"""
Minimal protocol for objects that provide teleoperation events.
This protocol only defines the additional get_teleop_events() method,
avoiding duplication of the entire Teleoperator interface.
This protocol defines the `get_teleop_events()` method, allowing processor
steps to interact with teleoperators that support event-based controls
(like episode termination or success flagging) without needing to know the
teleoperator's specific class.
"""
def get_teleop_events(self) -> dict[str, Any]:
"""Get extra control events from the teleoperator.
"""
Get extra control events from the teleoperator.
Returns:
Dictionary containing control events such as:
- is_intervention: bool - Whether human is currently intervening
- terminate_episode: bool - Whether to terminate the current episode
- success: bool - Whether the episode was successful
- rerecord_episode: bool - Whether to rerecord the episode
A dictionary containing control events such as:
- `is_intervention`: bool - Whether the human is currently intervening.
- `terminate_episode`: bool - Whether to terminate the current episode.
- `success`: bool - Whether the episode was successful.
- `rerecord_episode`: bool - Whether to rerecord the episode.
"""
...
@@ -53,7 +74,15 @@ TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
def _check_teleop_with_events(teleop: Teleoperator) -> None:
"""Runtime check that a teleoperator implements get_teleop_events."""
"""
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
Args:
teleop: The teleoperator instance to check.
Raises:
TypeError: If the teleoperator does not have a `get_teleop_events` method.
"""
if not isinstance(teleop, HasTeleopEvents):
raise TypeError(
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
@@ -64,11 +93,30 @@ def _check_teleop_with_events(teleop: Teleoperator) -> None:
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@dataclass
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
"""Add teleoperator action to transition complementary data."""
"""
Adds the raw action from a teleoperator to the transition's complementary data.
This is useful for human-in-the-loop scenarios where the human's input needs to
be available to downstream processors, for example, to override a policy's action
during an intervention.
Attributes:
teleop_device: The teleoperator instance to get the action from.
"""
teleop_device: Teleoperator
def complementary_data(self, complementary_data: dict) -> dict:
"""
Retrieves the teleoperator's action and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data dictionary.
Returns:
A new dictionary with the teleoperator action added under the
`teleop_action` key.
"""
new_complementary_data = dict(complementary_data)
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
return new_complementary_data
@@ -80,26 +128,33 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
@ProcessorStepRegistry.register("add_teleop_action_as_info")
@dataclass
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
"""Add teleoperator control events to transition info.
"""
Adds teleoperator control events (e.g., terminate, success) to the transition's info.
This processor step extracts control events from teleoperators that support
event-based interaction (intervention detection, episode termination, etc.).
This step extracts control events from teleoperators that support event-based
interaction, making these signals available to other parts of the system.
Works with any teleoperator that inherits from Teleoperator and implements the
get_teleop_events() method, including custom user-defined teleoperators.
Built-in compatible teleoperators:
- GamepadTeleop: Uses gamepad buttons for control events
- KeyboardEndEffectorTeleop: Uses keyboard keys for control events
Attributes:
teleop_device: An instance of a teleoperator that implements the
`HasTeleopEvents` protocol.
"""
teleop_device: TeleopWithEvents
def __post_init__(self):
"""Validate that the teleoperator supports events."""
"""Validates that the provided teleoperator supports events after initialization."""
_check_teleop_with_events(self.teleop_device)
def info(self, info: dict) -> dict:
"""
Retrieves teleoperator events and updates the info dictionary.
Args:
info: The incoming info dictionary.
Returns:
A new dictionary including the teleoperator events.
"""
new_info = dict(info)
teleop_events = self.teleop_device.get_teleop_events()
@@ -113,12 +168,32 @@ class AddTeleopEventsAsInfoStep(InfoProcessorStep):
@ProcessorStepRegistry.register("image_crop_resize_processor")
@dataclass
class ImageCropResizeProcessorStep(ObservationProcessorStep):
"""Crop and resize image observations."""
"""
Crops and/or resizes image observations.
This step iterates through all image keys in an observation dictionary and applies
the specified transformations. It handles device placement, moving tensors to the
CPU if necessary for operations not supported on certain accelerators like MPS.
Attributes:
crop_params_dict: A dictionary mapping image keys to cropping parameters
(top, left, height, width).
resize_size: A tuple (height, width) to resize all images to.
"""
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
resize_size: tuple[int, int] | None = None
def observation(self, observation: dict) -> dict:
"""
Applies cropping and resizing to all images in the observation dictionary.
Args:
observation: The observation dictionary, potentially containing image tensors.
Returns:
A new observation dictionary with transformed images.
"""
if self.resize_size is None and not self.crop_params_dict:
return observation
@@ -146,12 +221,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep):
return new_observation
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary with the crop parameters and resize dimensions.
"""
return {
"crop_params_dict": self.crop_params_dict,
"resize_size": self.resize_size,
}
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Updates the image feature shapes in the policy features dictionary if resizing is applied.
Args:
features: The policy features dictionary.
Returns:
The updated policy features dictionary with new image shapes.
"""
if self.resize_size is None:
return features
for key in features:
@@ -163,12 +253,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep):
@dataclass
@ProcessorStepRegistry.register("time_limit_processor")
class TimeLimitProcessorStep(TruncatedProcessorStep):
"""Track episode steps and enforce time limits."""
"""
Tracks episode steps and enforces a time limit by truncating the episode.
Attributes:
max_episode_steps: The maximum number of steps allowed per episode.
current_step: The current step count for the active episode.
"""
max_episode_steps: int
current_step: int = 0
def truncated(self, truncated):
def truncated(self, truncated: bool) -> bool:
"""
Increments the step counter and sets the truncated flag if the time limit is reached.
Args:
truncated: The incoming truncated flag.
Returns:
True if the episode step limit is reached, otherwise the incoming value.
"""
self.current_step += 1
if self.current_step >= self.max_episode_steps:
truncated = True
@@ -176,11 +281,18 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
return truncated
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the `max_episode_steps`.
"""
return {
"max_episode_steps": self.max_episode_steps,
}
def reset(self) -> None:
"""Resets the step counter, typically called at the start of a new episode."""
self.current_step = 0
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
@@ -190,13 +302,31 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
@dataclass
@ProcessorStepRegistry.register("gripper_penalty_processor")
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
"""Apply penalty for inappropriate gripper usage."""
"""
Applies a penalty for inefficient gripper usage.
This step penalizes actions that attempt to close an already closed gripper or
open an already open one, based on position thresholds.
Attributes:
penalty: The negative reward value to apply.
max_gripper_pos: The maximum position value for the gripper, used for normalization.
"""
penalty: float = -0.01
max_gripper_pos: float = 30.0
def complementary_data(self, complementary_data):
"""Calculate gripper penalty and add to complementary data."""
def complementary_data(self, complementary_data: dict) -> dict:
"""
Calculates the gripper penalty and adds it to the complementary data.
Args:
complementary_data: The incoming complementary data, which should contain
raw joint positions.
Returns:
A new complementary data dictionary with the `discrete_penalty` key added.
"""
action = self.transition.get(TransitionKey.ACTION)
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
@@ -223,14 +353,20 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
return new_complementary_data
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the penalty value and max gripper position.
"""
return {
"penalty": self.penalty,
"max_gripper_pos": self.max_gripper_pos,
}
def reset(self) -> None:
"""Reset the processor state."""
self.last_gripper_state = None
"""Resets the processor's internal state."""
pass
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
@@ -239,12 +375,33 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
@dataclass
@ProcessorStepRegistry.register("intervention_action_processor")
class InterventionActionProcessorStep(ProcessorStep):
"""Handle human intervention actions and episode termination."""
"""
Handles human intervention, overriding policy actions and managing episode termination.
When an intervention is detected (via teleoperator events in the `info` dict),
this step replaces the policy's action with the human's teleoperated action.
It also processes signals to terminate the episode or flag success.
Attributes:
use_gripper: Whether to include the gripper in the teleoperated action.
terminate_on_success: If True, automatically sets the `done` flag when a
`success` event is received.
"""
use_gripper: bool = False
terminate_on_success: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Processes the transition to handle interventions.
Args:
transition: The incoming environment transition.
Returns:
The modified transition, potentially with an overridden action, updated
reward, and termination status.
"""
action = transition.get(TransitionKey.ACTION)
if action is None:
return transition
@@ -300,6 +457,12 @@ class InterventionActionProcessorStep(ProcessorStep):
return new_transition
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the step's configuration attributes.
"""
return {
"use_gripper": self.use_gripper,
"terminate_on_success": self.terminate_on_success,
@@ -312,7 +475,20 @@ class InterventionActionProcessorStep(ProcessorStep):
@dataclass
@ProcessorStepRegistry.register("reward_classifier_processor")
class RewardClassifierProcessorStep(ProcessorStep):
"""Apply reward classification to image observations."""
"""
Applies a pretrained reward classifier to image observations to predict success.
This step uses a model to determine if the current state is successful, updating
the reward and potentially terminating the episode.
Attributes:
pretrained_path: Path to the pretrained reward classifier model.
device: The device to run the classifier on.
success_threshold: The probability threshold to consider a prediction as successful.
success_reward: The reward value to assign on success.
terminate_on_success: If True, terminates the episode upon successful classification.
reward_classifier: The loaded classifier model instance.
"""
pretrained_path: str | None = None
device: str = "cpu"
@@ -323,7 +499,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
reward_classifier: Any = None
def __post_init__(self):
"""Initialize the reward classifier after dataclass initialization."""
"""Initializes the reward classifier model after the dataclass is created."""
if self.pretrained_path is not None:
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -332,6 +508,16 @@ class RewardClassifierProcessorStep(ProcessorStep):
self.reward_classifier.eval()
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Processes a transition, applying the reward classifier to its image observations.
Args:
transition: The incoming environment transition.
Returns:
The modified transition with an updated reward and done flag based on the
classifier's prediction.
"""
new_transition = transition.copy()
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is None or self.reward_classifier is None:
@@ -371,6 +557,12 @@ class RewardClassifierProcessorStep(ProcessorStep):
return new_transition
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the step's configuration attributes.
"""
return {
"device": self.device,
"success_threshold": self.success_threshold,

View File

@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any
@@ -15,17 +31,42 @@ from lerobot.robots import Robot
@dataclass
@ProcessorStepRegistry.register("joint_velocity_processor")
class JointVelocityProcessorStep(ObservationProcessorStep):
"""Add joint velocity information to observations."""
"""
Calculates and appends joint velocity information to the observation state.
This step computes the velocity of each joint by calculating the finite
difference between the current and the last observed joint positions. The
resulting velocity vector is then concatenated to the original state vector.
Attributes:
dt: The time step (delta time) in seconds between observations, used for
calculating velocity.
last_joint_positions: Stores the joint positions from the previous step
to enable velocity calculation.
"""
dt: float = 0.1
last_joint_positions: torch.Tensor | None = None
def observation(self, observation: dict) -> dict:
"""
Computes joint velocities and adds them to the observation state.
Args:
observation: The input observation dictionary, expected to contain
an `observation.state` key with joint positions.
Returns:
A new observation dictionary with the `observation.state` tensor
extended to include joint velocities.
Raises:
ValueError: If `observation.state` is not found in the observation.
"""
# Get current joint positions (assuming they're in observation.state)
current_positions = observation.get(OBS_STATE)
if current_positions is None:
# TODO(steven): if we get here, then the transform_features method will not hold
raise ValueError(f"{OBS_STATE} is not in observation")
# Initialize last joint positions if not already set
@@ -48,14 +89,33 @@ class JointVelocityProcessorStep(ObservationProcessorStep):
return new_observation
def get_config(self) -> dict[str, Any]:
"""
Returns the configuration of the step for serialization.
Returns:
A dictionary containing the time step `dt`.
"""
return {
"dt": self.dt,
}
def reset(self) -> None:
"""Resets the internal state, clearing the last known joint positions."""
self.last_joint_positions = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Updates the `observation.state` feature to reflect the added velocities.
This method doubles the size of the first dimension of the `observation.state`
shape to account for the concatenation of position and velocity vectors.
Args:
features: The policy features dictionary.
Returns:
The updated policy features dictionary.
"""
if OBS_STATE in features:
original_feature = features[OBS_STATE]
# Double the shape to account for positions + velocities
@@ -68,11 +128,33 @@ class JointVelocityProcessorStep(ObservationProcessorStep):
@dataclass
@ProcessorStepRegistry.register("current_processor")
class MotorCurrentProcessorStep(ObservationProcessorStep):
"""Add motor current information to observations."""
"""
Reads motor currents from a robot and appends them to the observation state.
This step queries the robot's hardware interface to get the present current
for each motor and concatenates this information to the existing state vector.
Attributes:
robot: An instance of a `lerobot` Robot class that provides access to
the hardware bus.
"""
robot: Robot | None = None
def observation(self, observation: dict) -> dict:
"""
Fetches motor currents and adds them to the observation state.
Args:
observation: The input observation dictionary.
Returns:
A new observation dictionary with the `observation.state` tensor
extended to include motor currents.
Raises:
ValueError: If the `robot` attribute has not been set.
"""
# Get current values from robot state
if self.robot is None:
raise ValueError("Robot is not set")
@@ -96,6 +178,18 @@ class MotorCurrentProcessorStep(ObservationProcessorStep):
return new_observation
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""
Updates the `observation.state` feature to reflect the added motor currents.
This method increases the size of the first dimension of the `observation.state`
shape by the number of motors in the robot.
Args:
features: The policy features dictionary.
Returns:
The updated policy features dictionary.
"""
if OBS_STATE in features and self.robot is not None:
original_feature = features[OBS_STATE]
# Add motor current dimensions to the original state shape

View File

@@ -15,16 +15,22 @@
# limitations under the License.
"""
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
A generic script to migrate LeRobot policies with built-in normalization layers to the new
pipeline-based processor system.
This script:
1. Loads an existing pretrained policy model
2. Extracts normalization statistics from the model
3. Creates both preprocessor and postprocessor:
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
- Postprocessor: unnormalizes outputs (actions) for inference
4. Removes normalization layers from the model state_dict
5. Saves the new model and both processors
This script performs the following steps:
1. Loads a pretrained policy model and its configuration from a local path or the
Hugging Face Hub.
2. Scans the model's state dictionary to extract normalization statistics (e.g., mean,
std, min, max) for all features.
3. Creates two new processor pipelines:
- A preprocessor that normalizes inputs (observations) and outputs (actions).
- A postprocessor that unnormalizes outputs (actions) for inference.
4. Removes the original normalization layers from the model's state dictionary,
creating a "clean" model.
5. Saves the new clean model, the preprocessor, the postprocessor, and a generated
model card to a new directory.
6. Optionally pushes all the new artifacts to the Hugging Face Hub.
Usage:
python src/lerobot/processor/migrate_policy_normalization.py \
@@ -68,7 +74,21 @@ POLICY_CLASSES = {
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Extract normalization statistics from model state_dict."""
"""
Scans a model's state_dict to find and extract normalization statistics.
This function identifies keys corresponding to normalization layers (e.g., those
for mean, std, min, max) based on a set of predefined patterns and organizes
them into a nested dictionary.
Args:
state_dict: The state dictionary of a pretrained policy model.
Returns:
A nested dictionary where outer keys are feature names (e.g.,
'observation.state') and inner keys are statistic types ('mean', 'std'),
mapping to their corresponding tensor values.
"""
stats = {}
# Define patterns to match and their prefixes to remove
@@ -112,7 +132,25 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
def detect_features_and_norm_modes(
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
"""Detect features and normalization modes from config and stats."""
"""
Infers policy features and normalization modes from the model config and stats.
This function first attempts to find feature definitions and normalization
mappings directly from the policy's configuration file. If this information is
not present, it infers it from the extracted normalization statistics, using
tensor shapes to determine feature shapes and the presence of specific stat
keys (e.g., 'mean'/'std' vs 'min'/'max') to determine the normalization mode.
It applies sensible defaults if inference is not possible.
Args:
config: The policy's configuration dictionary from `config.json`.
stats: The normalization statistics extracted from the model's state_dict.
Returns:
A tuple containing:
- A dictionary mapping feature names to `PolicyFeature` objects.
- A dictionary mapping `FeatureType` enums to `NormalizationMode` enums.
"""
features = {}
norm_modes = {}
@@ -204,7 +242,19 @@ def detect_features_and_norm_modes(
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Remove normalization layers from state_dict."""
"""
Creates a new state_dict with all normalization-related layers removed.
This function filters the original state dictionary, excluding any keys that
match a set of predefined patterns associated with normalization modules.
Args:
state_dict: The original model state dictionary.
Returns:
A new state dictionary containing only the core model weights, without
any normalization parameters.
"""
new_state_dict = {}
# Patterns to remove
@@ -228,7 +278,16 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""Convert features from old format to PolicyFeature objects."""
"""
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
Args:
features_dict: The feature dictionary in the old format, where values are
simple dictionaries (e.g., `{"shape": [7]}`).
Returns:
A dictionary mapping feature names to `PolicyFeature` dataclass objects.
"""
converted_features = {}
for key, feature_dict in features_dict.items():
@@ -254,8 +313,18 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
def load_model_from_hub(
repo_id: str, revision: str = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
"""Load model state_dict and config from hub."""
# Download files
"""
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
Args:
repo_id: The repository ID on the Hub (e.g., 'lerobot/aloha').
revision: The specific git revision (branch, tag, or commit hash) to use.
Returns:
A tuple containing the model's state dictionary, the policy configuration,
and the training configuration.
"""
# Download files.
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)

View File

@@ -1,3 +1,20 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from copy import deepcopy
@@ -20,9 +37,26 @@ class _NormalizationMixin:
"""
A mixin class providing core functionality for normalization and unnormalization.
This class manages normalization statistics, their conversion to tensors, device placement,
and the application of normalization transformations. It is designed to be inherited by
concrete ProcessorStep implementations.
This class manages normalization statistics (`stats`), converts them to tensors for
efficient computation, handles device placement, and implements the logic for
applying normalization transformations (mean/std and min/max). It is designed to
be inherited by concrete `ProcessorStep` implementations and should not be used
directly.
Attributes:
features: A dictionary mapping feature names to `PolicyFeature` objects, defining
the data structure to be processed.
norm_map: A dictionary mapping `FeatureType` to `NormalizationMode`, specifying
which normalization method to use for each type of feature.
stats: A dictionary containing the normalization statistics (e.g., mean, std,
min, max) for each feature.
device: The PyTorch device on which to store and perform tensor operations.
eps: A small epsilon value to prevent division by zero in normalization
calculations.
normalize_observation_keys: An optional set of keys to selectively apply
normalization to specific observation features.
_tensor_stats: An internal dictionary holding the normalization statistics as
PyTorch tensors.
"""
features: dict[str, PolicyFeature]
@@ -36,7 +70,15 @@ class _NormalizationMixin:
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
def __post_init__(self):
# Robust JSON deserialization handling (guard empty maps)
"""
Initializes the mixin after dataclass construction.
This method handles the robust deserialization of `features` and `norm_map`
from JSON-compatible formats (where enums become strings and tuples become
lists) and converts the provided `stats` dictionary into a dictionary of
tensors (`_tensor_stats`) on the specified device.
"""
# Robust JSON deserialization handling (guard empty maps).
if self.features:
first_val = next(iter(self.features.values()))
if isinstance(first_val, dict):
@@ -65,7 +107,15 @@ class _NormalizationMixin:
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
) -> _NormalizationMixin:
"""Moves the processor's normalization stats to the specified device and returns self."""
"""
Moves the processor's normalization stats to the specified device.
Args:
device: The target PyTorch device.
Returns:
The instance of the class, allowing for method chaining.
"""
if device is not None:
self.device = device
if dtype is not None:
@@ -74,6 +124,16 @@ class _NormalizationMixin:
return self
def state_dict(self) -> dict[str, Tensor]:
"""
Returns the normalization statistics as a flat state dictionary.
All tensors are moved to the CPU before being returned, which is standard practice
for saving state dictionaries.
Returns:
A flat dictionary mapping from `'feature_name.stat_name'` to the
corresponding statistics tensor on the CPU.
"""
flat: dict[str, Tensor] = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
@@ -81,6 +141,15 @@ class _NormalizationMixin:
return flat
def load_state_dict(self, state: dict[str, Tensor]) -> None:
"""
Loads normalization statistics from a state dictionary.
The loaded tensors are moved to the processor's configured device.
Args:
state: A flat state dictionary with keys in the format
`'feature_name.stat_name'`.
"""
self._tensor_stats.clear()
for flat_key, tensor in state.items():
key, stat_name = flat_key.rsplit(".", 1)
@@ -90,6 +159,15 @@ class _NormalizationMixin:
)
def get_config(self) -> dict[str, Any]:
"""
Returns a serializable dictionary of the processor's configuration.
This method is used when saving the processor to disk, ensuring that its
configuration can be reconstructed later.
Returns:
A JSON-serializable dictionary containing the configuration.
"""
config = {
"eps": self.eps,
"features": {
@@ -102,6 +180,16 @@ class _NormalizationMixin:
return config
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
"""
Applies (un)normalization to all relevant features in an observation dictionary.
Args:
observation: The observation dictionary to process.
inverse: If `True`, applies unnormalization; otherwise, applies normalization.
Returns:
A new observation dictionary with the transformed tensor values.
"""
new_observation = dict(observation)
for key, feature in self.features.items():
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
@@ -114,6 +202,16 @@ class _NormalizationMixin:
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
# Convert to tensor but preserve original dtype for adaptation logic
"""
Applies (un)normalization to an action tensor.
Args:
action: The action tensor to process.
inverse: If `True`, applies unnormalization; otherwise, applies normalization.
Returns:
The transformed action tensor.
"""
tensor = torch.as_tensor(action)
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
return processed_action
@@ -121,7 +219,24 @@ class _NormalizationMixin:
def _apply_transform(
self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
) -> Tensor:
"""Core logic to apply normalization or unnormalization."""
"""
Core logic to apply a normalization or unnormalization transformation to a tensor.
This method selects the appropriate normalization mode (e.g., mean/std, min/max)
based on the feature type and applies the corresponding mathematical operation.
Args:
tensor: The input tensor to transform.
key: The feature key corresponding to the tensor.
feature_type: The `FeatureType` of the tensor.
inverse: If `True`, applies the inverse transformation (unnormalization).
Returns:
The transformed tensor.
Raises:
ValueError: If an unsupported normalization mode is encountered.
"""
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
return tensor
@@ -168,11 +283,11 @@ class _NormalizationMixin:
@ProcessorStepRegistry.register(name="normalizer_processor")
class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
"""
A processor that applies normalization to observations and actions in a transition.
A processor step that applies normalization to observations and actions in a transition.
This class directly implements the normalization logic for both observation and action
components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
initialization.
This class uses the logic from `_NormalizationMixin` to perform forward normalization
(e.g., scaling data to have zero mean and unit variance, or to the range [-1, 1]).
It is typically used in the pre-processing pipeline before feeding data to a policy.
"""
@classmethod
@@ -186,6 +301,20 @@ class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
eps: float = 1e-8,
device: torch.device | str | None = None,
) -> NormalizerProcessorStep:
"""
Creates a `NormalizerProcessorStep` instance using statistics from a `LeRobotDataset`.
Args:
dataset: The dataset from which to extract normalization statistics.
features: The feature definition for the processor.
norm_map: The mapping from feature types to normalization modes.
normalize_observation_keys: An optional set of observation keys to normalize.
eps: A small epsilon value for numerical stability.
device: The target device for the processor.
Returns:
A new instance of `NormalizerProcessorStep`.
"""
return cls(
features=features,
norm_map=norm_map,
@@ -220,11 +349,12 @@ class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
@ProcessorStepRegistry.register(name="unnormalizer_processor")
class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
"""
A processor that applies unnormalization (the inverse of normalization) to
observations and actions in a transition.
A processor step that applies unnormalization to observations and actions.
This is typically used to transform actions from a normalized policy output back into
the original scale for execution in an environment.
This class inverts the normalization process, scaling data back to its original
range. It is typically used in the post-processing pipeline to convert a policy's
normalized action output into a format that can be executed by a robot or
environment.
"""
@classmethod
@@ -236,6 +366,18 @@ class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
*,
device: torch.device | str | None = None,
) -> UnnormalizerProcessorStep:
"""
Creates an `UnnormalizerProcessorStep` using statistics from a `LeRobotDataset`.
Args:
dataset: The dataset from which to extract normalization statistics.
features: The feature definition for the processor.
norm_map: The mapping from feature types to normalization modes.
device: The target device for the processor.
Returns:
A new instance of `UnnormalizerProcessorStep`.
"""
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
def __call__(self, transition: EnvTransition) -> EnvTransition:
@@ -261,12 +403,20 @@ def hotswap_stats(
policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]]
) -> PolicyProcessorPipeline:
"""
Replaces normalization statistics in a PolicyProcessor pipeline.
Replaces normalization statistics in an existing `PolicyProcessorPipeline` instance.
This function creates a deep copy of the provided `PolicyProcessorPipeline` and updates the
statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` steps within it.
It's useful for adapting a trained policy to a new environment or dataset with
different data distributions.
This function creates a deep copy of the provided pipeline and updates the
statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` it
contains. This is useful for adapting a trained policy to a new environment or
dataset with different data distributions without having to reconstruct the entire
pipeline.
Args:
policy_processor: The policy processor pipeline to modify.
stats: The new dictionary of normalization statistics to apply.
Returns:
A new `PolicyProcessorPipeline` instance with the updated statistics.
"""
rp = deepcopy(policy_processor)
for step in rp.steps:

View File

@@ -30,23 +30,44 @@ from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessorStep(ObservationProcessorStep):
"""
Processes environment observations into the LeRobot format by handling both images and states.
Processes standard Gymnasium observations into the LeRobot format.
Image processing:
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
- Adds a batch dimension if missing
- Supports single images and image dictionaries
This step handles both image and state data from a typical observation dictionary,
preparing it for use in a LeRobot policy.
State processing:
- Maps 'environment_state' to observation.environment_state
- Maps 'agent_pos' to observation.state
- Converts numpy arrays to tensors
- Adds a batch dimension if missing
**Image Processing:**
- Converts channel-last (H, W, C), `uint8` images to channel-first (C, H, W),
`float32` tensors.
- Normalizes pixel values from the [0, 255] range to [0, 1].
- Adds a batch dimension if one is not already present.
- Recognizes a single image under the key `"pixels"` and maps it to
`"observation.image"`.
- Recognizes a dictionary of images under the key `"pixels"` and maps them
to `"observation.images.{camera_name}"`.
**State Processing:**
- Maps the `"environment_state"` key to `"observation.environment_state"`.
- Maps the `"agent_pos"` key to `"observation.state"`.
- Converts NumPy arrays to PyTorch tensors.
- Adds a batch dimension if one is not already present.
"""
def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array."""
"""
Processes a single NumPy image array into a channel-first, normalized tensor.
Args:
img: A NumPy array representing the image, expected to be in channel-last
(H, W, C) format with a `uint8` dtype.
Returns:
A `float32` PyTorch tensor in channel-first (B, C, H, W) format, with
pixel values normalized to the [0, 1] range.
Raises:
ValueError: If the input image does not appear to be in channel-last
format or is not of `uint8` dtype.
"""
# Convert to tensor
img_tensor = torch.from_numpy(img)
@@ -108,16 +129,24 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
return self._process_observation(observation)
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Transforms feature keys to a standardized contract.
This method handles several renaming patterns:
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
- environment_state -> OBS_ENV_STATE,
- agent_pos -> OBS_STATE,
- observation.environment_state -> OBS_ENV_STATE,
- observation.agent_pos -> OBS_STATE
"""
Transforms feature keys from the Gym standard to the LeRobot standard.
This method standardizes the feature dictionary by renaming keys according
to LeRobot's conventions, ensuring that policies can be constructed correctly.
It handles various raw key formats, including those with an "observation." prefix.
**Renaming Rules:**
- `pixels` or `observation.pixels` -> `observation.image`
- `pixels.{cam}` or `observation.pixels.{cam}` -> `observation.images.{cam}`
- `environment_state` or `observation.environment_state` -> `observation.environment_state`
- `agent_pos` or `observation.agent_pos` -> `observation.state`
Args:
features: The policy features dictionary with Gym-style keys.
Returns:
The policy features dictionary with standardized LeRobot keys.
"""
exact_pairs = {
"pixels": OBS_IMAGE,

View File

@@ -25,7 +25,18 @@ from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="rename_processor")
class RenameProcessorStep(ObservationProcessorStep):
"""Rename processor that renames keys in the observation."""
"""
A processor step that renames keys in an observation dictionary.
This step is useful for creating a standardized data interface by mapping keys
from an environment's format to the format expected by a LeRobot policy or
other downstream components.
Attributes:
rename_map: A dictionary mapping from old key names to new key names.
Keys present in an observation that are not in this map will
be kept with their original names.
"""
rename_map: dict[str, str] = field(default_factory=dict)
@@ -51,7 +62,22 @@ class RenameProcessorStep(ObservationProcessorStep):
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
"""Rename keys in the stats dictionary according to rename_map (defensive copy)."""
"""
Renames the top-level keys in a statistics dictionary using a provided mapping.
This is a helper function typically used to keep normalization statistics
consistent with renamed observation or action features. It performs a defensive
deep copy to avoid modifying the original `stats` dictionary.
Args:
stats: A nested dictionary of statistics, where top-level keys are
feature names (e.g., `{"observation.state": {"mean": 0.5}}`).
rename_map: A dictionary mapping old feature names to new feature names.
Returns:
A new statistics dictionary with its top-level keys renamed. Returns an
empty dictionary if the input `stats` is empty.
"""
if not stats:
return {}
renamed: dict[str, dict[str, Any]] = {}

View File

@@ -1,5 +1,24 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tokenizer processor for handling text tokenization in robot transitions.
This script defines a processor for tokenizing natural language instructions from an environment transition.
It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
token IDs and attention masks, which are then added to the observation dictionary.
"""
from __future__ import annotations
@@ -16,6 +35,7 @@ from lerobot.utils.import_utils import _transformers_available
from .core import EnvTransition, TransitionKey
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
else:
@@ -25,54 +45,48 @@ else:
@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessorStep(ObservationProcessorStep):
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
"""
Processor step to tokenize a natural language task description.
This processor handles tokenization of task strings found in the complementary_data
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
to the observation data for model processing while preserving the original task string.
This step extracts a task string from the `complementary_data` of an `EnvTransition`,
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
token IDs and attention mask to the `observation` dictionary.
The processor supports both single strings and lists of strings as task inputs.
Requires the `transformers` library to be installed.
Args:
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
tokenizer: A tokenizer object (e.g., from transformers library) that implements
the __call__ method. If provided, tokenizer_name is ignored. This parameter
is not serialized and must be provided via overrides when loading.
max_length: Maximum sequence length for tokenization. Defaults to 512.
task_key: Key in complementary_data containing the task text. Defaults to "task".
padding: Padding strategy for tokenization. Defaults to "max_length".
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
Examples:
Using tokenizer name (auto-loaded):
```python
processor = TokenizerProcessorStep(tokenizer_name="bert-base-uncased", max_length=128)
```
Using custom tokenizer object:
```python
from transformers import AutoTokenizer
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = TokenizerProcessorStep(tokenizer=custom_tokenizer, max_length=128)
```
Attributes:
tokenizer_name: The name of a pretrained tokenizer from the Hugging Face Hub (e.g., "bert-base-uncased").
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
max_length: The maximum length to pad or truncate sequences to.
task_key: The key in `complementary_data` where the task string is stored.
padding_side: The side to pad on ('left' or 'right').
padding: The padding strategy ('max_length', 'longest', etc.).
truncation: Whether to truncate sequences longer than `max_length`.
input_tokenizer: The internal tokenizer instance, loaded during initialization.
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
max_length: int = 512
task_key: str = "task"
padding_side: str = "right"
padding: str = "max_length"
truncation: bool = True
# Internal tokenizer instance (not serialized)
# Internal tokenizer instance (not part of the config)
input_tokenizer: Any = field(default=None, init=False, repr=False)
def __post_init__(self):
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
"""
Initializes the tokenizer after the dataclass is created.
It checks for the availability of the `transformers` library and loads the tokenizer
either from a provided object or by name from the Hugging Face Hub.
Raises:
ImportError: If the `transformers` library is not installed.
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
"""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
@@ -93,13 +107,14 @@ class TokenizerProcessorStep(ObservationProcessorStep):
)
def get_task(self, transition: EnvTransition) -> list[str] | None:
"""Extract and normalize task from complementary data.
"""
Extracts the task description(s) from the transition's complementary data.
Args:
transition: Input transition containing complementary_data.
transition: The environment transition.
Returns:
List of task strings if task is present, None otherwise.
A list of task strings, or None if the task key is not found or the value is None.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
@@ -112,7 +127,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
if task is None:
return None
# Convert to list of strings
# Standardize to a list of strings for the tokenizer
if isinstance(task, str):
return [task]
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
@@ -120,78 +135,80 @@ class TokenizerProcessorStep(ObservationProcessorStep):
return None
def observation(self, observation):
"""Process the transition by tokenizing the task text.
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
"""
Tokenizes the task description and adds it to the observation dictionary.
This method retrieves the task, tokenizes it, moves the resulting tensors to the
same device as other data in the transition, and updates the observation.
Args:
transition: Input transition containing complementary_data with task text.
observation: The original observation dictionary.
Returns:
Modified transition with tokenized task added to observation.
Raises:
ValueError: If tokenizer initialization failed.
The updated observation dictionary including token IDs and an attention mask.
"""
task = self.get_task(self.transition)
if task is None:
return observation
# Tokenize the task (creates CPU tensors)
# Tokenize the task (this will create CPU tensors)
tokenized_prompt = self._tokenize_text(task)
# Detect device from existing tensors in the transition
# Detect the device from existing tensors in the transition to ensure consistency
target_device = self._detect_device(self.transition)
# Move tokenized tensors to match the device of other data
# Move new tokenized tensors to the detected device
if target_device is not None:
tokenized_prompt = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_prompt.items()
}
# Get or create observation dict
# Create a new observation dict to avoid modifying the original in place
new_observation = dict(observation)
# Add tokenized data to observation
# Add tokenized data to the observation
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
return new_observation
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
"""Detect device from existing tensors in the transition.
"""
Detects the torch.device from existing tensors in the transition.
This allows the tokenized tensors to match the device of other data,
which is especially important for multi-GPU training with Accelerate.
It checks tensors in the observation dictionary first, then the action tensor.
Args:
transition: The transition to search for existing tensors.
transition: The environment transition.
Returns:
The device of the first tensor found, or None if no tensors exist.
The detected `torch.device`, or None if no tensors are found.
"""
# Check observation tensors first (most likely to exist)
# Check observation tensors first (most likely place to find tensors)
observation = transition.get(TransitionKey.OBSERVATION)
if observation:
for value in observation.values():
if isinstance(value, torch.Tensor):
return value.device
# Check action tensor
# Fallback to checking the action tensor
action = transition.get(TransitionKey.ACTION)
if isinstance(action, torch.Tensor):
return action.device
return None # No tensors found, keep on CPU
return None # No tensors found, default will be CPU
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
"""Tokenize text using the configured tokenizer.
"""
A wrapper around the tokenizer call.
Args:
text: Text string or list of strings to tokenize.
text: A string or list of strings to tokenize.
Returns:
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors.
"""
return self.input_tokenizer(
text,
@@ -203,10 +220,14 @@ class TokenizerProcessorStep(ObservationProcessorStep):
)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.
"""
Returns the serializable configuration of the processor.
Note: Only tokenizer_name is saved, not the tokenizer object itself.
When loading, provide the tokenizer via overrides if needed.
Note: The tokenizer object itself is not serialized. If the processor was initialized
with a tokenizer name, that name will be included in the config.
Returns:
A dictionary with the processor's configuration parameters.
"""
config = {
"max_length": self.max_length,
@@ -216,28 +237,30 @@ class TokenizerProcessorStep(ObservationProcessorStep):
"truncation": self.truncation,
}
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
# TODO(steven): Consider saving the name of the _tokenizer if it was loaded
# Only save tokenizer_name if it was used to create the tokenizer
if self.tokenizer_name is not None and self.tokenizer is None:
config["tokenizer_name"] = self.tokenizer_name
return config
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the feature contract.
"""
Adds feature definitions for the language tokens and attention mask.
This updates the policy features dictionary to include the new data added to the
observation, ensuring downstream components are aware of their shape and type.
Args:
features: Input feature dictionary.
features: The dictionary of existing policy features.
Returns:
Updated feature dictionary with tokenized task features added.
The updated dictionary of policy features.
"""
# Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask
# Add a feature for the token IDs if it doesn't already exist
if OBS_LANGUAGE_TOKENS not in features:
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
# Add a feature for the attention mask if it doesn't already exist
if OBS_LANGUAGE_ATTENTION_MASK not in features:
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)

View File

@@ -1,4 +1,4 @@
# !/usr/bin/env python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
@@ -38,18 +38,27 @@ from lerobot.utils.rotation import Rotation
@dataclass
class EEReferenceAndDelta(ActionProcessorStep):
"""
Compute the desired end-effector pose from the target pose and the current pose.
Computes a target end-effector pose from a relative delta command.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
This step takes a desired change in position and orientation (`target_*`) and applies it to a
reference end-effector pose to calculate an absolute target pose. The reference pose is derived
from the current robot joint positions using forward kinematics.
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
The processor can operate in two modes:
1. `use_latched_reference=True`: The reference pose is "latched" or saved at the moment the action
is first enabled. Subsequent commands are relative to this fixed reference.
2. `use_latched_reference=False`: The reference pose is updated to the robot's current pose at
every step.
Attributes:
kinematics: The robot's kinematic model for forward kinematics.
end_effector_step_sizes: A dictionary scaling the input delta commands.
motor_names: A list of motor names required for forward kinematics.
use_latched_reference: If True, latch the reference pose on enable; otherwise, always use the
current pose as the reference.
reference_ee_pose: Internal state storing the latched reference pose.
_prev_enabled: Internal state to detect the rising edge of the enable signal.
_command_when_disabled: Internal state to hold the last command while disabled.
"""
kinematics: RobotKinematics
@@ -135,6 +144,7 @@ class EEReferenceAndDelta(ActionProcessorStep):
return new_action
def reset(self):
"""Resets the internal state of the processor."""
self._prev_enabled = False
self.reference_ee_pose = None
self._command_when_disabled = None
@@ -161,17 +171,17 @@ class EEReferenceAndDelta(ActionProcessorStep):
@dataclass
class EEBoundsAndSafety(ActionProcessorStep):
"""
Clip the end-effector pose to the bounds and check for jumps.
Clips the end-effector pose to predefined bounds and checks for unsafe jumps.
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
This step ensures that the target end-effector pose remains within a safe operational workspace.
It also moderates the command to prevent large, sudden movements between consecutive steps.
Output ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
}
Attributes:
end_effector_bounds: A dictionary with "min" and "max" keys for position clipping.
max_ee_step_m: The maximum allowed change in position (in meters) between steps.
max_ee_twist_step_rad: The maximum allowed change in orientation (in radians) between steps.
_last_pos: Internal state storing the last commanded position.
_last_twist: Internal state storing the last commanded orientation.
"""
end_effector_bounds: dict
@@ -219,6 +229,7 @@ class EEBoundsAndSafety(ActionProcessorStep):
return act
def reset(self):
"""Resets the last known position and orientation."""
self._last_pos = None
self._last_twist = None
@@ -232,21 +243,17 @@ class EEBoundsAndSafety(ActionProcessorStep):
@dataclass
class InverseKinematicsEEToJoints(ProcessorStep):
"""
Compute the desired joint positions from the desired end-effector pose.
Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
Input ACTION keys:
{
"action.ee.{x,y,z,wx,wy,wz}" : float
"complementary_data.raw_joint_positions": dict,
}
This step translates a Cartesian command (position and orientation of the end-effector) into
the corresponding joint-space commands for each motor.
Output ACTION keys:
{
"action.joint_name_1.pos": float,
"action.joint_name_2.pos": float,
...
"action.joint_name_n.pos": float,
}
Attributes:
kinematics: The robot's kinematic model for inverse kinematics.
motor_names: A list of motor names for which to compute joint positions.
q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver.
initial_guess_current_joints: If True, use the robot's current joint state as the IK guess.
If False, use the solution from the previous step.
"""
kinematics: RobotKinematics
@@ -312,6 +319,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
return features
def reset(self):
"""Resets the initial guess for the IK solver."""
self.q_curr = None
@@ -319,17 +327,18 @@ class InverseKinematicsEEToJoints(ProcessorStep):
@dataclass
class GripperVelocityToJoint(ProcessorStep):
"""
Convert the gripper velocity to a joint velocity.
Converts a gripper velocity command into a target gripper joint position.
Input ACTION keys:
{
"action.gripper": float,
}
This step integrates a normalized velocity command over time to produce a position command,
taking the current gripper position as a starting point. It also supports a discrete mode
where integer actions map to open, close, or no-op.
Output ACTION keys:
{
"action.gripper.pos": float,
}
Attributes:
motor_names: A list of motor names, which must include 'gripper'.
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
clip_min: The minimum allowed gripper joint position.
clip_max: The maximum allowed gripper joint position.
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
"""
motor_names: list[str]
@@ -365,7 +374,7 @@ class GripperVelocityToJoint(ProcessorStep):
raw = comp.get("raw_joint_positions") or {}
curr_pos = float(raw.get("gripper"))
# Compute desired gripper velocity
# Compute desired gripper position
u = float(act.get(f"{ACTION}.gripper", 0.0))
delta = u * float(self.speed_factor)
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
@@ -391,17 +400,14 @@ class GripperVelocityToJoint(ProcessorStep):
@dataclass
class ForwardKinematicsJointsToEE(ObservationProcessorStep):
"""
Compute the end-effector pose from the joint positions.
Computes the end-effector pose from joint positions using forward kinematics (FK).
Input OBSERVATION keys:
{
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
}
This step is typically used to add the robot's Cartesian pose to the observation space,
which can be useful for visualization or as an input to a policy.
Output OBSERVATION keys:
{
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
}
Attributes:
kinematics: The robot's kinematic model.
motor_names: A list of motor names whose joint positions are used for FK.
"""
kinematics: RobotKinematics
@@ -435,10 +441,14 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep):
@dataclass
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessorStep):
"""
Read the robot's current observation and insert it into the transition as complementary data.
Reads the robot's current observation and adds it to the transition's complementary data.
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
{ "<motor_name>": <float position>, ... }
This step acts as a bridge to the physical robot, injecting its real-time sensor readings
(like raw joint positions) into the data processing pipeline. This data is then available
for other processing steps.
Attributes:
robot: An instance of a `Robot` class used to get observations from hardware.
"""
robot: Robot

View File

@@ -98,9 +98,7 @@ from lerobot.utils.utils import (
ACTOR_SHUTDOWN_TIMEOUT = 30
#################################################
# Main entry point #
#################################################
# Main entry point
@parser.wrap()
@@ -207,9 +205,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
logging.info("[ACTOR] queues closed")
#################################################
# Core algorithm functions #
#################################################
# Core algorithm functions
def act_with_policy(
@@ -406,9 +402,7 @@ def act_with_policy(
busy_wait(1 / cfg.env.fps - dt_time)
#################################################
# Communication Functions - Group all gRPC/messaging functions #
#################################################
# Communication Functions - Group all gRPC/messaging functions
def establish_learner_connection(
@@ -653,9 +647,7 @@ def interactions_stream(
return services_pb2.Empty()
#################################################
# Policy functions #
#################################################
# Policy functions
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
@@ -687,9 +679,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
#################################################
# Utilities functions #
#################################################
# Utilities functions
def push_transitions_to_transport_queue(transitions: list, transitions_queue):

View File

@@ -103,11 +103,6 @@ from lerobot.utils.wandb_utils import WandBLogger
LOG_PREFIX = "[LEARNER]"
#################################################
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
#################################################
@parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig):
if not use_threads(cfg):
@@ -250,9 +245,7 @@ def start_learner_threads(
logging.info("[LEARNER] queues closed")
#################################################
# Core algorithm functions #
#################################################
# Core algorithm functions
def add_actor_information_and_train(
@@ -820,9 +813,7 @@ def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.M
return optimizers, lr_scheduler
#################################################
# Training setup functions #
#################################################
# Training setup functions
def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig:
@@ -1023,9 +1014,7 @@ def initialize_offline_replay_buffer(
return offline_replay_buffer
#################################################
# Utilities/Helpers functions #
#################################################
# Utilities/Helpers functions
def get_observation_features(

View File

@@ -65,6 +65,28 @@ def update_policy(
use_amp: bool = False,
lock=None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
learning rate scheduler. It also handles mixed-precision training via a GradScaler.
Args:
train_metrics: A MetricsTracker instance to record training statistics.
policy: The policy model to be trained.
batch: A batch of training data.
optimizer: The optimizer used to update the policy's parameters.
grad_clip_norm: The maximum norm for gradient clipping.
grad_scaler: The GradScaler for automatic mixed-precision training.
lr_scheduler: An optional learning rate scheduler.
use_amp: A boolean indicating whether to use automatic mixed precision.
lock: An optional lock for thread-safe optimizer updates.
Returns:
A tuple containing:
- The updated MetricsTracker with new statistics for this step.
- A dictionary of outputs from the policy's forward pass, for logging purposes.
"""
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
@@ -108,6 +130,20 @@ def update_policy(
@parser.wrap()
def train(cfg: TrainPipelineConfig):
"""
Main function to train a policy.
This function orchestrates the entire training pipeline, including:
- Setting up logging, seeding, and device configuration.
- Creating the dataset, evaluation environment (if applicable), policy, and optimizer.
- Handling resumption from a checkpoint.
- Running the main training loop, which involves fetching data batches and calling `update_policy`.
- Periodically logging metrics, saving model checkpoints, and evaluating the policy.
- Pushing the final trained model to the Hugging Face Hub if configured.
Args:
cfg: A `TrainPipelineConfig` object containing all training configurations.
"""
cfg.validate()
logging.info(pformat(cfg.to_dict()))

View File

@@ -121,6 +121,21 @@ def teleop_loop(
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] | None = None,
robot_observation_processor: RobotProcessorPipeline[EnvTransition] | None = None,
):
"""
This function continuously reads actions from a teleoperation device, processes them through optional
pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a
specified frequency until a set duration is reached or it is manually interrupted.
Args:
teleop: The teleoperator device instance providing control actions.
robot: The robot instance being controlled.
fps: The target frequency for the control loop in frames per second.
display_data: If True, fetches robot observations and displays them in the console and Rerun.
duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely.
teleop_action_processor: An optional pipeline to process raw actions from the teleoperator.
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
robot_observation_processor: An optional pipeline to process raw observations from the robot.
"""
# Initialize processors with defaults if not provided
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
teleop_action_processor

View File

@@ -26,28 +26,35 @@ from lerobot.teleoperators.phone.config_phone import PhoneOS
@dataclass
class MapPhoneActionToRobotAction(ActionProcessorStep):
"""
Map calibrated phone pose (actions) to the inputs for robot actions
Maps calibrated phone pose actions to standardized robot action inputs.
Expected input ACTION keys:
{
"action.phone.enabled": bool,
"action.phone.pos": np.ndarray,
"action.phone.rot": Rotation,
"action.phone.raw_inputs": dict,
}
This processor step acts as a bridge between the phone teleoperator's output
and the robot's expected action format. It remaps the phone's 6-DoF pose
(position and rotation) to the robot's target end-effector pose, applying
necessary axis inversions and swaps. It also interprets platform-specific
button presses to generate a gripper command.
Output ACTION keys:
{
"action.enabled": bool,
"action.ee.{x,y,z,wx,wy,wz}" : float
"action.gripper": float,
}
Attributes:
platform: The operating system of the phone (iOS or Android), used
to determine the correct button mappings for the gripper.
"""
platform: PhoneOS
_enabled_prev: bool = field(default=False, init=False, repr=False)
def action(self, act: dict) -> dict:
"""
Processes the phone action dictionary to create a robot action dictionary.
Args:
act: The input action dictionary from the phone teleoperator.
Returns:
A new action dictionary formatted for the robot controller.
Raises:
ValueError: If 'pos' or 'rot' keys are missing from the input action.
"""
# Pop them from the action
enabled = bool(act.pop(f"{ACTION}.phone.enabled", 0))
pos = act.pop(f"{ACTION}.phone.pos", None)

View File

@@ -108,7 +108,17 @@ class IOSPhone(BasePhone, Teleoperator):
print("Calibration done\n")
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
"""
Blocks execution until the calibration trigger is detected from the iOS device.
This method enters a loop, continuously reading the phone's state. It waits for the user to press
and hold the 'B1' button in the HEBI Mobile I/O app. Once B1 is pressed, the loop breaks and
returns the phone's pose at that exact moment.
Returns:
A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
moment the trigger was activated.
"""
while True:
has_pose, position, rotation, fb_pose = self._read_current_pose()
if not has_pose:
@@ -126,6 +136,21 @@ class IOSPhone(BasePhone, Teleoperator):
time.sleep(0.01)
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
"""
Reads the instantaneous 6-DoF pose from the connected iOS device via the HEBI SDK.
This method fetches the latest feedback packet from the HEBI group, extracts the ARKit
position and orientation, and converts them into a standard format. It also applies a
configured camera offset to adjust the pose from the camera's frame to the phone's
physical frame.
Returns:
A tuple containing:
- A boolean indicating if a valid pose was successfully read.
- The 3D position as a NumPy array, or None if not available.
- The orientation as a `Rotation` object, or None if not available.
- The raw HEBI feedback object for accessing other data like button presses.
"""
fbk = self._group.get_next_feedback()
pose = fbk[0]
ar_pos = getattr(pose, "ar_position", None)
@@ -228,7 +253,18 @@ class AndroidPhone(BasePhone, Teleoperator):
print("Calibration done\n")
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
"""
Blocks execution until the calibration trigger is detected from the Android device.
This method enters a loop, continuously checking the latest message received from the WebXR
session. It waits for the user to touch and move their finger on the screen, which generates
a `move` event. Once this event is detected, the loop breaks and returns the phone's current
pose.
Returns:
A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
moment the trigger was activated.
"""
while True:
with self._android_lock:
msg = self._latest_message or {}
@@ -241,6 +277,20 @@ class AndroidPhone(BasePhone, Teleoperator):
time.sleep(0.01)
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
"""
Reads the latest 6-DoF pose received from the Android device's WebXR session.
This method accesses the most recent pose data stored by the `_android_callback`. It uses a
thread lock to safely read the shared `_latest_pose` variable. The pose, a 4x4 matrix, is
then decomposed into position and rotation, and the configured camera offset is applied.
Returns:
A tuple containing:
- A boolean indicating if a valid pose was available.
- The 3D position as a NumPy array, or None if no pose has been received yet.
- The orientation as a `Rotation` object, or None if no pose has been received.
- The raw 4x4 pose matrix as received from the teleop stream.
"""
with self._android_lock:
if self._latest_pose is None:
return False, None, None, None
@@ -251,6 +301,19 @@ class AndroidPhone(BasePhone, Teleoperator):
return True, pos, rot, pose
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
"""
Callback function to handle incoming data from the Android teleop stream.
This method is executed by the `teleop` package's subscriber thread whenever a new
pose and message are received from the WebXR session on the Android phone. It updates
the internal state (`_latest_pose` and `_latest_message`) with the new data.
A thread lock is used to ensure that these shared variables are updated atomically,
preventing race conditions with the main thread that reads them.
Args:
pose: A 4x4 NumPy array representing the phone's transformation matrix.
message: A dictionary containing additional data, such as button presses or touch events.
"""
with self._android_lock:
self._latest_pose = pose
self._latest_message = message

View File

@@ -36,6 +36,20 @@ from lerobot.robots import Robot
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
"""
Logs performance metrics for a single step of the robot control loop.
This function formats and prints a single line of log information, including episode/frame counters,
total loop time (dt), and detailed timings for various robot and camera operations. It can also
highlight performance drops in yellow if the actual FPS is lower than the target FPS.
Args:
robot: The `Robot` instance, used to access its internal logs for detailed timings.
dt_s: The total duration of the control loop step in seconds.
episode_index: The index of the current episode.
frame_index: The index of the current frame within the episode.
fps: The target frames per second, used to check for performance degradation.
"""
log_items = []
if episode_index is not None:
log_items.append(f"ep:{episode_index}")
@@ -81,7 +95,16 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
@cache
def is_headless():
"""Detects if python is running without a monitor."""
"""
Detects if the Python script is running in a headless environment (e.g., without a display).
This function attempts to import `pynput`, a library that requires a graphical environment.
If the import fails, it assumes the environment is headless. The result is cached to avoid
re-running the check.
Returns:
True if the environment is determined to be headless, False otherwise.
"""
try:
import pynput # noqa
@@ -108,6 +131,29 @@ def predict_action(
task: str | None = None,
robot_type: str | None = None,
):
"""
Performs a single-step inference to predict a robot action from an observation.
This function encapsulates the full inference pipeline:
1. Prepares the observation by converting it to PyTorch tensors and adding a batch dimension.
2. Runs the preprocessor pipeline on the observation.
3. Feeds the processed observation to the policy to get a raw action.
4. Runs the postprocessor pipeline on the raw action.
5. Formats the final action by removing the batch dimension and moving it to the CPU.
Args:
observation: A dictionary of NumPy arrays representing the robot's current observation.
policy: The `PreTrainedPolicy` model to use for action prediction.
device: The `torch.device` (e.g., 'cuda' or 'cpu') to run inference on.
preprocessor: The `PolicyProcessorPipeline` for preprocessing observations.
postprocessor: The `PolicyProcessorPipeline` for postprocessing actions.
use_amp: A boolean to enable/disable Automatic Mixed Precision for CUDA inference.
task: An optional string identifier for the task.
robot_type: An optional string identifier for the robot type.
Returns:
A `torch.Tensor` containing the predicted action, ready for the robot.
"""
observation = copy(observation)
with (
torch.inference_mode(),
@@ -143,6 +189,18 @@ def predict_action(
def init_keyboard_listener():
"""
Initializes a non-blocking keyboard listener for real-time user interaction.
This function sets up a listener for specific keys (right arrow, left arrow, escape) to control
the program flow during execution, such as stopping recording or exiting loops. It gracefully
handles headless environments where keyboard listening is not possible.
Returns:
A tuple containing:
- The `pynput.keyboard.Listener` instance, or `None` if in a headless environment.
- A dictionary of event flags (e.g., `exit_early`) that are set by key presses.
"""
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
@@ -184,6 +242,19 @@ def init_keyboard_listener():
def sanity_check_dataset_name(repo_id, policy_cfg):
"""
Validates the dataset repository name against the presence of a policy configuration.
This function enforces a naming convention: a dataset repository ID should start with "eval_"
if and only if a policy configuration is provided for evaluation purposes.
Args:
repo_id: The Hugging Face Hub repository ID of the dataset.
policy_cfg: The configuration object for the policy, or `None`.
Raises:
ValueError: If the naming convention is violated.
"""
_, dataset_name = repo_id.split("/")
# either repo_id doesnt start with "eval_" and there is no policy
# or repo_id starts with "eval_" and there is a policy
@@ -204,6 +275,21 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, features: dict
) -> None:
"""
Checks if a dataset's metadata is compatible with the current robot and recording setup.
This function compares key metadata fields (`robot_type`, `fps`, and `features`) from the
dataset against the current configuration to ensure that appended data will be consistent.
Args:
dataset: The `LeRobotDataset` instance to check.
robot: The `Robot` instance representing the current hardware setup.
fps: The current recording frequency (frames per second).
features: The dictionary of features for the current recording session.
Raises:
ValueError: If any of the checked metadata fields do not match.
"""
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),

View File

@@ -42,7 +42,23 @@ def log_rerun_data(
observation: dict[str, Any] | None = None,
action: dict[str, Any] | None = None,
) -> None:
"""Log observation and action data to Rerun for visualization."""
"""
Logs observation and action data to Rerun for real-time visualization.
This function iterates through the provided observation and action dictionaries and sends their contents
to the Rerun viewer. It handles different data types appropriately:
- Scalar values (floats, ints) are logged as `rr.Scalar`.
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
from CHW to HWC format and logged as `rr.Image`.
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
- Other multi-dimensional arrays are flattened and logged as individual scalars.
Keys are automatically namespaced with "observation." or "action." if not already present.
Args:
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
"""
if observation:
for k, v in observation.items():
if v is None: