fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958)

* feat(processor): enhance normalization handling and state management

- Added support for additional normalization modes including IDENTITY.
- Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys.
- Implemented preservation of explicitly provided normalization statistics during state loading.
- Updated the training script to pass dataset statistics to the processors only when not resuming from a saved processor state.
- Expanded tests to verify the correct behavior of stats override preservation and loading.

* fix(train): remove redundant comment regarding state loading

- Removed a comment noting that the preprocessor and postprocessor state is already loaded when resuming training; the comment was redundant and did not add clarity.
This commit is contained in:
Adil Zouitine
2025-09-16 16:45:13 +02:00
committed by GitHub
parent 772da63a8e
commit a7d1179aab
4 changed files with 321 additions and 11 deletions

View File

@@ -26,7 +26,6 @@ from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
@@ -177,8 +176,15 @@ def train(cfg: TrainPipelineConfig):
cfg=cfg.policy,
ds_meta=dataset.meta,
)
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
if not (cfg.resume and cfg.policy.pretrained_path):
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
)
logging.info("Creating optimizer and scheduler")
@@ -189,12 +195,6 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
preprocessor.from_pretrained(
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
)
postprocessor.from_pretrained(
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())