raise if no state key is found

This commit is contained in:
Pepijn
2025-11-25 16:21:29 +01:00
parent 7beb20819e
commit 6b6a82bbdf
2 changed files with 20 additions and 3 deletions

View File

@@ -222,7 +222,7 @@ class SARMRewardModel(PreTrainedPolicy):
# Initialize CLIP encoder for images
logging.info("Loading CLIP encoder...")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
@@ -237,8 +237,10 @@ class SARMRewardModel(PreTrainedPolicy):
self.minilm_model.to(self.device)
self.minilm_model.eval()
# Auto-detect state_dim from dataset_statss
# Auto-detect state_dim from dataset_stats
if config.state_dim is None:
logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}")
if dataset_stats is not None:
if "observation.state" in dataset_stats:
config.state_dim = dataset_stats["observation.state"]["mean"].shape[0]
@@ -246,7 +248,22 @@ class SARMRewardModel(PreTrainedPolicy):
elif "state" in dataset_stats:
config.state_dim = dataset_stats["state"]["mean"].shape[0]
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']")
else:
logging.warning(f"State keys not found in dataset_stats. Available keys: {list(dataset_stats.keys())}")
else:
logging.warning("dataset_stats is None, cannot auto-detect state_dim")
# Raise explicit error if still None
if config.state_dim is None:
raise ValueError(
"Could not determine state_dim! "
f"dataset_stats={'None' if dataset_stats is None else f'available with keys: {list(dataset_stats.keys())}'}, "
"config.state_dim=None. "
"Please either:\n"
"1. Provide --policy.state_dim=<your_state_dimension> explicitly, or\n"
"2. Ensure dataset_stats contains 'observation.state' or 'state' key"
)
# Initialize SARM transformer
self.sarm_transformer = SARMTransformer(
video_dim=config.image_dim,

View File

@@ -85,7 +85,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
logging.info("Initializing CLIP encoder for SARM...")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(device)
self.clip_model.eval()