mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
raise if no state key is found
This commit is contained in:
@@ -222,7 +222,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
# Initialize CLIP encoder for images
|
||||
logging.info("Loading CLIP encoder...")
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||
self.clip_model.to(self.device)
|
||||
self.clip_model.eval()
|
||||
|
||||
@@ -237,8 +237,10 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
self.minilm_model.to(self.device)
|
||||
self.minilm_model.eval()
|
||||
|
||||
# Auto-detect state_dim from dataset_statss
|
||||
# Auto-detect state_dim from dataset_stats
|
||||
if config.state_dim is None:
|
||||
logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}")
|
||||
|
||||
if dataset_stats is not None:
|
||||
if "observation.state" in dataset_stats:
|
||||
config.state_dim = dataset_stats["observation.state"]["mean"].shape[0]
|
||||
@@ -246,7 +248,22 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
elif "state" in dataset_stats:
|
||||
config.state_dim = dataset_stats["state"]["mean"].shape[0]
|
||||
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']")
|
||||
else:
|
||||
logging.warning(f"State keys not found in dataset_stats. Available keys: {list(dataset_stats.keys())}")
|
||||
else:
|
||||
logging.warning("dataset_stats is None, cannot auto-detect state_dim")
|
||||
|
||||
# Raise explicit error if still None
|
||||
if config.state_dim is None:
|
||||
raise ValueError(
|
||||
"Could not determine state_dim! "
|
||||
f"dataset_stats={'None' if dataset_stats is None else f'available with keys: {list(dataset_stats.keys())}'}, "
|
||||
"config.state_dim=None. "
|
||||
"Please either:\n"
|
||||
"1. Provide --policy.state_dim=<your_state_dimension> explicitly, or\n"
|
||||
"2. Ensure dataset_stats contains 'observation.state' or 'state' key"
|
||||
)
|
||||
|
||||
# Initialize SARM transformer
|
||||
self.sarm_transformer = SARMTransformer(
|
||||
video_dim=config.image_dim,
|
||||
|
||||
@@ -85,7 +85,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
logging.info("Initializing CLIP encoder for SARM...")
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||
self.clip_model.to(device)
|
||||
self.clip_model.eval()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user