mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
revert lerobot_train changes
This commit is contained in:
@@ -211,30 +211,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
# Create processors first - for SARM/ReWiND, this updates config.num_stages from dataset annotations
|
||||
# This must happen BEFORE make_policy() so the model is created with the correct number of stages
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
# For ReWiND and SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type in ["rewind", "sarm"]:
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
# For pretrained paths, we need to defer some overrides until after policy creation
|
||||
# But for SARM/ReWiND, we need processor to run first to update num_stages
|
||||
if cfg.policy.pretrained_path is None or cfg.policy.type in ["rewind", "sarm"]:
|
||||
if is_main_process:
|
||||
logging.info("Creating processors")
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path if cfg.policy.type not in ["rewind", "sarm"] else None,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
@@ -246,8 +222,18 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# For pretrained paths (non-SARM/ReWiND), create processors now with policy-dependent overrides
|
||||
if cfg.policy.pretrained_path is not None and cfg.policy.type not in ["rewind", "sarm"]:
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
# For ReWiND and SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type in ["rewind", "sarm"]:
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
@@ -267,12 +253,12 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
},
|
||||
}
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
|
||||
Reference in New Issue
Block a user