refactor(pipeline): Transition from tuple to dictionary format for EnvTransition

- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability.
- Replaced instances of TransitionIndex with TransitionKey for accessing transition components.
- Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
Adil Zouitine
2025-07-21 14:54:31 +02:00
parent 14c2ece004
commit f2b79656eb
16 changed files with 828 additions and 650 deletions

View File

@@ -2,7 +2,7 @@ import torch
from lerobot.processor.pipeline import (
RobotProcessor,
TransitionIndex,
TransitionKey,
_default_batch_to_transition,
_default_transition_to_batch,
)
@@ -63,27 +63,27 @@ def test_batch_to_transition_observation_grouping():
transition = _default_batch_to_transition(batch)
# Check observation is a dict with all observation.* keys
assert isinstance(transition[TransitionIndex.OBSERVATION], dict)
assert "observation.image.top" in transition[TransitionIndex.OBSERVATION]
assert "observation.image.left" in transition[TransitionIndex.OBSERVATION]
assert "observation.state" in transition[TransitionIndex.OBSERVATION]
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
assert "observation.state" in transition[TransitionKey.OBSERVATION]
# Check values are preserved
assert torch.allclose(
transition[TransitionIndex.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
)
assert torch.allclose(
transition[TransitionIndex.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
)
assert transition[TransitionIndex.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
# Check other fields
assert transition[TransitionIndex.ACTION] == "action_data"
assert transition[TransitionIndex.REWARD] == 1.5
assert transition[TransitionIndex.DONE]
assert not transition[TransitionIndex.TRUNCATED]
assert transition[TransitionIndex.INFO] == {"episode": 42}
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 1.5
assert transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {"episode": 42}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
def test_transition_to_batch_observation_flattening():
@@ -94,15 +94,15 @@ def test_transition_to_batch_observation_flattening():
"observation.state": [1, 2, 3, 4],
}
transition = (
observation_dict, # observation
"action_data", # action
1.5, # reward
True, # done
False, # truncated
{"episode": 42}, # info
{}, # complementary_data
)
transition = {
TransitionKey.OBSERVATION: observation_dict,
TransitionKey.ACTION: "action_data",
TransitionKey.REWARD: 1.5,
TransitionKey.DONE: True,
TransitionKey.TRUNCATED: False,
TransitionKey.INFO: {"episode": 42},
TransitionKey.COMPLEMENTARY_DATA: {},
}
batch = _default_transition_to_batch(transition)
@@ -137,14 +137,14 @@ def test_no_observation_keys():
transition = _default_batch_to_transition(batch)
# Observation should be None when no observation.* keys
assert transition[TransitionIndex.OBSERVATION] is None
assert transition[TransitionKey.OBSERVATION] is None
# Check other fields
assert transition[TransitionIndex.ACTION] == "action_data"
assert transition[TransitionIndex.REWARD] == 2.0
assert not transition[TransitionIndex.DONE]
assert transition[TransitionIndex.TRUNCATED]
assert transition[TransitionIndex.INFO] == {"test": "no_obs"}
assert transition[TransitionKey.ACTION] == "action_data"
assert transition[TransitionKey.REWARD] == 2.0
assert not transition[TransitionKey.DONE]
assert transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
# Round trip should work
reconstructed_batch = _default_transition_to_batch(transition)
@@ -162,15 +162,15 @@ def test_minimal_batch():
transition = _default_batch_to_transition(batch)
# Check observation
assert transition[TransitionIndex.OBSERVATION] == {"observation.state": "minimal_state"}
assert transition[TransitionIndex.ACTION] == "minimal_action"
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
assert transition[TransitionKey.ACTION] == "minimal_action"
# Check defaults
assert transition[TransitionIndex.REWARD] == 0.0
assert not transition[TransitionIndex.DONE]
assert not transition[TransitionIndex.TRUNCATED]
assert transition[TransitionIndex.INFO] == {}
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
assert transition[TransitionKey.REWARD] == 0.0
assert not transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
@@ -189,13 +189,13 @@ def test_empty_batch():
transition = _default_batch_to_transition(batch)
# All fields should have defaults
assert transition[TransitionIndex.OBSERVATION] is None
assert transition[TransitionIndex.ACTION] is None
assert transition[TransitionIndex.REWARD] == 0.0
assert not transition[TransitionIndex.DONE]
assert not transition[TransitionIndex.TRUNCATED]
assert transition[TransitionIndex.INFO] == {}
assert transition[TransitionIndex.COMPLEMENTARY_DATA] == {}
assert transition[TransitionKey.OBSERVATION] is None
assert transition[TransitionKey.ACTION] is None
assert transition[TransitionKey.REWARD] == 0.0
assert not transition[TransitionKey.DONE]
assert not transition[TransitionKey.TRUNCATED]
assert transition[TransitionKey.INFO] == {}
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
# Round trip
reconstructed_batch = _default_transition_to_batch(transition)
@@ -256,33 +256,27 @@ def test_custom_converter():
# Custom converter that modifies the reward
tr = _default_batch_to_transition(batch)
# Double the reward
reward = tr[TransitionIndex.REWARD] * 2 if tr[TransitionIndex.REWARD] is not None else 0.0
return (
tr[TransitionIndex.OBSERVATION],
tr[TransitionIndex.ACTION],
reward,
tr[TransitionIndex.DONE],
tr[TransitionIndex.TRUNCATED],
tr[TransitionIndex.INFO],
tr[TransitionIndex.COMPLEMENTARY_DATA],
)
reward = tr.get(TransitionKey.REWARD, 0.0)
new_tr = tr.copy()
new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
return new_tr
def to_batch(tr):
    """Run the default transition-to-batch converter, then tag the result with a custom field."""
    converted = _default_transition_to_batch(tr)
    # Extra key added on top of whatever the default converter produced.
    converted["custom_field"] = "custom_value"
    return converted
proc = RobotProcessor([], to_transition=to_tr, to_output=to_batch)
batch = _dummy_batch()
out = proc(batch)
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
# Check that custom modifications were applied
assert out["next.reward"] == batch["next.reward"] * 2
assert out["custom_field"] == "custom_value"
batch = {
"observation.state": torch.randn(1, 4),
"action": torch.randn(1, 2),
"next.reward": 1.0,
"next.done": False,
}
# Check that observation.* keys are still preserved
original_obs_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
output_obs_keys = {k: v for k, v in out.items() if k.startswith("observation.")}
result = processor(batch)
assert set(original_obs_keys.keys()) == set(output_obs_keys.keys())
# Check the reward was doubled by our custom converter
assert result["next.reward"] == 2.0
assert torch.allclose(result["observation.state"], batch["observation.state"])
assert torch.allclose(result["action"], batch["action"])

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
import numpy as np
@@ -10,7 +25,22 @@ from lerobot.processor.normalize_processor import (
UnnormalizerProcessor,
_convert_stats_to_tensors,
)
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
def create_transition(
    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
    """Build an EnvTransition dictionary keyed by ``TransitionKey`` members.

    Each argument is stored unchanged under its matching key; any argument
    that is omitted is stored as ``None``.
    """
    transition = {}
    transition[TransitionKey.OBSERVATION] = observation
    transition[TransitionKey.ACTION] = action
    transition[TransitionKey.REWARD] = reward
    transition[TransitionKey.DONE] = done
    transition[TransitionKey.TRUNCATED] = truncated
    transition[TransitionKey.INFO] = info
    transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
    return transition
def test_numpy_conversion():
@@ -120,10 +150,10 @@ def test_mean_std_normalization(observation_normalizer):
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check mean/std normalization
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
@@ -134,10 +164,10 @@ def test_min_max_normalization(observation_normalizer):
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = observation_normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check min/max normalization to [-1, 1]
# For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0
@@ -157,10 +187,10 @@ def test_selective_normalization(observation_stats):
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Only image should be normalized
assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
@@ -176,10 +206,10 @@ def test_device_compatibility(observation_stats):
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionIndex.OBSERVATION]
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
assert normalized_obs["observation.image"].device.type == "cuda"
@@ -220,10 +250,10 @@ def test_state_dict_save_load(observation_normalizer):
# Test that it works the same
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result1 = observation_normalizer(transition)[0]
result2 = new_normalizer(transition)[0]
result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION]
result2 = new_normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(result1["observation.image"], result2["observation.image"])
@@ -271,10 +301,10 @@ def test_mean_std_unnormalization(action_stats_mean_std):
)
normalized_action = torch.tensor([1.0, -0.5, 2.0])
transition = (None, normalized_action, None, None, None, None, None)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# action * std + mean
expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
@@ -290,10 +320,10 @@ def test_min_max_unnormalization(action_stats_min_max):
# Actions in [-1, 1]
normalized_action = torch.tensor([0.0, -1.0, 1.0])
transition = (None, normalized_action, None, None, None, None, None)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
# Map from [-1, 1] to [min, max]
# (action + 1) / 2 * (max - min) + min
@@ -315,10 +345,10 @@ def test_numpy_action_input(action_stats_mean_std):
)
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
transition = (None, normalized_action, None, None, None, None, None)
transition = create_transition(action=normalized_action)
unnormalized_transition = unnormalizer(transition)
unnormalized_action = unnormalized_transition[TransitionIndex.ACTION]
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
assert isinstance(unnormalized_action, torch.Tensor)
expected = torch.tensor([1.0, -1.0, 1.0])
@@ -332,7 +362,7 @@ def test_none_action(action_stats_mean_std):
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
)
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = unnormalizer(transition)
# Should return transition unchanged
@@ -396,23 +426,31 @@ def test_combined_normalization(normalizer_processor):
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = normalizer_processor(transition)
# Check normalized observations
processed_obs = processed_transition[TransitionIndex.OBSERVATION]
processed_obs = processed_transition[TransitionKey.OBSERVATION]
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
assert torch.allclose(processed_obs["observation.image"], expected_image)
# Check normalized action
processed_action = processed_transition[TransitionIndex.ACTION]
processed_action = processed_transition[TransitionKey.ACTION]
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
assert torch.allclose(processed_action, expected_action)
# Check other fields remain unchanged
assert processed_transition[TransitionIndex.REWARD] == 1.0
assert not processed_transition[TransitionIndex.DONE]
assert processed_transition[TransitionKey.REWARD] == 1.0
assert not processed_transition[TransitionKey.DONE]
def test_processor_from_lerobot_dataset(full_stats):
@@ -466,13 +504,21 @@ def test_integration_with_robot_processor(normalizer_processor):
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
processed_transition = robot_processor(transition)
# Verify the processing worked
assert isinstance(processed_transition[TransitionIndex.OBSERVATION], dict)
assert isinstance(processed_transition[TransitionIndex.ACTION], torch.Tensor)
assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict)
assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor)
# Edge case tests
@@ -482,7 +528,7 @@ def test_empty_observation():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = normalizer(transition)
assert result == transition
@@ -493,11 +539,13 @@ def test_empty_stats():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
observation = {"observation.image": torch.tensor([0.5])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = normalizer(transition)
# Should return observation unchanged since no stats are available
assert torch.allclose(result[0]["observation.image"], observation["observation.image"])
assert torch.allclose(
result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"]
)
def test_partial_stats():
@@ -507,9 +555,9 @@ def test_partial_stats():
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
processed = normalizer(transition)[TransitionIndex.OBSERVATION]
processed = normalizer(transition)[TransitionKey.OBSERVATION]
assert torch.allclose(processed["observation.image"], observation["observation.image"])
@@ -551,14 +599,25 @@ def test_serialization_roundtrip(full_stats):
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = (observation, action, 1.0, False, False, {}, {})
transition = create_transition(
observation=observation,
action=action,
reward=1.0,
done=False,
truncated=False,
info={},
complementary_data={},
)
result1 = original_processor(transition)
result2 = new_processor(transition)
# Compare results
assert torch.allclose(result1[0]["observation.image"], result2[0]["observation.image"])
assert torch.allclose(result1[1], result2[1])
assert torch.allclose(
result1[TransitionKey.OBSERVATION]["observation.image"],
result2[TransitionKey.OBSERVATION]["observation.image"],
)
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
# Verify features and norm_map are correctly reconstructed
assert new_processor.features.keys() == original_processor.features.keys()

View File

@@ -23,6 +23,22 @@ from lerobot.processor import (
StateProcessor,
VanillaObservationProcessor,
)
from lerobot.processor.pipeline import TransitionKey
def create_transition(
    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
    """Assemble an EnvTransition dictionary from the individual transition fields.

    The seven ``TransitionKey`` entries are always present in the result;
    fields that were not supplied are stored as ``None``.
    """
    keys = (
        TransitionKey.OBSERVATION,
        TransitionKey.ACTION,
        TransitionKey.REWARD,
        TransitionKey.DONE,
        TransitionKey.TRUNCATED,
        TransitionKey.INFO,
        TransitionKey.COMPLEMENTARY_DATA,
    )
    values = (observation, action, reward, done, truncated, info, complementary_data)
    # Pair keys and values positionally; both tuples list the fields in the same order.
    return dict(zip(keys, values))
def test_process_single_image():
@@ -33,10 +49,10 @@ def test_process_single_image():
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that the image was processed correctly
assert "observation.image" in processed_obs
@@ -60,10 +76,10 @@ def test_process_image_dict():
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
observation = {"pixels": {"camera1": image1, "camera2": image2}}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both images were processed
assert "observation.images.camera1" in processed_obs
@@ -82,10 +98,10 @@ def test_process_batched_image():
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
@@ -98,7 +114,7 @@ def test_invalid_image_format():
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected channel-last images"):
processor(transition)
@@ -111,7 +127,7 @@ def test_invalid_image_dtype():
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
processor(transition)
@@ -122,10 +138,10 @@ def test_no_pixels_in_observation():
processor = ImageProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve other data unchanged
assert "other_data" in processed_obs
@@ -136,7 +152,7 @@ def test_none_observation():
"""Test processor with None observation."""
processor = ImageProcessor()
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = processor(transition)
assert result == transition
@@ -167,10 +183,10 @@ def test_process_environment_state():
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
@@ -188,10 +204,10 @@ def test_process_agent_pos():
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
@@ -211,10 +227,10 @@ def test_process_batched_states():
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
@@ -229,10 +245,10 @@ def test_process_both_states():
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that both states were processed
assert "observation.environment_state" in processed_obs
@@ -251,10 +267,10 @@ def test_no_states_in_observation():
processor = StateProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Should preserve data unchanged
np.testing.assert_array_equal(processed_obs, observation)
@@ -275,10 +291,10 @@ def test_complete_observation_processing():
"agent_pos": agent_pos,
"other_data": "preserve_me",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that image was processed
assert "observation.image" in processed_obs
@@ -303,10 +319,10 @@ def test_image_only_processing():
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
@@ -318,10 +334,10 @@ def test_state_only_processing():
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
@@ -332,10 +348,10 @@ def test_empty_observation():
processor = VanillaObservationProcessor()
observation = {}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[0]
processed_obs = result[TransitionKey.OBSERVATION]
assert processed_obs == {}
@@ -369,8 +385,8 @@ def test_equivalent_to_original_function():
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
@@ -396,8 +412,8 @@ def test_equivalent_with_image_dict():
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
transition = create_transition(observation=observation)
processor_result = processor(transition)[TransitionKey.OBSERVATION]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())

View File

@@ -18,7 +18,7 @@ import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict
from typing import Any
import numpy as np
import pytest
@@ -26,6 +26,22 @@ import torch
import torch.nn as nn
from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor
from lerobot.processor.pipeline import TransitionKey
def create_transition(
    observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None
):
    """Build an EnvTransition dictionary keyed by ``TransitionKey`` members.

    ``info`` and ``complementary_data`` are replaced by a fresh empty dict
    when passed as ``None``; every other argument is stored as given.
    """
    # Create the empty dicts per call rather than sharing one mutable default.
    if info is None:
        info = {}
    if complementary_data is None:
        complementary_data = {}

    transition = {}
    transition[TransitionKey.OBSERVATION] = observation
    transition[TransitionKey.ACTION] = action
    transition[TransitionKey.REWARD] = reward
    transition[TransitionKey.DONE] = done
    transition[TransitionKey.TRUNCATED] = truncated
    transition[TransitionKey.INFO] = info
    transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
    return transition
@dataclass
@@ -45,14 +61,16 @@ class MockStep:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add a counter to the complementary_data."""
obs, action, reward, done, truncated, info, comp_data = transition
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data) # Make a copy
comp_data[f"{self.name}_counter"] = self.counter
self.counter += 1
return (obs, action, reward, done, truncated, info, comp_data)
# Create a new transition with updated complementary_data
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
# Return all JSON-serializable attributes that should be persisted
@@ -79,12 +97,14 @@ class MockStepWithoutOptionalMethods:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Multiply reward by multiplier."""
obs, action, reward, done, truncated, info, comp_data = transition
reward = transition.get(TransitionKey.REWARD)
if reward is not None:
reward = reward * self.multiplier
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = reward * self.multiplier
return new_transition
return (obs, action, reward, done, truncated, info, comp_data)
return transition
@dataclass
@@ -105,7 +125,7 @@ class MockStepWithTensorState:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Update running statistics."""
obs, action, reward, done, truncated, info, comp_data = transition
reward = transition.get(TransitionKey.REWARD)
if reward is not None:
# Update running mean
@@ -143,7 +163,7 @@ def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = RobotProcessor()
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert result == transition
@@ -155,15 +175,15 @@ def test_single_step_pipeline():
step = MockStep("test_step")
pipeline = RobotProcessor([step])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert len(pipeline) == 1
assert result[6]["test_step_counter"] == 0 # complementary_data
assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0
# Call again to test counter increment
result = pipeline(transition)
assert result[6]["test_step_counter"] == 1
assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 1
def test_multiple_steps_pipeline():
@@ -172,46 +192,46 @@ def test_multiple_steps_pipeline():
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert len(pipeline) == 2
assert result[6]["step1_counter"] == 0
assert result[6]["step2_counter"] == 0
assert result[TransitionKey.COMPLEMENTARY_DATA]["step1_counter"] == 0
assert result[TransitionKey.COMPLEMENTARY_DATA]["step2_counter"] == 0
def test_invalid_transition_format():
"""Test pipeline with invalid transition format."""
pipeline = RobotProcessor([MockStep()])
# Test with wrong number of elements
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline((None, None, 0.0)) # Only 3 elements
# Test with wrong type (tuple instead of dict)
with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
pipeline((None, None, 0.0, False, False, {}, {})) # Tuple instead of dict
# Test with wrong type
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline("not a tuple")
# Test with wrong type (string)
with pytest.raises(ValueError, match="EnvTransition must be a dictionary"):
pipeline("not a dict")
def test_step_through():
"""Test step_through method with tuple input."""
"""Test step_through method with dict input."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
results = list(pipeline.step_through(transition))
assert len(results) == 3 # Original + 2 steps
assert results[0] == transition # Original
assert "step1_counter" in results[1][6] # After step1
assert "step2_counter" in results[2][6] # After step2
assert "step1_counter" in results[1][TransitionKey.COMPLEMENTARY_DATA] # After step1
assert "step2_counter" in results[2][TransitionKey.COMPLEMENTARY_DATA] # After step2
# Ensure all results are tuples (same format as input)
# Ensure all results are dicts (same format as input)
for result in results:
assert isinstance(result, tuple)
assert len(result) == 7
assert isinstance(result, dict)
assert all(isinstance(k, TransitionKey) for k in result.keys())
def test_step_through_with_dict():
@@ -279,7 +299,7 @@ def test_hooks():
pipeline.register_before_step_hook(before_hook)
pipeline.register_after_step_hook(after_hook)
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
pipeline(transition)
assert before_calls == [0]
@@ -292,15 +312,16 @@ def test_hook_modification():
pipeline = RobotProcessor([step])
def modify_reward_hook(idx: int, transition: EnvTransition):
obs, action, reward, done, truncated, info, comp_data = transition
return (obs, action, 42.0, done, truncated, info, comp_data)
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = 42.0
return new_transition
pipeline.register_before_step_hook(modify_reward_hook)
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = pipeline(transition)
assert result[2] == 42.0 # reward modified by hook
assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook
def test_reset():
@@ -316,7 +337,7 @@ def test_reset():
pipeline.register_reset_hook(reset_hook)
# Make some calls to increment counter
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
pipeline(transition)
pipeline(transition)
@@ -335,7 +356,7 @@ def test_profile_steps():
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
profile_results = pipeline.profile_steps(transition, num_runs=10)
@@ -397,10 +418,10 @@ def test_step_without_optional_methods():
step = MockStepWithoutOptionalMethods(multiplier=3.0)
pipeline = RobotProcessor([step])
transition = (None, None, 2.0, False, False, {}, {})
transition = create_transition(reward=2.0)
result = pipeline(transition)
assert result[2] == 6.0 # 2.0 * 3.0
assert result[TransitionKey.REWARD] == 6.0 # 2.0 * 3.0
# Reset should work even if step doesn't implement reset
pipeline.reset()
@@ -419,7 +440,7 @@ def test_mixed_json_and_tensor_state():
# Process some transitions with rewards
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
transition = create_transition(reward=float(i))
pipeline(transition)
# Check state
@@ -466,7 +487,7 @@ class MockModuleStep(nn.Module):
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process transition and update running mean."""
obs, action, reward, done, truncated, info, comp_data = transition
obs = transition.get(TransitionKey.OBSERVATION)
if obs is not None and isinstance(obs, torch.Tensor):
# Process observation through linear layer
@@ -509,7 +530,7 @@ def test_to_device_with_state_dict():
# Process some transitions to populate state
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
transition = create_transition(reward=float(i))
pipeline(transition)
# Check initial device (should be CPU)
@@ -551,7 +572,7 @@ def test_to_device_with_module():
# Process some data
obs = torch.randn(2, 5)
transition = (obs, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs, reward=1.0)
pipeline(transition)
# Check initial device
@@ -575,7 +596,7 @@ def test_to_device_with_module():
# Verify the module still works after transfer
obs_cuda = torch.randn(2, 5, device="cuda:0")
transition = (obs_cuda, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs_cuda, reward=1.0)
pipeline(transition) # Should not raise an error
@@ -589,7 +610,7 @@ def test_to_device_mixed_steps():
# Process some data
for i in range(5):
transition = (torch.randn(2, 10), None, float(i), False, False, {}, {})
transition = create_transition(observation=torch.randn(2, 10), reward=float(i))
pipeline(transition)
# Check initial state
@@ -630,7 +651,7 @@ def test_to_device_preserves_functionality():
# Process initial data
rewards = [1.0, 2.0, 3.0]
for r in rewards:
transition = (None, None, r, False, False, {}, {})
transition = create_transition(reward=r)
pipeline(transition)
# Check state before transfer
@@ -645,7 +666,7 @@ def test_to_device_preserves_functionality():
assert step.running_count == initial_count
# Process more data to ensure functionality
transition = (None, None, 4.0, False, False, {}, {})
transition = create_transition(reward=4.0)
_ = pipeline(transition)
assert step.running_count == 4
@@ -700,7 +721,8 @@ class MockNonModuleStepWithState:
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process transition using tensor operations."""
obs, action, reward, done, truncated, info, comp_data = transition
obs = transition.get(TransitionKey.OBSERVATION)
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if obs is not None and isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim:
# Perform some tensor operations
@@ -718,7 +740,12 @@ class MockNonModuleStepWithState:
comp_data[f"{self.name}_mean_output"] = output.mean().item()
comp_data[f"{self.name}_steps"] = self.step_count.item()
return (obs, action, reward, done, truncated, info, comp_data)
# Return updated transition
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
return transition
def get_config(self) -> dict[str, Any]:
return {
@@ -763,9 +790,9 @@ def test_to_device_non_module_class():
# Process some data to populate state
for i in range(3):
obs = torch.randn(2, 5)
transition = (obs, None, float(i), False, False, {}, {})
transition = create_transition(observation=obs, reward=float(i))
result = pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert f"{non_module_step.name}_steps" in comp_data
# Verify all tensors are on CPU initially
@@ -811,9 +838,9 @@ def test_to_device_non_module_class():
# Test that step still works on GPU
obs_gpu = torch.randn(2, 5, device="cuda")
transition = (obs_gpu, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs_gpu, reward=1.0)
result = pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Verify processing worked
assert comp_data[f"{non_module_step.name}_steps"] == 4
@@ -835,7 +862,7 @@ def test_to_device_module_vs_non_module():
# Process some data
obs = torch.randn(2, 5)
transition = (obs, None, 1.0, False, False, {}, {})
transition = create_transition(observation=obs, reward=1.0)
_ = pipeline(transition)
# Check initial devices
@@ -860,7 +887,7 @@ def test_to_device_module_vs_non_module():
# Process data on GPU
obs_gpu = torch.randn(2, 5, device="cuda")
transition = (obs_gpu, None, 2.0, False, False, {}, {})
transition = create_transition(observation=obs_gpu, reward=2.0)
_ = pipeline(transition)
# Verify both steps processed the data
@@ -889,7 +916,8 @@ class MockStepWithNonSerializableParam:
self.env = env # Non-serializable parameter (like gym.Env)
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
reward = transition.get(TransitionKey.REWARD)
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
# Use the env parameter if provided
if self.env is not None:
@@ -897,10 +925,14 @@ class MockStepWithNonSerializableParam:
comp_data[f"{self.name}_env_info"] = str(self.env)
# Apply multiplier to reward
new_transition = transition.copy()
if reward is not None:
reward = reward * self.multiplier
new_transition[TransitionKey.REWARD] = reward * self.multiplier
return (obs, action, reward, done, truncated, info, comp_data)
if comp_data:
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
# Note: env is intentionally NOT included here as it's not serializable
@@ -928,13 +960,15 @@ class RegisteredMockStep:
device: str = "cpu"
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs, action, reward, done, truncated, info, comp_data = transition
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["registered_step_value"] = self.value
comp_data["registered_step_device"] = self.device
return (obs, action, reward, done, truncated, info, comp_data)
new_transition = transition.copy()
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
return {
@@ -993,18 +1027,18 @@ def test_from_pretrained_with_overrides():
assert loaded_pipeline.name == "TestOverrides"
# Test the loaded steps
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
# Check that overrides were applied
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert "env_step_env_info" in comp_data
assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)"
assert comp_data["registered_step_value"] == 200
assert comp_data["registered_step_device"] == "cuda"
# Check that multiplier override was applied
assert result[2] == 3.0 # 1.0 * 3.0 (overridden multiplier)
assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 (overridden multiplier)
def test_from_pretrained_with_partial_overrides():
@@ -1024,13 +1058,13 @@ def test_from_pretrained_with_partial_overrides():
# Both steps will get the override
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
# The reward should be affected by both steps, both getting the override
# First step: 1.0 * 5.0 = 5.0 (overridden)
# Second step: 5.0 * 5.0 = 25.0 (also overridden)
assert result[2] == 25.0
assert result[TransitionKey.REWARD] == 25.0
def test_from_pretrained_invalid_override_key():
@@ -1082,10 +1116,10 @@ def test_from_pretrained_registered_step_override():
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test that overrides were applied
transition = (None, None, 0.0, False, False, {}, {})
transition = create_transition()
result = loaded_pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert comp_data["registered_step_value"] == 999
assert comp_data["registered_step_device"] == "cuda"
@@ -1110,13 +1144,13 @@ def test_from_pretrained_mixed_registered_and_unregistered():
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
# Test both steps
transition = (None, None, 2.0, False, False, {}, {})
transition = create_transition(reward=2.0)
result = loaded_pipeline(transition)
comp_data = result[6]
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)"
assert comp_data["registered_step_value"] == 777
assert result[2] == 8.0 # 2.0 * 4.0
assert result[TransitionKey.REWARD] == 8.0 # 2.0 * 4.0
def test_from_pretrained_no_overrides():
@@ -1133,10 +1167,10 @@ def test_from_pretrained_no_overrides():
assert len(loaded_pipeline) == 1
# Test that the step works (env will be None)
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
assert result[2] == 3.0 # 1.0 * 3.0
assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0
def test_from_pretrained_empty_overrides():
@@ -1153,10 +1187,10 @@ def test_from_pretrained_empty_overrides():
assert len(loaded_pipeline) == 1
# Test that the step works normally
transition = (None, None, 1.0, False, False, {}, {})
transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
assert result[2] == 2.0
assert result[TransitionKey.REWARD] == 2.0
def test_from_pretrained_override_instantiation_error():
@@ -1185,7 +1219,7 @@ def test_from_pretrained_with_state_and_overrides():
# Process some data to create state
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
transition = create_transition(reward=float(i))
pipeline(transition)
with tempfile.TemporaryDirectory() as tmp_dir:

View File

@@ -13,14 +13,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from pathlib import Path
import numpy as np
import torch
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionIndex
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
def create_transition(
    observation=None,
    action=None,
    reward=None,
    done=None,
    truncated=None,
    info=None,
    complementary_data=None,
):
    """Build an EnvTransition dictionary keyed by ``TransitionKey`` members.

    All seven transition slots are always present in the returned dict;
    any slot whose argument is omitted is stored as ``None``.
    """
    # Keys and values are listed in the same slot order so that zip pairs
    # each TransitionKey member with its corresponding argument.
    slot_keys = (
        TransitionKey.OBSERVATION,
        TransitionKey.ACTION,
        TransitionKey.REWARD,
        TransitionKey.DONE,
        TransitionKey.TRUNCATED,
        TransitionKey.INFO,
        TransitionKey.COMPLEMENTARY_DATA,
    )
    slot_values = (observation, action, reward, done, truncated, info, complementary_data)
    return dict(zip(slot_keys, slot_values))
def test_basic_renaming():
@@ -36,10 +50,10 @@ def test_basic_renaming():
"old_key2": np.array([3.0, 4.0]),
"unchanged_key": "keep_me",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "new_key1" in processed_obs
@@ -63,10 +77,10 @@ def test_empty_rename_map():
"key1": torch.tensor([1.0]),
"key2": "value2",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# All keys should be unchanged
assert processed_obs.keys() == observation.keys()
@@ -78,7 +92,7 @@ def test_none_observation():
"""Test processor with None observation."""
processor = RenameProcessor(rename_map={"old": "new"})
transition = (None, None, None, None, None, None, None)
transition = create_transition()
result = processor(transition)
# Should return transition unchanged
@@ -98,10 +112,10 @@ def test_overlapping_rename():
"b": 2,
"x": 3,
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that renaming happens correctly
assert "a" not in processed_obs
@@ -124,10 +138,10 @@ def test_partial_rename():
"reward": 1.0,
"info": {"episode": 1},
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renamed keys
assert "observation.proprio_state" in processed_obs
@@ -178,10 +192,12 @@ def test_integration_with_robot_processor():
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
"other_data": "preserve_me",
}
transition = (observation, None, 0.5, False, False, {}, {})
transition = create_transition(
observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
)
result = pipeline(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renaming worked through pipeline
assert "observation.state" in processed_obs
@@ -191,8 +207,8 @@ def test_integration_with_robot_processor():
assert processed_obs["other_data"] == "preserve_me"
# Check other transition elements unchanged
assert result[TransitionIndex.REWARD] == 0.5
assert result[TransitionIndex.DONE] is False
assert result[TransitionKey.REWARD] == 0.5
assert result[TransitionKey.DONE] is False
def test_save_and_load_pretrained():
@@ -229,10 +245,10 @@ def test_save_and_load_pretrained():
# Test functionality after loading
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = loaded_pipeline(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
assert "observation.state" in processed_obs
assert "observation.image" in processed_obs
@@ -306,17 +322,17 @@ def test_chained_rename_processors():
"img": "image_data",
"extra": "keep_me",
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
# Step through to see intermediate results
results = list(pipeline.step_through(transition))
# After first processor
assert "agent_position" in results[1][TransitionIndex.OBSERVATION]
assert "camera_image" in results[1][TransitionIndex.OBSERVATION]
assert "agent_position" in results[1][TransitionKey.OBSERVATION]
assert "camera_image" in results[1][TransitionKey.OBSERVATION]
# After second processor
final_obs = results[2][TransitionIndex.OBSERVATION]
final_obs = results[2][TransitionKey.OBSERVATION]
assert "observation.state" in final_obs
assert "observation.image" in final_obs
assert final_obs["extra"] == "keep_me"
@@ -343,10 +359,10 @@ def test_nested_observation_rename():
"observation.proprio": torch.randn(7),
"observation.gripper": torch.tensor([0.0]), # Not renamed
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check renames
assert "observation.camera.left_view" in processed_obs
@@ -378,10 +394,10 @@ def test_value_types_preserved():
"old_dict": {"nested": "value"},
"old_list": [1, 2, 3],
}
transition = (observation, None, None, None, None, None, None)
transition = create_transition(observation=observation)
result = processor(transition)
processed_obs = result[TransitionIndex.OBSERVATION]
processed_obs = result[TransitionKey.OBSERVATION]
# Check that values and types are preserved
assert torch.equal(processed_obs["new_tensor"], tensor_value)