refactor: more changes

This commit is contained in:
Steven Palma
2026-04-11 11:13:15 +02:00
parent 4767f51971
commit 964acd0151
27 changed files with 338 additions and 292 deletions

View File

@@ -23,7 +23,6 @@ from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ( # noqa: F401
@@ -36,6 +35,7 @@ from lerobot.policies import ( # noqa: F401
)
from lerobot.robots.robot import Robot
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.utils import init_logging
Action = torch.Tensor

View File

@@ -746,7 +746,7 @@ def save_annotations_to_dataset(
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
):
"""Save annotations to LeRobot dataset parquet format."""
from lerobot.datasets.io_utils import load_episodes
from lerobot.datasets import load_episodes
from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
episodes_dataset = load_episodes(dataset_path)
@@ -841,7 +841,7 @@ def generate_auto_sparse_annotations(
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
"""Load annotations from LeRobot dataset parquet files."""
from lerobot.datasets.io_utils import load_episodes
from lerobot.datasets import load_episodes
episodes_dataset = load_episodes(dataset_path)
if not episodes_dataset or len(episodes_dataset) == 0:

View File

@@ -20,10 +20,15 @@ from lerobot.utils.import_utils import require_package
require_package("datasets", extra="dataset")
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.io_utils import load_episodes, write_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.video_utils import VideoEncodingManager
__all__ = [
"EpisodeAwareSampler",
@@ -31,4 +36,11 @@ __all__ = [
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"StreamingLeRobotDataset",
"VideoEncodingManager",
"aggregate_pipeline_dataset_features",
"create_initial_features",
"load_episodes",
"make_dataset",
"safe_stop_image_writer",
"write_stats",
]

View File

@@ -24,7 +24,7 @@ import pyarrow.parquet as pq
from huggingface_hub import snapshot_download
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
from lerobot.datasets.feature_utils import create_empty_dataset_info
from lerobot.datasets.io_utils import (
get_file_size_in_mb,
load_episodes,
@@ -39,7 +39,6 @@ from lerobot.datasets.io_utils import (
)
from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
INFO_PATH,
check_version_compatibility,
flatten_dict,
@@ -49,7 +48,8 @@ from lerobot.datasets.utils import (
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
from lerobot.utils.feature_utils import _validate_feature_names
CODEBASE_VERSION = "v3.0"

View File

@@ -25,12 +25,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
def resolve_delta_timestamps(

View File

@@ -14,22 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pprint import pformat
from typing import Any
import datasets
import numpy as np
from PIL import Image as PILImage
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_FEATURES,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
@@ -71,199 +68,6 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
return datasets.Features(hf_features)
def _validate_feature_names(features: dict[str, dict]) -> None:
    """Reject feature dictionaries whose keys contain a forward slash.

    Args:
        features (dict): The LeRobot features dictionary, keyed by feature name.

    Raises:
        ValueError: If at least one feature name contains '/'. The message lists
            every offending entry.
    """
    offending = {}
    for feature_name, spec in features.items():
        if "/" in feature_name:
            offending[feature_name] = spec

    if not offending:
        return
    raise ValueError(f"Feature names should not contain '/'. Found '/' in '{offending}'.")
def hw_to_dataset_features(
    hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
    """Translate a hardware feature description into LeRobot dataset features.

    Scalar channels (declared as ``float`` or as a non-visual ``PolicyFeature``) are
    packed into a single float32 vector feature. Entries declared as a tuple are
    treated as camera image shapes and each one gets its own image/video feature.

    Args:
        hw_features (dict): Maps a feature name to ``float`` / a ``PolicyFeature``
            (scalar channel) or to a shape tuple (camera).
        prefix (str): Key prefix. The vector is stored under ``prefix`` itself when it
            equals ``ACTION`` and under ``"{prefix}.state"`` when it equals ``OBS_STR``;
            for any other prefix no vector feature is emitted.
        use_video (bool): If True, camera features get dtype "video", otherwise "image".

    Returns:
        dict: A LeRobot features dictionary.

    Raises:
        ValueError: If a generated feature name contains '/'.
    """
    scalar_names = []
    camera_shapes = {}
    for name, spec in hw_features.items():
        is_scalar = spec is float or (isinstance(spec, PolicyFeature) and spec.type != FeatureType.VISUAL)
        if is_scalar:
            scalar_names.append(name)
        if isinstance(spec, tuple):
            camera_shapes[name] = spec

    def _vector_spec() -> dict:
        # A fresh dict (and names list) per call so emitted features never share state.
        return {
            "dtype": "float32",
            "shape": (len(scalar_names),),
            "names": list(scalar_names),
        }

    features = {}
    if scalar_names and prefix == ACTION:
        features[prefix] = _vector_spec()
    if scalar_names and prefix == OBS_STR:
        features[f"{prefix}.state"] = _vector_spec()

    image_dtype = "video" if use_video else "image"
    for camera, shape in camera_shapes.items():
        features[f"{prefix}.images.{camera}"] = {
            "dtype": image_dtype,
            "shape": shape,
            "names": ["height", "width", "channels"],
        }

    _validate_feature_names(features)
    return features
def build_dataset_frame(
    ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
    """Assemble the data of one timestep ("frame") from raw values.

    Only features whose key starts with ``prefix`` and that are not part of
    ``DEFAULT_FEATURES`` are considered. One-dimensional float32 features are built by
    looking up each of their ``names`` in ``values``; image/video features are taken
    from ``values`` under the key with the ``"{prefix}.images."`` prefix removed.
    Features of any other kind are left out of the frame.

    Args:
        ds_features (dict): The LeRobot dataset features dictionary.
        values (dict): Raw values coming from the hardware/environment.
        prefix (str): The prefix to filter features by (e.g., "observation" or "action").

    Returns:
        dict: A dictionary representing a single frame of data.
    """
    image_key_prefix = f"{prefix}.images."
    frame = {}
    for key, spec in ds_features.items():
        if key in DEFAULT_FEATURES:
            continue
        if not key.startswith(prefix):
            continue

        dtype = spec["dtype"]
        if dtype == "float32" and len(spec["shape"]) == 1:
            components = [values[name] for name in spec["names"]]
            frame[key] = np.array(components, dtype=np.float32)
        elif dtype in ["image", "video"]:
            frame[key] = values[key.removeprefix(image_key_prefix)]
    return frame
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
    """Derive policy features from a dataset feature specification.

    Each dataset feature is classified as visual (dtype "image"/"video"), environment
    state (key equal to ``OBS_ENV_STATE``), state (key starting with ``OBS_STR``) or
    action (key starting with ``ACTION``); anything else is skipped. Visual features
    whose third axis name is "channel"/"channels" have their shape reordered from
    (h, w, c) to (c, h, w).

    Args:
        features (dict): The LeRobot dataset features dictionary.

    Returns:
        dict: A dictionary mapping feature keys to `PolicyFeature` objects.

    Raises:
        ValueError: If an image feature does not have a 3D shape.
    """
    # TODO(aliberts): Implement "type" in dataset features and simplify this
    policy_features = {}
    for key, spec in features.items():
        shape = spec["shape"]
        if spec["dtype"] in ["image", "video"]:
            feature_type = FeatureType.VISUAL
            if len(shape) != 3:
                raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")

            axis_names = spec["names"]
            # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
            if axis_names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                height, width, channels = shape[0], shape[1], shape[2]
                shape = (channels, height, width)
        elif key == OBS_ENV_STATE:
            feature_type = FeatureType.ENV
        elif key.startswith(OBS_STR):
            feature_type = FeatureType.STATE
        elif key.startswith(ACTION):
            feature_type = FeatureType.ACTION
        else:
            # Not an observation or action feature: irrelevant for the policy.
            continue

        policy_features[key] = PolicyFeature(type=feature_type, shape=shape)

    return policy_features
def combine_feature_dicts(*dicts: dict) -> dict:
    """Merge LeRobot grouped feature dicts.

    - For 1D numeric specs (dtype not image/video/string) with a non-null "names": the names
      are merged in first-seen order without duplicates and the shape is recomputed.
    - For everything else (e.g. `observation.images.*`, or specs whose "names" is None such as
      the default meta-features), the last definition wins.

    The input dictionaries are never mutated: merged vectors are accumulated in dicts that
    belong to the result.

    Args:
        *dicts: A variable number of LeRobot feature dictionaries to merge.

    Returns:
        dict: A single merged feature dictionary.

    Raises:
        ValueError: If there's a dtype mismatch for a feature being merged.
    """
    out: dict = {}
    # Keys whose entry in `out` is an accumulator created here, hence safe to mutate in place.
    owned: set = set()

    for d in dicts:
        for key, value in d.items():
            is_vector = False
            if isinstance(value, dict):
                dtype = value.get("dtype")
                shape = value.get("shape")
                is_vector = (
                    dtype not in ("image", "video", "string")
                    and isinstance(shape, tuple)
                    and len(shape) == 1
                    # `"names": None` cannot be merged name-by-name; treat it as "last one wins"
                    # instead of failing while iterating over None.
                    and value.get("names") is not None
                )

            if not is_vector:
                # Images/videos, non-1D entries and non-dict values: override with the latest definition
                out[key] = value
                owned.discard(key)
                continue

            if key not in owned:
                previous = out.get(key)
                if isinstance(previous, dict) and isinstance(previous.get("names"), list):
                    # Keep merging into what an earlier dict defined, but on a copy so the
                    # caller's dictionary is left untouched.
                    target = {**previous, "names": list(previous["names"])}
                else:
                    target = {"dtype": dtype, "names": [], "shape": (0,)}
                out[key] = target
                owned.add(key)
            target = out[key]

            # Ensure consistent data types across merged entries
            if "dtype" in target and dtype != target["dtype"]:
                raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")

            # Merge feature names: append only new ones to preserve order without duplicates
            seen = set(target["names"])
            for n in value["names"]:
                if n not in seen:
                    target["names"].append(n)
                    seen.add(n)

            # Recompute the shape to reflect the updated number of features
            target["shape"] = (len(target["names"]),)

    return out
def create_empty_dataset_info(
codebase_version: str,
fps: int,

View File

@@ -17,10 +17,10 @@ from collections.abc import Sequence
from typing import Any
from lerobot.configs.types import PipelineFeatureType
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.feature_utils import hw_to_dataset_features
def create_initial_features(

View File

@@ -93,14 +93,6 @@ LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
"index": {"dtype": "int64", "shape": (1,), "names": None},
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
def has_legacy_hub_download_metadata(root: Path) -> bool:
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.

View File

@@ -29,24 +29,17 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from torch import Tensor, nn
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="training")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler # noqa: E402
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler # noqa: E402
from torch import Tensor, nn # noqa: E402
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig # noqa: E402
from lerobot.policies.pretrained import PreTrainedPolicy # noqa: E402
from lerobot.policies.utils import ( # noqa: E402
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
get_output_shape,
populate_queues,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE # noqa: E402
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class DiffusionPolicy(PreTrainedPolicy):
@@ -156,11 +149,17 @@ class DiffusionPolicy(PreTrainedPolicy):
return loss, None
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
def _make_noise_scheduler(name: str, **kwargs: dict):
"""
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="training")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":

View File

@@ -495,7 +495,7 @@ def make_policy(
kwargs = {}
if ds_meta is not None:
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.utils.feature_utils import dataset_to_policy_features
features = dataset_to_policy_features(ds_meta.features)
else:

View File

@@ -34,17 +34,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision
from torch import Tensor
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="training")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler # noqa: E402
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler # noqa: E402
from torch import Tensor # noqa: E402
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig # noqa: E402
from lerobot.utils.import_utils import _transformers_available # noqa: E402
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
@@ -52,9 +45,9 @@ if TYPE_CHECKING or _transformers_available:
else:
CLIPTextModel = None
CLIPVisionModel = None
from lerobot.policies.pretrained import PreTrainedPolicy # noqa: E402
from lerobot.policies.utils import populate_queues # noqa: E402
from lerobot.utils.constants import ( # noqa: E402
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -648,6 +641,12 @@ class DiffusionObjective(nn.Module):
"prediction_type": config.prediction_type,
}
from lerobot.utils.import_utils import require_package
require_package("diffusers", extra="training")
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":

View File

@@ -162,7 +162,7 @@ def build_inference_frame(
Returns:
A dictionary of preprocessed tensors ready for model inference.
"""
from lerobot.datasets.feature_utils import build_dataset_frame
from lerobot.utils.feature_utils import build_dataset_frame
# Extracts the correct keys from the incoming raw observation
observation = build_dataset_frame(ds_features, observation, prefix=OBS_STR)

View File

@@ -21,7 +21,6 @@ import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.datasets.factory import IMAGENET_STATS
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
from lerobot.processor import (
@@ -40,6 +39,7 @@ from lerobot.processor import (
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
IMAGENET_STATS,
OBS_IMAGES,
OBS_PREFIX,
OBS_STATE,

View File

@@ -62,8 +62,7 @@ from torch.optim.optimizer import Optimizer
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets import LeRobotDataset, make_dataset
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions

View File

@@ -44,10 +44,9 @@ from huggingface_hub import HfApi
from requests import HTTPError
from tqdm import tqdm
from lerobot.datasets import LeRobotDataset
from lerobot.datasets import LeRobotDataset, write_stats
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
from lerobot.datasets.io_utils import write_stats
from lerobot.utils.utils import init_logging

View File

@@ -85,11 +85,13 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets import LeRobotDataset
from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.datasets import (
LeRobotDataset,
VideoEncodingManager,
aggregate_pipeline_dataset_features,
create_initial_features,
safe_stop_image_writer,
)
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator
@@ -143,6 +145,7 @@ from lerobot.utils.control_utils import (
sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import (

View File

@@ -13,45 +13,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import logging
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any
from typing import TYPE_CHECKING, Any
from lerobot.utils.import_utils import require_package
if TYPE_CHECKING:
from accelerate import Accelerator
require_package("accelerate", extra="training")
import torch
from termcolor import colored
from torch.optim import Optimizer
from tqdm import tqdm
import torch # noqa: E402
from accelerate import Accelerator # noqa: E402
from termcolor import colored # noqa: E402
from torch.optim import Optimizer # noqa: E402
from tqdm import tqdm # noqa: E402
from lerobot.configs import parser # noqa: E402
from lerobot.configs.train import TrainPipelineConfig # noqa: E402
from lerobot.datasets import EpisodeAwareSampler # noqa: E402
from lerobot.datasets.factory import make_dataset # noqa: E402
from lerobot.envs.factory import make_env, make_env_pre_post_processors # noqa: E402
from lerobot.envs.utils import close_envs # noqa: E402
from lerobot.optim.factory import make_optimizer_and_scheduler # noqa: E402
from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: E402
from lerobot.policies.pretrained import PreTrainedPolicy # noqa: E402
from lerobot.rl.wandb_utils import WandBLogger # noqa: E402
from lerobot.scripts.lerobot_eval import eval_policy_all # noqa: E402
from lerobot.utils.import_utils import register_third_party_plugins # noqa: E402
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from lerobot.utils.train_utils import ( # noqa: E402
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.utils import ( # noqa: E402
from lerobot.utils.utils import (
cycle,
format_big_number,
has_method,
@@ -171,6 +170,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg: A `TrainPipelineConfig` object containing all training configurations.
accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
from lerobot.utils.import_utils import require_package
require_package("accelerate", extra="training")
from accelerate import Accelerator
cfg.validate()
# Create Accelerator if not provided

View File

@@ -75,6 +75,21 @@ default_calibration_path = HF_LEROBOT_HOME / "calibration"
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
# Dataset meta-features (auto-populated by the recording pipeline)
DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
"index": {"dtype": "int64", "shape": (1,), "names": None},
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
# ImageNet normalization constants
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}
# streaming datasets
LOOKBACK_BACKTRACKTABLE = 100
LOOKAHEAD_BACKTRACKTABLE = 100

View File

@@ -223,7 +223,7 @@ def sanity_check_dataset_robot_compatibility(
require_package("deepdiff", extra="hardware")
from deepdiff import DeepDiff
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.utils.constants import DEFAULT_FEATURES
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),

View File

@@ -0,0 +1,222 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Lightweight feature-manipulation utilities.
These functions are intentionally kept free of heavy dependencies (e.g. the
HuggingFace ``datasets`` library) so that they can be imported from anywhere
in the codebase, including modules that are part of the *minimal* install,
without triggering the ``lerobot.datasets`` package guard.
"""
from typing import Any
import numpy as np
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, DEFAULT_FEATURES, OBS_ENV_STATE, OBS_STR
def _validate_feature_names(features: dict[str, dict]) -> None:
    """Reject feature dictionaries whose keys contain a forward slash.

    Args:
        features (dict): The LeRobot features dictionary, keyed by feature name.

    Raises:
        ValueError: If at least one feature name contains '/'. The message lists
            every offending entry.
    """
    offending = {}
    for feature_name, spec in features.items():
        if "/" in feature_name:
            offending[feature_name] = spec

    if not offending:
        return
    raise ValueError(f"Feature names should not contain '/'. Found '/' in '{offending}'.")
def hw_to_dataset_features(
    hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
    """Translate a hardware feature description into LeRobot dataset features.

    Scalar channels (declared as ``float`` or as a non-visual ``PolicyFeature``) are
    packed into a single float32 vector feature. Entries declared as a tuple are
    treated as camera image shapes and each one gets its own image/video feature.

    Args:
        hw_features (dict): Maps a feature name to ``float`` / a ``PolicyFeature``
            (scalar channel) or to a shape tuple (camera).
        prefix (str): Key prefix. The vector is stored under ``prefix`` itself when it
            equals ``ACTION`` and under ``"{prefix}.state"`` when it equals ``OBS_STR``;
            for any other prefix no vector feature is emitted.
        use_video (bool): If True, camera features get dtype "video", otherwise "image".

    Returns:
        dict: A LeRobot features dictionary.

    Raises:
        ValueError: If a generated feature name contains '/'.
    """
    scalar_names = []
    camera_shapes = {}
    for name, spec in hw_features.items():
        is_scalar = spec is float or (isinstance(spec, PolicyFeature) and spec.type != FeatureType.VISUAL)
        if is_scalar:
            scalar_names.append(name)
        if isinstance(spec, tuple):
            camera_shapes[name] = spec

    def _vector_spec() -> dict:
        # A fresh dict (and names list) per call so emitted features never share state.
        return {
            "dtype": "float32",
            "shape": (len(scalar_names),),
            "names": list(scalar_names),
        }

    features = {}
    if scalar_names and prefix == ACTION:
        features[prefix] = _vector_spec()
    if scalar_names and prefix == OBS_STR:
        features[f"{prefix}.state"] = _vector_spec()

    image_dtype = "video" if use_video else "image"
    for camera, shape in camera_shapes.items():
        features[f"{prefix}.images.{camera}"] = {
            "dtype": image_dtype,
            "shape": shape,
            "names": ["height", "width", "channels"],
        }

    _validate_feature_names(features)
    return features
def build_dataset_frame(
    ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
    """Assemble the data of one timestep ("frame") from raw values.

    Only features whose key starts with ``prefix`` and that are not part of
    ``DEFAULT_FEATURES`` are considered. One-dimensional float32 features are built by
    looking up each of their ``names`` in ``values``; image/video features are taken
    from ``values`` under the key with the ``"{prefix}.images."`` prefix removed.
    Features of any other kind are left out of the frame.

    Args:
        ds_features (dict): The LeRobot dataset features dictionary.
        values (dict): Raw values coming from the hardware/environment.
        prefix (str): The prefix to filter features by (e.g., "observation" or "action").

    Returns:
        dict: A dictionary representing a single frame of data.
    """
    image_key_prefix = f"{prefix}.images."
    frame = {}
    for key, spec in ds_features.items():
        if key in DEFAULT_FEATURES:
            continue
        if not key.startswith(prefix):
            continue

        dtype = spec["dtype"]
        if dtype == "float32" and len(spec["shape"]) == 1:
            components = [values[name] for name in spec["names"]]
            frame[key] = np.array(components, dtype=np.float32)
        elif dtype in ["image", "video"]:
            frame[key] = values[key.removeprefix(image_key_prefix)]
    return frame
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
    """Derive policy features from a dataset feature specification.

    Each dataset feature is classified as visual (dtype "image"/"video"), environment
    state (key equal to ``OBS_ENV_STATE``), state (key starting with ``OBS_STR``) or
    action (key starting with ``ACTION``); anything else is skipped. Visual features
    whose third axis name is "channel"/"channels" have their shape reordered from
    (h, w, c) to (c, h, w).

    Args:
        features (dict): The LeRobot dataset features dictionary.

    Returns:
        dict: A dictionary mapping feature keys to `PolicyFeature` objects.

    Raises:
        ValueError: If an image feature does not have a 3D shape.
    """
    # TODO(aliberts): Implement "type" in dataset features and simplify this
    policy_features = {}
    for key, spec in features.items():
        shape = spec["shape"]
        if spec["dtype"] in ["image", "video"]:
            feature_type = FeatureType.VISUAL
            if len(shape) != 3:
                raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")

            axis_names = spec["names"]
            # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
            if axis_names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                height, width, channels = shape[0], shape[1], shape[2]
                shape = (channels, height, width)
        elif key == OBS_ENV_STATE:
            feature_type = FeatureType.ENV
        elif key.startswith(OBS_STR):
            feature_type = FeatureType.STATE
        elif key.startswith(ACTION):
            feature_type = FeatureType.ACTION
        else:
            # Not an observation or action feature: irrelevant for the policy.
            continue

        policy_features[key] = PolicyFeature(type=feature_type, shape=shape)

    return policy_features
def combine_feature_dicts(*dicts: dict) -> dict:
    """Merge LeRobot grouped feature dicts.

    - For 1D numeric specs (dtype not image/video/string) with a non-null "names": the names
      are merged in first-seen order without duplicates and the shape is recomputed.
    - For everything else (e.g. `observation.images.*`, or specs whose "names" is None such as
      the default meta-features), the last definition wins.

    The input dictionaries are never mutated: merged vectors are accumulated in dicts that
    belong to the result.

    Args:
        *dicts: A variable number of LeRobot feature dictionaries to merge.

    Returns:
        dict: A single merged feature dictionary.

    Raises:
        ValueError: If there's a dtype mismatch for a feature being merged.
    """
    out: dict = {}
    # Keys whose entry in `out` is an accumulator created here, hence safe to mutate in place.
    owned: set = set()

    for d in dicts:
        for key, value in d.items():
            is_vector = False
            if isinstance(value, dict):
                dtype = value.get("dtype")
                shape = value.get("shape")
                is_vector = (
                    dtype not in ("image", "video", "string")
                    and isinstance(shape, tuple)
                    and len(shape) == 1
                    # `"names": None` cannot be merged name-by-name; treat it as "last one wins"
                    # instead of failing while iterating over None.
                    and value.get("names") is not None
                )

            if not is_vector:
                # Images/videos, non-1D entries and non-dict values: override with the latest definition
                out[key] = value
                owned.discard(key)
                continue

            if key not in owned:
                previous = out.get(key)
                if isinstance(previous, dict) and isinstance(previous.get("names"), list):
                    # Keep merging into what an earlier dict defined, but on a copy so the
                    # caller's dictionary is left untouched.
                    target = {**previous, "names": list(previous["names"])}
                else:
                    target = {"dtype": dtype, "names": [], "shape": (0,)}
                out[key] = target
                owned.add(key)
            target = out[key]

            # Ensure consistent data types across merged entries
            if "dtype" in target and dtype != target["dtype"]:
                raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")

            # Merge feature names: append only new ones to preserve order without duplicates
            seen = set(target["names"])
            for n in value["names"]:
                if n not in seen:
                    target["names"].append(n)
                    seen.add(n)

            # Recompute the shape to reflect the updated number of features
            target["shape"] = (len(target["names"]),)

    return out

View File

@@ -90,7 +90,8 @@ def require_package(pkg_name: str, extra: str, import_name: str | None = None) -
_require_package_cache[cache_key] = is_package_available(pkg_name, import_name)
if not _require_package_cache[cache_key]:
raise ImportError(
f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'"
f"'{pkg_name}' is required but not installed. Install it with: "
f"pip install 'lerobot[{extra}]' (or uv pip install 'lerobot[{extra}]')"
)

View File

@@ -81,6 +81,8 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
stream.height = height
stream.pix_fmt = "yuv420p"
for frame_array in stacked_frames:
if height != orig_height or width != orig_width:
frame_array = frame_array[:height, :width]
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)

View File

@@ -292,9 +292,8 @@ class SuppressProgressBars:
disable_progress_bar()
except ImportError:
logging.getLogger(__name__).info(
"SuppressProgressBars is a no-op because 'datasets' is not installed. "
"Install it with: pip install 'lerobot[dataset]'"
logging.getLogger(__name__).debug(
"SuppressProgressBars is a no-op because 'datasets' is not installed."
)
def __exit__(self, exc_type, exc_val, exc_tb):

View File

@@ -21,7 +21,7 @@ from safetensors.torch import save_file
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets import make_dataset
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
from lerobot.utils.constants import OBS_STR

View File

@@ -19,10 +19,10 @@ import torch
from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.utils import create_lerobot_dataset_card
from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.feature_utils import combine_feature_dicts
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:

View File

@@ -29,8 +29,8 @@ from torchvision.transforms import v2
import lerobot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
from lerobot.datasets import make_dataset
from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.image_writer import image_array_to_pil_image
from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -47,6 +47,7 @@ from lerobot.policies.factory import make_policy_config
from lerobot.robots import make_robot_from_config
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
from lerobot.utils.feature_utils import hw_to_dataset_features
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig
from tests.utils import require_x86_64_kernel

View File

@@ -27,8 +27,7 @@ from lerobot import available_policies
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets import make_dataset
from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import close_envs, preprocess_observation
from lerobot.optim.factory import make_optimizer_and_scheduler
@@ -44,6 +43,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.feature_utils import dataset_to_policy_features
from lerobot.utils.random_utils import seeded_context
from lerobot.utils.utils import cycle
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats