add reward output

This commit is contained in:
Pepijn
2025-11-25 13:44:04 +01:00
parent ca67231892
commit d286ea30d4

View File

@@ -91,8 +91,8 @@ class SARMConfig(PreTrainedConfig):
"state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms
})
output_features: dict = field(default_factory=lambda: {
-"stage": PolicyFeature(shape=(1,), type=FeatureType.OTHER),
-"progress": PolicyFeature(shape=(1,), type=FeatureType.OTHER)
+"stage": PolicyFeature(shape=(1,), type=FeatureType.REWARD),
+"progress": PolicyFeature(shape=(1,), type=FeatureType.REWARD)
})
def __post_init__(self):