[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-07-03 16:35:37 +00:00
committed by Adil Zouitine
parent 8774aec304
commit 8ebf79c494
2 changed files with 18 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Mapping
from typing import Any, Mapping
import numpy as np
import torch
@@ -11,9 +11,9 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
def _convert_stats_to_tensors(stats: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Tensor]]:
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
"""Convert numpy arrays and other types to torch tensors."""
tensor_stats: Dict[str, Dict[str, Tensor]] = {}
tensor_stats: dict[str, dict[str, Tensor]] = {}
for key, sub in stats.items():
tensor_stats[key] = {}
for stat_name, value in sub.items():
@@ -50,12 +50,12 @@ class ObservationNormalizer:
Small constant to avoid division by zero.
"""
stats: Dict[str, Dict[str, Any]]
stats: dict[str, dict[str, Any]]
normalize_keys: set[str] | None = None
eps: float = 1e-8
# Cached tensors for performance
_tensor_stats: Dict[str, Dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@classmethod
def from_lerobot_dataset(
@@ -135,14 +135,14 @@ class ObservationNormalizer:
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
def get_config(self) -> Dict[str, Any]:
def get_config(self) -> dict[str, Any]:
return {
"normalize_keys": list(self.normalize_keys) if self.normalize_keys is not None else None,
"eps": self.eps,
}
def state_dict(self) -> Dict[str, Tensor]:
flat_state: Dict[str, Tensor] = {}
def state_dict(self) -> dict[str, Tensor]:
flat_state: dict[str, Tensor] = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat_state[f"{key}.{stat_name}"] = tensor
@@ -178,11 +178,11 @@ class ActionUnnormalizer:
Small constant used during normalization (not used in unnormalization).
"""
action_stats: Dict[str, Any]
action_stats: dict[str, Any]
eps: float = 1e-8 # Kept for consistency, not used in unnormalization
# Cached tensors for performance
_tensor_stats: Dict[str, Tensor] = field(default_factory=dict, init=False, repr=False)
_tensor_stats: dict[str, Tensor] = field(default_factory=dict, init=False, repr=False)
@classmethod
def from_lerobot_dataset(
@@ -238,10 +238,10 @@ class ActionUnnormalizer:
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
def get_config(self) -> Dict[str, Any]:
def get_config(self) -> dict[str, Any]:
return {"eps": self.eps}
def state_dict(self) -> Dict[str, Tensor]:
def state_dict(self) -> dict[str, Tensor]:
return dict(self._tensor_stats.items())
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
@@ -273,13 +273,13 @@ class NormalizationProcessor:
Small constant to avoid division by zero.
"""
stats: Dict[str, Dict[str, Any]]
stats: dict[str, dict[str, Any]]
normalize_keys: set[str] | None = None
unnormalize_action: bool = True
eps: float = 1e-8
# Cached tensors for performance
_tensor_stats: Dict[str, Dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
@classmethod
def from_lerobot_dataset(
@@ -365,15 +365,15 @@ class NormalizationProcessor:
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
def get_config(self) -> Dict[str, Any]:
def get_config(self) -> dict[str, Any]:
return {
"normalize_keys": list(self.normalize_keys) if self.normalize_keys is not None else None,
"unnormalize_action": self.unnormalize_action,
"eps": self.eps,
}
def state_dict(self) -> Dict[str, Tensor]:
flat_state: Dict[str, Tensor] = {}
def state_dict(self) -> dict[str, Tensor]:
flat_state: dict[str, Tensor] = {}
for key, sub in self._tensor_stats.items():
for stat_name, tensor in sub.items():
flat_state[f"{key}.{stat_name}"] = tensor