diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 2cee5372a..a4fede3d3 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -211,30 +211,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): logging.info("Creating env") eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) - # Create processors first - for SARM/ReWiND, this updates config.num_stages from dataset annotations - # This must happen BEFORE make_policy() so the model is created with the correct number of stages - processor_kwargs = {} - postprocessor_kwargs = {} - if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: - # Only provide dataset_stats when not resuming from saved processor state - processor_kwargs["dataset_stats"] = dataset.meta.stats - - # For ReWiND and SARM, always provide dataset_meta for progress normalization - if cfg.policy.type in ["rewind", "sarm"]: - processor_kwargs["dataset_meta"] = dataset.meta - - # For pretrained paths, we need to defer some overrides until after policy creation - # But for SARM/ReWiND, we need processor to run first to update num_stages - if cfg.policy.pretrained_path is None or cfg.policy.type in ["rewind", "sarm"]: - if is_main_process: - logging.info("Creating processors") - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path if cfg.policy.type not in ["rewind", "sarm"] else None, - **processor_kwargs, - **postprocessor_kwargs, - ) - if is_main_process: logging.info("Creating policy") policy = make_policy( @@ -246,8 +222,18 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # Wait for all processes to finish policy creation before continuing accelerator.wait_for_everyone() - # For pretrained paths (non-SARM/ReWiND), create processors now with policy-dependent overrides - if cfg.policy.pretrained_path is not None and cfg.policy.type not in ["rewind", "sarm"]: + # Create processors - only provide dataset_stats if not resuming from saved processors + processor_kwargs = {} + postprocessor_kwargs = {} + if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: + # Only provide dataset_stats when not resuming from saved processor state + processor_kwargs["dataset_stats"] = dataset.meta.stats + + # For ReWiND and SARM, always provide dataset_meta for progress normalization + if cfg.policy.type in ["rewind", "sarm"]: + processor_kwargs["dataset_meta"] = dataset.meta + + if cfg.policy.pretrained_path is not None: processor_kwargs["preprocessor_overrides"] = { "device_processor": {"device": device.type}, "normalizer_processor": { @@ -267,12 +253,12 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): }, } - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - **processor_kwargs, - **postprocessor_kwargs, - ) + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + **processor_kwargs, + **postprocessor_kwargs, + ) if is_main_process: logging.info("Creating optimizer and scheduler")