Add Quantile stats to LeRobotDataset (#1985)

* - Add RunningQuantileStats class for efficient histogram-based quantile computation
- Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset
- Support quantile computation during episode collection and aggregation
- Add comprehensive function-based test suite (24 tests) for quantile functionality
- Maintain full backward compatibility with existing stats computation
- Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization

* style fixes, make quantiles computation by default to new datasets

* fix tests

* - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each features instead of being chosen by the user
- Fortified tests.

* - add helper functions to reshape stats
- add missing test for quantiles

* - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles.
- Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles.

* style fixes

* Added missing lisence

* Simplify compute_stats

* - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that dont have quatniles
- modified quantile computation instead of using the edge for the value, interpolate the values in the bin
This commit is contained in:
Michel Aractingi
2025-09-22 17:57:32 +02:00
committed by GitHub
parent 5d9acf9d51
commit d691d1e4fe
7 changed files with 1689 additions and 34 deletions

View File

@@ -165,6 +165,229 @@ def test_min_max_normalization(observation_normalizer):
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
def test_quantile_normalization():
"""Test QUANTILES mode using 1st-99th percentiles."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILES,
}
stats = {
"observation.state": {
"q01": np.array([0.1, -0.8]), # 1st percentile
"q99": np.array([0.9, 0.8]), # 99th percentile
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check quantile normalization to [0, 1]
# For state[0]: (0.5 - 0.1) / (0.9 - 0.1) = 0.4 / 0.8 = 0.5
# For state[1]: (0.0 - (-0.8)) / (0.8 - (-0.8)) = 0.8 / 1.6 = 0.5
expected_state = torch.tensor([0.5, 0.5])
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
def test_quantile10_normalization():
"""Test QUANTILE10 mode using 10th-90th percentiles."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILE10,
}
stats = {
"observation.state": {
"q10": np.array([0.2, -0.6]), # 10th percentile
"q90": np.array([0.8, 0.6]), # 90th percentile
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check quantile normalization to [0, 1]
# For state[0]: (0.5 - 0.2) / (0.8 - 0.2) = 0.3 / 0.6 = 0.5
# For state[1]: (0.0 - (-0.6)) / (0.6 - (-0.6)) = 0.6 / 1.2 = 0.5
expected_state = torch.tensor([0.5, 0.5])
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
def test_quantile_unnormalization():
"""Test that quantile normalization can be reversed properly."""
features = {
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.ACTION: NormalizationMode.QUANTILES,
}
stats = {
"action": {
"q01": np.array([0.1, -0.8]),
"q99": np.array([0.9, 0.8]),
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
# Test round-trip normalization
original_action = torch.tensor([0.5, 0.0])
transition = create_transition(action=original_action)
# Normalize then unnormalize
normalized = normalizer(transition)
unnormalized = unnormalizer(normalized)
# Should recover original values
recovered_action = unnormalized[TransitionKey.ACTION]
assert torch.allclose(recovered_action, original_action, atol=1e-6)
def test_quantile_division_by_zero():
"""Test quantile normalization handles edge case where q01 == q99."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (1,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILES,
}
stats = {
"observation.state": {
"q01": np.array([0.5]), # Same value
"q99": np.array([0.5]), # Same value -> division by zero case
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.state": torch.tensor([0.5]),
}
transition = create_transition(observation=observation)
# Should not crash and should handle gracefully
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# When quantiles are identical, should normalize to 0 (due to epsilon handling)
assert torch.isfinite(normalized_obs["observation.state"]).all()
def test_quantile_partial_stats():
"""Test that quantile normalization handles missing quantile stats gracefully."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILES,
}
# Missing q99 - should pass through unchanged
stats_partial = {
"observation.state": {
"q01": np.array([0.1, -0.8]), # Only q01, missing q99
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats_partial)
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Should pass through unchanged when stats are incomplete
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
def test_quantile_mixed_with_other_modes():
"""Test quantile normalization mixed with other normalization modes."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD, # Standard normalization
FeatureType.STATE: NormalizationMode.QUANTILES, # Quantile normalization
FeatureType.ACTION: NormalizationMode.QUANTILE10, # Different quantile mode
}
stats = {
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
"observation.state": {"q01": [0.1, -0.8], "q99": [0.9, 0.8]},
"action": {"q10": [0.2, -0.6], "q90": [0.8, 0.6]},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]), # Should use QUANTILES
}
action = torch.tensor([0.5, 0.0]) # Should use QUANTILE10
transition = create_transition(observation=observation, action=action)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
normalized_action = normalized_transition[TransitionKey.ACTION]
# Image should be mean/std normalized: (0.7 - 0.5) / 0.2 = 1.0, etc.
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
assert torch.allclose(normalized_obs["observation.image"], expected_image)
# State should be quantile normalized: (0.5 - 0.1) / (0.9 - 0.1) = 0.5, etc.
expected_state = torch.tensor([0.5, 0.5])
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
# Action should be quantile10 normalized: (0.5 - 0.2) / (0.8 - 0.2) = 0.5, etc.
expected_action = torch.tensor([0.5, 0.5])
assert torch.allclose(normalized_action, expected_action, atol=1e-6)
def test_quantile_with_missing_stats():
"""Test that quantile normalization handles completely missing stats gracefully."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILES,
}
stats = {} # No stats provided
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Should pass through unchanged when no stats available
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
def test_selective_normalization(observation_stats):
features = _create_observation_features()
norm_map = _create_observation_norm_map()