mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
refactor(pipeline): feature contract now categorizes between OBS or Action (#1867)
* refactor(processor): signature of transform_features * refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly * refactor(processor): rename now is only for visual * refactor(processor): update normalize processor * refactor(processor): update vanilla processor features * refactor(processor): feature contract now uses its own enum * chore(processor): rename renameprocessor * chore(processor): minor changes * refactor(processor): add create & change aggregate * refactor(processor): update aggregate * refactor(processor): simplify to functions, fix features contracts and rename function * test(processor): remove to converter tests as now they are very simple * chore(docs): recover docs joint observations processor * fix(processor): update RKP * fix(tests): recv diff test_pipeline * chore(tests): add docs to test * chore(processor): leave obs language constant untouched * fix(processor): correct new shape of feature in crop image processor
This commit is contained in:
@@ -8,7 +8,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE
|
||||
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
|
||||
from tests.utils import require_package
|
||||
@@ -512,23 +512,27 @@ def test_features_basic():
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128)
|
||||
|
||||
input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
||||
},
|
||||
PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))},
|
||||
}
|
||||
|
||||
output_features = processor.transform_features(input_features)
|
||||
|
||||
# Check that original features are preserved
|
||||
assert "observation.state" in output_features
|
||||
assert "action" in output_features
|
||||
assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert "action" in output_features[PipelineFeatureType.ACTION]
|
||||
|
||||
# Check that tokenized features are added
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in output_features
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
|
||||
# Check feature properties
|
||||
tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][
|
||||
f"{OBS_LANGUAGE}.attention_mask"
|
||||
]
|
||||
|
||||
assert tokens_feature.type == FeatureType.LANGUAGE
|
||||
assert tokens_feature.shape == (128,)
|
||||
@@ -542,15 +546,17 @@ def test_features_with_custom_max_length():
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=64)
|
||||
|
||||
input_features = {}
|
||||
input_features = {PipelineFeatureType.OBSERVATION: {}}
|
||||
output_features = processor.transform_features(input_features)
|
||||
|
||||
# Check that features use correct max_length
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in output_features
|
||||
assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION]
|
||||
|
||||
tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
|
||||
attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][
|
||||
f"{OBS_LANGUAGE}.attention_mask"
|
||||
]
|
||||
|
||||
assert tokens_feature.shape == (64,)
|
||||
assert attention_mask_feature.shape == (64,)
|
||||
@@ -563,15 +569,19 @@ def test_features_existing_features():
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=256)
|
||||
|
||||
input_features = {
|
||||
f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
|
||||
f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
|
||||
f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
|
||||
}
|
||||
}
|
||||
|
||||
output_features = processor.transform_features(input_features)
|
||||
|
||||
# Should not overwrite existing features
|
||||
assert output_features[f"{OBS_LANGUAGE}.tokens"].shape == (100,) # Original shape preserved
|
||||
assert output_features[f"{OBS_LANGUAGE}.attention_mask"].shape == (100,)
|
||||
assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"].shape == (
|
||||
100,
|
||||
) # Original shape preserved
|
||||
assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"].shape == (100,)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
|
||||
Reference in New Issue
Block a user