From 6b6a82bbdf981d8ba055ef49cc6067074b2bf09b Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 25 Nov 2025 16:21:29 +0100 Subject: [PATCH] raise if no state key is found --- src/lerobot/policies/sarm/modeling_sarm.py | 21 +++++++++++++++++++-- src/lerobot/policies/sarm/processor_sarm.py | 2 +- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index bc8371897..05c7d8fb7 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -222,7 +222,7 @@ class SARMRewardModel(PreTrainedPolicy): # Initialize CLIP encoder for images logging.info("Loading CLIP encoder...") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") - self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True) self.clip_model.to(self.device) self.clip_model.eval() @@ -237,8 +237,10 @@ class SARMRewardModel(PreTrainedPolicy): self.minilm_model.to(self.device) self.minilm_model.eval() - # Auto-detect state_dim from dataset_statss + # Auto-detect state_dim from dataset_stats if config.state_dim is None: + logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}") + if dataset_stats is not None: if "observation.state" in dataset_stats: config.state_dim = dataset_stats["observation.state"]["mean"].shape[0] @@ -246,7 +248,22 @@ class SARMRewardModel(PreTrainedPolicy): elif "state" in dataset_stats: config.state_dim = dataset_stats["state"]["mean"].shape[0] logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']") + else: + logging.warning(f"State keys not found in dataset_stats. Available keys: {list(dataset_stats.keys())}") + else: + logging.warning("dataset_stats is None, cannot auto-detect state_dim") + # Raise explicit error if still None + if config.state_dim is None: + raise ValueError( + "Could not determine state_dim! " + f"dataset_stats={'None' if dataset_stats is None else f'available with keys: {list(dataset_stats.keys())}'}, " + "config.state_dim=None. " + "Please either:\n" + "1. Provide --policy.state_dim= explicitly, or\n" + "2. Ensure dataset_stats contains 'observation.state' or 'state' key" + ) + # Initialize SARM transformer self.sarm_transformer = SARMTransformer( video_dim=config.image_dim, diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index c3a6e915e..74ed5fafb 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -85,7 +85,7 @@ class SARMEncodingProcessorStep(ProcessorStep): logging.info("Initializing CLIP encoder for SARM...") self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") - self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") + self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True) self.clip_model.to(device) self.clip_model.eval()