revert lerobot_train changes

This commit is contained in:
Pepijn
2025-11-25 22:09:27 +01:00
parent 2dc2a3ae55
commit 006185ff4a

View File

@@ -211,30 +211,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
# Create processors first - for SARM/ReWiND, this updates config.num_stages from dataset annotations
# This must happen BEFORE make_policy() so the model is created with the correct number of stages
processor_kwargs = {}
postprocessor_kwargs = {}
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For ReWiND and SARM, always provide dataset_meta for progress normalization
if cfg.policy.type in ["rewind", "sarm"]:
processor_kwargs["dataset_meta"] = dataset.meta
# For pretrained paths, we need to defer some overrides until after policy creation
# But for SARM/ReWiND, we need processor to run first to update num_stages
if cfg.policy.pretrained_path is None or cfg.policy.type in ["rewind", "sarm"]:
if is_main_process:
logging.info("Creating processors")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path if cfg.policy.type not in ["rewind", "sarm"] else None,
**processor_kwargs,
**postprocessor_kwargs,
)
if is_main_process:
logging.info("Creating policy")
policy = make_policy(
@@ -246,8 +222,18 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()
# For pretrained paths (non-SARM/ReWiND), create processors now with policy-dependent overrides
if cfg.policy.pretrained_path is not None and cfg.policy.type not in ["rewind", "sarm"]:
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For ReWiND and SARM, always provide dataset_meta for progress normalization
if cfg.policy.type in ["rewind", "sarm"]:
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
"normalizer_processor": {
@@ -267,12 +253,12 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
if is_main_process:
logging.info("Creating optimizer and scheduler")