From 2dc2a3ae550e604eb9f131f19bcbbbe13b9a4bf7 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 25 Nov 2025 22:06:20 +0100 Subject: [PATCH] add subtask init and detection --- src/lerobot/policies/sarm/modeling_sarm.py | 31 +++++++++++++- src/lerobot/scripts/lerobot_train.py | 50 ++++++++++++++-------- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index 9d2319fd3..7cb4b879f 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -208,12 +208,16 @@ class SARMRewardModel(PreTrainedPolicy): name = "sarm" config_class = SARMConfig - def __init__(self, config: SARMConfig, dataset_stats: dict | None = None): + def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None): super().__init__(config, dataset_stats) self.config = config self.dataset_stats = dataset_stats self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu") + # Auto-detect num_stages from dataset annotations before building the model + if dataset_meta is not None: + self._update_num_stages_from_dataset(dataset_meta) + # Initialize CLIP encoder for images logging.info("Loading CLIP encoder...") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") @@ -276,6 +280,31 @@ class SARMRewardModel(PreTrainedPolicy): logging.info(f"SARM Reward Model initialized on {self.device}") + def _update_num_stages_from_dataset(self, dataset_meta) -> None: + """Update num_stages in config based on dataset subtask annotations.""" + episodes = dataset_meta.episodes + if "annotation.subtask.name" not in episodes: + raise ValueError("No subtask annotations found in dataset annotations") + + all_subtask_names = set() + for i in range(len(episodes["annotation.subtask.name"])): + subtask_names = episodes["annotation.subtask.name"][i] + if subtask_names: + for name in subtask_names: + all_subtask_names.add(name) + + if not all_subtask_names: + raise ValueError("No subtask names found in dataset annotations") + + # Sort subtask names for consistent ordering + subtask_names = sorted(list(all_subtask_names)) + num_stages = len(subtask_names) + + self.config.num_stages = num_stages + self.config.subtask_names = subtask_names + + logging.info(f"Auto-detected {num_stages} subtasks from dataset: {subtask_names}") + def to(self, device): """Override to method to ensure all components move together.""" super().to(device) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index a4fede3d3..2cee5372a 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -211,6 +211,30 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): logging.info("Creating env") eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) + # Create processors first - for SARM/ReWiND, this updates config.num_stages from dataset annotations + # This must happen BEFORE make_policy() so the model is created with the correct number of stages + processor_kwargs = {} + postprocessor_kwargs = {} + if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: + # Only provide dataset_stats when not resuming from saved processor state + processor_kwargs["dataset_stats"] = dataset.meta.stats + + # For ReWiND and SARM, always provide dataset_meta for progress normalization + if cfg.policy.type in ["rewind", "sarm"]: + processor_kwargs["dataset_meta"] = dataset.meta + + # For pretrained paths, we need to defer some overrides until after policy creation + # But for SARM/ReWiND, we need processor to run first to update num_stages + if cfg.policy.pretrained_path is None or cfg.policy.type in ["rewind", "sarm"]: + if is_main_process: + logging.info("Creating processors") + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path if cfg.policy.type not in ["rewind", "sarm"] else None, + **processor_kwargs, + **postprocessor_kwargs, + ) + if is_main_process: logging.info("Creating policy") policy = make_policy( @@ -222,18 +246,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # Wait for all processes to finish policy creation before continuing accelerator.wait_for_everyone() - # Create processors - only provide dataset_stats if not resuming from saved processors - processor_kwargs = {} - postprocessor_kwargs = {} - if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: - # Only provide dataset_stats when not resuming from saved processor state - processor_kwargs["dataset_stats"] = dataset.meta.stats - - # For ReWiND and SARM, always provide dataset_meta for progress normalization - if cfg.policy.type in ["rewind", "sarm"]: - processor_kwargs["dataset_meta"] = dataset.meta - - if cfg.policy.pretrained_path is not None: + # For pretrained paths (non-SARM/ReWiND), create processors now with policy-dependent overrides + if cfg.policy.pretrained_path is not None and cfg.policy.type not in ["rewind", "sarm"]: processor_kwargs["preprocessor_overrides"] = { "device_processor": {"device": device.type}, "normalizer_processor": { @@ -253,12 +267,12 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): }, } - preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, - **processor_kwargs, - **postprocessor_kwargs, - ) + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + **processor_kwargs, + **postprocessor_kwargs, + ) if is_main_process: logging.info("Creating optimizer and scheduler")