From 7beb20819e61094bfcc2a8c5d9cfbe5b6aeae478 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 25 Nov 2025 16:17:28 +0100 Subject: [PATCH] get state input from dataset stats --- src/lerobot/policies/sarm/modeling_sarm.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index 084db750c..bc8371897 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -237,19 +237,16 @@ class SARMRewardModel(PreTrainedPolicy): self.minilm_model.to(self.device) self.minilm_model.eval() - # Auto-detect state_dim from input_features if not explicitly set + # Auto-detect state_dim from dataset_statss if config.state_dim is None: - # Look for "observation.state" or "state" in input_features - if "observation.state" in config.input_features: - config.state_dim = config.input_features["observation.state"].shape[0] - logging.info(f"Auto-detected state_dim={config.state_dim} from input_features['observation.state']") - elif "state" in config.input_features: - config.state_dim = config.input_features["state"].shape[0] - logging.info(f"Auto-detected state_dim={config.state_dim} from input_features['state']") - else: - config.state_dim = 14 - logging.warning(f"Could not find state in input_features, using default state_dim={config.state_dim}") - + if dataset_stats is not None: + if "observation.state" in dataset_stats: + config.state_dim = dataset_stats["observation.state"]["mean"].shape[0] + logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['observation.state']") + elif "state" in dataset_stats: + config.state_dim = dataset_stats["state"]["mean"].shape[0] + logging.info(f"Auto-detected state_dim={config.state_dim} from dataset_stats['state']") + # Initialize SARM transformer self.sarm_transformer = SARMTransformer( video_dim=config.image_dim,