change expected features

This commit is contained in:
Pepijn
2025-11-25 13:51:01 +01:00
parent d286ea30d4
commit 8d2fb5d298

View File

@@ -83,11 +83,9 @@ class SARMConfig(PreTrainedConfig):
encode_on_the_fly: bool = True # Encode images/text during training
use_dataset_task: bool = True # Use task descriptions from dataset
use_subtask_annotations: bool = True # Use subtask annotations for stage-aware training if available
# Features (required by PreTrainedPolicy)
# Video_features and text_features are generated by the processor from raw images/text, we don't declare them as VISUAL/LANGUAGE here to avoid validation errors
input_features: dict = field(default_factory=lambda: {
"video_features": PolicyFeature(shape=(9, 512), type=FeatureType.VISUAL),
"text_features": PolicyFeature(shape=(384,), type=FeatureType.LANGUAGE),
"state_features": PolicyFeature(shape=(9, 14), type=FeatureType.STATE) # Example: 7 DOF × 2 arms
})
output_features: dict = field(default_factory=lambda: {