#!/usr/bin/env python # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import logging import numpy as np from lerobot.datasets.io_utils import load_image_as_numpy from lerobot.utils.constants import ACTION, OBS_STATE DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99] class RunningQuantileStats: """ Maintains running statistics for batches of vectors, including mean, standard deviation, min, max, and approximate quantiles. Statistics are computed per feature dimension and updated incrementally as new batches are observed. Quantiles are estimated using histograms, which adapt dynamically if the observed data range expands. """ def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000): self._count = 0 self._mean = None self._mean_of_squares = None self._min = None self._max = None self._histograms = None self._bin_edges = None self._num_quantile_bins = num_quantile_bins self._quantile_list = quantile_list if self._quantile_list is None: self._quantile_list = DEFAULT_QUANTILES self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list] def update(self, batch: np.ndarray) -> None: """Update the running statistics with a batch of vectors. Args: batch: An array where all dimensions except the last are batch dimensions. """ batch = batch.reshape(-1, batch.shape[-1]) num_elements, vector_length = batch.shape if self._count == 0: self._mean = np.mean(batch, axis=0) self._mean_of_squares = np.mean(batch**2, axis=0) self._min = np.min(batch, axis=0) self._max = np.max(batch, axis=0) self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)] self._bin_edges = [ np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1) for i in range(vector_length) ] else: if vector_length != self._mean.size: raise ValueError("The length of new vectors does not match the initialized vector length.") new_max = np.max(batch, axis=0) new_min = np.min(batch, axis=0) max_changed = np.any(new_max > self._max) min_changed = np.any(new_min < self._min) self._max = np.maximum(self._max, new_max) self._min = np.minimum(self._min, new_min) if max_changed or min_changed: self._adjust_histograms() self._count += num_elements batch_mean = np.mean(batch, axis=0) batch_mean_of_squares = np.mean(batch**2, axis=0) # Update running mean and mean of squares self._mean += (batch_mean - self._mean) * (num_elements / self._count) self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * ( num_elements / self._count ) self._update_histograms(batch) def get_statistics(self) -> dict[str, np.ndarray]: """Compute and return the statistics of the vectors processed so far. Args: quantiles: List of quantiles to compute (e.g., [0.01, 0.10, 0.50, 0.90, 0.99]). If None, no quantiles computed. Returns: Dictionary containing the computed statistics. """ if self._count < 2: raise ValueError("Cannot compute statistics for less than 2 vectors.") variance = self._mean_of_squares - self._mean**2 stddev = np.sqrt(np.maximum(0, variance)) stats = { "min": self._min.copy(), "max": self._max.copy(), "mean": self._mean.copy(), "std": stddev, "count": np.array([self._count]), } quantile_results = self._compute_quantiles() for i, q in enumerate(self._quantile_keys): stats[q] = quantile_results[i] return stats def _adjust_histograms(self): """Adjust histograms when min or max changes.""" for i in range(len(self._histograms)): old_edges = self._bin_edges[i] old_hist = self._histograms[i] # Create new edges with small padding to ensure range coverage padding = (self._max[i] - self._min[i]) * 1e-10 new_edges = np.linspace( self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1 ) # Redistribute existing histogram counts to new bins # We need to map each old bin center to the new bins old_centers = (old_edges[:-1] + old_edges[1:]) / 2 new_hist = np.zeros(self._num_quantile_bins) for old_center, count in zip(old_centers, old_hist, strict=False): if count > 0: # Find which new bin this old center belongs to bin_idx = np.searchsorted(new_edges, old_center) - 1 bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1)) new_hist[bin_idx] += count self._histograms[i] = new_hist self._bin_edges[i] = new_edges def _update_histograms(self, batch: np.ndarray) -> None: """Update histograms with new vectors.""" for i in range(batch.shape[1]): hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i]) self._histograms[i] += hist def _compute_quantiles(self) -> list[np.ndarray]: """Compute quantiles based on histograms.""" results = [] for q in self._quantile_list: target_count = q * self._count q_values = [] for hist, edges in zip(self._histograms, self._bin_edges, strict=True): q_value = self._compute_single_quantile(hist, edges, target_count) q_values.append(q_value) results.append(np.array(q_values)) return results def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float: """Compute a single quantile value from histogram and bin edges.""" cumsum = np.cumsum(hist) idx = np.searchsorted(cumsum, target_count) if idx == 0: return edges[0] if idx >= len(cumsum): return edges[-1] # If not edge case, interpolate within the bin count_before = cumsum[idx - 1] count_in_bin = cumsum[idx] - count_before # If no samples in this bin, use the bin edge if count_in_bin == 0: return edges[idx] # Linear interpolation within the bin fraction = (target_count - count_before) / count_in_bin return edges[idx] + fraction * (edges[idx + 1] - edges[idx]) def estimate_num_samples( dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 ) -> int: """Heuristic to estimate the number of samples based on dataset size. The power controls the sample growth relative to dataset size. Lower the power for less number of samples. For default arguments, we have: - from 1 to ~500, num_samples=100 - at 1000, num_samples=177 - at 2000, num_samples=299 - at 5000, num_samples=594 - at 10000, num_samples=1000 - at 20000, num_samples=1681 """ if dataset_len < min_num_samples: min_num_samples = dataset_len return max(min_num_samples, min(int(dataset_len**power), max_num_samples)) def sample_indices(data_len: int) -> list[int]: num_samples = estimate_num_samples(data_len) return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist() def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300): _, height, width = img.shape if max(width, height) < max_size_threshold: # no downsampling needed return img downsample_factor = int(width / target_size) if width > height else int(height / target_size) return img[:, ::downsample_factor, ::downsample_factor] def sample_images(image_paths: list[str]) -> np.ndarray: sampled_indices = sample_indices(len(image_paths)) images = None for i, idx in enumerate(sampled_indices): path = image_paths[idx] # we load as uint8 to reduce memory usage img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True) img = auto_downsample_height_width(img) if images is None: images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8) images[i] = img return images def _reshape_stats_by_axis( stats: dict[str, np.ndarray], axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...], ) -> dict[str, np.ndarray]: """Reshape all statistics to match NumPy's output conventions. Applies consistent reshaping to all statistics (except 'count') based on the axis and keepdims parameters. This ensures statistics have the correct shape for broadcasting with the original data. Args: stats: Dictionary of computed statistics axis: Axis or axes along which statistics were computed keepdims: Whether to keep reduced dimensions as size-1 dimensions original_shape: Shape of the original array Returns: Dictionary with reshaped statistics Note: The 'count' statistic is never reshaped as it represents metadata rather than per-feature statistics. """ if axis == (1,) and not keepdims: return stats result = {} for key, value in stats.items(): if key == "count": result[key] = value else: result[key] = _reshape_single_stat(value, axis, keepdims, original_shape) return result def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray: """Reshape statistics for image data (axis=(0,2,3)).""" if keepdims and value.ndim == 1: return value.reshape(1, -1, 1, 1) return value def _reshape_for_vector_stats( value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...] ) -> np.ndarray: """Reshape statistics for vector data (axis=0 or axis=(0,)).""" if not keepdims: return value if len(original_shape) == 1 and value.ndim > 0: return value.reshape(1) elif len(original_shape) >= 2 and value.ndim == 1: return value.reshape(1, -1) return value def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray: """Reshape statistics for feature-wise computation (axis=(1,)).""" if not keepdims: return value if value.ndim == 0: return value.reshape(1, 1) elif value.ndim == 1: return value.reshape(-1, 1) return value def _reshape_for_global_stats( value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...] ) -> np.ndarray | float: """Reshape statistics for global reduction (axis=None).""" if keepdims: target_shape = tuple(1 for _ in original_shape) return value.reshape(target_shape) # Keep at least 1-D arrays to satisfy validator return np.atleast_1d(value) def _reshape_single_stat( value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...] ) -> np.ndarray | float: """Apply appropriate reshaping to a single statistic array. This function transforms statistic arrays to match expected output shapes based on the axis configuration and keepdims parameter. Args: value: The statistic array to reshape axis: Axis or axes that were reduced during computation keepdims: Whether to maintain reduced dimensions as size-1 dimensions original_shape: Shape of the original data before reduction Returns: Reshaped array following NumPy broadcasting conventions """ if axis == (0, 2, 3): return _reshape_for_image_stats(value, keepdims) if axis in [0, (0,)]: return _reshape_for_vector_stats(value, keepdims, original_shape) if axis == (1,): return _reshape_for_feature_stats(value, keepdims) if axis is None: return _reshape_for_global_stats(value, keepdims, original_shape) return value def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]: """Prepare array for statistics computation by reshaping according to axis. Args: array: Input data array axis: Axis or axes along which to compute statistics Returns: Tuple of (reshaped_array, sample_count) """ if axis == (0, 2, 3): # Image data batch_size, channels, height, width = array.shape reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels) return reshaped, batch_size if axis == 0 or axis == (0,): # Vector data reshaped = array if array.ndim == 1: reshaped = array.reshape(-1, 1) return reshaped, array.shape[0] if axis == (1,): # Feature-wise statistics return array.T, array.shape[1] if axis is None: # Global statistics reshaped = array.reshape(-1, 1) # For backward compatibility, count represents the first dimension size return reshaped, array.shape[0] if array.ndim > 0 else 1 raise ValueError(f"Unsupported axis configuration: {axis}") def _compute_basic_stats( array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None ) -> dict[str, np.ndarray]: """Compute basic statistics for arrays with insufficient samples for quantiles. Args: array: Reshaped array ready for statistics computation sample_count: Number of samples represented in the data Returns: Dictionary with basic statistics and quantiles set to mean values """ if quantile_list is None: quantile_list = DEFAULT_QUANTILES quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list] stats = { "min": np.min(array, axis=0), "max": np.max(array, axis=0), "mean": np.mean(array, axis=0), "std": np.std(array, axis=0), "count": np.array([sample_count]), } for q in quantile_list_keys: stats[q] = stats["mean"].copy() return stats def get_feature_stats( array: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, quantile_list: list[float] | None = None, ) -> dict[str, np.ndarray]: """Compute comprehensive statistics for array features along specified axes. This function calculates min, max, mean, std, and quantiles (1%, 10%, 50%, 90%, 99%) for the input array along the specified axes. It handles different data layouts: - Image data: axis=(0,2,3) computes per-channel statistics - Vector data: axis=0 computes per-feature statistics - Feature-wise: axis=1 computes statistics across features - Global: axis=None computes statistics over entire array Args: array: Input data array with shape appropriate for the specified axis axis: Axis or axes along which to compute statistics - (0, 2, 3): For image data (batch, channels, height, width) - 0 or (0,): For vector/tabular data (samples, features) - (1,): For computing across features - None: For global statistics over entire array keepdims: If True, reduced axes are kept as dimensions with size 1 Returns: Dictionary containing: - 'min': Minimum values - 'max': Maximum values - 'mean': Mean values - 'std': Standard deviation - 'count': Number of samples (always shape (1,)) - 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values """ if quantile_list is None: quantile_list = DEFAULT_QUANTILES original_shape = array.shape reshaped, sample_count = _prepare_array_for_stats(array, axis) if reshaped.shape[0] < 2: stats = _compute_basic_stats(reshaped, sample_count, quantile_list) else: running_stats = RunningQuantileStats() running_stats.update(reshaped) stats = running_stats.get_statistics() stats["count"] = np.array([sample_count]) stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape) return stats def compute_episode_stats( episode_data: dict[str, list[str] | np.ndarray], features: dict, quantile_list: list[float] | None = None, ) -> dict: """Compute comprehensive statistics for all features in an episode. Processes different data types appropriately: - Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1] - Numerical arrays: Computes per-feature statistics - Strings: Skipped (no statistics computed) Args: episode_data: Dictionary mapping feature names to data - For images/videos: list of file paths - For numerical data: numpy arrays features: Dictionary describing each feature's dtype and shape Returns: Dictionary mapping feature names to their statistics dictionaries. Each statistics dictionary contains min, max, mean, std, count, and quantiles. Note: Image statistics are normalized to [0,1] range and have shape (3,1,1) for per-channel values when dtype is 'image' or 'video'. """ if quantile_list is None: quantile_list = DEFAULT_QUANTILES ep_stats = {} for key, data in episode_data.items(): if features[key]["dtype"] == "string": continue if features[key]["dtype"] in ["image", "video"]: ep_ft_array = sample_images(data) axes_to_reduce = (0, 2, 3) keepdims = True else: ep_ft_array = data axes_to_reduce = 0 keepdims = data.ndim == 1 ep_stats[key] = get_feature_stats( ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list ) if features[key]["dtype"] in ["image", "video"]: ep_stats[key] = { k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() } return ep_stats def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None: """Validate a single statistic value.""" if not isinstance(value, np.ndarray): raise ValueError( f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' " f"is of type '{type(value)}' instead." ) if value.ndim == 0: raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") if key == "count" and value.shape != (1,): raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.") if "image" in feature_key and key != "count" and value.shape != (3, 1, 1): raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.") def _assert_type_and_shape(stats_list: list[dict[str, dict]]): """Validate that all statistics have correct types and shapes. Args: stats_list: List of statistics dictionaries to validate Raises: ValueError: If any statistic has incorrect type or shape """ for stats in stats_list: for feature_key, feature_stats in stats.items(): for stat_key, stat_value in feature_stats.items(): _validate_stat_value(stat_value, stat_key, feature_key) def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: """Aggregates stats for a single feature.""" means = np.stack([s["mean"] for s in stats_ft_list]) variances = np.stack([s["std"] ** 2 for s in stats_ft_list]) counts = np.stack([s["count"] for s in stats_ft_list]) total_count = counts.sum(axis=0) # Prepare weighted mean by matching number of dimensions while counts.ndim < means.ndim: counts = np.expand_dims(counts, axis=-1) # Compute the weighted mean weighted_means = means * counts total_mean = weighted_means.sum(axis=0) / total_count # Compute the variance using the parallel algorithm delta_means = means - total_mean weighted_variances = (variances + delta_means**2) * counts total_variance = weighted_variances.sum(axis=0) / total_count aggregated = { "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0), "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0), "mean": total_mean, "std": np.sqrt(total_variance), "count": total_count, } if stats_ft_list: quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()] for q_key in quantile_keys: if all(q_key in s for s in stats_ft_list): quantile_values = np.stack([s[q_key] for s in stats_ft_list]) weighted_quantiles = quantile_values * counts aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count return aggregated def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: """Aggregate stats from multiple compute_stats outputs into a single set of stats. The final stats will have the union of all data keys from each of the stats dicts. For instance: - new_min = min(min_dataset_0, min_dataset_1, ...) - new_max = max(max_dataset_0, max_dataset_1, ...) - new_mean = (mean of all data, weighted by counts) - new_std = (std of all data) """ _assert_type_and_shape(stats_list) data_keys = {key for stats in stats_list for key in stats} aggregated_stats = {key: {} for key in data_keys} for key in data_keys: stats_with_key = [stats[key] for stats in stats_list if key in stats] aggregated_stats[key] = aggregate_feature_stats(stats_with_key) return aggregated_stats def _get_valid_chunk_starts(episode_indices: np.ndarray, chunk_size: int) -> np.ndarray: """Return all start indices where a chunk of ``chunk_size`` stays within one episode.""" total = len(episode_indices) if total < chunk_size: return np.array([], dtype=np.int64) max_start = total - chunk_size starts = np.arange(max_start + 1) valid = episode_indices[starts] == episode_indices[starts + chunk_size - 1] return starts[valid] def _compute_relative_chunk_batch( start_indices: np.ndarray, all_actions: np.ndarray, all_states: np.ndarray, chunk_size: int, relative_mask: np.ndarray, ) -> np.ndarray: """Vectorised relative-action computation for a batch of start indices. Returns an ``(N * chunk_size, action_dim)`` float32 array. """ if len(start_indices) == 0: return np.empty((0, all_actions.shape[1]), dtype=np.float32) offsets = np.arange(chunk_size) frame_idx = start_indices[:, None] + offsets[None, :] chunks = all_actions[frame_idx].copy() states = all_states[start_indices] mask_dim = len(relative_mask) chunks[:, :, :mask_dim] -= states[:, None, :mask_dim] * relative_mask[None, None, :] return chunks.reshape(-1, all_actions.shape[1]) def compute_relative_action_stats( hf_dataset, features: dict, chunk_size: int, exclude_joints: list[str] | None = None, num_workers: int = 0, ) -> dict[str, np.ndarray]: """Compute normalization statistics for relative actions over the full dataset. Iterates *all* valid action chunks (within single episodes), converts them to relative actions (action − current_state), and computes per-dimension statistics suitable for normalization. Args: hf_dataset: The underlying HuggingFace dataset with "action", "observation.state", and "episode_index" columns. features: Dataset feature metadata (must contain "action" with "shape" and optionally "names"). chunk_size: Number of consecutive frames per action chunk. exclude_joints: Joint names whose dimensions should remain absolute (not converted to relative actions). num_workers: Number of parallel threads for computation. Values ≤1 mean single-threaded. Numpy releases the GIL so threads give real parallelism here. Returns: Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99". Raises: ValueError: If the dataset has fewer frames than ``chunk_size``. RuntimeError: If no valid (single-episode) chunks are found. """ from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep if exclude_joints is None: exclude_joints = [] action_dim = features[ACTION]["shape"][0] action_names = features.get(ACTION, {}).get("names") mask_step = RelativeActionsProcessorStep( enabled=True, exclude_joints=exclude_joints, action_names=action_names, ) relative_mask = np.array(mask_step._build_mask(action_dim), dtype=np.float32) logging.info("Loading action/state data for relative action stats...") all_actions = np.array(hf_dataset[ACTION], dtype=np.float32) all_states = np.array(hf_dataset[OBS_STATE], dtype=np.float32) episode_indices = np.array(hf_dataset["episode_index"]) valid_starts = _get_valid_chunk_starts(episode_indices, chunk_size) if len(valid_starts) == 0: raise RuntimeError( f"No valid chunks found (total_frames={len(episode_indices)}, chunk_size={chunk_size})" ) effective_workers = max(num_workers, 1) logging.info( f"Computing relative action stats from {len(valid_starts)} chunks " f"(chunk_size={chunk_size}, workers={effective_workers})" ) batch_size = 50_000 batches = [valid_starts[i : i + batch_size] for i in range(0, len(valid_starts), batch_size)] running_stats = RunningQuantileStats() if num_workers > 1: from concurrent.futures import ThreadPoolExecutor, as_completed with ThreadPoolExecutor(max_workers=num_workers) as pool: futures = [ pool.submit( _compute_relative_chunk_batch, batch, all_actions, all_states, chunk_size, relative_mask, ) for batch in batches ] for future in as_completed(futures): running_stats.update(future.result()) else: for batch in batches: running_stats.update( _compute_relative_chunk_batch(batch, all_actions, all_states, chunk_size, relative_mask) ) stats = running_stats.get_statistics() excluded_dims = int(len(relative_mask) - relative_mask.sum()) total_frames = len(valid_starts) * chunk_size logging.info( f"Relative action stats ({len(valid_starts)} chunks, {total_frames} frames): " f"relative_dims={int(relative_mask.sum())}/{len(relative_mask)} (excluded={excluded_dims}), " f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}, " f"q01={stats['q01'].mean():.4f}, q99={stats['q99'].mean():.4f}" ) return stats