refactor(tests): streamline transition creation in processor tests

- Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files.
- Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite.
This commit is contained in:
AdilZouitine
2025-09-10 13:08:44 +02:00
parent f286eb059c
commit 6f1e49dbc4
10 changed files with 165 additions and 218 deletions

View File

@@ -34,6 +34,7 @@ from lerobot.processor import (
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition
class MockTokenizerProcessorStep(ProcessorStep):
@@ -52,21 +53,6 @@ class MockTokenizerProcessorStep(ProcessorStep):
return features
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
elif key == "complementary_data":
transition[TransitionKey.COMPLEMENTARY_DATA] = value
return transition
def create_default_config():
"""Create a default PI0 configuration for testing."""
config = PI0Config()
@@ -219,7 +205,8 @@ def test_pi0_processor_cuda():
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(6)
transition = create_transition(observation, action, complementary_data={"task": "test task"})
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -275,7 +262,8 @@ def test_pi0_processor_accelerate_scenario():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -331,7 +319,8 @@ def test_pi0_processor_multi_gpu():
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -426,8 +415,9 @@ def test_pi0_processor_bfloat16_device_float32_normalizer():
}
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
transition = create_transition(
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
)
transition[TransitionKey.ACTION] = action
# Process through full pipeline
processed = preprocessor(transition)