mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
Add Quantile stats to LeRobotDataset (#1985)
* - Add RunningQuantileStats class for efficient histogram-based quantile computation - Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset - Support quantile computation during episode collection and aggregation - Add comprehensive function-based test suite (24 tests) for quantile functionality - Maintain full backward compatibility with existing stats computation - Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization * style fixes, make quantiles computation by default to new datasets * fix tests * - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each features instead of being chosen by the user - Fortified tests. * - add helper functions to reshape stats - add missing test for quantiles * - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles. - Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles. * style fixes * Added missing lisence * Simplify compute_stats * - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that dont have quatniles - modified quantile computation instead of using the edge for the value, interpolate the values in the bin
This commit is contained in:
@@ -36,6 +36,8 @@ class NormalizationMode(str, Enum):
|
||||
MIN_MAX = "MIN_MAX"
|
||||
MEAN_STD = "MEAN_STD"
|
||||
IDENTITY = "IDENTITY"
|
||||
QUANTILES = "QUANTILES"
|
||||
QUANTILE10 = "QUANTILE10"
|
||||
|
||||
|
||||
class DictLike(Protocol):
|
||||
|
||||
@@ -17,6 +17,171 @@ import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
|
||||
class RunningQuantileStats:
|
||||
"""Compute running statistics including quantiles for a batch of vectors."""
|
||||
|
||||
def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
|
||||
self._count = 0
|
||||
self._mean = None
|
||||
self._mean_of_squares = None
|
||||
self._min = None
|
||||
self._max = None
|
||||
self._histograms = None
|
||||
self._bin_edges = None
|
||||
self._num_quantile_bins = num_quantile_bins
|
||||
|
||||
self._quantile_list = quantile_list
|
||||
if self._quantile_list is None:
|
||||
self._quantile_list = DEFAULT_QUANTILES
|
||||
self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]
|
||||
|
||||
def update(self, batch: np.ndarray) -> None:
|
||||
"""Update the running statistics with a batch of vectors.
|
||||
|
||||
Args:
|
||||
batch: An array where all dimensions except the last are batch dimensions.
|
||||
"""
|
||||
batch = batch.reshape(-1, batch.shape[-1])
|
||||
num_elements, vector_length = batch.shape
|
||||
|
||||
if self._count == 0:
|
||||
self._mean = np.mean(batch, axis=0)
|
||||
self._mean_of_squares = np.mean(batch**2, axis=0)
|
||||
self._min = np.min(batch, axis=0)
|
||||
self._max = np.max(batch, axis=0)
|
||||
self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
|
||||
self._bin_edges = [
|
||||
np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
|
||||
for i in range(vector_length)
|
||||
]
|
||||
else:
|
||||
if vector_length != self._mean.size:
|
||||
raise ValueError("The length of new vectors does not match the initialized vector length.")
|
||||
|
||||
new_max = np.max(batch, axis=0)
|
||||
new_min = np.min(batch, axis=0)
|
||||
max_changed = np.any(new_max > self._max)
|
||||
min_changed = np.any(new_min < self._min)
|
||||
self._max = np.maximum(self._max, new_max)
|
||||
self._min = np.minimum(self._min, new_min)
|
||||
|
||||
if max_changed or min_changed:
|
||||
self._adjust_histograms()
|
||||
|
||||
self._count += num_elements
|
||||
|
||||
batch_mean = np.mean(batch, axis=0)
|
||||
batch_mean_of_squares = np.mean(batch**2, axis=0)
|
||||
|
||||
# Update running mean and mean of squares
|
||||
self._mean += (batch_mean - self._mean) * (num_elements / self._count)
|
||||
self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (
|
||||
num_elements / self._count
|
||||
)
|
||||
|
||||
self._update_histograms(batch)
|
||||
|
||||
def get_statistics(self) -> dict[str, np.ndarray]:
|
||||
"""Compute and return the statistics of the vectors processed so far.
|
||||
|
||||
Args:
|
||||
quantiles: List of quantiles to compute (e.g., [0.01, 0.10, 0.50, 0.90, 0.99]). If None, no quantiles computed.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the computed statistics.
|
||||
"""
|
||||
if self._count < 2:
|
||||
raise ValueError("Cannot compute statistics for less than 2 vectors.")
|
||||
|
||||
variance = self._mean_of_squares - self._mean**2
|
||||
stddev = np.sqrt(np.maximum(0, variance))
|
||||
|
||||
stats = {
|
||||
"min": self._min.copy(),
|
||||
"max": self._max.copy(),
|
||||
"mean": self._mean.copy(),
|
||||
"std": stddev,
|
||||
"count": np.array([self._count]),
|
||||
}
|
||||
|
||||
quantile_results = self._compute_quantiles()
|
||||
for i, q in enumerate(self._quantile_keys):
|
||||
stats[q] = quantile_results[i]
|
||||
|
||||
return stats
|
||||
|
||||
def _adjust_histograms(self):
|
||||
"""Adjust histograms when min or max changes."""
|
||||
for i in range(len(self._histograms)):
|
||||
old_edges = self._bin_edges[i]
|
||||
old_hist = self._histograms[i]
|
||||
|
||||
# Create new edges with small padding to ensure range coverage
|
||||
padding = (self._max[i] - self._min[i]) * 1e-10
|
||||
new_edges = np.linspace(
|
||||
self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1
|
||||
)
|
||||
|
||||
# Redistribute existing histogram counts to new bins
|
||||
# We need to map each old bin center to the new bins
|
||||
old_centers = (old_edges[:-1] + old_edges[1:]) / 2
|
||||
new_hist = np.zeros(self._num_quantile_bins)
|
||||
|
||||
for old_center, count in zip(old_centers, old_hist, strict=False):
|
||||
if count > 0:
|
||||
# Find which new bin this old center belongs to
|
||||
bin_idx = np.searchsorted(new_edges, old_center) - 1
|
||||
bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1))
|
||||
new_hist[bin_idx] += count
|
||||
|
||||
self._histograms[i] = new_hist
|
||||
self._bin_edges[i] = new_edges
|
||||
|
||||
def _update_histograms(self, batch: np.ndarray) -> None:
|
||||
"""Update histograms with new vectors."""
|
||||
for i in range(batch.shape[1]):
|
||||
hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
|
||||
self._histograms[i] += hist
|
||||
|
||||
def _compute_quantiles(self) -> list[np.ndarray]:
|
||||
"""Compute quantiles based on histograms."""
|
||||
results = []
|
||||
for q in self._quantile_list:
|
||||
target_count = q * self._count
|
||||
q_values = []
|
||||
|
||||
for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
|
||||
q_value = self._compute_single_quantile(hist, edges, target_count)
|
||||
q_values.append(q_value)
|
||||
|
||||
results.append(np.array(q_values))
|
||||
return results
|
||||
|
||||
def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float:
|
||||
"""Compute a single quantile value from histogram and bin edges."""
|
||||
cumsum = np.cumsum(hist)
|
||||
idx = np.searchsorted(cumsum, target_count)
|
||||
|
||||
if idx == 0:
|
||||
return edges[0]
|
||||
if idx >= len(cumsum):
|
||||
return edges[-1]
|
||||
|
||||
# If not edge case, interpolate within the bin
|
||||
count_before = cumsum[idx - 1]
|
||||
count_in_bin = cumsum[idx] - count_before
|
||||
|
||||
# If no samples in this bin, use the bin edge
|
||||
if count_in_bin == 0:
|
||||
return edges[idx]
|
||||
|
||||
# Linear interpolation within the bin
|
||||
fraction = (target_count - count_before) / count_in_bin
|
||||
return edges[idx] + fraction * (edges[idx + 1] - edges[idx])
|
||||
|
||||
|
||||
def estimate_num_samples(
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
@@ -72,33 +237,296 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
|
||||
return {
|
||||
"min": np.min(array, axis=axis, keepdims=keepdims),
|
||||
"max": np.max(array, axis=axis, keepdims=keepdims),
|
||||
"mean": np.mean(array, axis=axis, keepdims=keepdims),
|
||||
"std": np.std(array, axis=axis, keepdims=keepdims),
|
||||
"count": np.array([len(array)]),
|
||||
def _reshape_stats_by_axis(
|
||||
stats: dict[str, np.ndarray],
|
||||
axis: int | tuple[int, ...] | None,
|
||||
keepdims: bool,
|
||||
original_shape: tuple[int, ...],
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Reshape all statistics to match NumPy's output conventions.
|
||||
|
||||
Applies consistent reshaping to all statistics (except 'count') based on the
|
||||
axis and keepdims parameters. This ensures statistics have the correct shape
|
||||
for broadcasting with the original data.
|
||||
|
||||
Args:
|
||||
stats: Dictionary of computed statistics
|
||||
axis: Axis or axes along which statistics were computed
|
||||
keepdims: Whether to keep reduced dimensions as size-1 dimensions
|
||||
original_shape: Shape of the original array
|
||||
|
||||
Returns:
|
||||
Dictionary with reshaped statistics
|
||||
|
||||
Note:
|
||||
The 'count' statistic is never reshaped as it represents metadata
|
||||
rather than per-feature statistics.
|
||||
"""
|
||||
if axis == (1,) and not keepdims:
|
||||
return stats
|
||||
|
||||
result = {}
|
||||
for key, value in stats.items():
|
||||
if key == "count":
|
||||
result[key] = value
|
||||
else:
|
||||
result[key] = _reshape_single_stat(value, axis, keepdims, original_shape)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
|
||||
"""Reshape statistics for image data (axis=(0,2,3))."""
|
||||
if keepdims and value.ndim == 1:
|
||||
return value.reshape(1, -1, 1, 1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_vector_stats(
|
||||
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray:
|
||||
"""Reshape statistics for vector data (axis=0 or axis=(0,))."""
|
||||
if not keepdims:
|
||||
return value
|
||||
|
||||
if len(original_shape) == 1 and value.ndim > 0:
|
||||
return value.reshape(1)
|
||||
elif len(original_shape) >= 2 and value.ndim == 1:
|
||||
return value.reshape(1, -1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
|
||||
"""Reshape statistics for feature-wise computation (axis=(1,))."""
|
||||
if not keepdims:
|
||||
return value
|
||||
|
||||
if value.ndim == 0:
|
||||
return value.reshape(1, 1)
|
||||
elif value.ndim == 1:
|
||||
return value.reshape(-1, 1)
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_for_global_stats(
|
||||
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray | float:
|
||||
"""Reshape statistics for global reduction (axis=None)."""
|
||||
if keepdims:
|
||||
target_shape = tuple(1 for _ in original_shape)
|
||||
return value.reshape(target_shape)
|
||||
elif not keepdims and value.ndim > 0 and value.size == 1:
|
||||
return value.item()
|
||||
return value
|
||||
|
||||
|
||||
def _reshape_single_stat(
|
||||
value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
|
||||
) -> np.ndarray | float:
|
||||
"""Apply appropriate reshaping to a single statistic array.
|
||||
|
||||
This function transforms statistic arrays to match expected output shapes
|
||||
based on the axis configuration and keepdims parameter.
|
||||
|
||||
Args:
|
||||
value: The statistic array to reshape
|
||||
axis: Axis or axes that were reduced during computation
|
||||
keepdims: Whether to maintain reduced dimensions as size-1 dimensions
|
||||
original_shape: Shape of the original data before reduction
|
||||
|
||||
Returns:
|
||||
Reshaped array following NumPy broadcasting conventions
|
||||
|
||||
"""
|
||||
if axis == (0, 2, 3):
|
||||
return _reshape_for_image_stats(value, keepdims)
|
||||
|
||||
if axis in [0, (0,)]:
|
||||
return _reshape_for_vector_stats(value, keepdims, original_shape)
|
||||
|
||||
if axis == (1,):
|
||||
return _reshape_for_feature_stats(value, keepdims)
|
||||
|
||||
if axis is None:
|
||||
return _reshape_for_global_stats(value, keepdims, original_shape)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]:
|
||||
"""Prepare array for statistics computation by reshaping according to axis.
|
||||
|
||||
Args:
|
||||
array: Input data array
|
||||
axis: Axis or axes along which to compute statistics
|
||||
|
||||
Returns:
|
||||
Tuple of (reshaped_array, sample_count)
|
||||
"""
|
||||
if axis == (0, 2, 3): # Image data
|
||||
batch_size, channels, height, width = array.shape
|
||||
reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels)
|
||||
return reshaped, batch_size
|
||||
|
||||
if axis == 0 or axis == (0,): # Vector data
|
||||
if array.ndim == 1:
|
||||
reshaped = array.reshape(-1, 1)
|
||||
else:
|
||||
reshaped = array
|
||||
return reshaped, array.shape[0]
|
||||
|
||||
if axis == (1,): # Feature-wise statistics
|
||||
return array.T, array.shape[1]
|
||||
|
||||
if axis is None: # Global statistics
|
||||
reshaped = array.reshape(-1, 1)
|
||||
# For backward compatibility, count represents the first dimension size
|
||||
return reshaped, array.shape[0] if array.ndim > 0 else 1
|
||||
|
||||
raise ValueError(f"Unsupported axis configuration: {axis}")
|
||||
|
||||
|
||||
def _compute_basic_stats(
|
||||
array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Compute basic statistics for arrays with insufficient samples for quantiles.
|
||||
|
||||
Args:
|
||||
array: Reshaped array ready for statistics computation
|
||||
sample_count: Number of samples represented in the data
|
||||
|
||||
Returns:
|
||||
Dictionary with basic statistics and quantiles set to mean values
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list]
|
||||
|
||||
stats = {
|
||||
"min": np.min(array, axis=0),
|
||||
"max": np.max(array, axis=0),
|
||||
"mean": np.mean(array, axis=0),
|
||||
"std": np.std(array, axis=0),
|
||||
"count": np.array([sample_count]),
|
||||
}
|
||||
|
||||
# For single-element arrays with shape (1,1), convert to scalar arrays
|
||||
if array.shape == (1, 1):
|
||||
for key in stats:
|
||||
if key != "count" and stats[key].size == 1:
|
||||
stats[key] = np.array(stats[key].item())
|
||||
|
||||
for q in quantile_list_keys:
|
||||
stats[q] = stats["mean"].copy()
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def get_feature_stats(
|
||||
array: np.ndarray,
|
||||
axis: int | tuple[int, ...] | None,
|
||||
keepdims: bool,
|
||||
quantile_list: list[float] | None = None,
|
||||
) -> dict[str, np.ndarray]:
|
||||
"""Compute comprehensive statistics for array features along specified axes.
|
||||
|
||||
This function calculates min, max, mean, std, and quantiles (1%, 10%, 50%, 90%, 99%)
|
||||
for the input array along the specified axes. It handles different data layouts:
|
||||
- Image data: axis=(0,2,3) computes per-channel statistics
|
||||
- Vector data: axis=0 computes per-feature statistics
|
||||
- Feature-wise: axis=1 computes statistics across features
|
||||
- Global: axis=None computes statistics over entire array
|
||||
|
||||
Args:
|
||||
array: Input data array with shape appropriate for the specified axis
|
||||
axis: Axis or axes along which to compute statistics
|
||||
- (0, 2, 3): For image data (batch, channels, height, width)
|
||||
- 0 or (0,): For vector/tabular data (samples, features)
|
||||
- (1,): For computing across features
|
||||
- None: For global statistics over entire array
|
||||
keepdims: If True, reduced axes are kept as dimensions with size 1
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- 'min': Minimum values
|
||||
- 'max': Maximum values
|
||||
- 'mean': Mean values
|
||||
- 'std': Standard deviation
|
||||
- 'count': Number of samples (always shape (1,))
|
||||
- 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values
|
||||
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
original_shape = array.shape
|
||||
reshaped, sample_count = _prepare_array_for_stats(array, axis)
|
||||
|
||||
if reshaped.shape[0] < 2:
|
||||
stats = _compute_basic_stats(reshaped, sample_count, quantile_list)
|
||||
else:
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(reshaped)
|
||||
stats = running_stats.get_statistics()
|
||||
stats["count"] = np.array([sample_count])
|
||||
|
||||
# For axis=None, the stats are computed as 1D arrays but should be 0-dimensional arrays
|
||||
if axis is None and reshaped.shape[1] == 1:
|
||||
for key in stats:
|
||||
if key != "count" and stats[key].size == 1:
|
||||
stats[key] = np.array(stats[key].item())
|
||||
|
||||
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
|
||||
return stats
|
||||
|
||||
|
||||
def compute_episode_stats(
|
||||
episode_data: dict[str, list[str] | np.ndarray],
|
||||
features: dict,
|
||||
quantile_list: list[float] | None = None,
|
||||
) -> dict:
|
||||
"""Compute comprehensive statistics for all features in an episode.
|
||||
|
||||
Processes different data types appropriately:
|
||||
- Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1]
|
||||
- Numerical arrays: Computes per-feature statistics
|
||||
- Strings: Skipped (no statistics computed)
|
||||
|
||||
Args:
|
||||
episode_data: Dictionary mapping feature names to data
|
||||
- For images/videos: list of file paths
|
||||
- For numerical data: numpy arrays
|
||||
features: Dictionary describing each feature's dtype and shape
|
||||
|
||||
Returns:
|
||||
Dictionary mapping feature names to their statistics dictionaries.
|
||||
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
|
||||
|
||||
Note:
|
||||
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
|
||||
per-channel values when dtype is 'image' or 'video'.
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue # HACK: we should receive np.arrays of strings
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data) # data is a list of image paths
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
else:
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
ep_ft_array = data
|
||||
axes_to_reduce = 0
|
||||
keepdims = data.ndim == 1
|
||||
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
|
||||
)
|
||||
|
||||
# finally, we normalize and remove batch dim for images
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
@@ -107,20 +535,37 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
|
||||
return ep_stats
|
||||
|
||||
|
||||
def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
|
||||
"""Validate a single statistic value."""
|
||||
if not isinstance(value, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' "
|
||||
f"is of type '{type(value)}' instead."
|
||||
)
|
||||
|
||||
if value.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
|
||||
if key == "count" and value.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
|
||||
|
||||
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
|
||||
|
||||
|
||||
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
for i in range(len(stats_list)):
|
||||
for fkey in stats_list[i]:
|
||||
for k, v in stats_list[i][fkey].items():
|
||||
if not isinstance(v, np.ndarray):
|
||||
raise ValueError(
|
||||
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
|
||||
)
|
||||
if v.ndim == 0:
|
||||
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
|
||||
if k == "count" and v.shape != (1,):
|
||||
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
|
||||
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
"""Validate that all statistics have correct types and shapes.
|
||||
|
||||
Args:
|
||||
stats_list: List of statistics dictionaries to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If any statistic has incorrect type or shape
|
||||
"""
|
||||
for stats in stats_list:
|
||||
for feature_key, feature_stats in stats.items():
|
||||
for stat_key, stat_value in feature_stats.items():
|
||||
_validate_stat_value(stat_value, stat_key, feature_key)
|
||||
|
||||
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
@@ -143,7 +588,7 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
weighted_variances = (variances + delta_means**2) * counts
|
||||
total_variance = weighted_variances.sum(axis=0) / total_count
|
||||
|
||||
return {
|
||||
aggregated = {
|
||||
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
|
||||
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
|
||||
"mean": total_mean,
|
||||
@@ -151,6 +596,17 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
|
||||
"count": total_count,
|
||||
}
|
||||
|
||||
if stats_ft_list:
|
||||
quantile_keys = [k for k in stats_ft_list[0].keys() if k.startswith("q") and k[1:].isdigit()]
|
||||
|
||||
for q_key in quantile_keys:
|
||||
if all(q_key in s for s in stats_ft_list):
|
||||
quantile_values = np.stack([s[q_key] for s in stats_ft_list])
|
||||
weighted_quantiles = quantile_values * counts
|
||||
aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count
|
||||
|
||||
return aggregated
|
||||
|
||||
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
205
src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
Normal file
205
src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
Normal file
@@ -0,0 +1,205 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script augments existing LeRobot datasets with quantile statistics.
|
||||
|
||||
Most datasets created before the quantile feature was added do not contain
|
||||
quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script:
|
||||
|
||||
1. Loads an existing LeRobot dataset in v3.0 format
|
||||
2. Checks if it already contains quantile statistics
|
||||
3. If missing, computes quantile statistics for all features
|
||||
4. Updates the dataset metadata with the new quantile statistics
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
--repo-id=lerobot/pusht \
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import write_stats
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str] | None = None) -> bool:
|
||||
"""Check if dataset statistics already contain quantile information.
|
||||
|
||||
Args:
|
||||
stats: Dataset statistics dictionary
|
||||
|
||||
Returns:
|
||||
True if quantile statistics are present, False otherwise
|
||||
"""
|
||||
if quantile_list_keys is None:
|
||||
quantile_list_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES]
|
||||
|
||||
if stats is None:
|
||||
return False
|
||||
|
||||
for feature_stats in stats.values():
|
||||
if any(q_key in feature_stats for q_key in quantile_list_keys):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def load_episode_data(dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
"""Load episode data by accessing the underlying HuggingFace dataset.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset
|
||||
episode_idx: Index of the episode to load
|
||||
|
||||
Returns:
|
||||
Dictionary containing episode data for each feature
|
||||
"""
|
||||
|
||||
episode_info = dataset.meta.episodes[episode_idx]
|
||||
episode_length = episode_info["length"]
|
||||
|
||||
start_idx = sum(dataset.meta.episodes[i]["length"] for i in range(episode_idx))
|
||||
end_idx = start_idx + episode_length
|
||||
|
||||
episode_data = {}
|
||||
|
||||
episode_slice = dataset.hf_dataset.select(range(start_idx, end_idx))
|
||||
|
||||
for key, feature_info in dataset.features.items():
|
||||
if feature_info["dtype"] == "string":
|
||||
continue
|
||||
|
||||
if feature_info["dtype"] in ["image", "video"]:
|
||||
image_paths = []
|
||||
for row in episode_slice:
|
||||
if key in row:
|
||||
relative_path = row[key]
|
||||
if isinstance(relative_path, str):
|
||||
absolute_path = str(dataset.meta.root / relative_path)
|
||||
image_paths.append(absolute_path)
|
||||
|
||||
if image_paths:
|
||||
episode_data[key] = image_paths
|
||||
else:
|
||||
arrays = []
|
||||
for row in episode_slice:
|
||||
if key in row:
|
||||
arrays.append(np.array(row[key]))
|
||||
|
||||
if arrays:
|
||||
episode_data[key] = np.stack(arrays)
|
||||
|
||||
return episode_data
|
||||
|
||||
|
||||
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
|
||||
"""Compute quantile statistics for all episodes in the dataset.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to compute statistics for
|
||||
|
||||
Returns:
|
||||
Dictionary containing aggregated statistics with quantiles
|
||||
"""
|
||||
logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
|
||||
|
||||
episode_stats_list = []
|
||||
|
||||
for episode_idx in range(dataset.num_episodes):
|
||||
episode_data = load_episode_data(dataset, episode_idx)
|
||||
ep_stats = compute_episode_stats(episode_data, dataset.features)
|
||||
episode_stats_list.append(ep_stats)
|
||||
|
||||
if not episode_stats_list:
|
||||
raise ValueError("No episode data found for computing statistics")
|
||||
|
||||
logging.info(f"Aggregating statistics from {len(episode_stats_list)} episodes")
|
||||
return aggregate_stats(episode_stats_list)
|
||||
|
||||
|
||||
def augment_dataset_with_quantile_stats(
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
) -> None:
|
||||
"""Augment a dataset with quantile statistics if they are missing.
|
||||
|
||||
Args:
|
||||
repo_id: Repository ID of the dataset
|
||||
root: Local root directory for the dataset
|
||||
"""
|
||||
logging.info(f"Loading dataset: {repo_id}")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
)
|
||||
|
||||
if has_quantile_stats(dataset.meta.stats):
|
||||
logging.info("Dataset already contains quantile statistics. No action needed.")
|
||||
return
|
||||
|
||||
logging.info("Dataset does not contain quantile statistics. Computing them now...")
|
||||
|
||||
new_stats = compute_quantile_stats_for_dataset(dataset)
|
||||
|
||||
logging.info("Updating dataset metadata with new quantile statistics")
|
||||
dataset.meta.stats = new_stats
|
||||
|
||||
write_stats(new_stats, dataset.meta.root)
|
||||
|
||||
logging.info("Successfully updated dataset with quantile statistics")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run the augmentation script."""
|
||||
parser = argparse.ArgumentParser(description="Augment LeRobot dataset with quantile statistics")
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository ID of the dataset (e.g., 'lerobot/pusht')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
help="Local root directory for the dataset",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
root = Path(args.root) if args.root else None
|
||||
|
||||
init_logging()
|
||||
|
||||
augment_dataset_with_quantile_stats(
|
||||
repo_id=args.repo_id,
|
||||
root=root,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -281,8 +281,14 @@ class _NormalizationMixin:
|
||||
"""
|
||||
Core logic to apply a normalization or unnormalization transformation to a tensor.
|
||||
|
||||
This method selects the appropriate normalization mode (e.g., mean/std, min/max)
|
||||
based on the feature type and applies the corresponding mathematical operation.
|
||||
This method selects the appropriate normalization mode based on the feature type
|
||||
and applies the corresponding mathematical operation.
|
||||
|
||||
Normalization Modes:
|
||||
- MEAN_STD: Centers data around zero with unit variance.
|
||||
- MIN_MAX: Scales data to [-1, 1] range using actual min/max values.
|
||||
- QUANTILES: Scales data to [0, 1] range using 1st and 99th percentiles (q01/q99).
|
||||
- QUANTILE10: Scales data to [0, 1] range using 10th and 90th percentiles (q10/q90).
|
||||
|
||||
Args:
|
||||
tensor: The input tensor to transform.
|
||||
@@ -300,7 +306,12 @@ class _NormalizationMixin:
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
|
||||
return tensor
|
||||
|
||||
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
|
||||
if norm_mode not in (
|
||||
NormalizationMode.MEAN_STD,
|
||||
NormalizationMode.MIN_MAX,
|
||||
NormalizationMode.QUANTILES,
|
||||
NormalizationMode.QUANTILE10,
|
||||
):
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor
|
||||
@@ -334,6 +345,28 @@ class _NormalizationMixin:
|
||||
# Map from [min, max] to [-1, 1]
|
||||
return 2 * (tensor - min_val) / denom - 1
|
||||
|
||||
if norm_mode == NormalizationMode.QUANTILES and "q01" in stats and "q99" in stats:
|
||||
q01, q99 = stats["q01"], stats["q99"]
|
||||
denom = q99 - q01
|
||||
# Avoid division by zero by adding epsilon when quantiles are identical
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
||||
)
|
||||
if inverse:
|
||||
return tensor * denom + q01
|
||||
return (tensor - q01) / denom
|
||||
|
||||
if norm_mode == NormalizationMode.QUANTILE10 and "q10" in stats and "q90" in stats:
|
||||
q10, q90 = stats["q10"], stats["q90"]
|
||||
denom = q90 - q10
|
||||
# Avoid division by zero by adding epsilon when quantiles are identical
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
|
||||
)
|
||||
if inverse:
|
||||
return tensor * denom + q10
|
||||
return (tensor - q10) / denom
|
||||
|
||||
# If necessary stats are missing, return input unchanged.
|
||||
return tensor
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.compute_stats import (
|
||||
RunningQuantileStats,
|
||||
_assert_type_and_shape,
|
||||
aggregate_feature_stats,
|
||||
aggregate_stats,
|
||||
@@ -101,6 +102,9 @@ def test_get_feature_stats_axis_1(sample_array):
|
||||
"count": np.array([3]),
|
||||
}
|
||||
result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
|
||||
|
||||
# Check that basic stats are correct (quantiles are also included now)
|
||||
assert set(expected.keys()).issubset(set(result.keys()))
|
||||
for key in expected:
|
||||
np.testing.assert_allclose(result[key], expected[key])
|
||||
|
||||
@@ -114,6 +118,9 @@ def test_get_feature_stats_no_axis(sample_array):
|
||||
"count": np.array([3]),
|
||||
}
|
||||
result = get_feature_stats(sample_array, axis=None, keepdims=False)
|
||||
|
||||
# Check that basic stats are correct (quantiles are also included now)
|
||||
assert set(expected.keys()).issubset(set(result.keys()))
|
||||
for key in expected:
|
||||
np.testing.assert_allclose(result[key], expected[key])
|
||||
|
||||
@@ -307,3 +314,520 @@ def test_aggregate_stats():
|
||||
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
|
||||
)
|
||||
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
|
||||
|
||||
|
||||
def test_running_quantile_stats_initialization():
|
||||
"""Test proper initialization of RunningQuantileStats."""
|
||||
running_stats = RunningQuantileStats()
|
||||
assert running_stats._count == 0
|
||||
assert running_stats._mean is None
|
||||
assert running_stats._num_quantile_bins == 5000
|
||||
|
||||
# Test custom bin size
|
||||
running_stats_custom = RunningQuantileStats(num_quantile_bins=1000)
|
||||
assert running_stats_custom._num_quantile_bins == 1000
|
||||
|
||||
|
||||
def test_running_quantile_stats_single_batch_update():
|
||||
"""Test updating with a single batch."""
|
||||
np.random.seed(42)
|
||||
data = np.random.normal(0, 1, (100, 3))
|
||||
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(data)
|
||||
|
||||
assert running_stats._count == 100
|
||||
assert running_stats._mean.shape == (3,)
|
||||
assert len(running_stats._histograms) == 3
|
||||
assert len(running_stats._bin_edges) == 3
|
||||
|
||||
# Verify basic statistics are reasonable
|
||||
np.testing.assert_allclose(running_stats._mean, np.mean(data, axis=0), atol=1e-10)
|
||||
|
||||
|
||||
def test_running_quantile_stats_multiple_batch_updates():
|
||||
"""Test updating with multiple batches."""
|
||||
np.random.seed(42)
|
||||
data1 = np.random.normal(0, 1, (100, 2))
|
||||
data2 = np.random.normal(1, 1, (150, 2))
|
||||
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(data1)
|
||||
running_stats.update(data2)
|
||||
|
||||
assert running_stats._count == 250
|
||||
|
||||
# Verify running mean is correct
|
||||
combined_data = np.vstack([data1, data2])
|
||||
expected_mean = np.mean(combined_data, axis=0)
|
||||
np.testing.assert_allclose(running_stats._mean, expected_mean, atol=1e-10)
|
||||
|
||||
|
||||
def test_running_quantile_stats_get_statistics_basic():
|
||||
"""Test getting basic statistics without quantiles."""
|
||||
np.random.seed(42)
|
||||
data = np.random.normal(0, 1, (100, 2))
|
||||
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(data)
|
||||
|
||||
stats = running_stats.get_statistics()
|
||||
|
||||
# Should have basic stats
|
||||
expected_keys = {"min", "max", "mean", "std", "count"}
|
||||
assert expected_keys.issubset(set(stats.keys()))
|
||||
|
||||
# Verify values
|
||||
np.testing.assert_allclose(stats["mean"], np.mean(data, axis=0), atol=1e-10)
|
||||
np.testing.assert_allclose(stats["std"], np.std(data, axis=0), atol=1e-6)
|
||||
np.testing.assert_equal(stats["count"], np.array([100]))
|
||||
|
||||
|
||||
def test_running_quantile_stats_get_statistics_with_quantiles():
|
||||
"""Test getting statistics with quantiles."""
|
||||
np.random.seed(42)
|
||||
data = np.random.normal(0, 1, (1000, 2))
|
||||
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(data)
|
||||
|
||||
stats = running_stats.get_statistics()
|
||||
|
||||
# Should have basic stats plus quantiles
|
||||
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
||||
assert expected_keys.issubset(set(stats.keys()))
|
||||
|
||||
# Verify quantile values are reasonable
|
||||
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES
|
||||
|
||||
for i, q in enumerate(DEFAULT_QUANTILES):
|
||||
q_key = f"q{int(q * 100):02d}"
|
||||
assert q_key in stats
|
||||
assert stats[q_key].shape == (2,)
|
||||
|
||||
# Check that quantiles are in reasonable order
|
||||
if i > 0:
|
||||
prev_q_key = f"q{int(DEFAULT_QUANTILES[i - 1] * 100):02d}"
|
||||
assert np.all(stats[prev_q_key] <= stats[q_key])
|
||||
|
||||
|
||||
def test_running_quantile_stats_histogram_adjustment():
|
||||
"""Test that histograms adjust when min/max change."""
|
||||
running_stats = RunningQuantileStats()
|
||||
|
||||
# Initial data with small range
|
||||
data1 = np.array([[0.0, 1.0], [0.1, 1.1], [0.2, 1.2]])
|
||||
running_stats.update(data1)
|
||||
|
||||
initial_edges_0 = running_stats._bin_edges[0].copy()
|
||||
initial_edges_1 = running_stats._bin_edges[1].copy()
|
||||
|
||||
# Add data with much larger range
|
||||
data2 = np.array([[10.0, -10.0], [11.0, -11.0]])
|
||||
running_stats.update(data2)
|
||||
|
||||
# Bin edges should have changed
|
||||
assert not np.array_equal(initial_edges_0, running_stats._bin_edges[0])
|
||||
assert not np.array_equal(initial_edges_1, running_stats._bin_edges[1])
|
||||
|
||||
# New edges should cover the expanded range
|
||||
# First dimension: min should still be ~0.0, max should be ~11.0
|
||||
assert running_stats._bin_edges[0][0] <= 0.0
|
||||
assert running_stats._bin_edges[0][-1] >= 11.0
|
||||
|
||||
# Second dimension: min should be ~-11.0, max should be ~1.2
|
||||
assert running_stats._bin_edges[1][0] <= -11.0
|
||||
assert running_stats._bin_edges[1][-1] >= 1.2
|
||||
|
||||
|
||||
def test_running_quantile_stats_insufficient_data_error():
|
||||
"""Test error when trying to get stats with insufficient data."""
|
||||
running_stats = RunningQuantileStats()
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
|
||||
running_stats.get_statistics()
|
||||
|
||||
# Single vector should also fail
|
||||
running_stats.update(np.array([[1.0]]))
|
||||
with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
|
||||
running_stats.get_statistics()
|
||||
|
||||
|
||||
def test_running_quantile_stats_vector_length_consistency():
|
||||
"""Test error when vector lengths don't match."""
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(np.array([[1.0, 2.0], [3.0, 4.0]]))
|
||||
|
||||
with pytest.raises(ValueError, match="The length of new vectors does not match"):
|
||||
running_stats.update(np.array([[1.0, 2.0, 3.0]])) # Different length
|
||||
|
||||
|
||||
def test_running_quantile_stats_reshape_handling():
|
||||
"""Test that various input shapes are handled correctly."""
|
||||
running_stats = RunningQuantileStats()
|
||||
|
||||
# Test 3D input (e.g., images)
|
||||
data_3d = np.random.normal(0, 1, (10, 32, 32))
|
||||
running_stats.update(data_3d)
|
||||
|
||||
assert running_stats._count == 10 * 32
|
||||
assert running_stats._mean.shape == (32,)
|
||||
|
||||
# Test 1D input
|
||||
running_stats_1d = RunningQuantileStats()
|
||||
data_1d = np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
|
||||
running_stats_1d.update(data_1d)
|
||||
|
||||
assert running_stats_1d._count == 5
|
||||
assert running_stats_1d._mean.shape == (1,)
|
||||
|
||||
|
||||
def test_get_feature_stats_quantiles_enabled_by_default():
|
||||
"""Test that quantiles are computed by default."""
|
||||
data = np.random.normal(0, 1, (100, 5))
|
||||
stats = get_feature_stats(data, axis=0, keepdims=False)
|
||||
|
||||
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
||||
assert set(stats.keys()) == expected_keys
|
||||
|
||||
|
||||
def test_get_feature_stats_quantiles_with_vector_data():
|
||||
"""Test quantile computation with vector data."""
|
||||
np.random.seed(42)
|
||||
data = np.random.normal(0, 1, (100, 5))
|
||||
|
||||
stats = get_feature_stats(data, axis=0, keepdims=False)
|
||||
|
||||
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
||||
assert set(stats.keys()) == expected_keys
|
||||
|
||||
# Verify shapes
|
||||
assert stats["q01"].shape == (5,)
|
||||
assert stats["q99"].shape == (5,)
|
||||
|
||||
# Verify quantiles are reasonable
|
||||
assert np.all(stats["q01"] < stats["q99"])
|
||||
|
||||
|
||||
def test_get_feature_stats_quantiles_with_image_data():
|
||||
"""Test quantile computation with image data."""
|
||||
np.random.seed(42)
|
||||
data = np.random.normal(0, 1, (50, 3, 32, 32)) # batch, channels, height, width
|
||||
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
|
||||
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
||||
assert set(stats.keys()) == expected_keys
|
||||
|
||||
# Verify shapes for images (should be (1, channels, 1, 1))
|
||||
assert stats["q01"].shape == (1, 3, 1, 1)
|
||||
assert stats["q50"].shape == (1, 3, 1, 1)
|
||||
assert stats["q99"].shape == (1, 3, 1, 1)
|
||||
|
||||
|
||||
def test_get_feature_stats_fixed_quantiles():
|
||||
"""Test that fixed quantiles are always computed."""
|
||||
data = np.random.normal(0, 1, (200, 3))
|
||||
|
||||
stats = get_feature_stats(data, axis=0, keepdims=False)
|
||||
|
||||
expected_quantile_keys = {"q01", "q10", "q50", "q90", "q99"}
|
||||
assert expected_quantile_keys.issubset(set(stats.keys()))
|
||||
|
||||
|
||||
def test_get_feature_stats_unsupported_axis_error():
|
||||
"""Test error for unsupported axis configuration."""
|
||||
data = np.random.normal(0, 1, (10, 5))
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported axis configuration"):
|
||||
get_feature_stats(
|
||||
data,
|
||||
axis=(1, 2), # Unsupported axis
|
||||
keepdims=False,
|
||||
)
|
||||
|
||||
|
||||
def test_compute_episode_stats_backward_compatibility():
|
||||
"""Test that existing functionality is preserved."""
|
||||
episode_data = {
|
||||
"action": np.random.normal(0, 1, (100, 7)),
|
||||
"observation.state": np.random.normal(0, 1, (100, 10)),
|
||||
}
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (7,)},
|
||||
"observation.state": {"dtype": "float32", "shape": (10,)},
|
||||
}
|
||||
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
for key in ["action", "observation.state"]:
|
||||
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
||||
assert set(stats[key].keys()) == expected_keys
|
||||
|
||||
|
||||
def test_compute_episode_stats_with_custom_quantiles():
|
||||
"""Test quantile computation with custom quantile values."""
|
||||
np.random.seed(42)
|
||||
episode_data = {
|
||||
"action": np.random.normal(0, 1, (100, 7)),
|
||||
"observation.state": np.random.normal(2, 1, (100, 10)),
|
||||
}
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (7,)},
|
||||
"observation.state": {"dtype": "float32", "shape": (10,)},
|
||||
}
|
||||
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
# Should have quantiles
|
||||
for key in ["action", "observation.state"]:
|
||||
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
||||
assert set(stats[key].keys()) == expected_keys
|
||||
|
||||
# Verify shapes
|
||||
assert stats[key]["q01"].shape == (features[key]["shape"][0],)
|
||||
assert stats[key]["q99"].shape == (features[key]["shape"][0],)
|
||||
|
||||
|
||||
def test_compute_episode_stats_with_image_data():
|
||||
"""Test quantile computation with image features."""
|
||||
image_paths = [f"image_{i}.jpg" for i in range(50)]
|
||||
episode_data = {
|
||||
"observation.image": image_paths,
|
||||
"action": np.random.normal(0, 1, (50, 5)),
|
||||
}
|
||||
features = {
|
||||
"observation.image": {"dtype": "image"},
|
||||
"action": {"dtype": "float32", "shape": (5,)},
|
||||
}
|
||||
|
||||
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
# Image quantiles should be normalized and have correct shape
|
||||
assert "q01" in stats["observation.image"]
|
||||
assert "q50" in stats["observation.image"]
|
||||
assert "q99" in stats["observation.image"]
|
||||
assert stats["observation.image"]["q01"].shape == (3, 1, 1)
|
||||
assert stats["observation.image"]["q50"].shape == (3, 1, 1)
|
||||
assert stats["observation.image"]["q99"].shape == (3, 1, 1)
|
||||
|
||||
# Action quantiles should have correct shape
|
||||
assert stats["action"]["q01"].shape == (5,)
|
||||
assert stats["action"]["q50"].shape == (5,)
|
||||
assert stats["action"]["q99"].shape == (5,)
|
||||
|
||||
|
||||
def test_compute_episode_stats_string_features_skipped():
|
||||
"""Test that string features are properly skipped."""
|
||||
episode_data = {
|
||||
"task": ["pick_apple"] * 100, # String feature
|
||||
"action": np.random.normal(0, 1, (100, 5)),
|
||||
}
|
||||
features = {
|
||||
"task": {"dtype": "string"},
|
||||
"action": {"dtype": "float32", "shape": (5,)},
|
||||
}
|
||||
|
||||
stats = compute_episode_stats(
|
||||
episode_data,
|
||||
features,
|
||||
)
|
||||
|
||||
# String features should be skipped
|
||||
assert "task" not in stats
|
||||
assert "action" in stats
|
||||
assert "q01" in stats["action"]
|
||||
|
||||
|
||||
def test_aggregate_feature_stats_with_quantiles():
|
||||
"""Test aggregating feature stats that include quantiles."""
|
||||
stats_ft_list = [
|
||||
{
|
||||
"min": np.array([1.0]),
|
||||
"max": np.array([10.0]),
|
||||
"mean": np.array([5.0]),
|
||||
"std": np.array([2.0]),
|
||||
"count": np.array([100]),
|
||||
"q01": np.array([1.5]),
|
||||
"q99": np.array([9.5]),
|
||||
},
|
||||
{
|
||||
"min": np.array([2.0]),
|
||||
"max": np.array([12.0]),
|
||||
"mean": np.array([6.0]),
|
||||
"std": np.array([2.5]),
|
||||
"count": np.array([150]),
|
||||
"q01": np.array([2.5]),
|
||||
"q99": np.array([11.5]),
|
||||
},
|
||||
]
|
||||
|
||||
result = aggregate_feature_stats(stats_ft_list)
|
||||
|
||||
# Should preserve quantiles
|
||||
assert "q01" in result
|
||||
assert "q99" in result
|
||||
|
||||
# Verify quantile aggregation (weighted average)
|
||||
expected_q01 = (1.5 * 100 + 2.5 * 150) / 250 # ≈ 2.1
|
||||
expected_q99 = (9.5 * 100 + 11.5 * 150) / 250 # ≈ 10.7
|
||||
|
||||
np.testing.assert_allclose(result["q01"], np.array([expected_q01]), atol=1e-6)
|
||||
np.testing.assert_allclose(result["q99"], np.array([expected_q99]), atol=1e-6)
|
||||
|
||||
|
||||
def test_aggregate_stats_mixed_quantiles():
|
||||
"""Test aggregating stats where some have quantiles and some don't."""
|
||||
stats_with_quantiles = {
|
||||
"feature1": {
|
||||
"min": np.array([1.0]),
|
||||
"max": np.array([10.0]),
|
||||
"mean": np.array([5.0]),
|
||||
"std": np.array([2.0]),
|
||||
"count": np.array([100]),
|
||||
"q01": np.array([1.5]),
|
||||
"q99": np.array([9.5]),
|
||||
}
|
||||
}
|
||||
|
||||
stats_without_quantiles = {
|
||||
"feature2": {
|
||||
"min": np.array([0.0]),
|
||||
"max": np.array([5.0]),
|
||||
"mean": np.array([2.5]),
|
||||
"std": np.array([1.5]),
|
||||
"count": np.array([50]),
|
||||
}
|
||||
}
|
||||
|
||||
all_stats = [stats_with_quantiles, stats_without_quantiles]
|
||||
result = aggregate_stats(all_stats)
|
||||
|
||||
# Feature1 should keep its quantiles
|
||||
assert "q01" in result["feature1"]
|
||||
assert "q99" in result["feature1"]
|
||||
|
||||
# Feature2 should not have quantiles
|
||||
assert "q01" not in result["feature2"]
|
||||
assert "q99" not in result["feature2"]
|
||||
|
||||
|
||||
def test_assert_type_and_shape_with_quantiles():
|
||||
"""Test validation works correctly with quantile keys."""
|
||||
# Valid stats with quantiles
|
||||
valid_stats = [
|
||||
{
|
||||
"observation.image": {
|
||||
"min": np.array([0.0, 0.0, 0.0]).reshape(3, 1, 1),
|
||||
"max": np.array([1.0, 1.0, 1.0]).reshape(3, 1, 1),
|
||||
"mean": np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1),
|
||||
"std": np.array([0.2, 0.2, 0.2]).reshape(3, 1, 1),
|
||||
"count": np.array([100]),
|
||||
"q01": np.array([0.1, 0.1, 0.1]).reshape(3, 1, 1),
|
||||
"q99": np.array([0.9, 0.9, 0.9]).reshape(3, 1, 1),
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Should not raise error
|
||||
_assert_type_and_shape(valid_stats)
|
||||
|
||||
# Invalid shape for quantile
|
||||
invalid_stats = [
|
||||
{
|
||||
"observation.image": {
|
||||
"count": np.array([100]),
|
||||
"q01": np.array([0.1, 0.2]), # Wrong shape for image quantile
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="Shape of quantile 'q01' must be \\(3,1,1\\)"):
|
||||
_assert_type_and_shape(invalid_stats)
|
||||
|
||||
|
||||
def test_quantile_integration_single_value_quantiles():
|
||||
"""Test quantile computation with single repeated value."""
|
||||
data = np.ones((100, 3)) # All ones
|
||||
|
||||
running_stats = RunningQuantileStats()
|
||||
running_stats.update(data)
|
||||
|
||||
stats = running_stats.get_statistics()
|
||||
|
||||
# All quantiles should be approximately 1.0
|
||||
np.testing.assert_allclose(stats["q01"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
|
||||
np.testing.assert_allclose(stats["q50"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
|
||||
np.testing.assert_allclose(stats["q99"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
|
||||
|
||||
|
||||
def test_quantile_integration_fixed_quantiles():
|
||||
"""Test that fixed quantiles are computed."""
|
||||
np.random.seed(42)
|
||||
data = np.random.normal(0, 1, (1000, 2))
|
||||
|
||||
stats = get_feature_stats(data, axis=0, keepdims=False)
|
||||
|
||||
# Check all fixed quantiles are present
|
||||
assert "q01" in stats
|
||||
assert "q10" in stats
|
||||
assert "q50" in stats
|
||||
assert "q90" in stats
|
||||
assert "q99" in stats
|
||||
|
||||
|
||||
def test_quantile_integration_large_dataset_quantiles():
|
||||
"""Test quantile computation efficiency with large datasets."""
|
||||
np.random.seed(42)
|
||||
large_data = np.random.normal(0, 1, (10000, 5))
|
||||
|
||||
running_stats = RunningQuantileStats(num_quantile_bins=1000) # Reduced bins for speed
|
||||
running_stats.update(large_data)
|
||||
|
||||
stats = running_stats.get_statistics()
|
||||
|
||||
# Should complete without issues and produce reasonable results
|
||||
assert stats["count"][0] == 10000
|
||||
assert len(stats["q01"]) == 5
|
||||
|
||||
|
||||
def test_fixed_quantiles_always_computed():
|
||||
"""Test that the fixed quantiles [0.01, 0.10, 0.50, 0.90, 0.99] are always computed."""
|
||||
np.random.seed(42)
|
||||
# Test with vector data
|
||||
vector_data = np.random.normal(0, 1, (100, 5))
|
||||
vector_stats = get_feature_stats(vector_data, axis=0, keepdims=False)
|
||||
|
||||
# Check all fixed quantiles are present
|
||||
expected_quantiles = ["q01", "q10", "q50", "q90", "q99"]
|
||||
for q_key in expected_quantiles:
|
||||
assert q_key in vector_stats
|
||||
assert vector_stats[q_key].shape == (5,)
|
||||
|
||||
# Test with image data
|
||||
image_data = np.random.randint(0, 256, (50, 3, 32, 32), dtype=np.uint8)
|
||||
image_stats = get_feature_stats(image_data, axis=(0, 2, 3), keepdims=True)
|
||||
|
||||
# Check all fixed quantiles are present for images
|
||||
for q_key in expected_quantiles:
|
||||
assert q_key in image_stats
|
||||
assert image_stats[q_key].shape == (1, 3, 1, 1)
|
||||
|
||||
# Test with episode data
|
||||
episode_data = {
|
||||
"action": np.random.normal(0, 1, (100, 7)),
|
||||
"observation.state": np.random.normal(0, 1, (100, 10)),
|
||||
}
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (7,)},
|
||||
"observation.state": {"dtype": "float32", "shape": (10,)},
|
||||
}
|
||||
|
||||
episode_stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
# Check all fixed quantiles are present in episode stats
|
||||
for key in ["action", "observation.state"]:
|
||||
for q_key in expected_quantiles:
|
||||
assert q_key in episode_stats[key]
|
||||
assert episode_stats[key][q_key].shape == (features[key]["shape"][0],)
|
||||
|
||||
212
tests/datasets/test_quantiles_dataset_integration.py
Normal file
212
tests/datasets/test_quantiles_dataset_integration.py
Normal file
@@ -0,0 +1,212 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Integration tests for quantile functionality in LeRobotDataset."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
"""Mock image loading for consistent test results."""
|
||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_features():
|
||||
"""Simple feature configuration for testing."""
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": ["arm_x", "arm_y", "arm_z", "gripper"],
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (10,),
|
||||
"names": [f"joint_{i}" for i in range(10)],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_create_dataset_with_fixed_quantiles(tmp_path, simple_features):
|
||||
"""Test creating dataset with fixed quantiles."""
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test_dataset_fixed_quantiles",
|
||||
fps=30,
|
||||
features=simple_features,
|
||||
root=tmp_path / "create_fixed_quantiles",
|
||||
)
|
||||
|
||||
# Dataset should be created successfully
|
||||
assert dataset is not None
|
||||
|
||||
|
||||
def test_save_episode_computes_all_quantiles(tmp_path, simple_features):
|
||||
"""Test that all fixed quantiles are computed when saving an episode."""
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test_dataset_save_episode",
|
||||
fps=30,
|
||||
features=simple_features,
|
||||
root=tmp_path / "save_episode_quantiles",
|
||||
)
|
||||
|
||||
# Add some frames
|
||||
for _ in range(10):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"action": np.random.randn(4).astype(np.float32), # Correct shape for action
|
||||
"observation.state": np.random.randn(10).astype(np.float32),
|
||||
"task": "test_task",
|
||||
}
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
# Check that all fixed quantiles were computed
|
||||
stats = dataset.meta.stats
|
||||
for key in ["action", "observation.state"]:
|
||||
assert "q01" in stats[key]
|
||||
assert "q10" in stats[key]
|
||||
assert "q50" in stats[key]
|
||||
assert "q90" in stats[key]
|
||||
assert "q99" in stats[key]
|
||||
|
||||
|
||||
def test_quantile_values_ordering(tmp_path, simple_features):
|
||||
"""Test that quantile values are properly ordered."""
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test_dataset_quantile_ordering",
|
||||
fps=30,
|
||||
features=simple_features,
|
||||
root=tmp_path / "quantile_ordering",
|
||||
)
|
||||
|
||||
# Add data with known distribution
|
||||
np.random.seed(42)
|
||||
for _ in range(100):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"action": np.random.randn(4).astype(np.float32), # Correct shape for action
|
||||
"observation.state": np.random.randn(10).astype(np.float32),
|
||||
"task": "test_task",
|
||||
}
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
stats = dataset.meta.stats
|
||||
|
||||
# Verify quantile ordering
|
||||
for key in ["action", "observation.state"]:
|
||||
assert np.all(stats[key]["q01"] <= stats[key]["q10"])
|
||||
assert np.all(stats[key]["q10"] <= stats[key]["q50"])
|
||||
assert np.all(stats[key]["q50"] <= stats[key]["q90"])
|
||||
assert np.all(stats[key]["q90"] <= stats[key]["q99"])
|
||||
|
||||
|
||||
def test_save_episode_with_fixed_quantiles(tmp_path, simple_features):
|
||||
"""Test saving episode always computes fixed quantiles."""
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test_dataset_save_fixed",
|
||||
fps=30,
|
||||
features=simple_features,
|
||||
root=tmp_path / "save_fixed_quantiles",
|
||||
)
|
||||
|
||||
# Add frames to episode
|
||||
np.random.seed(42)
|
||||
for _ in range(50):
|
||||
frame = {
|
||||
"action": np.random.normal(0, 1, (4,)).astype(np.float32),
|
||||
"observation.state": np.random.normal(0, 1, (10,)).astype(np.float32),
|
||||
"task": "test_task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
# Check that all fixed quantiles are included
|
||||
stats = dataset.meta.stats
|
||||
for key in ["action", "observation.state"]:
|
||||
feature_stats = stats[key]
|
||||
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
||||
assert set(feature_stats.keys()) == expected_keys
|
||||
|
||||
|
||||
def test_quantile_aggregation_across_episodes(tmp_path, simple_features):
|
||||
"""Test quantile aggregation across multiple episodes."""
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test_dataset_aggregation",
|
||||
fps=30,
|
||||
features=simple_features,
|
||||
root=tmp_path / "quantile_aggregation",
|
||||
)
|
||||
|
||||
# Add frames to episode
|
||||
np.random.seed(42)
|
||||
for _ in range(100):
|
||||
frame = {
|
||||
"action": np.random.normal(0, 1, (4,)).astype(np.float32),
|
||||
"observation.state": np.random.normal(2, 1, (10,)).astype(np.float32),
|
||||
"task": "test_task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
# Check stats include all fixed quantiles
|
||||
stats = dataset.meta.stats
|
||||
for key in ["action", "observation.state"]:
|
||||
feature_stats = stats[key]
|
||||
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
|
||||
assert set(feature_stats.keys()) == expected_keys
|
||||
assert feature_stats["q01"].shape == (simple_features[key]["shape"][0],)
|
||||
assert feature_stats["q50"].shape == (simple_features[key]["shape"][0],)
|
||||
assert feature_stats["q99"].shape == (simple_features[key]["shape"][0],)
|
||||
assert np.all(feature_stats["q01"] <= feature_stats["q50"])
|
||||
assert np.all(feature_stats["q50"] <= feature_stats["q99"])
|
||||
|
||||
|
||||
def test_save_multiple_episodes_with_quantiles(tmp_path, simple_features):
|
||||
"""Test quantile aggregation across multiple episodes."""
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test_dataset_multiple_episodes",
|
||||
fps=30,
|
||||
features=simple_features,
|
||||
root=tmp_path / "multiple_episodes",
|
||||
)
|
||||
|
||||
# Save multiple episodes
|
||||
np.random.seed(42)
|
||||
for episode_idx in range(3):
|
||||
for _ in range(50):
|
||||
frame = {
|
||||
"action": np.random.normal(episode_idx * 2.0, 1, (4,)).astype(np.float32),
|
||||
"observation.state": np.random.normal(-episode_idx * 1.5, 1, (10,)).astype(np.float32),
|
||||
"task": f"task_{episode_idx}",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
# Verify final stats include properly aggregated quantiles
|
||||
stats = dataset.meta.stats
|
||||
for key in ["action", "observation.state"]:
|
||||
feature_stats = stats[key]
|
||||
assert "q01" in feature_stats and "q99" in feature_stats
|
||||
assert feature_stats["count"][0] == 150 # 3 episodes * 50 frames
|
||||
@@ -165,6 +165,229 @@ def test_min_max_normalization(observation_normalizer):
|
||||
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
|
||||
|
||||
|
||||
def test_quantile_normalization():
|
||||
"""Test QUANTILES mode using 1st-99th percentiles."""
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.QUANTILES,
|
||||
}
|
||||
stats = {
|
||||
"observation.state": {
|
||||
"q01": np.array([0.1, -0.8]), # 1st percentile
|
||||
"q99": np.array([0.9, 0.8]), # 99th percentile
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check quantile normalization to [0, 1]
|
||||
# For state[0]: (0.5 - 0.1) / (0.9 - 0.1) = 0.4 / 0.8 = 0.5
|
||||
# For state[1]: (0.0 - (-0.8)) / (0.8 - (-0.8)) = 0.8 / 1.6 = 0.5
|
||||
expected_state = torch.tensor([0.5, 0.5])
|
||||
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
|
||||
|
||||
|
||||
def test_quantile10_normalization():
|
||||
"""Test QUANTILE10 mode using 10th-90th percentiles."""
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.QUANTILE10,
|
||||
}
|
||||
stats = {
|
||||
"observation.state": {
|
||||
"q10": np.array([0.2, -0.6]), # 10th percentile
|
||||
"q90": np.array([0.8, 0.6]), # 90th percentile
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check quantile normalization to [0, 1]
|
||||
# For state[0]: (0.5 - 0.2) / (0.8 - 0.2) = 0.3 / 0.6 = 0.5
|
||||
# For state[1]: (0.0 - (-0.6)) / (0.6 - (-0.6)) = 0.6 / 1.2 = 0.5
|
||||
expected_state = torch.tensor([0.5, 0.5])
|
||||
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
|
||||
|
||||
|
||||
def test_quantile_unnormalization():
|
||||
"""Test that quantile normalization can be reversed properly."""
|
||||
features = {
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.ACTION: NormalizationMode.QUANTILES,
|
||||
}
|
||||
stats = {
|
||||
"action": {
|
||||
"q01": np.array([0.1, -0.8]),
|
||||
"q99": np.array([0.9, 0.8]),
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
# Test round-trip normalization
|
||||
original_action = torch.tensor([0.5, 0.0])
|
||||
transition = create_transition(action=original_action)
|
||||
|
||||
# Normalize then unnormalize
|
||||
normalized = normalizer(transition)
|
||||
unnormalized = unnormalizer(normalized)
|
||||
|
||||
# Should recover original values
|
||||
recovered_action = unnormalized[TransitionKey.ACTION]
|
||||
assert torch.allclose(recovered_action, original_action, atol=1e-6)
|
||||
|
||||
|
||||
def test_quantile_division_by_zero():
|
||||
"""Test quantile normalization handles edge case where q01 == q99."""
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (1,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.QUANTILES,
|
||||
}
|
||||
stats = {
|
||||
"observation.state": {
|
||||
"q01": np.array([0.5]), # Same value
|
||||
"q99": np.array([0.5]), # Same value -> division by zero case
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.tensor([0.5]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Should not crash and should handle gracefully
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# When quantiles are identical, should normalize to 0 (due to epsilon handling)
|
||||
assert torch.isfinite(normalized_obs["observation.state"]).all()
|
||||
|
||||
|
||||
def test_quantile_partial_stats():
|
||||
"""Test that quantile normalization handles missing quantile stats gracefully."""
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.QUANTILES,
|
||||
}
|
||||
|
||||
# Missing q99 - should pass through unchanged
|
||||
stats_partial = {
|
||||
"observation.state": {
|
||||
"q01": np.array([0.1, -0.8]), # Only q01, missing q99
|
||||
},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats_partial)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Should pass through unchanged when stats are incomplete
|
||||
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
|
||||
|
||||
|
||||
def test_quantile_mixed_with_other_modes():
|
||||
"""Test quantile normalization mixed with other normalization modes."""
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD, # Standard normalization
|
||||
FeatureType.STATE: NormalizationMode.QUANTILES, # Quantile normalization
|
||||
FeatureType.ACTION: NormalizationMode.QUANTILE10, # Different quantile mode
|
||||
}
|
||||
stats = {
|
||||
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
|
||||
"observation.state": {"q01": [0.1, -0.8], "q99": [0.9, 0.8]},
|
||||
"action": {"q10": [0.2, -0.6], "q90": [0.8, 0.6]},
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]), # Should use QUANTILES
|
||||
}
|
||||
action = torch.tensor([0.5, 0.0]) # Should use QUANTILE10
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
normalized_action = normalized_transition[TransitionKey.ACTION]
|
||||
|
||||
# Image should be mean/std normalized: (0.7 - 0.5) / 0.2 = 1.0, etc.
|
||||
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||
assert torch.allclose(normalized_obs["observation.image"], expected_image)
|
||||
|
||||
# State should be quantile normalized: (0.5 - 0.1) / (0.9 - 0.1) = 0.5, etc.
|
||||
expected_state = torch.tensor([0.5, 0.5])
|
||||
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
|
||||
|
||||
# Action should be quantile10 normalized: (0.5 - 0.2) / (0.8 - 0.2) = 0.5, etc.
|
||||
expected_action = torch.tensor([0.5, 0.5])
|
||||
assert torch.allclose(normalized_action, expected_action, atol=1e-6)
|
||||
|
||||
|
||||
def test_quantile_with_missing_stats():
|
||||
"""Test that quantile normalization handles completely missing stats gracefully."""
|
||||
features = {
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.STATE: NormalizationMode.QUANTILES,
|
||||
}
|
||||
stats = {} # No stats provided
|
||||
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Should pass through unchanged when no stats available
|
||||
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
|
||||
|
||||
|
||||
def test_selective_normalization(observation_stats):
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
|
||||
Reference in New Issue
Block a user