refactor(tests): streamline transition creation in processor tests

- Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files.
- Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite.
This commit is contained in:
AdilZouitine
2025-09-10 13:08:44 +02:00
parent f286eb059c
commit 6f1e49dbc4
10 changed files with 165 additions and 218 deletions

View File

@@ -33,19 +33,7 @@ from lerobot.processor import (
TransitionKey,
UnnormalizerProcessorStep,
)
def create_transition(observation=None, action=None, **kwargs):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
for key, value in kwargs.items():
if hasattr(TransitionKey, key.upper()):
transition[getattr(TransitionKey, key.upper())] = value
return transition
from lerobot.processor.converters import create_transition
def create_default_config():
@@ -117,7 +105,8 @@ def test_sac_processor_normalization_modes():
# Create test data
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
action = torch.rand(5) * 2 - 1 # Range [-1, 1]
transition = create_transition(observation, action)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -129,7 +118,8 @@ def test_sac_processor_normalization_modes():
assert processed[TransitionKey.ACTION].shape == (1, 5)
# Process action through postprocessor
action_transition = create_transition(action=processed[TransitionKey.ACTION])
action_transition = create_transition()
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
postprocessed = postprocessor(action_transition)
# Check that action is unnormalized (but still batched)
@@ -153,7 +143,8 @@ def test_sac_processor_cuda():
# Create CPU data
observation = {OBS_STATE: torch.randn(10)}
action = torch.randn(5)
transition = create_transition(observation, action)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -163,7 +154,8 @@ def test_sac_processor_cuda():
assert processed[TransitionKey.ACTION].device.type == "cuda"
# Process through postprocessor
action_transition = create_transition(action=processed[TransitionKey.ACTION])
action_transition = create_transition()
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
postprocessed = postprocessor(action_transition)
# Check that action is back on CPU
@@ -188,7 +180,8 @@ def test_sac_processor_accelerate_scenario():
device = torch.device("cuda:0")
observation = {OBS_STATE: torch.randn(10).to(device)}
action = torch.randn(5).to(device)
transition = create_transition(observation, action)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -216,7 +209,8 @@ def test_sac_processor_multi_gpu():
device = torch.device("cuda:1")
observation = {OBS_STATE: torch.randn(10).to(device)}
action = torch.randn(5).to(device)
transition = create_transition(observation, action)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -254,7 +248,8 @@ def test_sac_processor_without_stats():
# Process should still work
observation = {OBS_STATE: torch.randn(10)}
action = torch.randn(5)
transition = create_transition(observation, action)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
processed = preprocessor(transition)
assert processed is not None
@@ -284,7 +279,8 @@ def test_sac_processor_save_and_load():
# Test that loaded processor works
observation = {OBS_STATE: torch.randn(10)}
action = torch.randn(5)
transition = create_transition(observation, action)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
processed = loaded_preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
@@ -329,7 +325,8 @@ def test_sac_processor_mixed_precision():
# Create test data
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
action = torch.randn(5, dtype=torch.float32)
transition = create_transition(observation, action)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -355,7 +352,8 @@ def test_sac_processor_batch_data():
batch_size = 32
observation = {OBS_STATE: torch.randn(batch_size, 10)}
action = torch.randn(batch_size, 5)
transition = create_transition(observation, action)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
# Process through preprocessor
processed = preprocessor(transition)
@@ -378,13 +376,14 @@ def test_sac_processor_edge_cases():
)
# Test with empty observation
transition = create_transition(observation={}, action=torch.randn(5))
transition = create_transition(observation={})
transition[TransitionKey.ACTION] = torch.randn(5)
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION] == {}
assert processed[TransitionKey.ACTION].shape == (1, 5)
# Test with None action
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None)
transition = create_transition(observation={OBS_STATE: torch.randn(10)})
processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
# When action is None, it may still be present with None value
@@ -433,7 +432,8 @@ def test_sac_processor_bfloat16_device_float32_normalizer():
# Create test data
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32
action = torch.randn(5, dtype=torch.float32)
transition = create_transition(observation, action)
transition = create_transition(observation=observation)
transition[TransitionKey.ACTION] = action
# Process through full pipeline
processed = preprocessor(transition)