mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
update sarm processor
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import PolicyFeature, FeatureType
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
@@ -85,13 +86,13 @@ class SARMConfig(PreTrainedConfig):
|
||||
|
||||
# Features (required by PreTrainedPolicy)
|
||||
input_features: dict = field(default_factory=lambda: {
|
||||
"video_features": {"shape": [9, 512], "dtype": "float32"},
|
||||
"text_features": {"shape": [384], "dtype": "float32"},
|
||||
"state_features": {"shape": [9, 14], "dtype": "float32"} # Example: 7 DOF × 2 arms
|
||||
"video_features": PolicyFeature(shape=(9, 512), type=FeatureType.VISUAL),
|
||||
"text_features": PolicyFeature(shape=(384,), type=FeatureType.LANGUAGE),
|
||||
"state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms
|
||||
})
|
||||
output_features: dict = field(default_factory=lambda: {
|
||||
"stage": {"shape": [1], "dtype": "int64"},
|
||||
"progress": {"shape": [1], "dtype": "float32"}
|
||||
"stage": PolicyFeature(shape=(1,), type=FeatureType.OTHER),
|
||||
"progress": PolicyFeature(shape=(1,), type=FeatureType.OTHER)
|
||||
})
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
Reference in New Issue
Block a user