mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
get state input from dataset stats
This commit is contained in:
@@ -237,19 +237,16 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
self.minilm_model.to(self.device)
|
||||
self.minilm_model.eval()
|
||||
|
||||
# Auto-detect state_dim from input_features if not explicitly set
|
||||
# Auto-detect state_dim from dataset_statss
|
||||
if config.state_dim is None:
|
||||
# Look for "observation.state" or "state" in input_features
|
||||
if "observation.state" in config.input_features:
|
||||
config.state_dim = config.input_features["observation.state"].shape[0]
|
||||
logging.info(f"Auto-detected state_dim={config.state_dim} from input_features['observation.state']")
|
||||
elif "state" in config.input_features:
|
||||
config.state_dim = config.input_features["state"].shape[0]
|
||||
logging.info(f"Auto-detected state_dim={config.state_dim} from input_features['state']")
|
||||
else:
|
||||
config.state_dim = 14
|
||||
logging.warning(f"Could not find state in input_features, using default state_dim={config.state_dim}")
|
||||
|
||||
if dataset_stats is not None:
|
||||
if "observation.state" in dataset_stats:
|
||||
config.state_dim = dataset_stats["observation.state"]["mean"].shape[0]
|
||||
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['observation.state']")
|
||||
elif "state" in dataset_stats:
|
||||
config.state_dim = dataset_stats["state"]["mean"].shape[0]
|
||||
logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']")
|
||||
|
||||
# Initialize SARM transformer
|
||||
self.sarm_transformer = SARMTransformer(
|
||||
video_dim=config.image_dim,
|
||||
|
||||
Reference in New Issue
Block a user