mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
* fix(datasets): remove unreachable timestamp branch in add_frame and document caller contract
- Remove dead code: frame.pop("timestamp") branch in add_frame() could never
execute because validate_frame() raises ValueError for any DEFAULT_FEATURES
key (including timestamp) before we reach that line.
- Expand add_frame() docstring: explicitly document that timestamp and
frame_index must NOT be passed by the caller.
- Add explanatory comment in validate_frame(): clarifies why DEFAULT_FEATURES
are excluded from expected_features, preventing future re-introduction of
the dead branch.
The dead branch originated in #1200, which fixed a shape-(1,) mismatch for a
code path that was subsequently made unreachable by a refactor of validate_frame.
* chore(datasets): narrow PR scope
* fix(datasets): move add_frame timestamp cleanup to dataset_writer
557 lines
21 KiB
Python
557 lines
21 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from pprint import pformat
|
|
from typing import Any
|
|
|
|
import datasets
|
|
import numpy as np
|
|
from PIL import Image as PILImage
|
|
|
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
|
from lerobot.datasets.utils import (
|
|
DEFAULT_CHUNK_SIZE,
|
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
|
DEFAULT_DATA_PATH,
|
|
DEFAULT_FEATURES,
|
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
|
DEFAULT_VIDEO_PATH,
|
|
)
|
|
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
|
|
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
|
|
|
|
|
def get_hf_features_from_features(features: dict) -> datasets.Features:
    """Build the Hugging Face `datasets.Features` schema for a LeRobot feature spec.

    Video features are skipped and do not appear in the result. Image features
    become `datasets.Image`, features of shape `(1,)` become a scalar
    `datasets.Value`, other 1D features become a fixed-length
    `datasets.Sequence`, and 2D to 5D features become the matching
    `datasets.ArrayND` type.

    Args:
        features (dict): A LeRobot-style feature dictionary.

    Returns:
        datasets.Features: The corresponding Hugging Face `datasets.Features` object.

    Raises:
        ValueError: If a feature has an unsupported shape.
    """
    # Array constructors keyed by the number of dimensions of the feature shape.
    array_types = {
        2: datasets.Array2D,
        3: datasets.Array3D,
        4: datasets.Array4D,
        5: datasets.Array5D,
    }

    hf_features = {}
    for key, ft in features.items():
        dtype = ft["dtype"]
        if dtype == "video":
            continue
        if dtype == "image":
            hf_features[key] = datasets.Image()
            continue

        shape = ft["shape"]
        if shape == (1,):
            hf_features[key] = datasets.Value(dtype=dtype)
            continue

        ndim = len(shape)
        if ndim == 1:
            hf_features[key] = datasets.Sequence(length=shape[0], feature=datasets.Value(dtype=dtype))
        elif ndim in array_types:
            hf_features[key] = array_types[ndim](shape=shape, dtype=dtype)
        else:
            raise ValueError(f"Corresponding feature is not valid: {ft}")

    return datasets.Features(hf_features)
|
|
|
|
|
|
def _validate_feature_names(features: dict[str, dict]) -> None:
|
|
"""Validate that feature names do not contain invalid characters.
|
|
|
|
Args:
|
|
features (dict): The LeRobot features dictionary.
|
|
|
|
Raises:
|
|
ValueError: If any feature name contains '/'.
|
|
"""
|
|
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
|
|
if invalid_features:
|
|
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
|
|
|
|
|
|
def hw_to_dataset_features(
    hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
    """Translate a hardware feature description into LeRobot dataset features.

    Entries whose value is `float`, or a `PolicyFeature` that is not visual, are
    treated as joints and grouped into one float32 vector. Entries whose value is
    a tuple are treated as camera shapes and each gets its own image/video
    feature.

    Args:
        hw_features (dict): Dictionary mapping feature names to their type (float for
            joints) or shape (tuple for images).
        prefix (str): The prefix to add to the feature keys (e.g., "observation"
            or "action").
        use_video (bool): If True, image features are marked as "video", otherwise "image".

    Returns:
        dict: A LeRobot features dictionary.
    """
    joint_names = []
    camera_shapes = {}
    for key, spec in hw_features.items():
        if spec is float or (isinstance(spec, PolicyFeature) and spec.type != FeatureType.VISUAL):
            joint_names.append(key)
        if isinstance(spec, tuple):
            camera_shapes[key] = spec

    features = {}

    # The joint vector is stored under the bare prefix for actions...
    if joint_names and prefix == ACTION:
        features[prefix] = {
            "dtype": "float32",
            "shape": (len(joint_names),),
            "names": list(joint_names),
        }

    # ...and under "<prefix>.state" for observations.
    if joint_names and prefix == OBS_STR:
        features[f"{prefix}.state"] = {
            "dtype": "float32",
            "shape": (len(joint_names),),
            "names": list(joint_names),
        }

    image_dtype = "video" if use_video else "image"
    for key, shape in camera_shapes.items():
        features[f"{prefix}.images.{key}"] = {
            "dtype": image_dtype,
            "shape": shape,
            "names": ["height", "width", "channels"],
        }

    _validate_feature_names(features)
    return features
|
|
|
|
|
|
def build_dataset_frame(
    ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
    """Assemble one frame (a single timestep of data) from raw values.

    Only features whose key starts with `prefix` and is not one of
    `DEFAULT_FEATURES` are considered. 1D float32 features are gathered from
    `values` by name into a float32 array; image and video features are copied
    from `values` as they are. Any other feature is left out of the frame.

    Args:
        ds_features (dict): The LeRobot dataset features dictionary.
        values (dict): A dictionary of raw values from the hardware/environment.
        prefix (str): The prefix to filter features by (e.g., "observation"
            or "action").

    Returns:
        dict: A dictionary representing a single frame of data.
    """
    image_prefix = f"{prefix}.images."

    frame = {}
    for key, ft in ds_features.items():
        if key in DEFAULT_FEATURES:
            continue
        if not key.startswith(prefix):
            continue

        dtype = ft["dtype"]
        if dtype == "float32" and len(ft["shape"]) == 1:
            per_name = [values[name] for name in ft["names"]]
            frame[key] = np.array(per_name, dtype=np.float32)
        elif dtype in ("image", "video"):
            # Raw camera values are keyed by the bare camera name.
            frame[key] = values[key.removeprefix(image_prefix)]

    return frame
|
|
|
|
|
|
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
    """Derive policy features from a dataset feature specification.

    Each dataset feature is classified as visual, environment state, state or
    action from its dtype and key. Features that match none of these are
    dropped. Image shapes whose third name is "channel"/"channels" are reordered
    to channel-first.

    Args:
        features (dict): The LeRobot dataset features dictionary.

    Returns:
        dict: A dictionary mapping feature keys to `PolicyFeature` objects.

    Raises:
        ValueError: If an image feature does not have a 3D shape.
    """
    # TODO(aliberts): Implement "type" in dataset features and simplify this
    policy_features = {}
    for key, ft in features.items():
        shape = ft["shape"]

        if ft["dtype"] in ("image", "video"):
            if len(shape) != 3:
                raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")

            # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
            if ft["names"][2] in ("channel", "channels"):  # (h, w, c) -> (c, h, w)
                shape = (shape[2], shape[0], shape[1])
            feature_type = FeatureType.VISUAL
        elif key == OBS_ENV_STATE:
            # Must be tested before the generic observation prefix below.
            feature_type = FeatureType.ENV
        elif key.startswith(OBS_STR):
            feature_type = FeatureType.STATE
        elif key.startswith(ACTION):
            feature_type = FeatureType.ACTION
        else:
            continue

        policy_features[key] = PolicyFeature(type=feature_type, shape=shape)

    return policy_features
|
|
|
|
|
|
def combine_feature_dicts(*dicts: dict) -> dict:
    """Merge several LeRobot grouped feature dictionaries into one.

    - 1D numeric specs (dtype not image/video/string, tuple shape of length 1,
      with a "names" entry) are accumulated per key: names are appended in
      first-seen order without duplicates and the shape is recomputed.
    - Every other entry (e.g. `observation.images.*`, or non-dict values) is
      simply overwritten by the latest definition.

    Args:
        *dicts: A variable number of LeRobot feature dictionaries to merge.

    Returns:
        dict: A single merged feature dictionary.

    Raises:
        ValueError: If there's a dtype mismatch for a feature being merged.
    """
    merged: dict = {}
    non_vector_dtypes = ("image", "video", "string")

    for feature_dict in dicts:
        for key, spec in feature_dict.items():
            # Anything that is not a spec dict is stored as-is.
            if not isinstance(spec, dict):
                merged[key] = spec
                continue

            dtype = spec.get("dtype")
            shape = spec.get("shape")
            mergeable = (
                dtype not in non_vector_dtypes
                and isinstance(shape, tuple)
                and len(shape) == 1
                and "names" in spec
            )

            # Images/videos and non-1D entries: the latest definition wins.
            if not mergeable:
                merged[key] = spec
                continue

            # Fetch the accumulator for this key, creating an empty one on first sight.
            accumulator = merged.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})

            # All merged entries must agree on the dtype.
            if "dtype" in accumulator and dtype != accumulator["dtype"]:
                raise ValueError(f"dtype mismatch for '{key}': {accumulator['dtype']} vs {dtype}")

            # Append unseen names only, keeping the original order.
            names = accumulator["names"]
            known = set(names)
            for name in spec["names"]:
                if name in known:
                    continue
                names.append(name)
                known.add(name)

            # The vector length follows the number of accumulated names.
            accumulator["shape"] = (len(accumulator["names"]),)

    return merged
|
|
|
|
|
|
def create_empty_dataset_info(
    codebase_version: str,
    fps: int,
    features: dict,
    use_videos: bool,
    robot_type: str | None = None,
    chunks_size: int | None = None,
    data_files_size_in_mb: int | None = None,
    video_files_size_in_mb: int | None = None,
) -> dict:
    """Build the initial content of a new dataset's `info.json`.

    All counters start at zero and `splits` starts empty.

    Args:
        codebase_version (str): The version of the LeRobot codebase.
        fps (int): The frames per second of the data.
        features (dict): The LeRobot features dictionary for the dataset.
        use_videos (bool): Whether the dataset will store videos. When False,
            "video_path" is set to None.
        robot_type (str | None): The type of robot used, if any.
        chunks_size (int | None): Stored under "chunks_size"; falls back to
            `DEFAULT_CHUNK_SIZE` when falsy.
        data_files_size_in_mb (int | None): Stored under "data_files_size_in_mb";
            falls back to `DEFAULT_DATA_FILE_SIZE_IN_MB` when falsy.
        video_files_size_in_mb (int | None): Stored under "video_files_size_in_mb";
            falls back to `DEFAULT_VIDEO_FILE_SIZE_IN_MB` when falsy.

    Returns:
        dict: A dictionary with the initial dataset metadata.
    """
    # Resolve optional sizes up front; any falsy value selects the default.
    resolved_chunks_size = chunks_size or DEFAULT_CHUNK_SIZE
    resolved_data_size = data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB
    resolved_video_size = video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
    video_path = DEFAULT_VIDEO_PATH if use_videos else None

    info = {
        "codebase_version": codebase_version,
        "robot_type": robot_type,
        "total_episodes": 0,
        "total_frames": 0,
        "total_tasks": 0,
        "chunks_size": resolved_chunks_size,
        "data_files_size_in_mb": resolved_data_size,
        "video_files_size_in_mb": resolved_video_size,
        "fps": fps,
        "splits": {},
        "data_path": DEFAULT_DATA_PATH,
        "video_path": video_path,
        "features": features,
    }
    return info
|
|
|
|
|
|
def check_delta_timestamps(
    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
    """Verify that every delta timestamp is a multiple of 1/fps, within a tolerance.

    A delta passes when its distance to the nearest multiple of 1/fps is at
    most `tolerance_s` seconds.

    Args:
        delta_timestamps (dict): A dictionary where values are lists of time
            deltas in seconds.
        fps (int): The frames per second of the dataset.
        tolerance_s (float): The allowed tolerance in seconds.
        raise_value_error (bool): If True, raises an error on failure.

    Returns:
        bool: True if all deltas are valid, False otherwise.

    Raises:
        ValueError: If any delta is outside the tolerance and `raise_value_error` is True.
    """

    def _is_aligned(ts: float) -> bool:
        # Distance to the nearest frame, measured in frames, converted back to seconds.
        frames = ts * fps
        return abs(frames - round(frames)) / fps <= tolerance_s

    # Per key, keep only the deltas that fall outside the tolerance.
    outside_tolerance = {}
    for key, delta_ts in delta_timestamps.items():
        offenders = [ts for ts in delta_ts if not _is_aligned(ts)]
        if offenders:
            outside_tolerance[key] = offenders

    if outside_tolerance:
        if raise_value_error:
            raise ValueError(
                f"""
                The following delta_timestamps are found outside of tolerance range.
                Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
                their values accordingly.
                \n{pformat(outside_tolerance)}
                """
            )
        return False

    return True
|
|
|
|
|
|
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
    """Turn delta timestamps (seconds) into delta indices (frames).

    Each delta is multiplied by `fps` and rounded to the nearest integer.

    Args:
        delta_timestamps (dict): A dictionary of time deltas in seconds.
        fps (int): The frames per second of the dataset.

    Returns:
        dict: A dictionary of frame delta indices.
    """
    return {key: [round(delta * fps) for delta in deltas] for key, deltas in delta_timestamps.items()}
|
|
|
|
|
|
def validate_frame(frame: dict, features: dict) -> None:
    """Check a caller-provided frame against the dataset feature specification.

    The frame must contain a "task" entry and exactly the features from
    `features` that are not part of `DEFAULT_FEATURES`. Every feature present in
    both is further checked for dtype and shape. All problems found are reported
    together in a single error.

    Args:
        frame (dict): The frame to validate, keyed by feature name.
        features (dict): The LeRobot features dictionary for the dataset.

    Raises:
        ValueError: If "task" is missing, if features are missing or unexpected,
            or if a feature value fails its dtype/shape validation.
    """
    # DEFAULT_FEATURES (timestamp, frame_index, episode_index, index, task_index) are
    # auto-populated by the recording pipeline (add_frame / save_episode) and must not
    # be supplied by the caller. Excluding them here means any frame dict that contains
    # these keys will be rejected as extra features.
    expected_features = set(features) - set(DEFAULT_FEATURES)
    provided = set(frame)

    # task is a special required field that's not part of regular features
    if "task" not in provided:
        raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")

    # Regular feature validation ignores the task entry.
    provided_features = provided - {"task"}

    problems = [validate_features_presence(provided_features, expected_features)]
    for name in provided_features & expected_features:
        problems.append(validate_feature_dtype_and_shape(name, features[name], frame[name]))

    error_message = "".join(problems)
    if error_message:
        raise ValueError(error_message)
|
|
|
|
|
|
def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str:
    """Report features that are missing from, or unexpected in, a frame.

    Args:
        actual_features (set[str]): The set of feature names present in the frame.
        expected_features (set[str]): The set of feature names expected in the frame.

    Returns:
        str: An error message string if there's a mismatch, otherwise an empty string.
    """
    missing = expected_features - actual_features
    extra = actual_features - expected_features

    if not (missing or extra):
        return ""

    parts = ["Feature mismatch in `frame` dictionary:\n"]
    if missing:
        parts.append(f"Missing features: {missing}\n")
    if extra:
        parts.append(f"Extra features: {extra}\n")
    return "".join(parts)
|
|
|
|
|
|
def validate_feature_dtype_and_shape(
    name: str, feature: dict, value: np.ndarray | PILImage.Image | str
) -> str:
    """Dispatch a feature value to the validator matching its declared dtype.

    Numpy dtype strings go to `validate_feature_numpy_array`, "image"/"video" to
    `validate_feature_image_or_video`, and "string" to `validate_feature_string`.

    Args:
        name (str): The name of the feature.
        feature (dict): The feature specification from the LeRobot features dictionary.
        value: The value of the feature to validate.

    Returns:
        str: An error message if validation fails, otherwise an empty string.

    Raises:
        NotImplementedError: If the feature dtype is not supported for validation.
    """
    dtype = feature["dtype"]
    shape = feature["shape"]

    if is_valid_numpy_dtype_string(dtype):
        return validate_feature_numpy_array(name, dtype, shape, value)

    if dtype in ("image", "video"):
        return validate_feature_image_or_video(name, shape, value)

    if dtype == "string":
        return validate_feature_string(name, value)

    raise NotImplementedError(f"The feature dtype '{dtype}' is not implemented yet.")
|
|
|
|
|
|
def validate_feature_numpy_array(
    name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
) -> str:
    """Validate a feature that is expected to be a numpy array.

    The value must be an `np.ndarray` whose dtype equals `expected_dtype` and
    whose shape equals `expected_shape`. Dtype and shape problems are both
    reported when both occur.

    Args:
        name (str): The name of the feature.
        expected_dtype (str): The expected numpy dtype as a string.
        expected_shape (list[int]): The expected shape. Lists and tuples are
            both accepted.
        value (np.ndarray): The numpy array to validate.

    Returns:
        str: An error message if validation fails, otherwise an empty string.
    """
    if not isinstance(value, np.ndarray):
        return f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"

    error_message = ""

    actual_dtype = value.dtype
    if actual_dtype != np.dtype(expected_dtype):
        error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"

    # `ndarray.shape` is always a tuple, and a tuple never compares equal to a
    # list. Normalise the expected shape so a list-typed spec (as the annotation
    # allows) is not reported as a mismatch for a correctly shaped array.
    actual_shape = value.shape
    if actual_shape != tuple(expected_shape):
        error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"

    return error_message
|
|
|
|
|
|
def validate_feature_image_or_video(
|
|
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
|
|
) -> str:
|
|
"""Validate a feature that is expected to be an image or video frame.
|
|
|
|
Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`.
|
|
|
|
Args:
|
|
name (str): The name of the feature.
|
|
expected_shape (list[str]): The expected shape (C, H, W).
|
|
value: The image data to validate.
|
|
|
|
Returns:
|
|
str: An error message if validation fails, otherwise an empty string.
|
|
"""
|
|
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
|
error_message = ""
|
|
if isinstance(value, np.ndarray):
|
|
actual_shape = value.shape
|
|
c, h, w = expected_shape
|
|
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
|
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
|
elif isinstance(value, PILImage.Image):
|
|
pass
|
|
else:
|
|
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
|
|
|
|
return error_message
|
|
|
|
|
|
def validate_feature_string(name: str, value: str) -> str:
    """Validate a feature that is expected to be a `str`.

    Args:
        name (str): The name of the feature.
        value (str): The value to validate.

    Returns:
        str: An error message if validation fails, otherwise an empty string.
    """
    if isinstance(value, str):
        return ""
    return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
|
|
|
|
|
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
    """Validate the episode buffer before it's written to disk.

    Ensures the buffer has the required keys, contains at least one frame, and
    has features consistent with the dataset's specification.

    Args:
        episode_buffer (dict): The buffer containing data for a single episode.
        total_episodes (int): The current total number of episodes in the dataset.
        features (dict): The LeRobot features dictionary for the dataset.

    Raises:
        ValueError: If the "size" or "task" key is missing, if the buffer holds
            no frame, or if the buffer keys (other than "task" and "size") differ
            from the keys of `features`.
        NotImplementedError: If the episode index is manually set and doesn't match.
    """
    if "size" not in episode_buffer:
        raise ValueError("size key not found in episode_buffer")

    if "task" not in episode_buffer:
        raise ValueError("task key not found in episode_buffer")

    if episode_buffer["episode_index"] != total_episodes:
        # TODO(aliberts): Add option to use existing episode_index
        raise NotImplementedError(
            "You might have manually provided the episode_buffer with an episode_index that doesn't "
            "match the total number of episodes already in the dataset. This is not supported for now."
        )

    if episode_buffer["size"] == 0:
        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")

    # "task" and "size" are bookkeeping entries, not dataset features.
    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
    feature_keys = set(features)
    if buffer_keys != feature_keys:
        # Adjacent string literals are concatenated verbatim, so each part needs
        # its own line break to keep the two key sets readable.
        raise ValueError(
            "Features from `episode_buffer` don't match the ones in `features`.\n"
            f"In episode_buffer not in features: {buffer_keys - feature_keys}\n"
            f"In features not in episode_buffer: {feature_keys - buffer_keys}"
        )
|