diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index 3592f050b..17119fbad 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -17,6 +17,7 @@ from dataclasses import dataclass, field from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import PolicyFeature, FeatureType from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig @@ -85,13 +86,13 @@ class SARMConfig(PreTrainedConfig): # Features (required by PreTrainedPolicy) input_features: dict = field(default_factory=lambda: { - "video_features": {"shape": [9, 512], "dtype": "float32"}, - "text_features": {"shape": [384], "dtype": "float32"}, - "state_features": {"shape": [9, 14], "dtype": "float32"} # Example: 7 DOF × 2 arms + "video_features": PolicyFeature(shape=(9, 512), type=FeatureType.VISUAL), + "text_features": PolicyFeature(shape=(384,), type=FeatureType.LANGUAGE), + "state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms }) output_features: dict = field(default_factory=lambda: { - "stage": {"shape": [1], "dtype": "int64"}, - "progress": {"shape": [1], "dtype": "float32"} + "stage": PolicyFeature(shape=(1,), type=FeatureType.OTHER), + "progress": PolicyFeature(shape=(1,), type=FeatureType.OTHER) }) def __post_init__(self):