diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index 17119fbad..258531166 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -91,8 +91,8 @@ class SARMConfig(PreTrainedConfig): "state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms }) output_features: dict = field(default_factory=lambda: { - "stage": PolicyFeature(shape=(1,), type=FeatureType.OTHER), - "progress": PolicyFeature(shape=(1,), type=FeatureType.OTHER) + "stage": PolicyFeature(shape=(1,), type=FeatureType.REWARD), + "progress": PolicyFeature(shape=(1,), type=FeatureType.REWARD) }) def __post_init__(self):