mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
refactor(tests): streamline transition creation in processor tests
- Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files. - Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite.
This commit is contained in:
@@ -20,28 +20,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
if reward is not None:
|
||||
transition[TransitionKey.REWARD] = reward
|
||||
if done is not None:
|
||||
transition[TransitionKey.DONE] = done
|
||||
if truncated is not None:
|
||||
transition[TransitionKey.TRUNCATED] = truncated
|
||||
if info is not None:
|
||||
transition[TransitionKey.INFO] = info
|
||||
if complementary_data is not None:
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
return transition
|
||||
from lerobot.processor.converters import create_transition
|
||||
|
||||
|
||||
def test_basic_functionality():
|
||||
@@ -147,14 +126,14 @@ def test_none_values():
|
||||
# Test with None observation
|
||||
transition = create_transition(observation=None, action=torch.randn(5))
|
||||
result = processor(transition)
|
||||
assert TransitionKey.OBSERVATION not in result
|
||||
assert result[TransitionKey.OBSERVATION] is None
|
||||
assert result[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
# Test with None action
|
||||
transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None)
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
|
||||
assert TransitionKey.ACTION not in result
|
||||
assert result[TransitionKey.ACTION] is None
|
||||
|
||||
|
||||
def test_empty_observation():
|
||||
@@ -822,8 +801,8 @@ def test_complementary_data_none():
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Complementary data should not be in the result (same as input)
|
||||
assert TransitionKey.COMPLEMENTARY_DATA not in result
|
||||
# Complementary data should be an empty dict (standardized behavior)
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
|
||||
Reference in New Issue
Block a user