mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
refactor(tests): streamline transition creation in processor tests
- Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files. - Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite.
This commit is contained in:
@@ -31,19 +31,7 @@ from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
@@ -115,7 +103,8 @@ def test_classifier_processor_normalization():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1) # Dummy action/reward
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -146,7 +135,8 @@ def test_classifier_processor_cuda():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -157,7 +147,8 @@ def test_classifier_processor_cuda():
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
reward_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
reward_transition = create_transition()
|
||||
reward_transition[TransitionKey.ACTION] = processed[TransitionKey.ACTION]
|
||||
postprocessed = postprocessor(reward_transition)
|
||||
|
||||
# Check that output is back on CPU
|
||||
@@ -185,7 +176,8 @@ def test_classifier_processor_accelerate_scenario():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -212,7 +204,8 @@ def test_classifier_processor_multi_gpu():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224).to(device),
|
||||
}
|
||||
action = torch.randn(1).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -239,7 +232,8 @@ def test_classifier_processor_without_stats():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
@@ -273,7 +267,8 @@ def test_classifier_processor_save_and_load():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224),
|
||||
}
|
||||
action = torch.randn(1)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (10,)
|
||||
@@ -308,7 +303,8 @@ def test_classifier_processor_mixed_precision():
|
||||
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
|
||||
}
|
||||
action = torch.randn(1, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -338,7 +334,8 @@ def test_classifier_processor_batch_data():
|
||||
OBS_IMAGE: torch.randn(batch_size, 3, 224, 224),
|
||||
}
|
||||
action = torch.randn(batch_size, 1)
|
||||
transition = create_transition(observation, action)
|
||||
transition = create_transition(observation=observation)
|
||||
transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
@@ -363,7 +360,8 @@ def test_classifier_processor_postprocessor_identity():
|
||||
|
||||
# Create test data for postprocessor
|
||||
reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions
|
||||
transition = create_transition(action=reward)
|
||||
transition = create_transition()
|
||||
transition[TransitionKey.ACTION] = reward
|
||||
|
||||
# Process through postprocessor
|
||||
processed = postprocessor(transition)
|
||||
|
||||
Reference in New Issue
Block a user