Files
lerobot-clone/src/lerobot/processor/converters.py
Adil Zouitine fc43246942 feat(record): add transition features to dataset and handle scalar vs array formatting in converters (#1861)
- Introduced new transition features (`next.reward`, `next.done`, `next.truncated`) in the dataset during recording.
- Updated the `transition_to_dataset_frame` function to handle scalar values correctly, ensuring compatibility with expected array formats for reward, done, and truncated features.
2025-09-04 16:17:31 +02:00

482 lines
16 KiB
Python

# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from copy import deepcopy
from functools import singledispatch
from typing import Any
import numpy as np
import torch
from scipy.spatial.transform import Rotation
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
from .core import EnvTransition, TransitionKey
@singledispatch
def to_tensor(
value: Any,
*,
dtype: torch.dtype | None = torch.float32,
device: torch.device | str | None = None,
) -> torch.Tensor:
"""
Convert various data types to PyTorch tensors with configurable options.
This is a unified tensor conversion function using single dispatch to handle
different input types appropriately.
Args:
value: Input value to convert (tensor, array, scalar, sequence, etc.)
dtype: Target tensor dtype. If None, preserves original dtype.
device: Target device for the tensor.
Returns:
PyTorch tensor.
Raises:
TypeError: If the input type is not supported.
"""
raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")
@to_tensor.register(torch.Tensor)
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle existing PyTorch tensors."""
if dtype is not None:
value = value.to(dtype=dtype)
if device is not None:
value = value.to(device=device)
return value
@to_tensor.register(np.ndarray)
def _(
value: np.ndarray,
*,
dtype=torch.float32,
device=None,
**kwargs,
) -> torch.Tensor:
"""Handle numpy arrays."""
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
if value.ndim == 0:
# Numpy scalars should be converted to 0-dimensional tensors
scalar_value = value.item()
return torch.tensor(scalar_value, dtype=dtype, device=device)
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
tensor = torch.from_numpy(value)
# Apply dtype conversion if specified
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if device is not None:
tensor = tensor.to(device=device)
return tensor
@to_tensor.register(int)
@to_tensor.register(float)
@to_tensor.register(np.integer)
@to_tensor.register(np.floating)
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle scalar values including numpy scalars."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(list)
@to_tensor.register(tuple)
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
"""Handle sequences (lists, tuples)."""
return torch.tensor(value, dtype=dtype, device=device)
@to_tensor.register(dict)
def _(value: dict, *, device=None, **kwargs) -> dict:
"""Handle dictionaries by recursively converting values to tensors."""
if not value:
return {}
result = {}
for key, sub_value in value.items():
if sub_value is None:
continue
if isinstance(sub_value, dict):
# Recursively process nested dictionaries
result[key] = to_tensor(
sub_value,
device=device,
**kwargs,
)
continue
# Convert individual values to tensors
result[key] = to_tensor(
sub_value,
device=device,
**kwargs,
)
return result
def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
"""Convert tensor to numpy/scalar if needed."""
if isinstance(x, torch.Tensor):
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
return x
def _is_image(arr: Any) -> bool:
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
state, images = {}, {}
for k, v in obs.items():
if "image" in k.lower() or _is_image(v):
images[k] = v
else:
state[k] = v
return state, images
# ============================================================================
# Private Helper Functions (Common Logic)
# ============================================================================
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""Extract complementary data (pad flags, task, index, task_index)."""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key}
def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition:
"""Merge two transitions, with other taking precedence."""
out = deepcopy(base)
for key in (
TransitionKey.OBSERVATION,
TransitionKey.ACTION,
TransitionKey.INFO,
TransitionKey.COMPLEMENTARY_DATA,
):
if other.get(key):
out.setdefault(key, {}).update(deepcopy(other[key]))
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
if k in other:
out[k] = other[k]
return out
# ============================================================================
# Core Conversion Functions
# ============================================================================
def create_transition(
observation: dict[str, Any] | None = None,
action: dict[str, Any] | None = None,
reward: float = 0.0,
done: bool = False,
truncated: bool = False,
info: dict[str, Any] | None = None,
complementary_data: dict[str, Any] | None = None,
) -> EnvTransition:
"""Create an EnvTransition with sensible defaults.
Args:
observation: Observation dictionary.
action: Action dictionary.
reward: Scalar reward value.
done: Episode termination flag.
truncated: Episode truncation flag.
info: Additional info dictionary.
complementary_data: Complementary data dictionary.
Returns:
Complete EnvTransition dictionary.
"""
return {
TransitionKey.OBSERVATION: observation,
TransitionKey.ACTION: action,
TransitionKey.REWARD: reward,
TransitionKey.DONE: done,
TransitionKey.TRUNCATED: truncated,
TransitionKey.INFO: info if info is not None else {},
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
}
def action_to_transition(action: dict[str, Any]) -> EnvTransition: # action_to_transition
"""
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
"""
act_dict: dict[str, Any] = {}
for k, v in action.items():
# Check if the value is a type that should not be converted to a tensor.
if isinstance(v, (Rotation, dict)):
act_dict[f"{ACTION}.{k}"] = v
continue
arr = np.array(v) if np.isscalar(v) else v
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
return create_transition(observation={}, action=act_dict)
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
"""
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
"""
state, images = _split_obs_to_state_and_images(observation)
obs_dict: dict[str, Any] = {}
for k, v in state.items():
arr = np.array(v) if np.isscalar(v) else v
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
for cam, img in images.items():
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
return create_transition(observation=obs_dict, action={})
def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]:
"""
Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions.
"""
out: dict[str, Any] = {}
action_dict = transition.get(TransitionKey.ACTION) or {}
if action_dict is None:
return out
for k, v in action_dict.items():
if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
out[out_key] = float(v)
return out
def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition:
"""Merge multiple transitions or return single transition.
Args:
transitions: Either a single transition or iterable of transitions.
Returns:
Merged EnvTransition.
"""
if not isinstance(transitions, Sequence): # Single transition
return transitions
items = list(transitions)
if not items:
raise ValueError("merge_transitions() requires a non-empty sequence of transitions")
result = items[0]
for t in items[1:]:
result = _merge_transitions(result, t)
return result
def transition_to_dataset_frame(
transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict]
) -> dict[str, Any]:
"""Convert a single EnvTransition or an iterable of them into a flat, dataset-friendly dictionary for training or evaluation.
Processes transitions according to the provided feature specification and returns
data in the format expected by machine learning models and datasets.
Args:
transitions_or_transition: Either a single EnvTransition dict or an iterable of them
(which will be merged using merge_transitions).
features: Feature specification dictionary with the following structure:
- 'action': dict with 'names': list of action feature names
- 'observation.state': dict with 'names': list of state feature names
- keys starting with 'observation.images.' are passed through as-is
Returns:
Flat dictionary containing:
- numpy arrays for "observation.state" and "action" (vectorized from feature names)
- any image tensors defined in features (passed through unchanged)
- next.{reward,done,truncated} scalar values
- info dict
- *_is_pad flags and task from complementary_data
"""
action_names = features.get(ACTION, {}).get("names", [])
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
tr = merge_transitions(transitions_or_transition)
obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
act = tr.get(TransitionKey.ACTION, {}) or {}
batch: dict[str, Any] = {}
# Images passthrough
for k in image_keys:
if k in obs:
batch[k] = obs[k]
# Observation.state vector
if obs_state_names:
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
# Action vector
if action_names:
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
batch[ACTION] = np.asarray(vals, dtype=np.float32)
# Add transition metadata
if tr.get(TransitionKey.REWARD) is not None:
reward_val = _from_tensor(tr[TransitionKey.REWARD])
# Check if features expect array format, otherwise keep as scalar
if REWARD in features and features[REWARD].get("shape") == (1,):
batch[REWARD] = np.array([reward_val], dtype=np.float32)
else:
batch[REWARD] = reward_val
if tr.get(TransitionKey.DONE) is not None:
done_val = _from_tensor(tr[TransitionKey.DONE])
if DONE in features and features[DONE].get("shape") == (1,):
batch[DONE] = np.array([done_val], dtype=bool)
else:
batch[DONE] = done_val
if tr.get(TransitionKey.TRUNCATED) is not None:
truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED])
if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
else:
batch[TRUNCATED] = truncated_val
# Complementary data flags and task
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if comp:
# pad flags
for k, v in comp.items():
if k.endswith("_is_pad"):
batch[k] = v
# task label
if comp.get("task") is not None:
batch["task"] = comp["task"]
return batch
def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
"""Convert a batch dict coming from LeRobot replay/dataset code into an EnvTransition dictionary.
The function maps well known keys to the EnvTransition structure. Missing keys are
filled with sane defaults (None or 0.0/False).
Keys recognised (case-sensitive):
* "observation.*" (keys starting with "observation." are grouped into observation dict)
* "action"
* "next.reward"
* "next.done"
* "next.truncated"
* "info"
* "_is_pad" patterns (padding flags)
* "task", "index", "task_index" (complementary data)
Additional keys are ignored so that existing dataloaders can carry extra
metadata without breaking the processor.
Args:
batch: Batch dictionary from datasets or dataloaders containing the above keys.
Returns:
EnvTransition dictionary with properly structured transition data.
"""
# Validate input type
if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
# Extract observation keys
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
complementary_data = _extract_complementary_data(batch)
return create_transition(
observation=observation_keys if observation_keys else None,
action=batch.get("action"),
reward=batch.get("next.reward", 0.0),
done=batch.get("next.done", False),
truncated=batch.get("next.truncated", False),
info=batch.get("info", {}),
complementary_data=complementary_data if complementary_data else None,
)
def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"""Inverse of batch_to_transition. Returns a dict with canonical field names used throughout LeRobot.
Converts an EnvTransition back to the batch format expected by datasets, dataloaders,
and other LeRobot components.
Output format:
* "action": Action data from transition
* "next.reward": Reward value (defaults to 0.0)
* "next.done": Done flag (defaults to False)
* "next.truncated": Truncated flag (defaults to False)
* "info": Info dictionary (defaults to {})
* Flattened observation keys (e.g., "observation.state", "observation.images.cam1")
* Complementary data fields ("task", "index", "task_index", padding flags)
Args:
transition: EnvTransition dictionary to convert.
Returns:
Batch dictionary with canonical LeRobot field names suitable for dataloaders.
"""
batch = {
"action": transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
"info": transition.get(TransitionKey.INFO, {}),
}
# Add complementary data
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if comp_data:
batch.update(comp_data)
# Flatten observation dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
batch.update(observation)
return batch