chore (batch handling): Enhance processing components with batch conversion utilities

This commit is contained in:
Adil Zouitine
2025-07-06 21:29:51 +02:00
parent c227107f60
commit b08149a113
6 changed files with 606 additions and 53 deletions

View File

@@ -1,12 +1,13 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Mapping
from typing import Any, Mapping, Optional, Set
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
@@ -45,8 +46,18 @@ class NormalizerProcessor:
the normalize_keys parameter.
"""
stats: dict[str, dict[str, Any]]
normalize_keys: set[str] | None = None
# Features and normalisation map are mandatory to match the design of normalize.py
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
# Pre-computed statistics coming from dataset.meta.stats for instance.
stats: Optional[dict[str, dict[str, Any]]] = None
# Explicit subset of keys to normalise. If ``None``, every key in
# ``features`` whose type is not ``FeatureType.ACTION`` is selected instead.
# Using a ``set`` makes membership checks O(1).
normalize_keys: Optional[Set[str]] = None
eps: float = 1e-8
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@@ -55,24 +66,48 @@ class NormalizerProcessor:
def from_lerobot_dataset(
cls,
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
normalize_keys: set[str] | None = None,
normalize_keys: Optional[Set[str]] = None,
eps: float = 1e-8,
) -> NormalizerProcessor:
return cls(stats=dataset.meta.stats, normalize_keys=normalize_keys, eps=eps)
) -> "NormalizerProcessor":
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
The features and norm_map parameters are mandatory to match the design
pattern used in normalize.py.
"""
return cls(
features=features,
norm_map=norm_map,
stats=dataset.meta.stats,
normalize_keys=normalize_keys,
eps=eps,
)
def __post_init__(self):
    """Finalise the dataclass fields after ``__init__`` has run.

    Replaces missing or empty ``stats`` with an empty dict, builds the
    cached ``_tensor_stats`` from it, and coerces ``normalize_keys`` to a
    ``set`` when another iterable was supplied.
    """
    # A falsy ``stats`` (``None`` or empty) is replaced by a fresh dict.
    if not self.stats:
        self.stats = {}

    # Done once here so the numpy→Tensor conversion is not repeated at
    # runtime.
    self._tensor_stats = _convert_stats_to_tensors(self.stats)

    # ``None`` means "no explicit subset" and is kept as-is; any non-set
    # iterable is turned into a set for fast membership checks.
    requested = self.normalize_keys
    if requested is None or isinstance(requested, set):
        return
    self.normalize_keys = set(requested)
def _normalize_obs(self, observation):
if observation is None:
return None
keys_to_norm = (
self.normalize_keys
if self.normalize_keys is not None
else {k for k in self._tensor_stats if k != "action"}
)
# Decide which keys should be normalised for this call.
if self.normalize_keys is not None:
keys_to_norm = self.normalize_keys
else:
# Use feature map to skip action keys.
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
processed = dict(observation)
for key in keys_to_norm:
if key not in processed or key not in self._tensor_stats:
@@ -126,7 +161,11 @@ class NormalizerProcessor:
)
def get_config(self) -> dict[str, Any]:
return {"normalize_keys": list(self.normalize_keys) if self.normalize_keys else None, "eps": self.eps}
config = {"eps": self.eps}
if self.normalize_keys is not None:
# Serialise as a list for YAML / JSON friendliness
config["normalize_keys"] = sorted(self.normalize_keys)
return config
def state_dict(self) -> dict[str, Tensor]:
flat = {}
@@ -154,8 +193,9 @@ class UnnormalizerProcessor:
transform.
"""
stats: dict[str, dict[str, Any]]
unnormalize_keys: set[str] | None = None
features: dict[str, PolicyFeature]
norm_map: dict[FeatureType, NormalizationMode]
stats: Optional[dict[str, dict[str, Any]]] = None
eps: float = 1e-8
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@@ -164,23 +204,21 @@ class UnnormalizerProcessor:
def from_lerobot_dataset(
cls,
dataset: LeRobotDataset,
features: dict[str, PolicyFeature],
norm_map: dict[FeatureType, NormalizationMode],
*,
unnormalize_keys: set[str] | None = None,
eps: float = 1e-8,
) -> UnnormalizerProcessor:
return cls(stats=dataset.meta.stats, unnormalize_keys=unnormalize_keys, eps=eps)
) -> "UnnormalizerProcessor":
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
def __post_init__(self):
    """Finalise the dataclass fields after ``__init__`` has run.

    Replaces missing or empty ``stats`` with an empty dict and builds the
    cached ``_tensor_stats`` from it.
    """
    # A falsy ``stats`` (``None`` or empty) is replaced by a fresh dict.
    if not self.stats:
        self.stats = {}
    # Converted once up front and stored for later use.
    self._tensor_stats = _convert_stats_to_tensors(self.stats)
def _unnormalize_obs(self, observation):
if observation is None:
return None
keys = (
self.unnormalize_keys
if self.unnormalize_keys is not None
else {k for k in self._tensor_stats if k != "action"}
)
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
processed = dict(observation)
for key in keys:
if key not in processed or key not in self._tensor_stats:
@@ -231,10 +269,7 @@ class UnnormalizerProcessor:
)
def get_config(self) -> dict[str, Any]:
return {
"unnormalize_keys": list(self.unnormalize_keys) if self.unnormalize_keys else None,
"eps": self.eps,
}
return {"eps": self.eps}
def state_dict(self) -> dict[str, Tensor]:
flat = {}