mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
refactor(tests): streamline transition creation in processor tests
- Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files. - Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite.
This commit is contained in:
@@ -37,6 +37,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
@@ -55,21 +56,6 @@ class MockTokenizerProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
elif key == "complementary_data":
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default SmolVLA configuration for testing."""
|
||||
config = SmolVLAConfig()
|
||||
@@ -228,7 +214,8 @@ def test_smolvla_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(7)
|
||||
transition = create_transition(observation, action, complementary_data={"task": "test task"})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -286,7 +273,8 @@ def test_smolvla_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -344,7 +332,8 @@ def test_smolvla_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 7).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -455,8 +444,9 @@ def test_smolvla_processor_bfloat16_device_float32_normalizer():
|
||||
}
|
||||
action = torch.randn(7, dtype=torch.float32)
|
||||
transition = create_transition(
|
||||
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
Reference in New Issue
Block a user