mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 19:31:25 +00:00
refactor(tests): streamline transition creation in processor tests
- Replaced custom transition creation functions with a centralized `create_transition` function imported from converters across multiple test files. - Updated test cases to utilize keyword arguments for better readability and maintainability, ensuring consistent transition creation throughout the test suite.
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor import DataProcessorPipeline, TransitionKey
|
||||
@@ -20,7 +22,7 @@ def _dummy_batch():
|
||||
|
||||
def test_observation_grouping_roundtrip():
|
||||
"""Test that observation.* keys are properly grouped and ungrouped."""
|
||||
proc = DataProcessorPipeline([])
|
||||
proc = DataProcessorPipeline[dict[str, Any]]([])
|
||||
batch_in = _dummy_batch()
|
||||
batch_out = proc(batch_in)
|
||||
|
||||
@@ -45,11 +47,12 @@ def test_observation_grouping_roundtrip():
|
||||
|
||||
def test_batch_to_transition_observation_grouping():
|
||||
"""Test that batch_to_transition correctly groups observation.* keys."""
|
||||
base_batch = _dummy_batch()
|
||||
batch = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": base_batch["observation.image.left"],
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
"action": "action_data",
|
||||
"action": torch.tensor([[0.1, 0.2]]),
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
"next.truncated": False,
|
||||
@@ -74,7 +77,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.1, 0.2]]))
|
||||
assert transition[TransitionKey.REWARD] == 1.5
|
||||
assert transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
@@ -84,15 +87,16 @@ def test_batch_to_transition_observation_grouping():
|
||||
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
"""Test that transition_to_batch correctly flattens observation dict."""
|
||||
base_batch = _dummy_batch()
|
||||
observation_dict = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": base_batch["observation.image.left"],
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
}
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observation_dict,
|
||||
TransitionKey.ACTION: "action_data",
|
||||
TransitionKey.ACTION: torch.tensor([[0.3, 0.4]]),
|
||||
TransitionKey.REWARD: 1.5,
|
||||
TransitionKey.DONE: True,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
@@ -113,7 +117,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
assert batch["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields are mapped to next.* format
|
||||
assert batch["action"] == "action_data"
|
||||
assert torch.allclose(batch["action"], torch.tensor([[0.3, 0.4]]))
|
||||
assert batch["next.reward"] == 1.5
|
||||
assert batch["next.done"]
|
||||
assert not batch["next.truncated"]
|
||||
@@ -123,7 +127,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
def test_no_observation_keys():
|
||||
"""Test behavior when there are no observation.* keys."""
|
||||
batch = {
|
||||
"action": "action_data",
|
||||
"action": torch.tensor([[0.7, 0.8]]),
|
||||
"next.reward": 2.0,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
@@ -136,7 +140,7 @@ def test_no_observation_keys():
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.7, 0.8]]))
|
||||
assert transition[TransitionKey.REWARD] == 2.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert transition[TransitionKey.TRUNCATED]
|
||||
@@ -144,7 +148,7 @@ def test_no_observation_keys():
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] == "action_data"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.7, 0.8]]))
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert reconstructed_batch["next.truncated"]
|
||||
@@ -153,13 +157,13 @@ def test_no_observation_keys():
|
||||
|
||||
def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
||||
batch = {"observation.state": "minimal_state", "action": torch.tensor([[0.9]])}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([[0.9]]))
|
||||
|
||||
# Check defaults
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
@@ -171,7 +175,7 @@ def test_minimal_batch():
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert reconstructed_batch["action"] == "minimal_action"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([[0.9]]))
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
@@ -204,9 +208,10 @@ def test_empty_batch():
|
||||
|
||||
def test_complex_nested_observation():
|
||||
"""Test with complex nested observation data."""
|
||||
base_batch = _dummy_batch()
|
||||
batch = {
|
||||
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
"observation.image.left": {"image": base_batch["observation.image.left"], "timestamp": 1234567891},
|
||||
"observation.state": torch.randn(7),
|
||||
"action": torch.randn(8),
|
||||
"next.reward": 3.14,
|
||||
|
||||
Reference in New Issue
Block a user