mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
chore (batch handling): Enhance processing components with batch conversion utilities
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Mapping
|
||||
from typing import Any, Mapping, Optional, Set
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||
|
||||
@@ -45,8 +46,18 @@ class NormalizerProcessor:
|
||||
the normalize_keys parameter.
|
||||
"""
|
||||
|
||||
stats: dict[str, dict[str, Any]]
|
||||
normalize_keys: set[str] | None = None
|
||||
# Features and normalisation map are mandatory to match the design of normalize.py
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
|
||||
# Pre-computed statistics coming from dataset.meta.stats for instance.
|
||||
stats: Optional[dict[str, dict[str, Any]]] = None
|
||||
|
||||
# Explicit subset of keys to normalise. If ``None``, every key in
|
||||
# ``features`` whose type is not ``FeatureType.ACTION`` is a candidate for
|
||||
# normalisation. Using a ``set`` makes membership checks O(1).
|
||||
normalize_keys: Optional[Set[str]] = None
|
||||
|
||||
eps: float = 1e-8
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
@@ -55,24 +66,48 @@ class NormalizerProcessor:
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
normalize_keys: set[str] | None = None,
|
||||
normalize_keys: Optional[Set[str]] = None,
|
||||
eps: float = 1e-8,
|
||||
) -> NormalizerProcessor:
|
||||
return cls(stats=dataset.meta.stats, normalize_keys=normalize_keys, eps=eps)
|
||||
) -> "NormalizerProcessor":
|
||||
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
|
||||
|
||||
The features and norm_map parameters are mandatory to match the design
|
||||
pattern used in normalize.py.
|
||||
"""
|
||||
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats,
|
||||
normalize_keys=normalize_keys,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Convert statistics once so we avoid repeated numpy→Tensor conversions
|
||||
# during runtime.
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
|
||||
# Coerce *normalize_keys* to a ``set`` when it was supplied as any other
|
||||
# iterable, so that membership look-ups stay fast.
|
||||
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
|
||||
self.normalize_keys = set(self.normalize_keys)
|
||||
|
||||
def _normalize_obs(self, observation):
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
keys_to_norm = (
|
||||
self.normalize_keys
|
||||
if self.normalize_keys is not None
|
||||
else {k for k in self._tensor_stats if k != "action"}
|
||||
)
|
||||
# Decide which keys should be normalised for this call.
|
||||
if self.normalize_keys is not None:
|
||||
keys_to_norm = self.normalize_keys
|
||||
else:
|
||||
# Use feature map to skip action keys.
|
||||
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
|
||||
|
||||
processed = dict(observation)
|
||||
for key in keys_to_norm:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
@@ -126,7 +161,11 @@ class NormalizerProcessor:
|
||||
)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"normalize_keys": list(self.normalize_keys) if self.normalize_keys else None, "eps": self.eps}
|
||||
config = {"eps": self.eps}
|
||||
if self.normalize_keys is not None:
|
||||
# Serialise as a list for YAML / JSON friendliness
|
||||
config["normalize_keys"] = sorted(self.normalize_keys)
|
||||
return config
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
@@ -154,8 +193,9 @@ class UnnormalizerProcessor:
|
||||
transform.
|
||||
"""
|
||||
|
||||
stats: dict[str, dict[str, Any]]
|
||||
unnormalize_keys: set[str] | None = None
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
stats: Optional[dict[str, dict[str, Any]]] = None
|
||||
eps: float = 1e-8
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
@@ -164,23 +204,21 @@ class UnnormalizerProcessor:
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
unnormalize_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
) -> UnnormalizerProcessor:
|
||||
return cls(stats=dataset.meta.stats, unnormalize_keys=unnormalize_keys, eps=eps)
|
||||
) -> "UnnormalizerProcessor":
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, eps=eps)
|
||||
|
||||
def __post_init__(self):
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
|
||||
def _unnormalize_obs(self, observation):
|
||||
if observation is None:
|
||||
return None
|
||||
keys = (
|
||||
self.unnormalize_keys
|
||||
if self.unnormalize_keys is not None
|
||||
else {k for k in self._tensor_stats if k != "action"}
|
||||
)
|
||||
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
||||
processed = dict(observation)
|
||||
for key in keys:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
@@ -231,10 +269,7 @@ class UnnormalizerProcessor:
|
||||
)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"unnormalize_keys": list(self.unnormalize_keys) if self.unnormalize_keys else None,
|
||||
"eps": self.eps,
|
||||
}
|
||||
return {"eps": self.eps}
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
|
||||
Reference in New Issue
Block a user