update sarm processor

This commit is contained in:
Pepijn
2025-11-25 13:40:04 +01:00
parent 5245332e36
commit ca67231892

View File

@@ -17,6 +17,7 @@
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import PolicyFeature, FeatureType
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@@ -85,13 +86,13 @@ class SARMConfig(PreTrainedConfig):
# Features (required by PreTrainedPolicy)
input_features: dict = field(default_factory=lambda: {
"video_features": {"shape": [9, 512], "dtype": "float32"},
"text_features": {"shape": [384], "dtype": "float32"},
"state_features": {"shape": [9, 14], "dtype": "float32"} # Example: 7 DOF × 2 arms
"video_features": PolicyFeature(shape=(9, 512), type=FeatureType.VISUAL),
"text_features": PolicyFeature(shape=(384,), type=FeatureType.LANGUAGE),
"state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms
})
output_features: dict = field(default_factory=lambda: {
"stage": {"shape": [1], "dtype": "int64"},
"progress": {"shape": [1], "dtype": "float32"}
"stage": PolicyFeature(shape=(1,), type=FeatureType.OTHER),
"progress": PolicyFeature(shape=(1,), type=FeatureType.OTHER)
})
def __post_init__(self):