mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
refactor(tests): streamline transition creation in processor tests
- Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files. - Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite.
This commit is contained in:
@@ -34,6 +34,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
class MockTokenizerProcessorStep(ProcessorStep):
|
||||
@@ -52,21 +53,6 @@ class MockTokenizerProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
elif key == "complementary_data":
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default PI0 configuration for testing."""
|
||||
config = PI0Config()
|
||||
@@ -219,7 +205,8 @@ def test_pi0_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(6)
|
||||
transition = create_transition(observation, action, complementary_data={"task": "test task"})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": "test task"})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -275,7 +262,8 @@ def test_pi0_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -331,7 +319,8 @@ def test_pi0_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1, 6).to(device)
|
||||
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
|
||||
transition = create_transition(observation=observation, complementary_data={"task": ["test task"]})
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -426,8 +415,9 @@ def test_pi0_processor_bfloat16_device_float32_normalizer():
|
||||
}
|
||||
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
|
||||
transition = create_transition(
|
||||
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
observation=observation, complementary_data={"task": "test bfloat16 adaptation"}
|
||||
)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
Reference in New Issue
Block a user