add subtask init and detection

This commit is contained in:
Pepijn
2025-11-25 22:06:20 +01:00
parent 0c99b768f4
commit 2dc2a3ae55
2 changed files with 62 additions and 19 deletions

View File

@@ -208,12 +208,16 @@ class SARMRewardModel(PreTrainedPolicy):
name = "sarm"
config_class = SARMConfig
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None):
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
super().__init__(config, dataset_stats)
self.config = config
self.dataset_stats = dataset_stats
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
# Auto-detect num_stages from dataset annotations before building the model
if dataset_meta is not None:
self._update_num_stages_from_dataset(dataset_meta)
# Initialize CLIP encoder for images
logging.info("Loading CLIP encoder...")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
@@ -276,6 +280,31 @@ class SARMRewardModel(PreTrainedPolicy):
logging.info(f"SARM Reward Model initialized on {self.device}")
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
"""Update num_stages in config based on dataset subtask annotations."""
episodes = dataset_meta.episodes
if "annotation.subtask.name" not in episodes:
raise ValueError("No subtask annotations found in dataset annotations")
all_subtask_names = set()
for i in range(len(episodes["annotation.subtask.name"])):
subtask_names = episodes["annotation.subtask.name"][i]
if subtask_names:
for name in subtask_names:
all_subtask_names.add(name)
if not all_subtask_names:
raise ValueError("No subtask names found in dataset annotations")
# Sort subtask names for consistent ordering
subtask_names = sorted(list(all_subtask_names))
num_stages = len(subtask_names)
self.config.num_stages = num_stages
self.config.subtask_names = subtask_names
logging.info(f"Auto-detected {num_stages} subtasks from dataset: {subtask_names}")
def to(self, device):
"""Override to method to ensure all components move together."""
super().to(device)

View File

@@ -211,6 +211,30 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
# Create processors first - for SARM/ReWiND, this updates config.num_stages from dataset annotations
# This must happen BEFORE make_policy() so the model is created with the correct number of stages
processor_kwargs = {}
postprocessor_kwargs = {}
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For ReWiND and SARM, always provide dataset_meta for progress normalization
if cfg.policy.type in ["rewind", "sarm"]:
processor_kwargs["dataset_meta"] = dataset.meta
# For pretrained paths, we need to defer some overrides until after policy creation
# But for SARM/ReWiND, we need processor to run first to update num_stages
if cfg.policy.pretrained_path is None or cfg.policy.type in ["rewind", "sarm"]:
if is_main_process:
logging.info("Creating processors")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path if cfg.policy.type not in ["rewind", "sarm"] else None,
**processor_kwargs,
**postprocessor_kwargs,
)
if is_main_process:
logging.info("Creating policy")
policy = make_policy(
@@ -222,18 +246,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For ReWiND and SARM, always provide dataset_meta for progress normalization
if cfg.policy.type in ["rewind", "sarm"]:
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
# For pretrained paths (non-SARM/ReWiND), create processors now with policy-dependent overrides
if cfg.policy.pretrained_path is not None and cfg.policy.type not in ["rewind", "sarm"]:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
"normalizer_processor": {
@@ -253,12 +267,12 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
if is_main_process:
logging.info("Creating optimizer and scheduler")