mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 03:41:25 +00:00
feat(processor): multiple improvements to the pipeline porting (#1749)
* [Port codebase pipeline] General fixes for RL and scripts (#1748) * Refactor dataset configuration in documentation and codebase - Updated dataset configuration keys from `dataset_root` to `root` and `num_episodes` to `num_episodes_to_record` for consistency. - Adjusted replay episode handling by renaming `episode` to `replay_episode`. - Enhanced documentation - added specific processor to transform from policy actions to delta actions * Added Robot action to tensor processor Added new processor script for dealing with gym specific action processing * removed RobotAction2Tensor processor; imrpoved choosing observations in actor * nit in delta action * added missing reset functions to kinematics * Adapt teleoperate and replay to pipeline similar to record * refactor(processors): move to inheritance (#1750) * fix(teleoperator): improvements phone implementation (#1752) * fix(teleoperator): protect shared state in phone implementation * refactor(teleop): separate classes in phone * fix: solve breaking changes (#1753) * refactor(policies): multiple improvements (#1754) * refactor(processor): simpler logic in device processor (#1755) * refactor(processor): euclidean distance in delta action processor (#1757) * refactor(processor): improvements to joint observations processor migration (#1758) * refactor(processor): improvements to tokenizer migration (#1759) * refactor(processor): improvements to tokenizer migration * fix(tests): tokenizer tests regression from #1750 * fix(processors): fix float comparison and config in hil processors (#1760) * chore(teleop): remove unnecessary callbacks in KeyboardEndEffectorTeleop (#1761) * refactor(processor): improvements normalize pipeline migration (#1756) * refactor(processor): several improvements normalize processor step * refactor(processor): more improvements normalize processor * refactor(processor): more changes to normalizer * refactor(processor): take a different approach to DRY * refactor(processor): final design * chore(record): revert comment and continue deleted (#1764) * refactor(examples): pipeline phone examples (#1769) * refactor(examples): phone teleop + teleop script * refactor(examples): phone replay + replay * chore(examples): rename phone example files & folders * feat(processor): fix improvements to the pipeline porting (#1796) * refactor(processor): enhance tensor device handling in normalization process (#1795) * refactor(tests): remove unsupported device detection test for complementary data (#1797) * chore(tests): update ToBatchProcessor test (#1798) * refactor(tests): remove in-place mutation tests for actions and complementary data in batch processor * test(tests): add tests for action and task processing in batch processor * add names for android and ios phone (#1799) * use _tensor_stats in normalize processor (#1800) * fix(normalize_processor): correct device reference for tensor epsilon handling (#1801) * add point 5 add missing feature contracts (#1806) * Fix PR comments 1452 (#1807) * use key to determine image * Address rest of PR comments * use PolicyFeatures in transform_features --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
This commit is contained in:
@@ -25,7 +25,6 @@ from lerobot.processor.normalize_processor import (
|
||||
UnnormalizerProcessor,
|
||||
_convert_stats_to_tensors,
|
||||
hotswap_stats,
|
||||
rename_stats,
|
||||
)
|
||||
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
|
||||
|
||||
@@ -182,7 +181,10 @@ def test_selective_normalization(observation_stats):
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"}
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=observation_stats,
|
||||
normalize_observation_keys={"observation.image"},
|
||||
)
|
||||
|
||||
observation = {
|
||||
@@ -243,6 +245,7 @@ def test_from_lerobot_dataset():
|
||||
def test_state_dict_save_load(observation_normalizer):
|
||||
# Save state
|
||||
state_dict = observation_normalizer.state_dict()
|
||||
print("State dict:", state_dict)
|
||||
|
||||
# Create new normalizer and load state
|
||||
features = _create_observation_features()
|
||||
@@ -464,10 +467,10 @@ def test_processor_from_lerobot_dataset(full_stats):
|
||||
norm_map = _create_full_norm_map()
|
||||
|
||||
processor = NormalizerProcessor.from_lerobot_dataset(
|
||||
mock_dataset, features, norm_map, normalize_keys={"observation.image"}
|
||||
mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"}
|
||||
)
|
||||
|
||||
assert processor.normalize_keys == {"observation.image"}
|
||||
assert processor.normalize_observation_keys == {"observation.image"}
|
||||
assert "observation.image" in processor._tensor_stats
|
||||
assert "action" in processor._tensor_stats
|
||||
|
||||
@@ -476,12 +479,16 @@ def test_get_config(full_stats):
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
processor = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=full_stats,
|
||||
normalize_observation_keys={"observation.image"},
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
config = processor.get_config()
|
||||
expected_config = {
|
||||
"normalize_keys": ["observation.image"],
|
||||
"normalize_observation_keys": ["observation.image"],
|
||||
"eps": 1e-6,
|
||||
"features": {
|
||||
"observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
|
||||
@@ -580,7 +587,11 @@ def test_serialization_roundtrip(full_stats):
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
original_processor = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=full_stats,
|
||||
normalize_observation_keys={"observation.image"},
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
# Get config (serialization)
|
||||
@@ -591,7 +602,7 @@ def test_serialization_roundtrip(full_stats):
|
||||
features=config["features"],
|
||||
norm_map=config["norm_map"],
|
||||
stats=full_stats,
|
||||
normalize_keys=set(config["normalize_keys"]),
|
||||
normalize_observation_keys=set(config["normalize_observation_keys"]),
|
||||
eps=config["eps"],
|
||||
)
|
||||
|
||||
@@ -939,31 +950,31 @@ def test_identity_config_serialization():
|
||||
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
|
||||
|
||||
|
||||
def test_unsupported_normalization_mode_error():
|
||||
"""Test that unsupported normalization modes raise appropriate errors."""
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
|
||||
# def test_unsupported_normalization_mode_error():
|
||||
# """Test that unsupported normalization modes raise appropriate errors."""
|
||||
# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))}
|
||||
|
||||
# Create an invalid norm_map (this would never happen in practice, but tests error handling)
|
||||
from enum import Enum
|
||||
# # Create an invalid norm_map (this would never happen in practice, but tests error handling)
|
||||
# from enum import Enum
|
||||
|
||||
class InvalidMode(str, Enum):
|
||||
INVALID = "INVALID"
|
||||
# class InvalidMode(str, Enum):
|
||||
# INVALID = "INVALID"
|
||||
|
||||
# We can't actually pass an invalid enum to the processor due to type checking,
|
||||
# but we can test the error by manipulating the norm_map after creation
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
|
||||
# # We can't actually pass an invalid enum to the processor due to type checking,
|
||||
# # but we can test the error by manipulating the norm_map after creation
|
||||
# norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
# normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Manually inject an invalid mode to test error handling
|
||||
normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
|
||||
# # Manually inject an invalid mode to test error handling
|
||||
# normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE"
|
||||
|
||||
observation = {"observation.state": torch.tensor([1.0, -0.5])}
|
||||
transition = create_transition(observation=observation)
|
||||
# observation = {"observation.state": torch.tensor([1.0, -0.5])}
|
||||
# transition = create_transition(observation=observation)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported normalization mode"):
|
||||
normalizer(transition)
|
||||
# with pytest.raises(ValueError, match="Unsupported normalization mode"):
|
||||
# normalizer(transition)
|
||||
|
||||
|
||||
def test_hotswap_stats_basic_functionality():
|
||||
@@ -1149,11 +1160,15 @@ def test_hotswap_stats_preserves_other_attributes():
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalize_keys = {"observation.image"}
|
||||
normalize_observation_keys = {"observation.image"}
|
||||
eps = 1e-6
|
||||
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=initial_stats, normalize_keys=normalize_keys, eps=eps
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=initial_stats,
|
||||
normalize_observation_keys=normalize_observation_keys,
|
||||
eps=eps,
|
||||
)
|
||||
robot_processor = RobotProcessor(steps=[normalizer])
|
||||
|
||||
@@ -1164,7 +1179,7 @@ def test_hotswap_stats_preserves_other_attributes():
|
||||
new_normalizer = new_processor.steps[0]
|
||||
assert new_normalizer.features == features
|
||||
assert new_normalizer.norm_map == norm_map
|
||||
assert new_normalizer.normalize_keys == normalize_keys
|
||||
assert new_normalizer.normalize_observation_keys == normalize_observation_keys
|
||||
assert new_normalizer.eps == eps
|
||||
|
||||
# But stats should be updated
|
||||
@@ -1270,273 +1285,6 @@ def test_hotswap_stats_with_different_data_types():
|
||||
torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0))
|
||||
|
||||
|
||||
def test_normalization_info_tracking():
|
||||
"""Test that normalization info is tracked in complementary_data."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
FeatureType.ACTION: NormalizationMode.IDENTITY,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"observation.state": {
|
||||
"min": np.array([0.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
"action": {
|
||||
"mean": np.array([0.0, 0.0]),
|
||||
"std": np.array([1.0, 1.0]),
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Process the transition
|
||||
normalized_transition = normalizer(transition)
|
||||
|
||||
# Check that normalization info is added
|
||||
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "normalized_keys" in comp_data
|
||||
|
||||
norm_info = comp_data["normalized_keys"]
|
||||
assert norm_info["observation.image"] == "MEAN_STD"
|
||||
assert norm_info["observation.state"] == "MIN_MAX"
|
||||
assert norm_info["action"] == "IDENTITY"
|
||||
|
||||
|
||||
def test_unnormalization_info_tracking():
|
||||
"""Test that unnormalization info is tracked in complementary_data."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"action": {
|
||||
"min": np.array([-1.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
}
|
||||
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
action = torch.tensor([0.0, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Process the transition
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
|
||||
# Check that unnormalization info is added
|
||||
comp_data = unnormalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "unnormalized_keys" in comp_data
|
||||
|
||||
unnorm_info = comp_data["unnormalized_keys"]
|
||||
assert unnorm_info["observation.image"] == "MEAN_STD"
|
||||
assert unnorm_info["action"] == "MIN_MAX"
|
||||
|
||||
|
||||
def test_normalization_info_with_missing_stats():
|
||||
"""Test normalization info when stats are missing for some keys."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
# Only provide stats for image, not state
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Process the transition
|
||||
normalized_transition = normalizer(transition)
|
||||
|
||||
# Check that only keys with stats are in normalization info
|
||||
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "normalized_keys" in comp_data
|
||||
|
||||
norm_info = comp_data["normalized_keys"]
|
||||
assert norm_info["observation.image"] == "MEAN_STD"
|
||||
# State should not be in the normalization info since it has no stats
|
||||
assert "observation.state" not in norm_info
|
||||
|
||||
|
||||
def test_normalization_info_with_selective_keys():
|
||||
"""Test normalization info with selective normalization."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"observation.state": {
|
||||
"min": np.array([0.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
}
|
||||
|
||||
# Only normalize image
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.image"}
|
||||
)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Process the transition
|
||||
normalized_transition = normalizer(transition)
|
||||
|
||||
# Check that only selected keys are in normalization info
|
||||
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "normalized_keys" in comp_data
|
||||
|
||||
norm_info = comp_data["normalized_keys"]
|
||||
assert norm_info["observation.image"] == "MEAN_STD"
|
||||
# State should not be in the normalization info since it wasn't in normalize_keys
|
||||
assert "observation.state" not in norm_info
|
||||
|
||||
|
||||
def test_normalization_info_preserved_in_pipeline():
|
||||
"""Test that normalization info is preserved when using RobotProcessor pipeline."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"action": {
|
||||
"min": np.array([-1.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Create pipeline
|
||||
pipeline = RobotProcessor([normalizer, unnormalizer])
|
||||
|
||||
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
action = torch.tensor([0.5, -0.5])
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
# Process through pipeline
|
||||
result = pipeline(transition)
|
||||
|
||||
# Check that both normalization and unnormalization info are present
|
||||
comp_data = result.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is not None
|
||||
assert "normalized_keys" in comp_data
|
||||
assert "unnormalized_keys" in comp_data
|
||||
|
||||
# Check normalization info
|
||||
norm_info = comp_data["normalized_keys"]
|
||||
assert norm_info["observation.image"] == "MEAN_STD"
|
||||
assert norm_info["action"] == "MIN_MAX"
|
||||
|
||||
# Check unnormalization info
|
||||
unnorm_info = comp_data["unnormalized_keys"]
|
||||
assert unnorm_info["observation.image"] == "MEAN_STD"
|
||||
assert unnorm_info["action"] == "MIN_MAX"
|
||||
|
||||
|
||||
def test_normalization_info_empty_transition():
|
||||
"""Test that no normalization info is added for empty transitions."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||
"action": {"min": [-1.0], "max": [1.0]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Empty transition
|
||||
transition = create_transition()
|
||||
|
||||
# Process the transition
|
||||
normalized_transition = normalizer(transition)
|
||||
|
||||
# Check that no normalization info is added
|
||||
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
assert comp_data is None or "normalized_keys" not in comp_data
|
||||
|
||||
|
||||
def test_hotswap_stats_functional_test():
|
||||
"""Test that hotswapped processor actually works functionally."""
|
||||
# Create test data
|
||||
@@ -1631,8 +1379,8 @@ def test_min_equals_max_maps_to_minus_one():
|
||||
assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0]))
|
||||
|
||||
|
||||
def test_action_normalized_despite_normalize_keys():
|
||||
"""Action normalization is independent of normalize_keys filter for observations."""
|
||||
def test_action_normalized_despite_normalize_observation_keys():
|
||||
"""Action normalization is independent of normalize_observation_keys filter for observations."""
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (1,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
@@ -1640,7 +1388,7 @@ def test_action_normalized_despite_normalize_keys():
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.state"}
|
||||
features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"}
|
||||
)
|
||||
|
||||
transition = create_transition(
|
||||
@@ -1680,19 +1428,6 @@ def test_unnormalize_observations_mean_std_and_min_max():
|
||||
assert torch.allclose(out_mm, torch.tensor([1.0, 0.0])) # mid of [0,2] and [-2,2]
|
||||
|
||||
|
||||
def test_rename_stats_basic():
|
||||
orig = {
|
||||
"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])},
|
||||
"action": {"mean": np.array([0.0])},
|
||||
}
|
||||
mapping = {"observation.state": "observation.robot_state"}
|
||||
renamed = rename_stats(orig, mapping)
|
||||
assert "observation.robot_state" in renamed and "observation.state" not in renamed
|
||||
# Ensure deep copy: mutate original and verify renamed unaffected
|
||||
orig["observation.state"]["mean"][0] = 42.0
|
||||
assert renamed["observation.robot_state"]["mean"][0] != 42.0
|
||||
|
||||
|
||||
def test_unknown_observation_keys_ignored():
|
||||
features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))}
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
|
||||
@@ -1705,8 +1440,6 @@ def test_unknown_observation_keys_ignored():
|
||||
|
||||
# Unknown key should pass through unchanged and not be tracked
|
||||
assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.unknown"], obs["observation.unknown"])
|
||||
comp = out.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
assert "normalized_keys" in comp and "observation.unknown" not in comp["normalized_keys"]
|
||||
|
||||
|
||||
def test_batched_action_normalization():
|
||||
@@ -1731,7 +1464,7 @@ def test_complementary_data_preservation():
|
||||
tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp)
|
||||
out = normalizer(tr)
|
||||
new_comp = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert new_comp["existing"] == 123 and "normalized_keys" in new_comp
|
||||
assert new_comp["existing"] == 123
|
||||
|
||||
|
||||
def test_roundtrip_normalize_unnormalize_non_identity():
|
||||
|
||||
Reference in New Issue
Block a user