get state input from dataset stats

This commit is contained in:
Pepijn
2025-11-25 16:17:28 +01:00
parent 9a5a0ad575
commit 7beb20819e

View File

@@ -237,19 +237,16 @@ class SARMRewardModel(PreTrainedPolicy):
self.minilm_model.to(self.device)
self.minilm_model.eval()
# Auto-detect state_dim from input_features if not explicitly set
# Auto-detect state_dim from dataset_statss
if config.state_dim is None:
# Look for "observation.state" or "state" in input_features
if "observation.state" in config.input_features:
config.state_dim = config.input_features["observation.state"].shape[0]
logging.info(f"Auto-detected state_dim={config.state_dim} from input_features['observation.state']")
elif "state" in config.input_features:
config.state_dim = config.input_features["state"].shape[0]
logging.info(f"Auto-detected state_dim={config.state_dim} from input_features['state']")
else:
config.state_dim = 14
logging.warning(f"Could not find state in input_features, using default state_dim={config.state_dim}")
if dataset_stats is not None:
if "observation.state" in dataset_stats:
config.state_dim = dataset_stats["observation.state"]["mean"].shape[0]
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['observation.state']")
elif "state" in dataset_stats:
config.state_dim = dataset_stats["state"]["mean"].shape[0]
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']")
# Initialize SARM transformer
self.sarm_transformer = SARMTransformer(
video_dim=config.image_dim,