mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
refactor(normalization): Remove unused state dict transformation methods and streamline imports
- Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality.
This commit is contained in:
committed by
Steven Palma
parent
f02ce69df0
commit
8ff95be04c
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -10,7 +12,7 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey, RobotProcessor
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
@@ -402,3 +404,13 @@ class UnnormalizerProcessor:
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, Any]]) -> RobotProcessor:
|
||||
robot_processor = deepcopy(robot_processor)
|
||||
for step in robot_processor.steps:
|
||||
if isinstance(step, NormalizerProcessor) or isinstance(step, UnnormalizerProcessor):
|
||||
step: NormalizerProcessor | UnnormalizerProcessor
|
||||
step.stats = stats
|
||||
step._tensor_stats = _convert_stats_to_tensors(stats)
|
||||
return robot_processor
|
||||
|
||||
Reference in New Issue
Block a user