refactor(tests): streamline transition creation in processor tests

- Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files.
- Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite.
This commit is contained in:
AdilZouitine
2025-09-10 13:08:44 +02:00
parent f286eb059c
commit 6f1e49dbc4
10 changed files with 165 additions and 218 deletions

View File

@@ -37,6 +37,7 @@ from lerobot.processor import (
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition
class MockTokenizerProcessorStep(ProcessorStep):
@@ -55,21 +56,6 @@ class MockTokenizerProcessorStep(ProcessorStep):
return features
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
elif key == "complementary_data":
transition[TransitionKey.COMPLEMENTARY_DATA] = value
return transition
def create_default_config():
"""Create a default SmolVLA configuration for testing."""
config = SmolVLAConfig()
@@ -228,7 +214,8 @@ def test_smolvla_processor_cuda():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(7)
transition = create_transition(observation, action, complementary_data={"task": "test task"})
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -286,7 +273,8 @@ def test_smolvla_processor_accelerate_scenario():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 7).to(device)
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -344,7 +332,8 @@ def test_smolvla_processor_multi_gpu():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 7).to(device)
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -455,8 +444,9 @@ def test_smolvla_processor_bfloat16_device_float32_normalizer():
}
action = torch.randn(7, dtype=torch.float32)
transition = create_transition(
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
)
transition[TransitionKey.ACTION] = action
# Process through full pipeline
processed = preprocessor(transition)