mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
8774aec304
commit
8ebf79c494
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Mapping
|
||||
from typing import Any, Mapping
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -11,9 +11,9 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionIndex
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Tensor]]:
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
"""Convert numpy arrays and other types to torch tensors."""
|
||||
tensor_stats: Dict[str, Dict[str, Tensor]] = {}
|
||||
tensor_stats: dict[str, dict[str, Tensor]] = {}
|
||||
for key, sub in stats.items():
|
||||
tensor_stats[key] = {}
|
||||
for stat_name, value in sub.items():
|
||||
@@ -50,12 +50,12 @@ class ObservationNormalizer:
|
||||
Small constant to avoid division by zero.
|
||||
"""
|
||||
|
||||
stats: Dict[str, Dict[str, Any]]
|
||||
stats: dict[str, dict[str, Any]]
|
||||
normalize_keys: set[str] | None = None
|
||||
eps: float = 1e-8
|
||||
|
||||
# Cached tensors for performance
|
||||
_tensor_stats: Dict[str, Dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
@@ -135,14 +135,14 @@ class ObservationNormalizer:
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"normalize_keys": list(self.normalize_keys) if self.normalize_keys is not None else None,
|
||||
"eps": self.eps,
|
||||
}
|
||||
|
||||
def state_dict(self) -> Dict[str, Tensor]:
|
||||
flat_state: Dict[str, Tensor] = {}
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat_state: dict[str, Tensor] = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat_state[f"{key}.{stat_name}"] = tensor
|
||||
@@ -178,11 +178,11 @@ class ActionUnnormalizer:
|
||||
Small constant used during normalization (not used in unnormalization).
|
||||
"""
|
||||
|
||||
action_stats: Dict[str, Any]
|
||||
action_stats: dict[str, Any]
|
||||
eps: float = 1e-8 # Kept for consistency, not used in unnormalization
|
||||
|
||||
# Cached tensors for performance
|
||||
_tensor_stats: Dict[str, Tensor] = field(default_factory=dict, init=False, repr=False)
|
||||
_tensor_stats: dict[str, Tensor] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
@@ -238,10 +238,10 @@ class ActionUnnormalizer:
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"eps": self.eps}
|
||||
|
||||
def state_dict(self) -> Dict[str, Tensor]:
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
return dict(self._tensor_stats.items())
|
||||
|
||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||
@@ -273,13 +273,13 @@ class NormalizationProcessor:
|
||||
Small constant to avoid division by zero.
|
||||
"""
|
||||
|
||||
stats: Dict[str, Dict[str, Any]]
|
||||
stats: dict[str, dict[str, Any]]
|
||||
normalize_keys: set[str] | None = None
|
||||
unnormalize_action: bool = True
|
||||
eps: float = 1e-8
|
||||
|
||||
# Cached tensors for performance
|
||||
_tensor_stats: Dict[str, Dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
@@ -365,15 +365,15 @@ class NormalizationProcessor:
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"normalize_keys": list(self.normalize_keys) if self.normalize_keys is not None else None,
|
||||
"unnormalize_action": self.unnormalize_action,
|
||||
"eps": self.eps,
|
||||
}
|
||||
|
||||
def state_dict(self) -> Dict[str, Tensor]:
|
||||
flat_state: Dict[str, Tensor] = {}
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat_state: dict[str, Tensor] = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat_state[f"{key}.{stat_name}"] = tensor
|
||||
|
||||
Reference in New Issue
Block a user