mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
chore(docs): update doctrines pipeline files (#1872)
* docs(processor): update docstrings batch_processor * docs(processor): update docstrings device_processor * docs(processor): update docstrings tokenizer_processor * update docstrings processor_act * update docstrings for pipeline_features * update docstrings for utils * update docstring for processor_diffusion * update docstrings factory * add docstrings to pi0 processor * add docstring to pi0fast processor * add docstring classifier processor * add docstring to sac processor * add docstring smolvla processor * add docstring to tdmpc processor * add docstring to vqbet processor * add docstrings to converters * add docstrings for delta_action_processor * add docstring to gym action processor * update hil processor * add docstring to joint obs processor * add docstring to migrate_normalize_processor * update docstrings normalize processor * update docstring normalize processor * update docstrings observation processor * update docstrings rename_processor * add docstrings robot_kinematic_processor * cleanup rl comments * add docstring to train.py * add docstring to teleoperate.py * add docstrings to phone_processor.py * add docstrings to teleop_phone.py * add docstrings to control_utils.py * add docstrings to visualization_utils.py --------- Co-authored-by: Pepijn <pepijn@huggingface.co>
This commit is contained in:
@@ -27,14 +27,30 @@ def aggregate_pipeline_dataset_features(
|
||||
use_videos: bool = True,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates the pipeline's features and returns a features dict ready for the dataset,
|
||||
filtered to only those keys matching any of the given patterns (for action/state only).
|
||||
"""Aggregates and filters dataset features based on a data processing pipeline.
|
||||
|
||||
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
|
||||
- `use_videos`: whether to treat image features as video streams
|
||||
- `patterns`: regexes to filter action & state features; images are included
|
||||
whenever use_videos=True, regardless of patterns.
|
||||
This function determines the final structure of dataset features after applying a series
|
||||
of processing steps defined in a pipeline. It starts with an initial set of hardware
|
||||
features (e.g., camera image shapes), transforms them using the pipeline, and then
|
||||
filters the results.
|
||||
|
||||
Image features are controlled by the `use_videos` flag, while action and state features
|
||||
can be selectively included by matching their keys against the provided regex `patterns`.
|
||||
The final output is formatted to be compatible with Hugging Face Datasets feature dictionaries.
|
||||
|
||||
Args:
|
||||
pipeline (DataProcessorPipeline): The data processing pipeline that defines all
|
||||
feature transformations.
|
||||
initial_features (dict[str, Any]): A dictionary of initial hardware features, where
|
||||
keys are feature names and values are their shapes or types (e.g., camera resolutions).
|
||||
use_videos (bool): If `True`, includes image/video features in the output. Defaults to `True`.
|
||||
patterns (Sequence[str] | None): An optional sequence of regular expression patterns.
|
||||
Only action and state keys that match at least one pattern will be included. If `None`,
|
||||
all action and state keys are kept. Defaults to `None`.
|
||||
|
||||
Returns:
|
||||
dict[str, dict]: A dictionary representing the final dataset features, structured for
|
||||
use with `datasets.Features`.
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
@@ -75,13 +75,20 @@ DEFAULT_FEATURES = {
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
"""Flatten a nested dictionary by joining keys with a separator.
|
||||
|
||||
For example:
|
||||
```
|
||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
|
||||
>>> print(flatten_dict(dct))
|
||||
{"a/b": 1, "a/c/d": 2, "e": 3}
|
||||
Example:
|
||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
|
||||
>>> print(flatten_dict(dct))
|
||||
{'a/b': 1, 'a/c/d': 2, 'e': 3}
|
||||
|
||||
Args:
|
||||
d (dict): The dictionary to flatten.
|
||||
parent_key (str): The base key to prepend to the keys in this level.
|
||||
sep (str): The separator to use between keys.
|
||||
|
||||
Returns:
|
||||
dict: A flattened dictionary.
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
@@ -94,6 +101,20 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
|
||||
|
||||
def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
"""Unflatten a dictionary with delimited keys into a nested dictionary.
|
||||
|
||||
Example:
|
||||
>>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3}
|
||||
>>> print(unflatten_dict(flat_dct))
|
||||
{'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
|
||||
|
||||
Args:
|
||||
d (dict): A dictionary with flattened keys.
|
||||
sep (str): The separator used in the keys.
|
||||
|
||||
Returns:
|
||||
dict: A nested dictionary.
|
||||
"""
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
@@ -107,6 +128,16 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
|
||||
|
||||
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
||||
"""Access an item in a nested dictionary using a flattened key.
|
||||
|
||||
Args:
|
||||
obj (DictLike): The nested dictionary-like object.
|
||||
flattened_key (str): A key with parts separated by `sep`.
|
||||
sep (str): The separator used in the flattened key.
|
||||
|
||||
Returns:
|
||||
Any: The value from the nested dictionary.
|
||||
"""
|
||||
split_keys = flattened_key.split(sep)
|
||||
getter = obj[split_keys[0]]
|
||||
if len(split_keys) == 1:
|
||||
@@ -119,6 +150,19 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
"""Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible.
|
||||
|
||||
Converts torch.Tensor, np.ndarray, and np.generic types to lists or native Python types.
|
||||
|
||||
Args:
|
||||
stats (dict): A dictionary that may contain non-serializable numeric types.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with all values converted to JSON-serializable types.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If a value has an unsupported type.
|
||||
"""
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
@@ -133,6 +177,17 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
|
||||
|
||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
"""Embed image bytes into the dataset table before saving to Parquet.
|
||||
|
||||
This function prepares a Hugging Face dataset for serialization by converting
|
||||
image objects into an embedded format that can be stored in Arrow/Parquet.
|
||||
|
||||
Args:
|
||||
dataset (datasets.Dataset): The input dataset, possibly containing image features.
|
||||
|
||||
Returns:
|
||||
datasets.Dataset: The dataset with images embedded in the table storage.
|
||||
"""
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
@@ -142,38 +197,94 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
"""Load data from a JSON file.
|
||||
|
||||
Args:
|
||||
fpath (Path): Path to the JSON file.
|
||||
|
||||
Returns:
|
||||
Any: The data loaded from the JSON file.
|
||||
"""
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
"""Write data to a JSON file.
|
||||
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
data (dict): The dictionary to write.
|
||||
fpath (Path): The path to the output JSON file.
|
||||
"""
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_jsonlines(fpath: Path) -> list[Any]:
|
||||
"""Load data from a JSON Lines file.
|
||||
|
||||
Args:
|
||||
fpath (Path): Path to the JSON Lines file.
|
||||
|
||||
Returns:
|
||||
list[Any]: A list of objects loaded from the file.
|
||||
"""
|
||||
with jsonlines.open(fpath, "r") as reader:
|
||||
return list(reader)
|
||||
|
||||
|
||||
def write_jsonlines(data: dict, fpath: Path) -> None:
|
||||
"""Write a list of dictionaries to a JSON Lines file.
|
||||
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
data (dict): The list of dictionaries to write.
|
||||
fpath (Path): The path to the output JSON Lines file.
|
||||
"""
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(data)
|
||||
|
||||
|
||||
def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||
"""Append a dictionary to a JSON Lines file.
|
||||
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
data (dict): The dictionary to append.
|
||||
fpath (Path): The path to the JSON Lines file.
|
||||
"""
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "a") as writer:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path):
|
||||
"""Write dataset info metadata to its standard file path.
|
||||
|
||||
Args:
|
||||
info (dict): The dataset information dictionary.
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
"""Load dataset info metadata from its standard file path.
|
||||
|
||||
Also converts shape lists to tuples for consistency.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: The dataset information dictionary.
|
||||
"""
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
@@ -181,16 +292,40 @@ def load_info(local_dir: Path) -> dict:
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path):
|
||||
"""Serialize and write dataset statistics to their standard file path.
|
||||
|
||||
Args:
|
||||
stats (dict): The statistics dictionary (can contain tensors/numpy arrays).
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
serialized_stats = serialize_dict(stats)
|
||||
write_json(serialized_stats, local_dir / STATS_PATH)
|
||||
|
||||
|
||||
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Recursively cast numerical values in a stats dictionary to numpy arrays.
|
||||
|
||||
Args:
|
||||
stats (dict): The statistics dictionary.
|
||||
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Load dataset statistics and cast numerical values to numpy arrays.
|
||||
|
||||
Returns None if the stats file doesn't exist.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
A dictionary of statistics or None if the file is not found.
|
||||
"""
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
@@ -198,6 +333,13 @@ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
|
||||
|
||||
def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
"""Write a single task to the tasks metadata file.
|
||||
|
||||
Args:
|
||||
task_index (int): The index of the task.
|
||||
task (dict): The task description dictionary.
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
@@ -206,6 +348,16 @@ def write_task(task_index: int, task: dict, local_dir: Path):
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
"""Load tasks from the tasks metadata file.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A dictionary mapping task index to task description.
|
||||
- A dictionary mapping task description to task index.
|
||||
"""
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
||||
@@ -213,15 +365,36 @@ def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
|
||||
|
||||
def write_episode(episode: dict, local_dir: Path):
|
||||
"""Write a single episode's metadata to the episodes metadata file.
|
||||
|
||||
Args:
|
||||
episode (dict): The episode metadata dictionary.
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
append_jsonlines(episode, local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
"""Load episode metadata from the episodes metadata file.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping episode index to episode metadata.
|
||||
"""
|
||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
||||
|
||||
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
"""Write statistics for a single episode to the episode stats file.
|
||||
|
||||
Args:
|
||||
episode_index (int): The index of the episode.
|
||||
episode_stats (dict): The statistics for the episode.
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
"""
|
||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||
# is a dictionary of stats and not an integer.
|
||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||
@@ -229,6 +402,14 @@ def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path
|
||||
|
||||
|
||||
def load_episodes_stats(local_dir: Path) -> dict:
|
||||
"""Load per-episode statistics from the episode stats file.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping episode index to its statistics dictionary.
|
||||
"""
|
||||
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
|
||||
return {
|
||||
item["episode_index"]: cast_stats_to_numpy(item["stats"])
|
||||
@@ -239,12 +420,35 @@ def load_episodes_stats(local_dir: Path) -> dict:
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Create a per-episode stats dictionary from a global stats dictionary.
|
||||
|
||||
This is used for backward compatibility with older datasets that only had global stats.
|
||||
|
||||
Args:
|
||||
stats (dict): The global dataset statistics.
|
||||
episodes (list[int]): A list of episode indices.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping each episode index to the global stats.
|
||||
"""
|
||||
return dict.fromkeys(episodes, stats)
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
"""Load an image from a file into a numpy array.
|
||||
|
||||
Args:
|
||||
fpath (str | Path): Path to the image file.
|
||||
dtype (np.dtype): The desired data type of the output array. If floating,
|
||||
pixels are scaled to [0, 1].
|
||||
channel_first (bool): If True, converts the image to (C, H, W) format.
|
||||
Otherwise, it remains in (H, W, C) format.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The image as a numpy array.
|
||||
"""
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
@@ -255,10 +459,19 @@ def load_image_as_numpy(
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
a channel last representation (h w c) of uint8 type, to a torch image representation
|
||||
with channel first (c h w) of float32 type in range [0,1].
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
This transform function converts items from Hugging Face dataset format (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
||||
types are converted to torch.tensor.
|
||||
|
||||
Args:
|
||||
items_dict (dict): A dictionary representing a batch of data from a
|
||||
Hugging Face dataset.
|
||||
|
||||
Returns:
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
@@ -273,6 +486,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
|
||||
|
||||
def is_valid_version(version: str) -> bool:
|
||||
"""Check if a string is a valid PEP 440 version.
|
||||
|
||||
Args:
|
||||
version (str): The version string to check.
|
||||
|
||||
Returns:
|
||||
bool: True if the version string is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
packaging.version.parse(version)
|
||||
return True
|
||||
@@ -286,6 +507,18 @@ def check_version_compatibility(
|
||||
current_version: str | packaging.version.Version,
|
||||
enforce_breaking_major: bool = True,
|
||||
) -> None:
|
||||
"""Check for version compatibility between a dataset and the current codebase.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repository ID for logging purposes.
|
||||
version_to_check (str | packaging.version.Version): The version of the dataset.
|
||||
current_version (str | packaging.version.Version): The current version of the codebase.
|
||||
enforce_breaking_major (bool): If True, raise an error on major version mismatch.
|
||||
|
||||
Raises:
|
||||
BackwardCompatibilityError: If the dataset version is from a newer, incompatible
|
||||
major version of the codebase.
|
||||
"""
|
||||
v_check = (
|
||||
packaging.version.parse(version_to_check)
|
||||
if not isinstance(version_to_check, packaging.version.Version)
|
||||
@@ -303,7 +536,14 @@ def check_version_compatibility(
|
||||
|
||||
|
||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||
"""Returns available valid versions (branches and tags) on given repo."""
|
||||
"""Return available valid versions (branches and tags) on a given Hub repo.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repository ID on the Hugging Face Hub.
|
||||
|
||||
Returns:
|
||||
list[packaging.version.Version]: A list of valid versions found.
|
||||
"""
|
||||
api = HfApi()
|
||||
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
|
||||
@@ -316,9 +556,22 @@ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||
|
||||
|
||||
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
|
||||
"""
|
||||
Returns the version if available on repo or the latest compatible one.
|
||||
Otherwise, will throw a `CompatibilityError`.
|
||||
"""Return the specified version if available on repo, or the latest compatible one.
|
||||
|
||||
If the exact version is not found, it looks for the latest version with the
|
||||
same major version number that is less than or equal to the target minor version.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repository ID on the Hugging Face Hub.
|
||||
version (str | packaging.version.Version): The target version.
|
||||
|
||||
Returns:
|
||||
str: The safe version string (e.g., "v1.2.3") to use as a revision.
|
||||
|
||||
Raises:
|
||||
RevisionNotFoundError: If the repo has no version tags.
|
||||
BackwardCompatibilityError: If only older major versions are available.
|
||||
ForwardCompatibilityError: If only newer major versions are available.
|
||||
"""
|
||||
target_version = (
|
||||
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
|
||||
@@ -360,6 +613,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""Convert a LeRobot features dictionary to a `datasets.Features` object.
|
||||
|
||||
Args:
|
||||
features (dict): A LeRobot-style feature dictionary.
|
||||
|
||||
Returns:
|
||||
datasets.Features: The corresponding Hugging Face `datasets.Features` object.
|
||||
|
||||
Raises:
|
||||
ValueError: If a feature has an unsupported shape.
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
@@ -387,6 +651,14 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
|
||||
|
||||
def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||
"""Validate that feature names do not contain invalid characters.
|
||||
|
||||
Args:
|
||||
features (dict): The LeRobot features dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If any feature name contains '/'.
|
||||
"""
|
||||
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
|
||||
if invalid_features:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
|
||||
@@ -395,6 +667,22 @@ def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||
def hw_to_dataset_features(
|
||||
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
|
||||
) -> dict[str, dict]:
|
||||
"""Convert hardware-specific features to a LeRobot dataset feature dictionary.
|
||||
|
||||
This function takes a dictionary describing hardware outputs (like joint states
|
||||
or camera image shapes) and formats it into the standard LeRobot feature
|
||||
specification.
|
||||
|
||||
Args:
|
||||
hw_features (dict): Dictionary mapping feature names to their type (float for
|
||||
joints) or shape (tuple for images).
|
||||
prefix (str): The prefix to add to the feature keys (e.g., "observation"
|
||||
or "action").
|
||||
use_video (bool): If True, image features are marked as "video", otherwise "image".
|
||||
|
||||
Returns:
|
||||
dict: A LeRobot features dictionary.
|
||||
"""
|
||||
features = {}
|
||||
joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
@@ -427,6 +715,20 @@ def hw_to_dataset_features(
|
||||
def build_dataset_frame(
|
||||
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Construct a single data frame from raw values based on dataset features.
|
||||
|
||||
A "frame" is a dictionary containing all the data for a single timestep,
|
||||
formatted as numpy arrays according to the feature specification.
|
||||
|
||||
Args:
|
||||
ds_features (dict): The LeRobot dataset features dictionary.
|
||||
values (dict): A dictionary of raw values from the hardware/environment.
|
||||
prefix (str): The prefix to filter features by (e.g., "observation"
|
||||
or "action").
|
||||
|
||||
Returns:
|
||||
dict: A dictionary representing a single frame of data.
|
||||
"""
|
||||
frame = {}
|
||||
for key, ft in ds_features.items():
|
||||
if key in DEFAULT_FEATURES or not key.startswith(prefix):
|
||||
@@ -440,6 +742,21 @@ def build_dataset_frame(
|
||||
|
||||
|
||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""Convert dataset features to policy features.
|
||||
|
||||
This function transforms the dataset's feature specification into a format
|
||||
that a policy can use, classifying features by type (e.g., visual, state,
|
||||
action) and ensuring correct shapes (e.g., channel-first for images).
|
||||
|
||||
Args:
|
||||
features (dict): The LeRobot dataset features dictionary.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary mapping feature keys to `PolicyFeature` objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If an image feature does not have a 3D shape.
|
||||
"""
|
||||
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
||||
policy_features = {}
|
||||
for key, ft in features.items():
|
||||
@@ -471,11 +788,19 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
|
||||
|
||||
def combine_feature_dicts(*dicts: dict) -> dict:
|
||||
"""
|
||||
Merge LeRobot grouped feature dicts.
|
||||
"""Merge LeRobot grouped feature dicts.
|
||||
|
||||
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
|
||||
- For others (observation.images.*), last one wins (if they are identical).
|
||||
- For others (e.g. `observation.images.*`), the last one wins (if they are identical).
|
||||
|
||||
Args:
|
||||
*dicts: A variable number of LeRobot feature dictionaries to merge.
|
||||
|
||||
Returns:
|
||||
dict: A single merged feature dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If there's a dtype mismatch for a feature being merged.
|
||||
"""
|
||||
out: dict = {}
|
||||
for d in dicts:
|
||||
@@ -521,6 +846,18 @@ def create_empty_dataset_info(
|
||||
use_videos: bool,
|
||||
robot_type: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
|
||||
Args:
|
||||
codebase_version (str): The version of the LeRobot codebase.
|
||||
fps (int): The frames per second of the data.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
"""
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
@@ -541,6 +878,18 @@ def create_empty_dataset_info(
|
||||
def get_episode_data_index(
|
||||
episode_dicts: dict[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Calculate the start and end indices for each episode in a flattened dataset.
|
||||
|
||||
Args:
|
||||
episode_dicts (dict): A dictionary mapping episode index to episode metadata,
|
||||
which must contain a "length" key.
|
||||
episodes (list[int] | None): An optional list of episode indices to consider.
|
||||
If None, all episodes are used.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with "from" and "to" keys, containing torch tensors
|
||||
with the start and end indices for each episode.
|
||||
"""
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
@@ -560,16 +909,19 @@ def check_timestamps_sync(
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
|
||||
to account for possible numerical error.
|
||||
"""Check if timestamps are separated by (1/fps) +/- tolerance.
|
||||
|
||||
This check ensures that consecutive timestamps within an episode are spaced
|
||||
correctly, accounting for possible numerical errors. It ignores the boundaries
|
||||
between episodes.
|
||||
|
||||
Args:
|
||||
timestamps (np.ndarray): Array of timestamps in seconds.
|
||||
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
|
||||
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
|
||||
episode_data_index (dict): A dictionary that includes 'to',
|
||||
which identifies indices for the end of each episode.
|
||||
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
|
||||
fps (int): Frames per second. Used to check the expected difference between
|
||||
consecutive timestamps.
|
||||
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
|
||||
raise_value_error (bool): Whether to raise a ValueError if the check fails.
|
||||
|
||||
@@ -577,7 +929,8 @@ def check_timestamps_sync(
|
||||
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If the check fails and `raise_value_error` is True.
|
||||
ValueError: If `timestamps` and `episode_indices` shapes do not match, or if
|
||||
the check fails and `raise_value_error` is True.
|
||||
"""
|
||||
if timestamps.shape != episode_indices.shape:
|
||||
raise ValueError(
|
||||
@@ -628,9 +981,23 @@ def check_timestamps_sync(
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||
actual timestamps from the dataset.
|
||||
"""Check if delta timestamps are multiples of 1/fps +/- tolerance.
|
||||
|
||||
This ensures that adding these delta timestamps to any existing timestamp in
|
||||
the dataset will result in a value that aligns with the dataset's frame rate.
|
||||
|
||||
Args:
|
||||
delta_timestamps (dict): A dictionary where values are lists of time
|
||||
deltas in seconds.
|
||||
fps (int): The frames per second of the dataset.
|
||||
tolerance_s (float): The allowed tolerance in seconds.
|
||||
raise_value_error (bool): If True, raises an error on failure.
|
||||
|
||||
Returns:
|
||||
bool: True if all deltas are valid, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
@@ -656,6 +1023,15 @@ def check_delta_timestamps(
|
||||
|
||||
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
"""Convert delta timestamps in seconds to delta indices in frames.
|
||||
|
||||
Args:
|
||||
delta_timestamps (dict): A dictionary of time deltas in seconds.
|
||||
fps (int): The frames per second of the dataset.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary of frame delta indices.
|
||||
"""
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
@@ -664,9 +1040,17 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
|
||||
|
||||
|
||||
def cycle(iterable):
|
||||
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
||||
"""Create a dataloader-safe cyclical iterator.
|
||||
|
||||
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
|
||||
This is an equivalent of `itertools.cycle` but is safe for use with
|
||||
PyTorch DataLoaders with multiple workers.
|
||||
See https://github.com/pytorch/pytorch/issues/23900 for details.
|
||||
|
||||
Args:
|
||||
iterable: The iterable to cycle over.
|
||||
|
||||
Yields:
|
||||
Items from the iterable, restarting from the beginning when exhausted.
|
||||
"""
|
||||
iterator = iter(iterable)
|
||||
while True:
|
||||
@@ -677,8 +1061,14 @@ def cycle(iterable):
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
|
||||
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
|
||||
exists before creating it.
|
||||
"""Create a branch on an existing Hugging Face repo.
|
||||
|
||||
Deletes the branch if it already exists before creating it.
|
||||
|
||||
Args:
|
||||
repo_id (str): The ID of the repository.
|
||||
branch (str): The name of the branch to create.
|
||||
repo_type (str | None): The type of the repository (e.g., "dataset").
|
||||
"""
|
||||
api = HfApi()
|
||||
|
||||
@@ -696,9 +1086,20 @@ def create_lerobot_dataset_card(
|
||||
dataset_info: dict | None = None,
|
||||
**kwargs,
|
||||
) -> DatasetCard:
|
||||
"""
|
||||
Keyword arguments will be used to replace values in src/lerobot/datasets/card_template.md.
|
||||
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
|
||||
"""Create a `DatasetCard` for a LeRobot dataset.
|
||||
|
||||
Keyword arguments are used to replace values in the card template.
|
||||
Note: If specified, `license` must be a valid license identifier from
|
||||
https://huggingface.co/docs/hub/repositories-licenses.
|
||||
|
||||
Args:
|
||||
tags (list | None): A list of tags to add to the dataset card.
|
||||
dataset_info (dict | None): The dataset's info dictionary, which will
|
||||
be displayed on the card.
|
||||
**kwargs: Additional keyword arguments to populate the card template.
|
||||
|
||||
Returns:
|
||||
DatasetCard: The generated dataset card object.
|
||||
"""
|
||||
card_tags = ["LeRobot"]
|
||||
|
||||
@@ -730,19 +1131,16 @@ def create_lerobot_dataset_card(
|
||||
|
||||
|
||||
class IterableNamespace(SimpleNamespace):
|
||||
"""
|
||||
A namespace object that supports both dictionary-like iteration and dot notation access.
|
||||
Automatically converts nested dictionaries into IterableNamespaces.
|
||||
"""A namespace object that supports both dictionary-like iteration and dot notation.
|
||||
|
||||
This class extends SimpleNamespace to provide:
|
||||
- Dictionary-style iteration over keys
|
||||
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
|
||||
- Dictionary-like methods: items(), keys(), values()
|
||||
- Recursive conversion of nested dictionaries
|
||||
This class extends `SimpleNamespace` to provide dictionary-style iteration,
|
||||
access to items via brackets (`obj["key"]`), and dictionary-like methods
|
||||
(`items()`, `keys()`, `values()`). Nested dictionaries are recursively
|
||||
converted to `IterableNamespace` objects.
|
||||
|
||||
Args:
|
||||
dictionary: Optional dictionary to initialize the namespace
|
||||
**kwargs: Additional keyword arguments passed to SimpleNamespace
|
||||
dictionary (dict, optional): A dictionary to initialize the namespace with.
|
||||
**kwargs: Additional keyword arguments to initialize the namespace.
|
||||
|
||||
Examples:
|
||||
>>> data = {"name": "Alice", "details": {"age": 25}}
|
||||
@@ -756,10 +1154,16 @@ class IterableNamespace(SimpleNamespace):
|
||||
>>> for key, value in ns.items():
|
||||
... print(f"{key}: {value}")
|
||||
name: Alice
|
||||
details: IterableNamespace(age=25)
|
||||
details: <__main__.IterableNamespace object at ...>
|
||||
"""
|
||||
|
||||
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
|
||||
"""Initialize the IterableNamespace.
|
||||
|
||||
Args:
|
||||
dictionary (dict, optional): Dictionary to populate the namespace.
|
||||
**kwargs: Keyword arguments to populate the namespace.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
if dictionary is not None:
|
||||
for key, value in dictionary.items():
|
||||
@@ -769,22 +1173,46 @@ class IterableNamespace(SimpleNamespace):
|
||||
setattr(self, key, value)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
"""Return an iterator over the keys of the namespace."""
|
||||
return iter(vars(self))
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Allow bracket-style access to attributes.
|
||||
|
||||
Args:
|
||||
key (str): The name of the attribute.
|
||||
|
||||
Returns:
|
||||
Any: The value of the attribute.
|
||||
"""
|
||||
return vars(self)[key]
|
||||
|
||||
def items(self):
|
||||
"""Return a view of the namespace's (key, value) pairs."""
|
||||
return vars(self).items()
|
||||
|
||||
def values(self):
|
||||
"""Return a view of the namespace's values."""
|
||||
return vars(self).values()
|
||||
|
||||
def keys(self):
|
||||
"""Return a view of the namespace's keys."""
|
||||
return vars(self).keys()
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict):
|
||||
"""Validate a single data frame against the dataset's feature specification.
|
||||
|
||||
Checks for missing/extra features, and validates the dtype and shape of each
|
||||
provided feature.
|
||||
|
||||
Args:
|
||||
frame (dict): The data frame to validate.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If the frame does not match the feature specification.
|
||||
"""
|
||||
expected_features = set(features) - set(DEFAULT_FEATURES)
|
||||
actual_features = set(frame)
|
||||
|
||||
@@ -799,6 +1227,15 @@ def validate_frame(frame: dict, features: dict):
|
||||
|
||||
|
||||
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
|
||||
"""Check for missing or extra features in a frame.
|
||||
|
||||
Args:
|
||||
actual_features (set[str]): The set of feature names present in the frame.
|
||||
expected_features (set[str]): The set of feature names expected in the frame.
|
||||
|
||||
Returns:
|
||||
str: An error message string if there's a mismatch, otherwise an empty string.
|
||||
"""
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - expected_features
|
||||
@@ -814,6 +1251,19 @@ def validate_features_presence(actual_features: set[str], expected_features: set
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
"""Validate the dtype and shape of a single feature's value.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
feature (dict): The feature specification from the LeRobot features dictionary.
|
||||
value: The value of the feature to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the feature dtype is not supported for validation.
|
||||
"""
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
@@ -829,6 +1279,17 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
|
||||
def validate_feature_numpy_array(
|
||||
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||
):
|
||||
"""Validate a feature that is expected to be a numpy array.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_dtype (str): The expected numpy dtype as a string.
|
||||
expected_shape (list[int]): The expected shape.
|
||||
value (np.ndarray): The numpy array to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
@@ -846,6 +1307,18 @@ def validate_feature_numpy_array(
|
||||
|
||||
|
||||
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
"""Validate a feature that is expected to be an image or video frame.
|
||||
|
||||
Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
expected_shape (list[str]): The expected shape (C, H, W).
|
||||
value: The image data to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
@@ -862,12 +1335,35 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str):
|
||||
"""Validate a feature that is expected to be a string.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value (str): The value to validate.
|
||||
|
||||
Returns:
|
||||
str: An error message if validation fails, otherwise an empty string.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
Ensures the buffer has the required keys, contains at least one frame, and
|
||||
has features consistent with the dataset's specification.
|
||||
|
||||
Args:
|
||||
episode_buffer (dict): The buffer containing data for a single episode.
|
||||
total_episodes (int): The current total number of episodes in the dataset.
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If the buffer is invalid.
|
||||
NotImplementedError: If the episode index is manually set and doesn't match.
|
||||
"""
|
||||
if "size" not in episode_buffer:
|
||||
raise ValueError("size key not found in episode_buffer")
|
||||
|
||||
|
||||
@@ -34,6 +34,24 @@ def make_act_pre_post_processors(
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""Creates the pre- and post-processing pipelines for the ACT policy.
|
||||
|
||||
The pre-processing pipeline handles normalization, batching, and device placement for the model inputs.
|
||||
The post-processing pipeline handles unnormalization and moves the model outputs back to the CPU.
|
||||
|
||||
Args:
|
||||
config (ACTConfig): The ACT policy configuration object.
|
||||
dataset_stats (dict[str, dict[str, torch.Tensor]] | None): A dictionary containing dataset
|
||||
statistics (e.g., mean and std) used for normalization. Defaults to None.
|
||||
preprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the
|
||||
preprocessor pipeline's constructor. Defaults to None.
|
||||
postprocessor_kwargs (ProcessorKwargs | None): Extra keyword arguments to pass to the
|
||||
postprocessor pipeline's constructor. Defaults to None.
|
||||
|
||||
Returns:
|
||||
tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]: A tuple containing the
|
||||
pre-processor pipeline and the post-processor pipeline.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
|
||||
@@ -35,6 +35,32 @@ def make_diffusion_pre_post_processors(
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for a diffusion policy.
|
||||
|
||||
The pre-processing pipeline prepares the input data for the model by:
|
||||
1. Renaming features (if a `rename_map` is provided in `preprocessor_kwargs`).
|
||||
2. Normalizing the input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving the data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving the data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the diffusion policy,
|
||||
containing feature definitions, normalization mappings, and device information.
|
||||
dataset_stats: A dictionary of statistics used for normalization.
|
||||
Defaults to None.
|
||||
preprocessor_kwargs: Additional keyword arguments
|
||||
for the pre-processor pipeline. Defaults to an empty dictionary.
|
||||
postprocessor_kwargs: Additional keyword arguments
|
||||
for the post-processor pipeline. Defaults to an empty dictionary.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
|
||||
@@ -43,7 +43,22 @@ from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
"""
|
||||
Retrieves a policy class by its registered name.
|
||||
|
||||
This function uses dynamic imports to avoid loading all policy classes into memory
|
||||
at once, improving startup time and reducing dependencies.
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the policy name is not recognized.
|
||||
"""
|
||||
if name == "tdmpc":
|
||||
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
@@ -85,6 +100,24 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
"""
|
||||
Instantiates a policy configuration object based on the policy type.
|
||||
|
||||
This factory function simplifies the creation of policy configuration objects by
|
||||
mapping a string identifier to the corresponding config class.
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
|
||||
"reward_classifier".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of a `PreTrainedConfig` subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `policy_type` is not recognized.
|
||||
"""
|
||||
if policy_type == "tdmpc":
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
@@ -108,7 +141,21 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
|
||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for the processor config."""
|
||||
"""
|
||||
A TypedDict defining the keyword arguments for processor configuration.
|
||||
|
||||
This provides type hints for the optional arguments passed to `make_pre_post_processors`,
|
||||
improving code clarity and enabling static analysis.
|
||||
|
||||
Attributes:
|
||||
preprocessor_config_filename: The filename for the preprocessor configuration.
|
||||
postprocessor_config_filename: The filename for the postprocessor configuration.
|
||||
preprocessor_overrides: A dictionary of overrides for the preprocessor configuration.
|
||||
postprocessor_overrides: A dictionary of overrides for the postprocessor configuration.
|
||||
dataset_stats: Dataset statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`.
|
||||
postprocessor_kwargs: Additional arguments for the `PolicyProcessorPipeline`.
|
||||
"""
|
||||
|
||||
preprocessor_config_filename: str | None
|
||||
postprocessor_config_filename: str | None
|
||||
@@ -124,22 +171,27 @@ def make_pre_post_processors(
|
||||
pretrained_path: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""Make a processor instance for a given policy type.
|
||||
"""
|
||||
Create or load pre- and post-processor pipelines for a given policy.
|
||||
|
||||
This function creates the appropriate processor configuration based on the policy type.
|
||||
Each policy type has its own processor with specific preprocessing steps.
|
||||
This function acts as a factory. It can either load existing processor pipelines
|
||||
from a pretrained path or create new ones from scratch based on the policy
|
||||
configuration. Each policy type has a dedicated factory function for its
|
||||
processors (e.g., `make_tdmpc_pre_post_processors`).
|
||||
|
||||
Args:
|
||||
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
|
||||
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
|
||||
the processor from this path instead of creating a new one.
|
||||
**kwargs: Additional keyword arguments passed to the processor creation.
|
||||
policy_cfg: The configuration of the policy for which to create processors.
|
||||
pretrained_path: An optional path to load pretrained processor pipelines from.
|
||||
If provided, pipelines are loaded from this path.
|
||||
**kwargs: Keyword arguments for processor configuration, as defined in
|
||||
`ProcessorConfigKwargs`.
|
||||
|
||||
Returns:
|
||||
Tuple of (input_processor, output_processor) for the policy.
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the policy type doesn't have a processor implemented.
|
||||
NotImplementedError: If a processor factory is not implemented for the given
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# Extract preprocessor and postprocessor kwargs
|
||||
@@ -269,25 +321,29 @@ def make_policy(
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
"""Make an instance of a policy class.
|
||||
"""
|
||||
Instantiate a policy model.
|
||||
|
||||
This function exists because (for now) we need to parse features from either a dataset or an environment
|
||||
in order to properly dimension and instantiate a policy for that dataset or environment.
|
||||
This factory function handles the logic of creating a policy, which requires
|
||||
determining the input and output feature shapes. These shapes can be derived
|
||||
either from a `LeRobotDatasetMetadata` object or an `EnvConfig` object. The function
|
||||
can either initialize a new policy from scratch or load a pretrained one.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
|
||||
be loaded with the weights from that path.
|
||||
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
|
||||
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
|
||||
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
|
||||
provided if ds_meta is not. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: Either ds_meta or env and env_cfg must be provided.
|
||||
NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)
|
||||
cfg: The configuration for the policy to be created. If `cfg.pretrained_path` is
|
||||
set, the policy will be loaded with weights from that path.
|
||||
ds_meta: Dataset metadata used to infer feature shapes and types. Also provides
|
||||
statistics for normalization layers.
|
||||
env_cfg: Environment configuration used to infer feature shapes and types.
|
||||
One of `ds_meta` or `env_cfg` must be provided.
|
||||
|
||||
Returns:
|
||||
PreTrainedPolicy: _description_
|
||||
An instantiated and device-placed policy model.
|
||||
|
||||
Raises:
|
||||
ValueError: If both or neither of `ds_meta` and `env_cfg` are provided.
|
||||
NotImplementedError: If attempting to use an unsupported policy-backend
|
||||
combination (e.g., VQBeT with 'mps').
|
||||
"""
|
||||
if bool(ds_meta) == bool(env_cfg):
|
||||
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
|
||||
|
||||
@@ -37,11 +37,25 @@ from lerobot.processor import (
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
|
||||
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one.
|
||||
This is required for the PaliGemma tokenizer.
|
||||
"""
|
||||
Ensures that the task description string ends with a newline character.
|
||||
|
||||
This processing step is required for compatibility with the PaliGemma tokenizer,
|
||||
which expects a newline at the end of the text prompt. It handles both single
|
||||
strings and lists of strings for the 'task' key in complementary data.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
"""
|
||||
Adds a newline to the 'task' field if it doesn't already have one.
|
||||
|
||||
Args:
|
||||
complementary_data: A dictionary that may contain a 'task' key with a
|
||||
string or list of strings.
|
||||
|
||||
Returns:
|
||||
A new dictionary with the modified 'task' field.
|
||||
"""
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
@@ -64,6 +78,15 @@ class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
This step does not alter the feature definitions.
|
||||
|
||||
Args:
|
||||
features: The input feature dictionary.
|
||||
|
||||
Returns:
|
||||
The unchanged feature dictionary.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
@@ -73,6 +96,30 @@ def make_pi0_pre_post_processors(
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0 policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Appending a newline character to the task description for tokenizer compatibility.
|
||||
5. Tokenizing the text prompt using the PaliGemma tokenizer.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0 policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
@@ -30,11 +30,33 @@ from lerobot.processor import (
|
||||
|
||||
|
||||
def make_pi0fast_pre_post_processors(
|
||||
config: PI0Config,
|
||||
config: PI0FASTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the PI0Fast policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the PI0Fast policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
|
||||
@@ -36,6 +36,28 @@ def make_sac_pre_post_processors(
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the SAC policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the SAC policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
|
||||
@@ -31,6 +31,26 @@ def make_classifier_processor(
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the reward classifier.
|
||||
|
||||
The pre-processing pipeline prepares input data for the classifier by:
|
||||
1. Normalizing both input and output features based on dataset statistics.
|
||||
2. Moving the data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the classifier's output by:
|
||||
1. Moving the data to the CPU.
|
||||
2. Applying an identity step, as no unnormalization is needed for the output logits.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the RewardClassifier.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
|
||||
@@ -39,6 +39,30 @@ def make_smolvla_pre_post_processors(
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the SmolVLA policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Ensuring the language task description ends with a newline character.
|
||||
5. Tokenizing the language task description.
|
||||
6. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output actions to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the SmolVLA policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
@@ -83,7 +107,13 @@ def make_smolvla_pre_post_processors(
|
||||
|
||||
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
|
||||
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one."""
|
||||
"""
|
||||
A processor step that ensures the 'task' description ends with a newline character.
|
||||
|
||||
This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a
|
||||
newline at the end of the prompt. It handles both single string tasks and lists
|
||||
of string tasks.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
|
||||
@@ -35,6 +35,28 @@ def make_tdmpc_pre_post_processors(
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the TDMPC policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the TDMPC policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
|
||||
@@ -36,6 +36,28 @@ def make_vqbet_pre_post_processors(
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the VQ-BeT policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features, allowing customization to match pretrained configurations.
|
||||
2. Normalizing input and output features based on dataset statistics.
|
||||
3. Adding a batch dimension.
|
||||
4. Moving all data to the specified device.
|
||||
|
||||
The post-processing pipeline handles the model's output by:
|
||||
1. Moving data to the CPU.
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the VQ-BeT policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -11,6 +13,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script defines processor steps for adding a batch dimension to various components of an environment transition.
|
||||
|
||||
These steps are designed to process actions, observations, and complementary data, making them suitable for batch processing by adding a leading dimension. This is a common requirement before feeding data into a neural network model.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
@@ -31,24 +40,63 @@ from .pipeline import (
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_action")
|
||||
class AddBatchDimensionActionStep(ActionProcessorStep):
|
||||
"""Process action component in-place, adding batch dimension if needed."""
|
||||
"""
|
||||
Processor step to add a batch dimension to a 1D tensor action.
|
||||
|
||||
def action(self, action):
|
||||
This is useful for creating a batch of size 1 from a single action sample.
|
||||
"""
|
||||
|
||||
def action(self, action: Tensor) -> Tensor:
|
||||
"""
|
||||
Adds a batch dimension to the action if it's a 1D tensor.
|
||||
|
||||
Args:
|
||||
action: The action tensor.
|
||||
|
||||
Returns:
|
||||
The action tensor with an added batch dimension.
|
||||
"""
|
||||
if not isinstance(action, Tensor) or action.dim() != 1:
|
||||
return action
|
||||
|
||||
return action.unsqueeze(0)
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Adding a batch dimension does not alter the feature definition.
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
|
||||
class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
||||
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||
"""
|
||||
Processor step to add a batch dimension to observations.
|
||||
|
||||
def observation(self, observation):
|
||||
It handles different types of observations:
|
||||
- State vectors (1D tensors).
|
||||
- Single images (3D tensors).
|
||||
- Dictionaries of multiple images (3D tensors).
|
||||
"""
|
||||
|
||||
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""
|
||||
Adds a batch dimension to tensor-based observations in the observation dictionary.
|
||||
|
||||
Args:
|
||||
observation: The observation dictionary.
|
||||
|
||||
Returns:
|
||||
The observation dictionary with batch dimensions added to tensors.
|
||||
"""
|
||||
# Process state observations - add batch dim if 1D
|
||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||
if state_key in observation:
|
||||
@@ -69,15 +117,41 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
||||
return observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Adding a batch dimension does not alter the feature definition.
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
|
||||
class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
"""Process complementary data in-place, handling task field batching."""
|
||||
"""
|
||||
Processor step to add a batch dimension to complementary data fields.
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
Handles specific keys like 'task', 'index', and 'task_index' to make them batched.
|
||||
- 'task' (str) is wrapped in a list.
|
||||
- 'index' and 'task_index' (0D tensors) get a batch dimension.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Adds a batch dimension to specific fields in the complementary data dictionary.
|
||||
|
||||
Args:
|
||||
complementary_data: The complementary data dictionary.
|
||||
|
||||
Returns:
|
||||
The complementary data dictionary with batch dimensions added.
|
||||
"""
|
||||
# Process task field - wrap string in list to add batch dimension
|
||||
if "task" in complementary_data:
|
||||
task_value = complementary_data["task"]
|
||||
@@ -98,44 +172,33 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
return complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Adding a batch dimension does not alter the feature definition.
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||
class AddBatchDimensionProcessorStep(ProcessorStep):
|
||||
"""Processor that adds batch dimensions to observations and actions when needed.
|
||||
"""
|
||||
A composite processor step that adds a batch dimension to the entire environment transition.
|
||||
|
||||
This processor ensures that observations and actions have proper batch dimensions for model processing:
|
||||
This step combines individual processors for actions, observations, and complementary data
|
||||
to create a batched transition (batch size 1) from a single-instance transition.
|
||||
|
||||
- For state observations (observation.state, observation.environment_state):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
|
||||
- For image observations (observation.image, observation.images.*):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
|
||||
|
||||
- For actions:
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
|
||||
- For task field in complementary data:
|
||||
Wraps string task in a list to add batch dimension
|
||||
(task must be a string or list of strings)
|
||||
|
||||
This is useful when processing single transitions that need to be batched for
|
||||
model inference or when converting from unbatched environment outputs to
|
||||
batched model inputs.
|
||||
|
||||
The processor only modifies tensors that need batching and leaves already
|
||||
batched tensors unchanged.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# State: (7,) -> (1, 7)
|
||||
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
||||
# Action: (4,) -> (1, 4)
|
||||
# Task: "pick_cube" -> ["pick_cube"]
|
||||
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
||||
```
|
||||
Attributes:
|
||||
to_batch_action_processor: Processor for the action component.
|
||||
to_batch_observation_processor: Processor for the observation component.
|
||||
to_batch_complementary_data_processor: Processor for the complementary data component.
|
||||
"""
|
||||
|
||||
to_batch_action_processor: AddBatchDimensionActionStep = field(
|
||||
@@ -149,11 +212,31 @@ class AddBatchDimensionProcessorStep(ProcessorStep):
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Applies the batching process to all relevant parts of an environment transition.
|
||||
|
||||
Args:
|
||||
transition: The environment transition to process.
|
||||
|
||||
Returns:
|
||||
The environment transition with a batch dimension added.
|
||||
"""
|
||||
transition = self.to_batch_action_processor(transition)
|
||||
transition = self.to_batch_observation_processor(transition)
|
||||
transition = self.to_batch_complementary_data_processor(transition)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Adding a batch dimension does not alter the feature definition.
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
# NOTE: We ignore the batch dimension when transforming features
|
||||
return features
|
||||
|
||||
@@ -44,12 +44,12 @@ def to_tensor(
|
||||
different input types appropriately.
|
||||
|
||||
Args:
|
||||
value: Input value to convert (tensor, array, scalar, sequence, etc.)
|
||||
value: Input value to convert (tensor, array, scalar, sequence, etc.).
|
||||
dtype: Target tensor dtype. If None, preserves original dtype.
|
||||
device: Target device for the tensor.
|
||||
|
||||
Returns:
|
||||
PyTorch tensor.
|
||||
A PyTorch tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input type is not supported.
|
||||
@@ -59,7 +59,7 @@ def to_tensor(
|
||||
|
||||
@to_tensor.register(torch.Tensor)
|
||||
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle existing PyTorch tensors."""
|
||||
"""Handle conversion for existing PyTorch tensors."""
|
||||
if dtype is not None:
|
||||
value = value.to(dtype=dtype)
|
||||
if device is not None:
|
||||
@@ -75,17 +75,17 @@ def _(
|
||||
device=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Handle numpy arrays."""
|
||||
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
|
||||
"""Handle conversion for numpy arrays."""
|
||||
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars.
|
||||
if value.ndim == 0:
|
||||
# Numpy scalars should be converted to 0-dimensional tensors
|
||||
# Numpy scalars should be converted to 0-dimensional tensors.
|
||||
scalar_value = value.item()
|
||||
return torch.tensor(scalar_value, dtype=dtype, device=device)
|
||||
|
||||
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
|
||||
# Create tensor from numpy array.
|
||||
tensor = torch.from_numpy(value)
|
||||
|
||||
# Apply dtype conversion if specified
|
||||
# Apply dtype and device conversion if specified.
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
if device is not None:
|
||||
@@ -99,20 +99,20 @@ def _(
|
||||
@to_tensor.register(np.integer)
|
||||
@to_tensor.register(np.floating)
|
||||
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle scalar values including numpy scalars."""
|
||||
"""Handle conversion for scalar values including numpy scalars."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(list)
|
||||
@to_tensor.register(tuple)
|
||||
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle sequences (lists, tuples)."""
|
||||
"""Handle conversion for sequences (lists, tuples)."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(dict)
|
||||
def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
"""Handle dictionaries by recursively converting values to tensors."""
|
||||
"""Handle conversion for dictionaries by recursively converting their values to tensors."""
|
||||
if not value:
|
||||
return {}
|
||||
|
||||
@@ -122,7 +122,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
continue
|
||||
|
||||
if isinstance(sub_value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
# Recursively process nested dictionaries.
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
@@ -130,7 +130,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert individual values to tensors
|
||||
# Convert individual values to tensors.
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
@@ -140,17 +140,45 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
|
||||
|
||||
def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
|
||||
"""Convert tensor to numpy/scalar if needed."""
|
||||
"""
|
||||
Convert a PyTorch tensor to a numpy array or scalar if applicable.
|
||||
|
||||
If the input is not a tensor, it is returned unchanged.
|
||||
|
||||
Args:
|
||||
x: The input, which can be a tensor or any other type.
|
||||
|
||||
Returns:
|
||||
A numpy array, a scalar, or the original input.
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
|
||||
return x
|
||||
|
||||
|
||||
def _is_image(arr: Any) -> bool:
|
||||
"""
|
||||
Check if a given array is likely an image (uint8, 3D).
|
||||
|
||||
Args:
|
||||
arr: The array to check.
|
||||
|
||||
Returns:
|
||||
True if the array matches the image criteria, False otherwise.
|
||||
"""
|
||||
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
|
||||
|
||||
|
||||
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""
|
||||
Separate an observation dictionary into state and image components.
|
||||
|
||||
Args:
|
||||
obs: The observation dictionary.
|
||||
|
||||
Returns:
|
||||
A tuple containing two dictionaries: one for state and one for images.
|
||||
"""
|
||||
state, images = {}, {}
|
||||
for k, v in obs.items():
|
||||
if "image" in k.lower() or _is_image(v):
|
||||
@@ -160,13 +188,21 @@ def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any],
|
||||
return state, images
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Private Helper Functions (Common Logic)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract complementary data (pad flags, task, index, task_index)."""
|
||||
"""
|
||||
Extract complementary data from a batch dictionary.
|
||||
|
||||
This includes padding flags, task description, and indices.
|
||||
|
||||
Args:
|
||||
batch: The batch dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the extracted complementary data.
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
@@ -176,7 +212,16 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition:
|
||||
"""Merge two transitions, with other taking precedence."""
|
||||
"""
|
||||
Merge two transitions, with the second one taking precedence in case of conflicts.
|
||||
|
||||
Args:
|
||||
base: The base transition.
|
||||
other: The transition to merge, which will overwrite base values.
|
||||
|
||||
Returns:
|
||||
The merged transition dictionary.
|
||||
"""
|
||||
out = deepcopy(base)
|
||||
|
||||
for key in (
|
||||
@@ -194,9 +239,7 @@ def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransiti
|
||||
return out
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Core Conversion Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def create_transition(
|
||||
@@ -208,7 +251,8 @@ def create_transition(
|
||||
info: dict[str, Any] | None = None,
|
||||
complementary_data: dict[str, Any] | None = None,
|
||||
) -> EnvTransition:
|
||||
"""Create an EnvTransition with sensible defaults.
|
||||
"""
|
||||
Create an `EnvTransition` dictionary with sensible defaults.
|
||||
|
||||
Args:
|
||||
observation: Observation dictionary.
|
||||
@@ -220,7 +264,7 @@ def create_transition(
|
||||
complementary_data: Complementary data dictionary.
|
||||
|
||||
Returns:
|
||||
Complete EnvTransition dictionary.
|
||||
A complete `EnvTransition` dictionary.
|
||||
"""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
@@ -233,9 +277,19 @@ def create_transition(
|
||||
}
|
||||
|
||||
|
||||
def action_to_transition(action: dict[str, Any]) -> EnvTransition: # action_to_transition
|
||||
def action_to_transition(action: dict[str, Any]) -> EnvTransition:
|
||||
"""
|
||||
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
|
||||
Convert a raw action dictionary into a standardized `EnvTransition`.
|
||||
|
||||
The keys in the action dictionary are prefixed with "action." and stored under
|
||||
the `ACTION` key in the transition. Values are converted to tensors, except for
|
||||
special types like `Rotation`.
|
||||
|
||||
Args:
|
||||
action: The raw action dictionary from a teleoperation device or controller.
|
||||
|
||||
Returns:
|
||||
An `EnvTransition` containing the formatted action.
|
||||
"""
|
||||
act_dict: dict[str, Any] = {}
|
||||
for k, v in action.items():
|
||||
@@ -250,10 +304,19 @@ def action_to_transition(action: dict[str, Any]) -> EnvTransition: # action_to_
|
||||
return create_transition(observation={}, action=act_dict)
|
||||
|
||||
|
||||
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
|
||||
def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
|
||||
"""
|
||||
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
|
||||
Convert a raw robot observation dictionary into a standardized `EnvTransition`.
|
||||
|
||||
The observation is split into state and image components. State keys are prefixed
|
||||
with "observation.state." and image keys with "observation.images.". The result is
|
||||
stored under the `OBSERVATION` key in the transition.
|
||||
|
||||
Args:
|
||||
observation: The raw observation dictionary from the environment.
|
||||
|
||||
Returns:
|
||||
An `EnvTransition` containing the formatted observation.
|
||||
"""
|
||||
state, images = _split_obs_to_state_and_images(observation)
|
||||
|
||||
@@ -270,7 +333,16 @@ def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
|
||||
|
||||
def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]:
|
||||
"""
|
||||
Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions.
|
||||
Extract a raw action dictionary for a robot from an `EnvTransition`.
|
||||
|
||||
This function searches for keys in the format "action.*.pos" or "action.*.vel"
|
||||
and converts them into a flat dictionary suitable for sending to a robot controller.
|
||||
|
||||
Args:
|
||||
transition: The `EnvTransition` containing the action.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the raw robot action.
|
||||
"""
|
||||
out: dict[str, Any] = {}
|
||||
action_dict = transition.get(TransitionKey.ACTION) or {}
|
||||
@@ -287,13 +359,21 @@ def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]:
|
||||
|
||||
|
||||
def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition:
|
||||
"""Merge multiple transitions or return single transition.
|
||||
"""
|
||||
Merge a sequence of transitions into a single one.
|
||||
|
||||
If a single transition is provided, it is returned as is. For a sequence,
|
||||
transitions are merged sequentially, with later transitions in the sequence
|
||||
overwriting earlier ones.
|
||||
|
||||
Args:
|
||||
transitions: Either a single transition or iterable of transitions.
|
||||
transitions: A single transition or a sequence of them.
|
||||
|
||||
Returns:
|
||||
Merged EnvTransition.
|
||||
A single merged `EnvTransition`.
|
||||
|
||||
Raises:
|
||||
ValueError: If an empty sequence of transitions is provided.
|
||||
"""
|
||||
|
||||
if not isinstance(transitions, Sequence): # Single transition
|
||||
@@ -312,26 +392,18 @@ def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> E
|
||||
def transition_to_dataset_frame(
|
||||
transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict]
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a single EnvTransition or an iterable of them into a flat, dataset-friendly dictionary for training or evaluation.
|
||||
"""
|
||||
Convert one or more transitions into a flat dictionary suitable for a dataset frame.
|
||||
|
||||
Processes transitions according to the provided feature specification and returns
|
||||
data in the format expected by machine learning models and datasets.
|
||||
This function processes `EnvTransition` objects according to a feature
|
||||
specification, producing a format ready for training or evaluation.
|
||||
|
||||
Args:
|
||||
transitions_or_transition: Either a single EnvTransition dict or an iterable of them
|
||||
(which will be merged using merge_transitions).
|
||||
features: Feature specification dictionary with the following structure:
|
||||
- 'action': dict with 'names': list of action feature names
|
||||
- 'observation.state': dict with 'names': list of state feature names
|
||||
- keys starting with 'observation.images.' are passed through as-is
|
||||
transitions_or_transition: A single `EnvTransition` or a sequence to be merged.
|
||||
features: A feature specification dictionary.
|
||||
|
||||
Returns:
|
||||
Flat dictionary containing:
|
||||
- numpy arrays for "observation.state" and "action" (vectorized from feature names)
|
||||
- any image tensors defined in features (passed through unchanged)
|
||||
- next.{reward,done,truncated} scalar values
|
||||
- info dict
|
||||
- *_is_pad flags and task from complementary_data
|
||||
A flat dictionary representing a single frame of data for a dataset.
|
||||
"""
|
||||
action_names = features.get(ACTION, {}).get("names", [])
|
||||
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
|
||||
@@ -342,25 +414,25 @@ def transition_to_dataset_frame(
|
||||
act = tr.get(TransitionKey.ACTION, {}) or {}
|
||||
batch: dict[str, Any] = {}
|
||||
|
||||
# Images passthrough
|
||||
# Passthrough for images.
|
||||
for k in image_keys:
|
||||
if k in obs:
|
||||
batch[k] = obs[k]
|
||||
|
||||
# Observation.state vector
|
||||
# Create observation.state vector.
|
||||
if obs_state_names:
|
||||
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
|
||||
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Action vector
|
||||
# Create action vector.
|
||||
if action_names:
|
||||
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
|
||||
batch[ACTION] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Add transition metadata
|
||||
# Add transition metadata.
|
||||
if tr.get(TransitionKey.REWARD) is not None:
|
||||
reward_val = _from_tensor(tr[TransitionKey.REWARD])
|
||||
# Check if features expect array format, otherwise keep as scalar
|
||||
# Check if features expect array format, otherwise keep as scalar.
|
||||
if REWARD in features and features[REWARD].get("shape") == (1,):
|
||||
batch[REWARD] = np.array([reward_val], dtype=np.float32)
|
||||
else:
|
||||
@@ -380,14 +452,14 @@ def transition_to_dataset_frame(
|
||||
else:
|
||||
batch[TRUNCATED] = truncated_val
|
||||
|
||||
# Complementary data flags and task
|
||||
# Add complementary data flags and task.
|
||||
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if comp:
|
||||
# pad flags
|
||||
# Padding flags.
|
||||
for k, v in comp.items():
|
||||
if k.endswith("_is_pad"):
|
||||
batch[k] = v
|
||||
# task label
|
||||
# Task label.
|
||||
if comp.get("task") is not None:
|
||||
batch["task"] = comp["task"]
|
||||
|
||||
@@ -395,36 +467,27 @@ def transition_to_dataset_frame(
|
||||
|
||||
|
||||
def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
"""Convert a batch dict coming from LeRobot replay/dataset code into an EnvTransition dictionary.
|
||||
"""
|
||||
Convert a batch dictionary from a dataset/dataloader into an `EnvTransition`.
|
||||
|
||||
The function maps well known keys to the EnvTransition structure. Missing keys are
|
||||
filled with sane defaults (None or 0.0/False).
|
||||
|
||||
Keys recognised (case-sensitive):
|
||||
* "observation.*" (keys starting with "observation." are grouped into observation dict)
|
||||
* "action"
|
||||
* "next.reward"
|
||||
* "next.done"
|
||||
* "next.truncated"
|
||||
* "info"
|
||||
* "_is_pad" patterns (padding flags)
|
||||
* "task", "index", "task_index" (complementary data)
|
||||
|
||||
Additional keys are ignored so that existing dataloaders can carry extra
|
||||
metadata without breaking the processor.
|
||||
This function maps recognized keys from a batch to the `EnvTransition` structure,
|
||||
filling in missing keys with sensible defaults.
|
||||
|
||||
Args:
|
||||
batch: Batch dictionary from datasets or dataloaders containing the above keys.
|
||||
batch: A batch dictionary.
|
||||
|
||||
Returns:
|
||||
EnvTransition dictionary with properly structured transition data.
|
||||
An `EnvTransition` dictionary.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is not a dictionary.
|
||||
"""
|
||||
|
||||
# Validate input type
|
||||
# Validate input type.
|
||||
if not isinstance(batch, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||
|
||||
# Extract observation keys
|
||||
# Extract observation and complementary data keys.
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
complementary_data = _extract_complementary_data(batch)
|
||||
|
||||
@@ -440,25 +503,16 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
|
||||
|
||||
def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
"""Inverse of batch_to_transition. Returns a dict with canonical field names used throughout LeRobot.
|
||||
"""
|
||||
Convert an `EnvTransition` back to the canonical batch format used in LeRobot.
|
||||
|
||||
Converts an EnvTransition back to the batch format expected by datasets, dataloaders,
|
||||
and other LeRobot components.
|
||||
|
||||
Output format:
|
||||
* "action": Action data from transition
|
||||
* "next.reward": Reward value (defaults to 0.0)
|
||||
* "next.done": Done flag (defaults to False)
|
||||
* "next.truncated": Truncated flag (defaults to False)
|
||||
* "info": Info dictionary (defaults to {})
|
||||
* Flattened observation keys (e.g., "observation.state", "observation.images.cam1")
|
||||
* Complementary data fields ("task", "index", "task_index", padding flags)
|
||||
This is the inverse of `batch_to_transition`.
|
||||
|
||||
Args:
|
||||
transition: EnvTransition dictionary to convert.
|
||||
transition: The `EnvTransition` to convert.
|
||||
|
||||
Returns:
|
||||
Batch dictionary with canonical LeRobot field names suitable for dataloaders.
|
||||
A batch dictionary with canonical LeRobot field names.
|
||||
"""
|
||||
batch = {
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
@@ -468,12 +522,12 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Add complementary data
|
||||
# Add complementary data.
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if comp_data:
|
||||
batch.update(comp_data)
|
||||
|
||||
# Flatten observation dict
|
||||
# Flatten observation dictionary.
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if isinstance(observation, dict):
|
||||
batch.update(observation)
|
||||
@@ -482,4 +536,15 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
|
||||
|
||||
def identity_transition(tr: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
An identity function for transitions, returning the input unchanged.
|
||||
|
||||
Useful as a default or placeholder in processing pipelines.
|
||||
|
||||
Args:
|
||||
tr: An `EnvTransition`.
|
||||
|
||||
Returns:
|
||||
The same `EnvTransition`.
|
||||
"""
|
||||
return tr
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# !/usr/bin/env python
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -28,7 +28,15 @@ from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||
@dataclass
|
||||
class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
"""
|
||||
Map a tensor to a delta action dictionary.
|
||||
Maps a flat action tensor from a policy to a structured delta action dictionary.
|
||||
|
||||
This step is typically used after a policy outputs a continuous action vector.
|
||||
It decomposes the vector into named components for delta movements of the
|
||||
end-effector (x, y, z) and optionally the gripper.
|
||||
|
||||
Attributes:
|
||||
use_gripper: If True, assumes the 4th element of the tensor is the
|
||||
gripper action.
|
||||
"""
|
||||
|
||||
use_gripper: bool = True
|
||||
@@ -60,28 +68,17 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
@dataclass
|
||||
class MapDeltaActionToRobotActionStep(ActionProcessorStep):
|
||||
"""
|
||||
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
|
||||
for use with inverse kinematics processors.
|
||||
Maps delta actions from teleoperators to robot target actions for inverse kinematics.
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.delta_x": float,
|
||||
"action.delta_y": float,
|
||||
"action.delta_z": float,
|
||||
"action.gripper": float (optional),
|
||||
}
|
||||
This step converts a dictionary of delta movements (e.g., from a gamepad)
|
||||
into a target action format that includes an "enabled" flag and target
|
||||
end-effector positions. It also handles scaling and noise filtering.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.target_x": float,
|
||||
"action.target_y": float,
|
||||
"action.target_z": float,
|
||||
"action.target_wx": float,
|
||||
"action.target_wy": float,
|
||||
"action.target_wz": float,
|
||||
"action.gripper": float,
|
||||
}
|
||||
Attributes:
|
||||
position_scale: A factor to scale the delta position inputs.
|
||||
rotation_scale: A factor to scale the delta rotation inputs (currently unused).
|
||||
noise_threshold: The magnitude below which delta inputs are considered noise
|
||||
and do not trigger an "enabled" state.
|
||||
"""
|
||||
|
||||
# Scale factors for delta movements
|
||||
|
||||
@@ -13,6 +13,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script defines a processor step for moving environment transition data to a specific torch device and casting
|
||||
its floating-point precision.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@@ -28,12 +34,16 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
@ProcessorStepRegistry.register("device_processor")
|
||||
@dataclass
|
||||
class DeviceProcessorStep(ProcessorStep):
|
||||
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
|
||||
"""
|
||||
Processor step to move all tensors within an `EnvTransition` to a specified device and optionally cast their
|
||||
floating-point data type.
|
||||
|
||||
This processor ensures that all tensors in the transition are moved to the
|
||||
specified device (CPU or GPU) before they are returned. It can also convert
|
||||
floating-point tensors to a specified dtype while preserving non-float types
|
||||
(int, long, bool, etc.).
|
||||
This is crucial for preparing data for model training or inference on hardware like GPUs.
|
||||
|
||||
Attributes:
|
||||
device: The target device for tensors (e.g., "cpu", "cuda", "cuda:0").
|
||||
float_dtype: The target floating-point dtype as a string (e.g., "float32", "float16", "bfloat16").
|
||||
If None, the dtype is not changed.
|
||||
"""
|
||||
|
||||
device: str = "cpu"
|
||||
@@ -50,8 +60,15 @@ class DeviceProcessorStep(ProcessorStep):
|
||||
}
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Initializes the processor by converting string configurations to torch objects.
|
||||
|
||||
This method sets up the `torch.device`, determines if transfers can be non-blocking, and validates the
|
||||
`float_dtype` string, converting it to a `torch.dtype` object.
|
||||
"""
|
||||
self.tensor_device: torch.device = get_safe_torch_device(self.device)
|
||||
self.device = self.tensor_device.type # cuda might have changed to cuda:1
|
||||
# Update device string in case a specific GPU was selected (e.g., "cuda" -> "cuda:0")
|
||||
self.device = self.tensor_device.type
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
# Validate and convert float_dtype string to torch dtype
|
||||
@@ -60,27 +77,32 @@ class DeviceProcessorStep(ProcessorStep):
|
||||
raise ValueError(
|
||||
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
|
||||
)
|
||||
|
||||
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
|
||||
else:
|
||||
self._target_float_dtype = None
|
||||
|
||||
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Process a tensor by moving to device and optionally converting float dtype.
|
||||
"""
|
||||
Moves a single tensor to the target device and casts its dtype.
|
||||
|
||||
If the tensor is already on a GPU and we're configured for a GPU, it preserves
|
||||
that GPU placement (useful for multi-GPU training with Accelerate).
|
||||
Otherwise, it moves to the configured device.
|
||||
Handles multi-GPU scenarios by not moving a tensor if it's already on a different CUDA device than
|
||||
the target, which is useful when using frameworks like Accelerate.
|
||||
|
||||
Args:
|
||||
tensor: The input torch.Tensor.
|
||||
|
||||
Returns:
|
||||
The processed tensor on the correct device and with the correct dtype.
|
||||
"""
|
||||
# Determine target device
|
||||
if tensor.is_cuda and self.tensor_device.type == "cuda":
|
||||
# Both tensor and target are on GPU - preserve tensor's GPU placement
|
||||
# Both tensor and target are on GPU - preserve tensor's GPU placement.
|
||||
# This handles multi-GPU scenarios where Accelerate has already placed
|
||||
# tensors on the correct GPU for each process
|
||||
# tensors on the correct GPU for each process.
|
||||
target_device = tensor.device
|
||||
else:
|
||||
# Either tensor is on CPU, or we're configured for CPU
|
||||
# In both cases, use the configured device
|
||||
# Either tensor is on CPU, or we're configured for CPU.
|
||||
# In both cases, use the configured device.
|
||||
target_device = self.tensor_device
|
||||
|
||||
# Only move if necessary
|
||||
@@ -94,6 +116,18 @@ class DeviceProcessorStep(ProcessorStep):
|
||||
return tensor
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Applies device and dtype conversion to all tensors in an environment transition.
|
||||
|
||||
It iterates through the transition, finds all `torch.Tensor` objects (including those nested in
|
||||
dictionaries like `observation`), and processes them.
|
||||
|
||||
Args:
|
||||
transition: The input `EnvTransition` object.
|
||||
|
||||
Returns:
|
||||
A new `EnvTransition` object with all tensors moved to the target device and dtype.
|
||||
"""
|
||||
new_transition = transition.copy()
|
||||
|
||||
simple_tensor_keys = [
|
||||
@@ -108,13 +142,13 @@ class DeviceProcessorStep(ProcessorStep):
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
]
|
||||
|
||||
# Process simple tensors
|
||||
# Process simple, top-level tensors
|
||||
for key in simple_tensor_keys:
|
||||
value = transition.get(key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
new_transition[key] = self._process_tensor(value)
|
||||
|
||||
# Process dictionary-like tensors
|
||||
# Process tensors nested within dictionaries
|
||||
for key in dict_tensor_keys:
|
||||
data_dict = transition.get(key)
|
||||
if data_dict is not None:
|
||||
@@ -127,8 +161,24 @@ class DeviceProcessorStep(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
"""
|
||||
Returns the serializable configuration of the processor.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the device and float_dtype settings.
|
||||
"""
|
||||
return {"device": self.device, "float_dtype": self.float_dtype}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Returns the input features unchanged.
|
||||
|
||||
Device and dtype transformations do not alter the fundamental definition of the features (e.g., shape).
|
||||
|
||||
Args:
|
||||
features: A dictionary of policy features.
|
||||
|
||||
Returns:
|
||||
The original dictionary of policy features.
|
||||
"""
|
||||
return features
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#! /usr/bin/env python
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -10,6 +10,9 @@
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -25,7 +28,17 @@ from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||
@ProcessorStepRegistry.register("torch2numpy_action_processor")
|
||||
@dataclass
|
||||
class Torch2NumpyActionProcessorStep(ActionProcessorStep):
|
||||
"""Convert PyTorch tensor actions to NumPy arrays."""
|
||||
"""
|
||||
Converts a PyTorch tensor action to a NumPy array.
|
||||
|
||||
This step is useful when the output of a policy (typically a torch.Tensor)
|
||||
needs to be passed to an environment or component that expects a NumPy array.
|
||||
|
||||
Attributes:
|
||||
squeeze_batch_dim: If True, removes the first dimension of the array
|
||||
if it is of size 1. This is useful for converting a
|
||||
batched action of size (1, D) to a single action of size (D,).
|
||||
"""
|
||||
|
||||
squeeze_batch_dim: bool = True
|
||||
|
||||
@@ -38,8 +51,8 @@ class Torch2NumpyActionProcessorStep(ActionProcessorStep):
|
||||
|
||||
numpy_action = action.detach().cpu().numpy()
|
||||
|
||||
# Remove batch dimensions but preserve action dimensions
|
||||
# Only squeeze if there's a batch dimension (first dim == 1)
|
||||
# Remove batch dimensions but preserve action dimensions.
|
||||
# Only squeeze if there's a batch dimension (first dim == 1).
|
||||
if (
|
||||
self.squeeze_batch_dim
|
||||
and numpy_action.shape
|
||||
@@ -57,7 +70,13 @@ class Torch2NumpyActionProcessorStep(ActionProcessorStep):
|
||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
||||
@dataclass
|
||||
class Numpy2TorchActionProcessorStep(ActionProcessorStep):
|
||||
"""Convert NumPy array action to PyTorch tensor."""
|
||||
"""
|
||||
Converts a NumPy array action to a PyTorch tensor.
|
||||
|
||||
This step is useful for converting actions from environments or hardware,
|
||||
which are often NumPy arrays, into PyTorch tensors that can be processed
|
||||
by a policy or model.
|
||||
"""
|
||||
|
||||
def action(self, action: np.ndarray) -> torch.Tensor:
|
||||
if not isinstance(action, np.ndarray):
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
@@ -29,21 +46,25 @@ TELEOP_ACTION_KEY = "teleop_action"
|
||||
|
||||
@runtime_checkable
|
||||
class HasTeleopEvents(Protocol):
|
||||
"""Minimal protocol for objects that provide teleoperation events.
|
||||
"""
|
||||
Minimal protocol for objects that provide teleoperation events.
|
||||
|
||||
This protocol only defines the additional get_teleop_events() method,
|
||||
avoiding duplication of the entire Teleoperator interface.
|
||||
This protocol defines the `get_teleop_events()` method, allowing processor
|
||||
steps to interact with teleoperators that support event-based controls
|
||||
(like episode termination or success flagging) without needing to know the
|
||||
teleoperator's specific class.
|
||||
"""
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""Get extra control events from the teleoperator.
|
||||
"""
|
||||
Get extra control events from the teleoperator.
|
||||
|
||||
Returns:
|
||||
Dictionary containing control events such as:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
A dictionary containing control events such as:
|
||||
- `is_intervention`: bool - Whether the human is currently intervening.
|
||||
- `terminate_episode`: bool - Whether to terminate the current episode.
|
||||
- `success`: bool - Whether the episode was successful.
|
||||
- `rerecord_episode`: bool - Whether to rerecord the episode.
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -53,7 +74,15 @@ TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator)
|
||||
|
||||
|
||||
def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
"""Runtime check that a teleoperator implements get_teleop_events."""
|
||||
"""
|
||||
Runtime check that a teleoperator implements the `HasTeleopEvents` protocol.
|
||||
|
||||
Args:
|
||||
teleop: The teleoperator instance to check.
|
||||
|
||||
Raises:
|
||||
TypeError: If the teleoperator does not have a `get_teleop_events` method.
|
||||
"""
|
||||
if not isinstance(teleop, HasTeleopEvents):
|
||||
raise TypeError(
|
||||
f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. "
|
||||
@@ -64,11 +93,30 @@ def _check_teleop_with_events(teleop: Teleoperator) -> None:
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||
@dataclass
|
||||
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
"""Add teleoperator action to transition complementary data."""
|
||||
"""
|
||||
Adds the raw action from a teleoperator to the transition's complementary data.
|
||||
|
||||
This is useful for human-in-the-loop scenarios where the human's input needs to
|
||||
be available to downstream processors, for example, to override a policy's action
|
||||
during an intervention.
|
||||
|
||||
Attributes:
|
||||
teleop_device: The teleoperator instance to get the action from.
|
||||
"""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Retrieves the teleoperator's action and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data dictionary.
|
||||
|
||||
Returns:
|
||||
A new dictionary with the teleoperator action added under the
|
||||
`teleop_action` key.
|
||||
"""
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||
return new_complementary_data
|
||||
@@ -80,26 +128,33 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||
@dataclass
|
||||
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
|
||||
"""Add teleoperator control events to transition info.
|
||||
"""
|
||||
Adds teleoperator control events (e.g., terminate, success) to the transition's info.
|
||||
|
||||
This processor step extracts control events from teleoperators that support
|
||||
event-based interaction (intervention detection, episode termination, etc.).
|
||||
This step extracts control events from teleoperators that support event-based
|
||||
interaction, making these signals available to other parts of the system.
|
||||
|
||||
Works with any teleoperator that inherits from Teleoperator and implements the
|
||||
get_teleop_events() method, including custom user-defined teleoperators.
|
||||
|
||||
Built-in compatible teleoperators:
|
||||
- GamepadTeleop: Uses gamepad buttons for control events
|
||||
- KeyboardEndEffectorTeleop: Uses keyboard keys for control events
|
||||
Attributes:
|
||||
teleop_device: An instance of a teleoperator that implements the
|
||||
`HasTeleopEvents` protocol.
|
||||
"""
|
||||
|
||||
teleop_device: TeleopWithEvents
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that the teleoperator supports events."""
|
||||
"""Validates that the provided teleoperator supports events after initialization."""
|
||||
_check_teleop_with_events(self.teleop_device)
|
||||
|
||||
def info(self, info: dict) -> dict:
|
||||
"""
|
||||
Retrieves teleoperator events and updates the info dictionary.
|
||||
|
||||
Args:
|
||||
info: The incoming info dictionary.
|
||||
|
||||
Returns:
|
||||
A new dictionary including the teleoperator events.
|
||||
"""
|
||||
new_info = dict(info)
|
||||
|
||||
teleop_events = self.teleop_device.get_teleop_events()
|
||||
@@ -113,12 +168,32 @@ class AddTeleopEventsAsInfoStep(InfoProcessorStep):
|
||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||
@dataclass
|
||||
class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
"""Crop and resize image observations."""
|
||||
"""
|
||||
Crops and/or resizes image observations.
|
||||
|
||||
This step iterates through all image keys in an observation dictionary and applies
|
||||
the specified transformations. It handles device placement, moving tensors to the
|
||||
CPU if necessary for operations not supported on certain accelerators like MPS.
|
||||
|
||||
Attributes:
|
||||
crop_params_dict: A dictionary mapping image keys to cropping parameters
|
||||
(top, left, height, width).
|
||||
resize_size: A tuple (height, width) to resize all images to.
|
||||
"""
|
||||
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
"""
|
||||
Applies cropping and resizing to all images in the observation dictionary.
|
||||
|
||||
Args:
|
||||
observation: The observation dictionary, potentially containing image tensors.
|
||||
|
||||
Returns:
|
||||
A new observation dictionary with transformed images.
|
||||
"""
|
||||
if self.resize_size is None and not self.crop_params_dict:
|
||||
return observation
|
||||
|
||||
@@ -146,12 +221,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary with the crop parameters and resize dimensions.
|
||||
"""
|
||||
return {
|
||||
"crop_params_dict": self.crop_params_dict,
|
||||
"resize_size": self.resize_size,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Updates the image feature shapes in the policy features dictionary if resizing is applied.
|
||||
|
||||
Args:
|
||||
features: The policy features dictionary.
|
||||
|
||||
Returns:
|
||||
The updated policy features dictionary with new image shapes.
|
||||
"""
|
||||
if self.resize_size is None:
|
||||
return features
|
||||
for key in features:
|
||||
@@ -163,12 +253,27 @@ class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("time_limit_processor")
|
||||
class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
"""Track episode steps and enforce time limits."""
|
||||
"""
|
||||
Tracks episode steps and enforces a time limit by truncating the episode.
|
||||
|
||||
Attributes:
|
||||
max_episode_steps: The maximum number of steps allowed per episode.
|
||||
current_step: The current step count for the active episode.
|
||||
"""
|
||||
|
||||
max_episode_steps: int
|
||||
current_step: int = 0
|
||||
|
||||
def truncated(self, truncated):
|
||||
def truncated(self, truncated: bool) -> bool:
|
||||
"""
|
||||
Increments the step counter and sets the truncated flag if the time limit is reached.
|
||||
|
||||
Args:
|
||||
truncated: The incoming truncated flag.
|
||||
|
||||
Returns:
|
||||
True if the episode step limit is reached, otherwise the incoming value.
|
||||
"""
|
||||
self.current_step += 1
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
truncated = True
|
||||
@@ -176,11 +281,18 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
return truncated
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the `max_episode_steps`.
|
||||
"""
|
||||
return {
|
||||
"max_episode_steps": self.max_episode_steps,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the step counter, typically called at the start of a new episode."""
|
||||
self.current_step = 0
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
@@ -190,13 +302,31 @@ class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""Apply penalty for inappropriate gripper usage."""
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
|
||||
This step penalizes actions that attempt to close an already closed gripper or
|
||||
open an already open one, based on position thresholds.
|
||||
|
||||
Attributes:
|
||||
penalty: The negative reward value to apply.
|
||||
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
||||
"""
|
||||
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
"""Calculate gripper penalty and add to complementary data."""
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
"""
|
||||
Calculates the gripper penalty and adds it to the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The incoming complementary data, which should contain
|
||||
raw joint positions.
|
||||
|
||||
Returns:
|
||||
A new complementary data dictionary with the `discrete_penalty` key added.
|
||||
"""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
|
||||
@@ -223,14 +353,20 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
return new_complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the penalty value and max gripper position.
|
||||
"""
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the processor state."""
|
||||
self.last_gripper_state = None
|
||||
"""Resets the processor's internal state."""
|
||||
pass
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -239,12 +375,33 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||
class InterventionActionProcessorStep(ProcessorStep):
|
||||
"""Handle human intervention actions and episode termination."""
|
||||
"""
|
||||
Handles human intervention, overriding policy actions and managing episode termination.
|
||||
|
||||
When an intervention is detected (via teleoperator events in the `info` dict),
|
||||
this step replaces the policy's action with the human's teleoperated action.
|
||||
It also processes signals to terminate the episode or flag success.
|
||||
|
||||
Attributes:
|
||||
use_gripper: Whether to include the gripper in the teleoperated action.
|
||||
terminate_on_success: If True, automatically sets the `done` flag when a
|
||||
`success` event is received.
|
||||
"""
|
||||
|
||||
use_gripper: bool = False
|
||||
terminate_on_success: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Processes the transition to handle interventions.
|
||||
|
||||
Args:
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
The modified transition, potentially with an overridden action, updated
|
||||
reward, and termination status.
|
||||
"""
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return transition
|
||||
@@ -300,6 +457,12 @@ class InterventionActionProcessorStep(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the step's configuration attributes.
|
||||
"""
|
||||
return {
|
||||
"use_gripper": self.use_gripper,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
@@ -312,7 +475,20 @@ class InterventionActionProcessorStep(ProcessorStep):
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("reward_classifier_processor")
|
||||
class RewardClassifierProcessorStep(ProcessorStep):
|
||||
"""Apply reward classification to image observations."""
|
||||
"""
|
||||
Applies a pretrained reward classifier to image observations to predict success.
|
||||
|
||||
This step uses a model to determine if the current state is successful, updating
|
||||
the reward and potentially terminating the episode.
|
||||
|
||||
Attributes:
|
||||
pretrained_path: Path to the pretrained reward classifier model.
|
||||
device: The device to run the classifier on.
|
||||
success_threshold: The probability threshold to consider a prediction as successful.
|
||||
success_reward: The reward value to assign on success.
|
||||
terminate_on_success: If True, terminates the episode upon successful classification.
|
||||
reward_classifier: The loaded classifier model instance.
|
||||
"""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
device: str = "cpu"
|
||||
@@ -323,7 +499,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
reward_classifier: Any = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the reward classifier after dataclass initialization."""
|
||||
"""Initializes the reward classifier model after the dataclass is created."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
@@ -332,6 +508,16 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
self.reward_classifier.eval()
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Processes a transition, applying the reward classifier to its image observations.
|
||||
|
||||
Args:
|
||||
transition: The incoming environment transition.
|
||||
|
||||
Returns:
|
||||
The modified transition with an updated reward and done flag based on the
|
||||
classifier's prediction.
|
||||
"""
|
||||
new_transition = transition.copy()
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None or self.reward_classifier is None:
|
||||
@@ -371,6 +557,12 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the step's configuration attributes.
|
||||
"""
|
||||
return {
|
||||
"device": self.device,
|
||||
"success_threshold": self.success_threshold,
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@@ -15,17 +31,42 @@ from lerobot.robots import Robot
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||
class JointVelocityProcessorStep(ObservationProcessorStep):
|
||||
"""Add joint velocity information to observations."""
|
||||
"""
|
||||
Calculates and appends joint velocity information to the observation state.
|
||||
|
||||
This step computes the velocity of each joint by calculating the finite
|
||||
difference between the current and the last observed joint positions. The
|
||||
resulting velocity vector is then concatenated to the original state vector.
|
||||
|
||||
Attributes:
|
||||
dt: The time step (delta time) in seconds between observations, used for
|
||||
calculating velocity.
|
||||
last_joint_positions: Stores the joint positions from the previous step
|
||||
to enable velocity calculation.
|
||||
"""
|
||||
|
||||
dt: float = 0.1
|
||||
|
||||
last_joint_positions: torch.Tensor | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
"""
|
||||
Computes joint velocities and adds them to the observation state.
|
||||
|
||||
Args:
|
||||
observation: The input observation dictionary, expected to contain
|
||||
an `observation.state` key with joint positions.
|
||||
|
||||
Returns:
|
||||
A new observation dictionary with the `observation.state` tensor
|
||||
extended to include joint velocities.
|
||||
|
||||
Raises:
|
||||
ValueError: If `observation.state` is not found in the observation.
|
||||
"""
|
||||
# Get current joint positions (assuming they're in observation.state)
|
||||
current_positions = observation.get(OBS_STATE)
|
||||
if current_positions is None:
|
||||
# TODO(steven): if we get here, then the transform_features method will not hold
|
||||
raise ValueError(f"{OBS_STATE} is not in observation")
|
||||
|
||||
# Initialize last joint positions if not already set
|
||||
@@ -48,14 +89,33 @@ class JointVelocityProcessorStep(ObservationProcessorStep):
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the time step `dt`.
|
||||
"""
|
||||
return {
|
||||
"dt": self.dt,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the internal state, clearing the last known joint positions."""
|
||||
self.last_joint_positions = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Updates the `observation.state` feature to reflect the added velocities.
|
||||
|
||||
This method doubles the size of the first dimension of the `observation.state`
|
||||
shape to account for the concatenation of position and velocity vectors.
|
||||
|
||||
Args:
|
||||
features: The policy features dictionary.
|
||||
|
||||
Returns:
|
||||
The updated policy features dictionary.
|
||||
"""
|
||||
if OBS_STATE in features:
|
||||
original_feature = features[OBS_STATE]
|
||||
# Double the shape to account for positions + velocities
|
||||
@@ -68,11 +128,33 @@ class JointVelocityProcessorStep(ObservationProcessorStep):
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("current_processor")
|
||||
class MotorCurrentProcessorStep(ObservationProcessorStep):
|
||||
"""Add motor current information to observations."""
|
||||
"""
|
||||
Reads motor currents from a robot and appends them to the observation state.
|
||||
|
||||
This step queries the robot's hardware interface to get the present current
|
||||
for each motor and concatenates this information to the existing state vector.
|
||||
|
||||
Attributes:
|
||||
robot: An instance of a `lerobot` Robot class that provides access to
|
||||
the hardware bus.
|
||||
"""
|
||||
|
||||
robot: Robot | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
"""
|
||||
Fetches motor currents and adds them to the observation state.
|
||||
|
||||
Args:
|
||||
observation: The input observation dictionary.
|
||||
|
||||
Returns:
|
||||
A new observation dictionary with the `observation.state` tensor
|
||||
extended to include motor currents.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `robot` attribute has not been set.
|
||||
"""
|
||||
# Get current values from robot state
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot is not set")
|
||||
@@ -96,6 +178,18 @@ class MotorCurrentProcessorStep(ObservationProcessorStep):
|
||||
return new_observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""
|
||||
Updates the `observation.state` feature to reflect the added motor currents.
|
||||
|
||||
This method increases the size of the first dimension of the `observation.state`
|
||||
shape by the number of motors in the robot.
|
||||
|
||||
Args:
|
||||
features: The policy features dictionary.
|
||||
|
||||
Returns:
|
||||
The updated policy features dictionary.
|
||||
"""
|
||||
if OBS_STATE in features and self.robot is not None:
|
||||
original_feature = features[OBS_STATE]
|
||||
# Add motor current dimensions to the original state shape
|
||||
|
||||
@@ -15,16 +15,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
|
||||
A generic script to migrate LeRobot policies with built-in normalization layers to the new
|
||||
pipeline-based processor system.
|
||||
|
||||
This script:
|
||||
1. Loads an existing pretrained policy model
|
||||
2. Extracts normalization statistics from the model
|
||||
3. Creates both preprocessor and postprocessor:
|
||||
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
|
||||
- Postprocessor: unnormalizes outputs (actions) for inference
|
||||
4. Removes normalization layers from the model state_dict
|
||||
5. Saves the new model and both processors
|
||||
This script performs the following steps:
|
||||
1. Loads a pretrained policy model and its configuration from a local path or the
|
||||
Hugging Face Hub.
|
||||
2. Scans the model's state dictionary to extract normalization statistics (e.g., mean,
|
||||
std, min, max) for all features.
|
||||
3. Creates two new processor pipelines:
|
||||
- A preprocessor that normalizes inputs (observations) and outputs (actions).
|
||||
- A postprocessor that unnormalizes outputs (actions) for inference.
|
||||
4. Removes the original normalization layers from the model's state dictionary,
|
||||
creating a "clean" model.
|
||||
5. Saves the new clean model, the preprocessor, the postprocessor, and a generated
|
||||
model card to a new directory.
|
||||
6. Optionally pushes all the new artifacts to the Hugging Face Hub.
|
||||
|
||||
Usage:
|
||||
python src/lerobot/processor/migrate_policy_normalization.py \
|
||||
@@ -68,7 +74,21 @@ POLICY_CLASSES = {
|
||||
|
||||
|
||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Extract normalization statistics from model state_dict."""
|
||||
"""
|
||||
Scans a model's state_dict to find and extract normalization statistics.
|
||||
|
||||
This function identifies keys corresponding to normalization layers (e.g., those
|
||||
for mean, std, min, max) based on a set of predefined patterns and organizes
|
||||
them into a nested dictionary.
|
||||
|
||||
Args:
|
||||
state_dict: The state dictionary of a pretrained policy model.
|
||||
|
||||
Returns:
|
||||
A nested dictionary where outer keys are feature names (e.g.,
|
||||
'observation.state') and inner keys are statistic types ('mean', 'std'),
|
||||
mapping to their corresponding tensor values.
|
||||
"""
|
||||
stats = {}
|
||||
|
||||
# Define patterns to match and their prefixes to remove
|
||||
@@ -112,7 +132,25 @@ def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str
|
||||
def detect_features_and_norm_modes(
|
||||
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
|
||||
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
|
||||
"""Detect features and normalization modes from config and stats."""
|
||||
"""
|
||||
Infers policy features and normalization modes from the model config and stats.
|
||||
|
||||
This function first attempts to find feature definitions and normalization
|
||||
mappings directly from the policy's configuration file. If this information is
|
||||
not present, it infers it from the extracted normalization statistics, using
|
||||
tensor shapes to determine feature shapes and the presence of specific stat
|
||||
keys (e.g., 'mean'/'std' vs 'min'/'max') to determine the normalization mode.
|
||||
It applies sensible defaults if inference is not possible.
|
||||
|
||||
Args:
|
||||
config: The policy's configuration dictionary from `config.json`.
|
||||
stats: The normalization statistics extracted from the model's state_dict.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A dictionary mapping feature names to `PolicyFeature` objects.
|
||||
- A dictionary mapping `FeatureType` enums to `NormalizationMode` enums.
|
||||
"""
|
||||
features = {}
|
||||
norm_modes = {}
|
||||
|
||||
@@ -204,7 +242,19 @@ def detect_features_and_norm_modes(
|
||||
|
||||
|
||||
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Remove normalization layers from state_dict."""
|
||||
"""
|
||||
Creates a new state_dict with all normalization-related layers removed.
|
||||
|
||||
This function filters the original state dictionary, excluding any keys that
|
||||
match a set of predefined patterns associated with normalization modules.
|
||||
|
||||
Args:
|
||||
state_dict: The original model state dictionary.
|
||||
|
||||
Returns:
|
||||
A new state dictionary containing only the core model weights, without
|
||||
any normalization parameters.
|
||||
"""
|
||||
new_state_dict = {}
|
||||
|
||||
# Patterns to remove
|
||||
@@ -228,7 +278,16 @@ def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str
|
||||
|
||||
|
||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""Convert features from old format to PolicyFeature objects."""
|
||||
"""
|
||||
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
|
||||
|
||||
Args:
|
||||
features_dict: The feature dictionary in the old format, where values are
|
||||
simple dictionaries (e.g., `{"shape": [7]}`).
|
||||
|
||||
Returns:
|
||||
A dictionary mapping feature names to `PolicyFeature` dataclass objects.
|
||||
"""
|
||||
converted_features = {}
|
||||
|
||||
for key, feature_dict in features_dict.items():
|
||||
@@ -254,8 +313,18 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
||||
def load_model_from_hub(
|
||||
repo_id: str, revision: str = None
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
||||
"""Load model state_dict and config from hub."""
|
||||
# Download files
|
||||
"""
|
||||
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
repo_id: The repository ID on the Hub (e.g., 'lerobot/aloha').
|
||||
revision: The specific git revision (branch, tag, or commit hash) to use.
|
||||
|
||||
Returns:
|
||||
A tuple containing the model's state dictionary, the policy configuration,
|
||||
and the training configuration.
|
||||
"""
|
||||
# Download files.
|
||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||
|
||||
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
@@ -20,9 +37,26 @@ class _NormalizationMixin:
|
||||
"""
|
||||
A mixin class providing core functionality for normalization and unnormalization.
|
||||
|
||||
This class manages normalization statistics, their conversion to tensors, device placement,
|
||||
and the application of normalization transformations. It is designed to be inherited by
|
||||
concrete ProcessorStep implementations.
|
||||
This class manages normalization statistics (`stats`), converts them to tensors for
|
||||
efficient computation, handles device placement, and implements the logic for
|
||||
applying normalization transformations (mean/std and min/max). It is designed to
|
||||
be inherited by concrete `ProcessorStep` implementations and should not be used
|
||||
directly.
|
||||
|
||||
Attributes:
|
||||
features: A dictionary mapping feature names to `PolicyFeature` objects, defining
|
||||
the data structure to be processed.
|
||||
norm_map: A dictionary mapping `FeatureType` to `NormalizationMode`, specifying
|
||||
which normalization method to use for each type of feature.
|
||||
stats: A dictionary containing the normalization statistics (e.g., mean, std,
|
||||
min, max) for each feature.
|
||||
device: The PyTorch device on which to store and perform tensor operations.
|
||||
eps: A small epsilon value to prevent division by zero in normalization
|
||||
calculations.
|
||||
normalize_observation_keys: An optional set of keys to selectively apply
|
||||
normalization to specific observation features.
|
||||
_tensor_stats: An internal dictionary holding the normalization statistics as
|
||||
PyTorch tensors.
|
||||
"""
|
||||
|
||||
features: dict[str, PolicyFeature]
|
||||
@@ -36,7 +70,15 @@ class _NormalizationMixin:
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
# Robust JSON deserialization handling (guard empty maps)
|
||||
"""
|
||||
Initializes the mixin after dataclass construction.
|
||||
|
||||
This method handles the robust deserialization of `features` and `norm_map`
|
||||
from JSON-compatible formats (where enums become strings and tuples become
|
||||
lists) and converts the provided `stats` dictionary into a dictionary of
|
||||
tensors (`_tensor_stats`) on the specified device.
|
||||
"""
|
||||
# Robust JSON deserialization handling (guard empty maps).
|
||||
if self.features:
|
||||
first_val = next(iter(self.features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
@@ -65,7 +107,15 @@ class _NormalizationMixin:
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
) -> _NormalizationMixin:
|
||||
"""Moves the processor's normalization stats to the specified device and returns self."""
|
||||
"""
|
||||
Moves the processor's normalization stats to the specified device.
|
||||
|
||||
Args:
|
||||
device: The target PyTorch device.
|
||||
|
||||
Returns:
|
||||
The instance of the class, allowing for method chaining.
|
||||
"""
|
||||
if device is not None:
|
||||
self.device = device
|
||||
if dtype is not None:
|
||||
@@ -74,6 +124,16 @@ class _NormalizationMixin:
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
"""
|
||||
Returns the normalization statistics as a flat state dictionary.
|
||||
|
||||
All tensors are moved to the CPU before being returned, which is standard practice
|
||||
for saving state dictionaries.
|
||||
|
||||
Returns:
|
||||
A flat dictionary mapping from `'feature_name.stat_name'` to the
|
||||
corresponding statistics tensor on the CPU.
|
||||
"""
|
||||
flat: dict[str, Tensor] = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
@@ -81,6 +141,15 @@ class _NormalizationMixin:
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: dict[str, Tensor]) -> None:
|
||||
"""
|
||||
Loads normalization statistics from a state dictionary.
|
||||
|
||||
The loaded tensors are moved to the processor's configured device.
|
||||
|
||||
Args:
|
||||
state: A flat state dictionary with keys in the format
|
||||
`'feature_name.stat_name'`.
|
||||
"""
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
@@ -90,6 +159,15 @@ class _NormalizationMixin:
|
||||
)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns a serializable dictionary of the processor's configuration.
|
||||
|
||||
This method is used when saving the processor to disk, ensuring that its
|
||||
configuration can be reconstructed later.
|
||||
|
||||
Returns:
|
||||
A JSON-serializable dictionary containing the configuration.
|
||||
"""
|
||||
config = {
|
||||
"eps": self.eps,
|
||||
"features": {
|
||||
@@ -102,6 +180,16 @@ class _NormalizationMixin:
|
||||
return config
|
||||
|
||||
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
|
||||
"""
|
||||
Applies (un)normalization to all relevant features in an observation dictionary.
|
||||
|
||||
Args:
|
||||
observation: The observation dictionary to process.
|
||||
inverse: If `True`, applies unnormalization; otherwise, applies normalization.
|
||||
|
||||
Returns:
|
||||
A new observation dictionary with the transformed tensor values.
|
||||
"""
|
||||
new_observation = dict(observation)
|
||||
for key, feature in self.features.items():
|
||||
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
|
||||
@@ -114,6 +202,16 @@ class _NormalizationMixin:
|
||||
|
||||
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
|
||||
# Convert to tensor but preserve original dtype for adaptation logic
|
||||
"""
|
||||
Applies (un)normalization to an action tensor.
|
||||
|
||||
Args:
|
||||
action: The action tensor to process.
|
||||
inverse: If `True`, applies unnormalization; otherwise, applies normalization.
|
||||
|
||||
Returns:
|
||||
The transformed action tensor.
|
||||
"""
|
||||
tensor = torch.as_tensor(action)
|
||||
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
|
||||
return processed_action
|
||||
@@ -121,7 +219,24 @@ class _NormalizationMixin:
|
||||
def _apply_transform(
|
||||
self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
|
||||
) -> Tensor:
|
||||
"""Core logic to apply normalization or unnormalization."""
|
||||
"""
|
||||
Core logic to apply a normalization or unnormalization transformation to a tensor.
|
||||
|
||||
This method selects the appropriate normalization mode (e.g., mean/std, min/max)
|
||||
based on the feature type and applies the corresponding mathematical operation.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor to transform.
|
||||
key: The feature key corresponding to the tensor.
|
||||
feature_type: The `FeatureType` of the tensor.
|
||||
inverse: If `True`, applies the inverse transformation (unnormalization).
|
||||
|
||||
Returns:
|
||||
The transformed tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported normalization mode is encountered.
|
||||
"""
|
||||
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
|
||||
return tensor
|
||||
@@ -168,11 +283,11 @@ class _NormalizationMixin:
|
||||
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||
class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies normalization to observations and actions in a transition.
|
||||
A processor step that applies normalization to observations and actions in a transition.
|
||||
|
||||
This class directly implements the normalization logic for both observation and action
|
||||
components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
|
||||
initialization.
|
||||
This class uses the logic from `_NormalizationMixin` to perform forward normalization
|
||||
(e.g., scaling data to have zero mean and unit variance, or to the range [-1, 1]).
|
||||
It is typically used in the pre-processing pipeline before feeding data to a policy.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -186,6 +301,20 @@ class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
eps: float = 1e-8,
|
||||
device: torch.device | str | None = None,
|
||||
) -> NormalizerProcessorStep:
|
||||
"""
|
||||
Creates a `NormalizerProcessorStep` instance using statistics from a `LeRobotDataset`.
|
||||
|
||||
Args:
|
||||
dataset: The dataset from which to extract normalization statistics.
|
||||
features: The feature definition for the processor.
|
||||
norm_map: The mapping from feature types to normalization modes.
|
||||
normalize_observation_keys: An optional set of observation keys to normalize.
|
||||
eps: A small epsilon value for numerical stability.
|
||||
device: The target device for the processor.
|
||||
|
||||
Returns:
|
||||
A new instance of `NormalizerProcessorStep`.
|
||||
"""
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
@@ -220,11 +349,12 @@ class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies unnormalization (the inverse of normalization) to
|
||||
observations and actions in a transition.
|
||||
A processor step that applies unnormalization to observations and actions.
|
||||
|
||||
This is typically used to transform actions from a normalized policy output back into
|
||||
the original scale for execution in an environment.
|
||||
This class inverts the normalization process, scaling data back to its original
|
||||
range. It is typically used in the post-processing pipeline to convert a policy's
|
||||
normalized action output into a format that can be executed by a robot or
|
||||
environment.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -236,6 +366,18 @@ class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
*,
|
||||
device: torch.device | str | None = None,
|
||||
) -> UnnormalizerProcessorStep:
|
||||
"""
|
||||
Creates an `UnnormalizerProcessorStep` using statistics from a `LeRobotDataset`.
|
||||
|
||||
Args:
|
||||
dataset: The dataset from which to extract normalization statistics.
|
||||
features: The feature definition for the processor.
|
||||
norm_map: The mapping from feature types to normalization modes.
|
||||
device: The target device for the processor.
|
||||
|
||||
Returns:
|
||||
A new instance of `UnnormalizerProcessorStep`.
|
||||
"""
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -261,12 +403,20 @@ def hotswap_stats(
|
||||
policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]]
|
||||
) -> PolicyProcessorPipeline:
|
||||
"""
|
||||
Replaces normalization statistics in a PolicyProcessor pipeline.
|
||||
Replaces normalization statistics in an existing `PolicyProcessorPipeline` instance.
|
||||
|
||||
This function creates a deep copy of the provided `PolicyProcessorPipeline` and updates the
|
||||
statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` steps within it.
|
||||
It's useful for adapting a trained policy to a new environment or dataset with
|
||||
different data distributions.
|
||||
This function creates a deep copy of the provided pipeline and updates the
|
||||
statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` it
|
||||
contains. This is useful for adapting a trained policy to a new environment or
|
||||
dataset with different data distributions without having to reconstruct the entire
|
||||
pipeline.
|
||||
|
||||
Args:
|
||||
policy_processor: The policy processor pipeline to modify.
|
||||
stats: The new dictionary of normalization statistics to apply.
|
||||
|
||||
Returns:
|
||||
A new `PolicyProcessorPipeline` instance with the updated statistics.
|
||||
"""
|
||||
rp = deepcopy(policy_processor)
|
||||
for step in rp.steps:
|
||||
|
||||
@@ -30,23 +30,44 @@ from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes environment observations into the LeRobot format by handling both images and states.
|
||||
Processes standard Gymnasium observations into the LeRobot format.
|
||||
|
||||
Image processing:
|
||||
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
|
||||
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
|
||||
- Adds a batch dimension if missing
|
||||
- Supports single images and image dictionaries
|
||||
This step handles both image and state data from a typical observation dictionary,
|
||||
preparing it for use in a LeRobot policy.
|
||||
|
||||
State processing:
|
||||
- Maps 'environment_state' to observation.environment_state
|
||||
- Maps 'agent_pos' to observation.state
|
||||
- Converts numpy arrays to tensors
|
||||
- Adds a batch dimension if missing
|
||||
**Image Processing:**
|
||||
- Converts channel-last (H, W, C), `uint8` images to channel-first (C, H, W),
|
||||
`float32` tensors.
|
||||
- Normalizes pixel values from the [0, 255] range to [0, 1].
|
||||
- Adds a batch dimension if one is not already present.
|
||||
- Recognizes a single image under the key `"pixels"` and maps it to
|
||||
`"observation.image"`.
|
||||
- Recognizes a dictionary of images under the key `"pixels"` and maps them
|
||||
to `"observation.images.{camera_name}"`.
|
||||
|
||||
**State Processing:**
|
||||
- Maps the `"environment_state"` key to `"observation.environment_state"`.
|
||||
- Maps the `"agent_pos"` key to `"observation.state"`.
|
||||
- Converts NumPy arrays to PyTorch tensors.
|
||||
- Adds a batch dimension if one is not already present.
|
||||
"""
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||
"""Process a single image array."""
|
||||
"""
|
||||
Processes a single NumPy image array into a channel-first, normalized tensor.
|
||||
|
||||
Args:
|
||||
img: A NumPy array representing the image, expected to be in channel-last
|
||||
(H, W, C) format with a `uint8` dtype.
|
||||
|
||||
Returns:
|
||||
A `float32` PyTorch tensor in channel-first (B, C, H, W) format, with
|
||||
pixel values normalized to the [0, 1] range.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input image does not appear to be in channel-last
|
||||
format or is not of `uint8` dtype.
|
||||
"""
|
||||
# Convert to tensor
|
||||
img_tensor = torch.from_numpy(img)
|
||||
|
||||
@@ -108,16 +129,24 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms feature keys to a standardized contract.
|
||||
This method handles several renaming patterns:
|
||||
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
|
||||
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
|
||||
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- environment_state -> OBS_ENV_STATE,
|
||||
- agent_pos -> OBS_STATE,
|
||||
- observation.environment_state -> OBS_ENV_STATE,
|
||||
- observation.agent_pos -> OBS_STATE
|
||||
"""
|
||||
Transforms feature keys from the Gym standard to the LeRobot standard.
|
||||
|
||||
This method standardizes the feature dictionary by renaming keys according
|
||||
to LeRobot's conventions, ensuring that policies can be constructed correctly.
|
||||
It handles various raw key formats, including those with an "observation." prefix.
|
||||
|
||||
**Renaming Rules:**
|
||||
- `pixels` or `observation.pixels` -> `observation.image`
|
||||
- `pixels.{cam}` or `observation.pixels.{cam}` -> `observation.images.{cam}`
|
||||
- `environment_state` or `observation.environment_state` -> `observation.environment_state`
|
||||
- `agent_pos` or `observation.agent_pos` -> `observation.state`
|
||||
|
||||
Args:
|
||||
features: The policy features dictionary with Gym-style keys.
|
||||
|
||||
Returns:
|
||||
The policy features dictionary with standardized LeRobot keys.
|
||||
"""
|
||||
exact_pairs = {
|
||||
"pixels": OBS_IMAGE,
|
||||
|
||||
@@ -25,7 +25,18 @@ from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="rename_processor")
|
||||
class RenameProcessorStep(ObservationProcessorStep):
|
||||
"""Rename processor that renames keys in the observation."""
|
||||
"""
|
||||
A processor step that renames keys in an observation dictionary.
|
||||
|
||||
This step is useful for creating a standardized data interface by mapping keys
|
||||
from an environment's format to the format expected by a LeRobot policy or
|
||||
other downstream components.
|
||||
|
||||
Attributes:
|
||||
rename_map: A dictionary mapping from old key names to new key names.
|
||||
Keys present in an observation that are not in this map will
|
||||
be kept with their original names.
|
||||
"""
|
||||
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@@ -51,7 +62,22 @@ class RenameProcessorStep(ObservationProcessorStep):
|
||||
|
||||
|
||||
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
|
||||
"""Rename keys in the stats dictionary according to rename_map (defensive copy)."""
|
||||
"""
|
||||
Renames the top-level keys in a statistics dictionary using a provided mapping.
|
||||
|
||||
This is a helper function typically used to keep normalization statistics
|
||||
consistent with renamed observation or action features. It performs a defensive
|
||||
deep copy to avoid modifying the original `stats` dictionary.
|
||||
|
||||
Args:
|
||||
stats: A nested dictionary of statistics, where top-level keys are
|
||||
feature names (e.g., `{"observation.state": {"mean": 0.5}}`).
|
||||
rename_map: A dictionary mapping old feature names to new feature names.
|
||||
|
||||
Returns:
|
||||
A new statistics dictionary with its top-level keys renamed. Returns an
|
||||
empty dictionary if the input `stats` is empty.
|
||||
"""
|
||||
if not stats:
|
||||
return {}
|
||||
renamed: dict[str, dict[str, Any]] = {}
|
||||
|
||||
@@ -1,5 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tokenizer processor for handling text tokenization in robot transitions.
|
||||
This script defines a processor for tokenizing natural language instructions from an environment transition.
|
||||
|
||||
It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
|
||||
token IDs and attention masks, which are then added to the observation dictionary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -16,6 +35,7 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
else:
|
||||
@@ -25,54 +45,48 @@ else:
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="tokenizer_processor")
|
||||
class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
|
||||
"""
|
||||
Processor step to tokenize a natural language task description.
|
||||
|
||||
This processor handles tokenization of task strings found in the complementary_data
|
||||
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
|
||||
to the observation data for model processing while preserving the original task string.
|
||||
This step extracts a task string from the `complementary_data` of an `EnvTransition`,
|
||||
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
|
||||
token IDs and attention mask to the `observation` dictionary.
|
||||
|
||||
The processor supports both single strings and lists of strings as task inputs.
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Args:
|
||||
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
|
||||
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
|
||||
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
|
||||
tokenizer: A tokenizer object (e.g., from transformers library) that implements
|
||||
the __call__ method. If provided, tokenizer_name is ignored. This parameter
|
||||
is not serialized and must be provided via overrides when loading.
|
||||
max_length: Maximum sequence length for tokenization. Defaults to 512.
|
||||
task_key: Key in complementary_data containing the task text. Defaults to "task".
|
||||
padding: Padding strategy for tokenization. Defaults to "max_length".
|
||||
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
|
||||
|
||||
Examples:
|
||||
Using tokenizer name (auto-loaded):
|
||||
```python
|
||||
processor = TokenizerProcessorStep(tokenizer_name="bert-base-uncased", max_length=128)
|
||||
```
|
||||
|
||||
Using custom tokenizer object:
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
processor = TokenizerProcessorStep(tokenizer=custom_tokenizer, max_length=128)
|
||||
```
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained tokenizer from the Hugging Face Hub (e.g., "bert-base-uncased").
|
||||
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
max_length: The maximum length to pad or truncate sequences to.
|
||||
task_key: The key in `complementary_data` where the task string is stored.
|
||||
padding_side: The side to pad on ('left' or 'right').
|
||||
padding: The padding strategy ('max_length', 'longest', etc.).
|
||||
truncation: Whether to truncate sequences longer than `max_length`.
|
||||
input_tokenizer: The internal tokenizer instance, loaded during initialization.
|
||||
"""
|
||||
|
||||
tokenizer_name: str | None = None
|
||||
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
|
||||
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
|
||||
max_length: int = 512
|
||||
task_key: str = "task"
|
||||
padding_side: str = "right"
|
||||
padding: str = "max_length"
|
||||
truncation: bool = True
|
||||
|
||||
# Internal tokenizer instance (not serialized)
|
||||
# Internal tokenizer instance (not part of the config)
|
||||
input_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
|
||||
"""
|
||||
Initializes the tokenizer after the dataclass is created.
|
||||
|
||||
It checks for the availability of the `transformers` library and loads the tokenizer
|
||||
either from a provided object or by name from the Hugging Face Hub.
|
||||
|
||||
Raises:
|
||||
ImportError: If the `transformers` library is not installed.
|
||||
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
|
||||
"""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
@@ -93,13 +107,14 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
)
|
||||
|
||||
def get_task(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""Extract and normalize task from complementary data.
|
||||
"""
|
||||
Extracts the task description(s) from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: Input transition containing complementary_data.
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
List of task strings if task is present, None otherwise.
|
||||
A list of task strings, or None if the task key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
@@ -112,7 +127,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
if task is None:
|
||||
return None
|
||||
|
||||
# Convert to list of strings
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(task, str):
|
||||
return [task]
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
@@ -120,78 +135,80 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation):
|
||||
"""Process the transition by tokenizing the task text.
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
|
||||
This method retrieves the task, tokenizes it, moves the resulting tensors to the
|
||||
same device as other data in the transition, and updates the observation.
|
||||
|
||||
Args:
|
||||
transition: Input transition containing complementary_data with task text.
|
||||
observation: The original observation dictionary.
|
||||
|
||||
Returns:
|
||||
Modified transition with tokenized task added to observation.
|
||||
|
||||
Raises:
|
||||
ValueError: If tokenizer initialization failed.
|
||||
The updated observation dictionary including token IDs and an attention mask.
|
||||
"""
|
||||
task = self.get_task(self.transition)
|
||||
if task is None:
|
||||
return observation
|
||||
|
||||
# Tokenize the task (creates CPU tensors)
|
||||
# Tokenize the task (this will create CPU tensors)
|
||||
tokenized_prompt = self._tokenize_text(task)
|
||||
|
||||
# Detect device from existing tensors in the transition
|
||||
# Detect the device from existing tensors in the transition to ensure consistency
|
||||
target_device = self._detect_device(self.transition)
|
||||
|
||||
# Move tokenized tensors to match the device of other data
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_prompt.items()
|
||||
}
|
||||
|
||||
# Get or create observation dict
|
||||
# Create a new observation dict to avoid modifying the original in place
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Add tokenized data to observation
|
||||
# Add tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
"""Detect device from existing tensors in the transition.
|
||||
"""
|
||||
Detects the torch.device from existing tensors in the transition.
|
||||
|
||||
This allows the tokenized tensors to match the device of other data,
|
||||
which is especially important for multi-GPU training with Accelerate.
|
||||
It checks tensors in the observation dictionary first, then the action tensor.
|
||||
|
||||
Args:
|
||||
transition: The transition to search for existing tensors.
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
The device of the first tensor found, or None if no tensors exist.
|
||||
The detected `torch.device`, or None if no tensors are found.
|
||||
"""
|
||||
# Check observation tensors first (most likely to exist)
|
||||
# Check observation tensors first (most likely place to find tensors)
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation:
|
||||
for value in observation.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
|
||||
# Check action tensor
|
||||
# Fallback to checking the action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if isinstance(action, torch.Tensor):
|
||||
return action.device
|
||||
|
||||
return None # No tensors found, keep on CPU
|
||||
return None # No tensors found, default will be CPU
|
||||
|
||||
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
||||
"""Tokenize text using the configured tokenizer.
|
||||
"""
|
||||
A wrapper around the tokenizer call.
|
||||
|
||||
Args:
|
||||
text: Text string or list of strings to tokenize.
|
||||
text: A string or list of strings to tokenize.
|
||||
|
||||
Returns:
|
||||
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
|
||||
A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors.
|
||||
"""
|
||||
return self.input_tokenizer(
|
||||
text,
|
||||
@@ -203,10 +220,14 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization.
|
||||
"""
|
||||
Returns the serializable configuration of the processor.
|
||||
|
||||
Note: Only tokenizer_name is saved, not the tokenizer object itself.
|
||||
When loading, provide the tokenizer via overrides if needed.
|
||||
Note: The tokenizer object itself is not serialized. If the processor was initialized
|
||||
with a tokenizer name, that name will be included in the config.
|
||||
|
||||
Returns:
|
||||
A dictionary with the processor's configuration parameters.
|
||||
"""
|
||||
config = {
|
||||
"max_length": self.max_length,
|
||||
@@ -216,28 +237,30 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
"truncation": self.truncation,
|
||||
}
|
||||
|
||||
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
|
||||
# TODO(steven): Consider saving the name of the _tokenizer if it was loaded
|
||||
# Only save tokenizer_name if it was used to create the tokenizer
|
||||
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||
config["tokenizer_name"] = self.tokenizer_name
|
||||
|
||||
return config
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Add tokenized task features to the feature contract.
|
||||
"""
|
||||
Adds feature definitions for the language tokens and attention mask.
|
||||
|
||||
This updates the policy features dictionary to include the new data added to the
|
||||
observation, ensuring downstream components are aware of their shape and type.
|
||||
|
||||
Args:
|
||||
features: Input feature dictionary.
|
||||
features: The dictionary of existing policy features.
|
||||
|
||||
Returns:
|
||||
Updated feature dictionary with tokenized task features added.
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
# Add features for tokenized output if they don't exist
|
||||
# Standard tokenizer output includes tokens and attention_mask
|
||||
|
||||
# Add a feature for the token IDs if it doesn't already exist
|
||||
if OBS_LANGUAGE_TOKENS not in features:
|
||||
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
|
||||
# Add a feature for the attention mask if it doesn't already exist
|
||||
if OBS_LANGUAGE_ATTENTION_MASK not in features:
|
||||
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# !/usr/bin/env python
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -38,18 +38,27 @@ from lerobot.utils.rotation import Rotation
|
||||
@dataclass
|
||||
class EEReferenceAndDelta(ActionProcessorStep):
|
||||
"""
|
||||
Compute the desired end-effector pose from the target pose and the current pose.
|
||||
Computes a target end-effector pose from a relative delta command.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
This step takes a desired change in position and orientation (`target_*`) and applies it to a
|
||||
reference end-effector pose to calculate an absolute target pose. The reference pose is derived
|
||||
from the current robot joint positions using forward kinematics.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
The processor can operate in two modes:
|
||||
1. `use_latched_reference=True`: The reference pose is "latched" or saved at the moment the action
|
||||
is first enabled. Subsequent commands are relative to this fixed reference.
|
||||
2. `use_latched_reference=False`: The reference pose is updated to the robot's current pose at
|
||||
every step.
|
||||
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model for forward kinematics.
|
||||
end_effector_step_sizes: A dictionary scaling the input delta commands.
|
||||
motor_names: A list of motor names required for forward kinematics.
|
||||
use_latched_reference: If True, latch the reference pose on enable; otherwise, always use the
|
||||
current pose as the reference.
|
||||
reference_ee_pose: Internal state storing the latched reference pose.
|
||||
_prev_enabled: Internal state to detect the rising edge of the enable signal.
|
||||
_command_when_disabled: Internal state to hold the last command while disabled.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
@@ -135,6 +144,7 @@ class EEReferenceAndDelta(ActionProcessorStep):
|
||||
return new_action
|
||||
|
||||
def reset(self):
|
||||
"""Resets the internal state of the processor."""
|
||||
self._prev_enabled = False
|
||||
self.reference_ee_pose = None
|
||||
self._command_when_disabled = None
|
||||
@@ -161,17 +171,17 @@ class EEReferenceAndDelta(ActionProcessorStep):
|
||||
@dataclass
|
||||
class EEBoundsAndSafety(ActionProcessorStep):
|
||||
"""
|
||||
Clip the end-effector pose to the bounds and check for jumps.
|
||||
Clips the end-effector pose to predefined bounds and checks for unsafe jumps.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
This step ensures that the target end-effector pose remains within a safe operational workspace.
|
||||
It also moderates the command to prevent large, sudden movements between consecutive steps.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
Attributes:
|
||||
end_effector_bounds: A dictionary with "min" and "max" keys for position clipping.
|
||||
max_ee_step_m: The maximum allowed change in position (in meters) between steps.
|
||||
max_ee_twist_step_rad: The maximum allowed change in orientation (in radians) between steps.
|
||||
_last_pos: Internal state storing the last commanded position.
|
||||
_last_twist: Internal state storing the last commanded orientation.
|
||||
"""
|
||||
|
||||
end_effector_bounds: dict
|
||||
@@ -219,6 +229,7 @@ class EEBoundsAndSafety(ActionProcessorStep):
|
||||
return act
|
||||
|
||||
def reset(self):
|
||||
"""Resets the last known position and orientation."""
|
||||
self._last_pos = None
|
||||
self._last_twist = None
|
||||
|
||||
@@ -232,21 +243,17 @@ class EEBoundsAndSafety(ActionProcessorStep):
|
||||
@dataclass
|
||||
class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
"""
|
||||
Compute the desired joint positions from the desired end-effector pose.
|
||||
Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
This step translates a Cartesian command (position and orientation of the end-effector) into
|
||||
the corresponding joint-space commands for each motor.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.joint_name_1.pos": float,
|
||||
"action.joint_name_2.pos": float,
|
||||
...
|
||||
"action.joint_name_n.pos": float,
|
||||
}
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model for inverse kinematics.
|
||||
motor_names: A list of motor names for which to compute joint positions.
|
||||
q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver.
|
||||
initial_guess_current_joints: If True, use the robot's current joint state as the IK guess.
|
||||
If False, use the solution from the previous step.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
@@ -312,6 +319,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
"""Resets the initial guess for the IK solver."""
|
||||
self.q_curr = None
|
||||
|
||||
|
||||
@@ -319,17 +327,18 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
@dataclass
|
||||
class GripperVelocityToJoint(ProcessorStep):
|
||||
"""
|
||||
Convert the gripper velocity to a joint velocity.
|
||||
Converts a gripper velocity command into a target gripper joint position.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.gripper": float,
|
||||
}
|
||||
This step integrates a normalized velocity command over time to produce a position command,
|
||||
taking the current gripper position as a starting point. It also supports a discrete mode
|
||||
where integer actions map to open, close, or no-op.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.gripper.pos": float,
|
||||
}
|
||||
Attributes:
|
||||
motor_names: A list of motor names, which must include 'gripper'.
|
||||
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
|
||||
clip_min: The minimum allowed gripper joint position.
|
||||
clip_max: The maximum allowed gripper joint position.
|
||||
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
|
||||
"""
|
||||
|
||||
motor_names: list[str]
|
||||
@@ -365,7 +374,7 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
raw = comp.get("raw_joint_positions") or {}
|
||||
curr_pos = float(raw.get("gripper"))
|
||||
|
||||
# Compute desired gripper velocity
|
||||
# Compute desired gripper position
|
||||
u = float(act.get(f"{ACTION}.gripper", 0.0))
|
||||
delta = u * float(self.speed_factor)
|
||||
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
|
||||
@@ -391,17 +400,14 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
@dataclass
|
||||
class ForwardKinematicsJointsToEE(ObservationProcessorStep):
|
||||
"""
|
||||
Compute the end-effector pose from the joint positions.
|
||||
Computes the end-effector pose from joint positions using forward kinematics (FK).
|
||||
|
||||
Input OBSERVATION keys:
|
||||
{
|
||||
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
|
||||
}
|
||||
This step is typically used to add the robot's Cartesian pose to the observation space,
|
||||
which can be useful for visualization or as an input to a policy.
|
||||
|
||||
Output OBSERVATION keys:
|
||||
{
|
||||
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model.
|
||||
motor_names: A list of motor names whose joint positions are used for FK.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
@@ -435,10 +441,14 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep):
|
||||
@dataclass
|
||||
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Read the robot's current observation and insert it into the transition as complementary data.
|
||||
Reads the robot's current observation and adds it to the transition's complementary data.
|
||||
|
||||
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
|
||||
{ "<motor_name>": <float position>, ... }
|
||||
This step acts as a bridge to the physical robot, injecting its real-time sensor readings
|
||||
(like raw joint positions) into the data processing pipeline. This data is then available
|
||||
for other processing steps.
|
||||
|
||||
Attributes:
|
||||
robot: An instance of a `Robot` class used to get observations from hardware.
|
||||
"""
|
||||
|
||||
robot: Robot
|
||||
|
||||
@@ -98,9 +98,7 @@ from lerobot.utils.utils import (
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
# Main entry point
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
@@ -207,9 +205,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
# Core algorithm functions
|
||||
|
||||
|
||||
def act_with_policy(
|
||||
@@ -406,9 +402,7 @@ def act_with_policy(
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions #
|
||||
#################################################
|
||||
# Communication Functions - Group all gRPC/messaging functions
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
@@ -653,9 +647,7 @@ def interactions_stream(
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
#################################################
|
||||
# Policy functions #
|
||||
#################################################
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
@@ -687,9 +679,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities functions #
|
||||
#################################################
|
||||
# Utilities functions
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
|
||||
@@ -103,11 +103,6 @@ from lerobot.utils.wandb_utils import WandBLogger
|
||||
LOG_PREFIX = "[LEARNER]"
|
||||
|
||||
|
||||
#################################################
|
||||
# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS #
|
||||
#################################################
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train_cli(cfg: TrainRLServerPipelineConfig):
|
||||
if not use_threads(cfg):
|
||||
@@ -250,9 +245,7 @@ def start_learner_threads(
|
||||
logging.info("[LEARNER] queues closed")
|
||||
|
||||
|
||||
#################################################
|
||||
# Core algorithm functions #
|
||||
#################################################
|
||||
# Core algorithm functions
|
||||
|
||||
|
||||
def add_actor_information_and_train(
|
||||
@@ -820,9 +813,7 @@ def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.M
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
#################################################
|
||||
# Training setup functions #
|
||||
#################################################
|
||||
# Training setup functions
|
||||
|
||||
|
||||
def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig:
|
||||
@@ -1023,9 +1014,7 @@ def initialize_offline_replay_buffer(
|
||||
return offline_replay_buffer
|
||||
|
||||
|
||||
#################################################
|
||||
# Utilities/Helpers functions #
|
||||
#################################################
|
||||
# Utilities/Helpers functions
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
|
||||
@@ -65,6 +65,28 @@ def update_policy(
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
|
||||
This function executes the forward and backward passes, clips gradients, and steps the optimizer and
|
||||
learning rate scheduler. It also handles mixed-precision training via a GradScaler.
|
||||
|
||||
Args:
|
||||
train_metrics: A MetricsTracker instance to record training statistics.
|
||||
policy: The policy model to be trained.
|
||||
batch: A batch of training data.
|
||||
optimizer: The optimizer used to update the policy's parameters.
|
||||
grad_clip_norm: The maximum norm for gradient clipping.
|
||||
grad_scaler: The GradScaler for automatic mixed-precision training.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
use_amp: A boolean indicating whether to use automatic mixed precision.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The updated MetricsTracker with new statistics for this step.
|
||||
- A dictionary of outputs from the policy's forward pass, for logging purposes.
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
@@ -108,6 +130,20 @@ def update_policy(
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
"""
|
||||
Main function to train a policy.
|
||||
|
||||
This function orchestrates the entire training pipeline, including:
|
||||
- Setting up logging, seeding, and device configuration.
|
||||
- Creating the dataset, evaluation environment (if applicable), policy, and optimizer.
|
||||
- Handling resumption from a checkpoint.
|
||||
- Running the main training loop, which involves fetching data batches and calling `update_policy`.
|
||||
- Periodically logging metrics, saving model checkpoints, and evaluating the policy.
|
||||
- Pushing the final trained model to the Hugging Face Hub if configured.
|
||||
|
||||
Args:
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
"""
|
||||
cfg.validate()
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
|
||||
@@ -121,6 +121,21 @@ def teleop_loop(
|
||||
robot_action_processor: RobotProcessorPipeline[dict[str, Any]] | None = None,
|
||||
robot_observation_processor: RobotProcessorPipeline[EnvTransition] | None = None,
|
||||
):
|
||||
"""
|
||||
This function continuously reads actions from a teleoperation device, processes them through optional
|
||||
pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a
|
||||
specified frequency until a set duration is reached or it is manually interrupted.
|
||||
|
||||
Args:
|
||||
teleop: The teleoperator device instance providing control actions.
|
||||
robot: The robot instance being controlled.
|
||||
fps: The target frequency for the control loop in frames per second.
|
||||
display_data: If True, fetches robot observations and displays them in the console and Rerun.
|
||||
duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely.
|
||||
teleop_action_processor: An optional pipeline to process raw actions from the teleoperator.
|
||||
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
|
||||
robot_observation_processor: An optional pipeline to process raw observations from the robot.
|
||||
"""
|
||||
# Initialize processors with defaults if not provided
|
||||
teleop_action_processor: RobotProcessorPipeline[EnvTransition] = (
|
||||
teleop_action_processor
|
||||
|
||||
@@ -26,28 +26,35 @@ from lerobot.teleoperators.phone.config_phone import PhoneOS
|
||||
@dataclass
|
||||
class MapPhoneActionToRobotAction(ActionProcessorStep):
|
||||
"""
|
||||
Map calibrated phone pose (actions) to the inputs for robot actions
|
||||
Maps calibrated phone pose actions to standardized robot action inputs.
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.phone.enabled": bool,
|
||||
"action.phone.pos": np.ndarray,
|
||||
"action.phone.rot": Rotation,
|
||||
"action.phone.raw_inputs": dict,
|
||||
}
|
||||
This processor step acts as a bridge between the phone teleoperator's output
|
||||
and the robot's expected action format. It remaps the phone's 6-DoF pose
|
||||
(position and rotation) to the robot's target end-effector pose, applying
|
||||
necessary axis inversions and swaps. It also interprets platform-specific
|
||||
button presses to generate a gripper command.
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"action.gripper": float,
|
||||
}
|
||||
Attributes:
|
||||
platform: The operating system of the phone (iOS or Android), used
|
||||
to determine the correct button mappings for the gripper.
|
||||
"""
|
||||
|
||||
platform: PhoneOS
|
||||
_enabled_prev: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
"""
|
||||
Processes the phone action dictionary to create a robot action dictionary.
|
||||
|
||||
Args:
|
||||
act: The input action dictionary from the phone teleoperator.
|
||||
|
||||
Returns:
|
||||
A new action dictionary formatted for the robot controller.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'pos' or 'rot' keys are missing from the input action.
|
||||
"""
|
||||
# Pop them from the action
|
||||
enabled = bool(act.pop(f"{ACTION}.phone.enabled", 0))
|
||||
pos = act.pop(f"{ACTION}.phone.pos", None)
|
||||
|
||||
@@ -108,7 +108,17 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
"""
|
||||
Blocks execution until the calibration trigger is detected from the iOS device.
|
||||
|
||||
This method enters a loop, continuously reading the phone's state. It waits for the user to press
|
||||
and hold the 'B1' button in the HEBI Mobile I/O app. Once B1 is pressed, the loop breaks and
|
||||
returns the phone's pose at that exact moment.
|
||||
|
||||
Returns:
|
||||
A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
|
||||
moment the trigger was activated.
|
||||
"""
|
||||
while True:
|
||||
has_pose, position, rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose:
|
||||
@@ -126,6 +136,21 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
"""
|
||||
Reads the instantaneous 6-DoF pose from the connected iOS device via the HEBI SDK.
|
||||
|
||||
This method fetches the latest feedback packet from the HEBI group, extracts the ARKit
|
||||
position and orientation, and converts them into a standard format. It also applies a
|
||||
configured camera offset to adjust the pose from the camera's frame to the phone's
|
||||
physical frame.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A boolean indicating if a valid pose was successfully read.
|
||||
- The 3D position as a NumPy array, or None if not available.
|
||||
- The orientation as a `Rotation` object, or None if not available.
|
||||
- The raw HEBI feedback object for accessing other data like button presses.
|
||||
"""
|
||||
fbk = self._group.get_next_feedback()
|
||||
pose = fbk[0]
|
||||
ar_pos = getattr(pose, "ar_position", None)
|
||||
@@ -228,7 +253,18 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
"""
|
||||
Blocks execution until the calibration trigger is detected from the Android device.
|
||||
|
||||
This method enters a loop, continuously checking the latest message received from the WebXR
|
||||
session. It waits for the user to touch and move their finger on the screen, which generates
|
||||
a `move` event. Once this event is detected, the loop breaks and returns the phone's current
|
||||
pose.
|
||||
|
||||
Returns:
|
||||
A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the
|
||||
moment the trigger was activated.
|
||||
"""
|
||||
while True:
|
||||
with self._android_lock:
|
||||
msg = self._latest_message or {}
|
||||
@@ -241,6 +277,20 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
"""
|
||||
Reads the latest 6-DoF pose received from the Android device's WebXR session.
|
||||
|
||||
This method accesses the most recent pose data stored by the `_android_callback`. It uses a
|
||||
thread lock to safely read the shared `_latest_pose` variable. The pose, a 4x4 matrix, is
|
||||
then decomposed into position and rotation, and the configured camera offset is applied.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A boolean indicating if a valid pose was available.
|
||||
- The 3D position as a NumPy array, or None if no pose has been received yet.
|
||||
- The orientation as a `Rotation` object, or None if no pose has been received.
|
||||
- The raw 4x4 pose matrix as received from the teleop stream.
|
||||
"""
|
||||
with self._android_lock:
|
||||
if self._latest_pose is None:
|
||||
return False, None, None, None
|
||||
@@ -251,6 +301,19 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
return True, pos, rot, pose
|
||||
|
||||
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
|
||||
"""
|
||||
Callback function to handle incoming data from the Android teleop stream.
|
||||
|
||||
This method is executed by the `teleop` package's subscriber thread whenever a new
|
||||
pose and message are received from the WebXR session on the Android phone. It updates
|
||||
the internal state (`_latest_pose` and `_latest_message`) with the new data.
|
||||
A thread lock is used to ensure that these shared variables are updated atomically,
|
||||
preventing race conditions with the main thread that reads them.
|
||||
|
||||
Args:
|
||||
pose: A 4x4 NumPy array representing the phone's transformation matrix.
|
||||
message: A dictionary containing additional data, such as button presses or touch events.
|
||||
"""
|
||||
with self._android_lock:
|
||||
self._latest_pose = pose
|
||||
self._latest_message = message
|
||||
|
||||
@@ -36,6 +36,20 @@ from lerobot.robots import Robot
|
||||
|
||||
|
||||
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None):
|
||||
"""
|
||||
Logs performance metrics for a single step of the robot control loop.
|
||||
|
||||
This function formats and prints a single line of log information, including episode/frame counters,
|
||||
total loop time (dt), and detailed timings for various robot and camera operations. It can also
|
||||
highlight performance drops in yellow if the actual FPS is lower than the target FPS.
|
||||
|
||||
Args:
|
||||
robot: The `Robot` instance, used to access its internal logs for detailed timings.
|
||||
dt_s: The total duration of the control loop step in seconds.
|
||||
episode_index: The index of the current episode.
|
||||
frame_index: The index of the current frame within the episode.
|
||||
fps: The target frames per second, used to check for performance degradation.
|
||||
"""
|
||||
log_items = []
|
||||
if episode_index is not None:
|
||||
log_items.append(f"ep:{episode_index}")
|
||||
@@ -81,7 +95,16 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
|
||||
|
||||
@cache
|
||||
def is_headless():
|
||||
"""Detects if python is running without a monitor."""
|
||||
"""
|
||||
Detects if the Python script is running in a headless environment (e.g., without a display).
|
||||
|
||||
This function attempts to import `pynput`, a library that requires a graphical environment.
|
||||
If the import fails, it assumes the environment is headless. The result is cached to avoid
|
||||
re-running the check.
|
||||
|
||||
Returns:
|
||||
True if the environment is determined to be headless, False otherwise.
|
||||
"""
|
||||
try:
|
||||
import pynput # noqa
|
||||
|
||||
@@ -108,6 +131,29 @@ def predict_action(
|
||||
task: str | None = None,
|
||||
robot_type: str | None = None,
|
||||
):
|
||||
"""
|
||||
Performs a single-step inference to predict a robot action from an observation.
|
||||
|
||||
This function encapsulates the full inference pipeline:
|
||||
1. Prepares the observation by converting it to PyTorch tensors and adding a batch dimension.
|
||||
2. Runs the preprocessor pipeline on the observation.
|
||||
3. Feeds the processed observation to the policy to get a raw action.
|
||||
4. Runs the postprocessor pipeline on the raw action.
|
||||
5. Formats the final action by removing the batch dimension and moving it to the CPU.
|
||||
|
||||
Args:
|
||||
observation: A dictionary of NumPy arrays representing the robot's current observation.
|
||||
policy: The `PreTrainedPolicy` model to use for action prediction.
|
||||
device: The `torch.device` (e.g., 'cuda' or 'cpu') to run inference on.
|
||||
preprocessor: The `PolicyProcessorPipeline` for preprocessing observations.
|
||||
postprocessor: The `PolicyProcessorPipeline` for postprocessing actions.
|
||||
use_amp: A boolean to enable/disable Automatic Mixed Precision for CUDA inference.
|
||||
task: An optional string identifier for the task.
|
||||
robot_type: An optional string identifier for the robot type.
|
||||
|
||||
Returns:
|
||||
A `torch.Tensor` containing the predicted action, ready for the robot.
|
||||
"""
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
@@ -143,6 +189,18 @@ def predict_action(
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""
|
||||
Initializes a non-blocking keyboard listener for real-time user interaction.
|
||||
|
||||
This function sets up a listener for specific keys (right arrow, left arrow, escape) to control
|
||||
the program flow during execution, such as stopping recording or exiting loops. It gracefully
|
||||
handles headless environments where keyboard listening is not possible.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- The `pynput.keyboard.Listener` instance, or `None` if in a headless environment.
|
||||
- A dictionary of event flags (e.g., `exit_early`) that are set by key presses.
|
||||
"""
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
@@ -184,6 +242,19 @@ def init_keyboard_listener():
|
||||
|
||||
|
||||
def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
"""
|
||||
Validates the dataset repository name against the presence of a policy configuration.
|
||||
|
||||
This function enforces a naming convention: a dataset repository ID should start with "eval_"
|
||||
if and only if a policy configuration is provided for evaluation purposes.
|
||||
|
||||
Args:
|
||||
repo_id: The Hugging Face Hub repository ID of the dataset.
|
||||
policy_cfg: The configuration object for the policy, or `None`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the naming convention is violated.
|
||||
"""
|
||||
_, dataset_name = repo_id.split("/")
|
||||
# either repo_id doesnt start with "eval_" and there is no policy
|
||||
# or repo_id starts with "eval_" and there is a policy
|
||||
@@ -204,6 +275,21 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, features: dict
|
||||
) -> None:
|
||||
"""
|
||||
Checks if a dataset's metadata is compatible with the current robot and recording setup.
|
||||
|
||||
This function compares key metadata fields (`robot_type`, `fps`, and `features`) from the
|
||||
dataset against the current configuration to ensure that appended data will be consistent.
|
||||
|
||||
Args:
|
||||
dataset: The `LeRobotDataset` instance to check.
|
||||
robot: The `Robot` instance representing the current hardware setup.
|
||||
fps: The current recording frequency (frames per second).
|
||||
features: The dictionary of features for the current recording session.
|
||||
|
||||
Raises:
|
||||
ValueError: If any of the checked metadata fields do not match.
|
||||
"""
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
|
||||
@@ -42,7 +42,23 @@ def log_rerun_data(
|
||||
observation: dict[str, Any] | None = None,
|
||||
action: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Log observation and action data to Rerun for visualization."""
|
||||
"""
|
||||
Logs observation and action data to Rerun for real-time visualization.
|
||||
|
||||
This function iterates through the provided observation and action dictionaries and sends their contents
|
||||
to the Rerun viewer. It handles different data types appropriately:
|
||||
- Scalar values (floats, ints) are logged as `rr.Scalar`.
|
||||
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
|
||||
from CHW to HWC format and logged as `rr.Image`.
|
||||
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
|
||||
- Other multi-dimensional arrays are flattened and logged as individual scalars.
|
||||
|
||||
Keys are automatically namespaced with "observation." or "action." if not already present.
|
||||
|
||||
Args:
|
||||
observation: An optional dictionary containing observation data to log.
|
||||
action: An optional dictionary containing action data to log.
|
||||
"""
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if v is None:
|
||||
|
||||
Reference in New Issue
Block a user