mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 03:41:25 +00:00
refactor(tests): streamline transition creation in processor tests
- Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files. - Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite.
This commit is contained in:
@@ -33,19 +33,7 @@ from lerobot.processor import (
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -112,7 +100,8 @@ def test_act_processor_normalization():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -122,7 +111,8 @@ def test_act_processor_normalization():
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized
|
||||
@@ -146,7 +136,8 @@ def test_act_processor_cuda():
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -156,7 +147,8 @@ def test_act_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
action_transition = create_transition()
|
||||
action_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
@@ -181,7 +173,8 @@ def test_act_processor_accelerate_scenario():
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -204,7 +197,8 @@ def test_act_processor_multi_gpu():
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)}
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -227,7 +221,8 @@ def test_act_processor_without_stats():
|
||||
# Process should still work (but won't normalize without stats)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -257,7 +252,8 @@ def test_act_processor_save_and_load():
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
@@ -281,7 +277,8 @@ def test_act_processor_device_placement_preservation():
|
||||
# Process CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
@@ -326,7 +323,8 @@ def test_act_processor_mixed_precision():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -351,7 +349,8 @@ def test_act_processor_batch_consistency():
|
||||
# Test single sample (unbatched)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
|
||||
@@ -359,7 +358,8 @@ def test_act_processor_batch_consistency():
|
||||
# Test already batched data
|
||||
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
|
||||
action_batched = torch.randn(8, 4)
|
||||
transition_batched = create_transition(observation_batched, action_batched)
|
||||
transition_batched = create_transition(observation=observation_batched)
|
||||
transition_batched[TransitionKey.ACTION] = action_batched
|
||||
|
||||
processed_batched = preprocessor(transition_batched)
|
||||
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
|
||||
@@ -407,7 +407,8 @@ def test_act_processor_bfloat16_device_float32_normalizer():
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through full pipeline
|
||||
processed = preprocessor(transition)
|
||||
|
||||
Reference in New Issue
Block a user