refactor(normalization): Remove unused state dict transformation methods and streamline imports

- Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process.
- Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging.
- Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation.
- Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.
This commit is contained in:
Adil Zouitine
2025-08-01 08:55:38 +02:00
committed by Steven Palma
parent f02ce69df0
commit 8ff95be04c
6 changed files with 398 additions and 98 deletions

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
@@ -10,7 +12,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey, RobotProcessor
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
@@ -402,3 +404,13 @@ class UnnormalizerProcessor:
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
robot_processor = deepcopy(robot_processor)
for step in robot_processor.steps:
if isinstance(step, NormalizerProcessor) or isinstance(step, UnnormalizerProcessor):
step: NormalizerProcessor | UnnormalizerProcessor
step.stats = stats
step._tensor_stats = _convert_stats_to_tensors(stats)
return robot_processor