mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
* Add option for pi family models to train with relative actions (relative to state) * formatting * add recomputation of stats and option to compute delta stats * normalzie after delta conversion * only recompute state for stats * calulate chunk based stats * sample 100k * load from parquet * sample 1m * stats per chunck * fix * use quantiles * stats for entire dataset * fix * max 1m frames * compute before dist * fix multi gpu processor bug * Fix RTC with delta actions and OpenArms motor_type wiring * feat: align pi0_fast delta actions with pi0/pi05 and add RTC integration tests - Add delta_exclude_joints and action_feature_names to PI0FastConfig - Move to_absolute_actions from modeling to processor pipeline for pi0_fast - Add delta action detection and logging to eval_with_real_robot.py - Add delta actions documentation to pi0 and pi05 READMEs - Fix ruff lint issues in test_delta_actions.py - Add test_rtc_delta_actions.py (24 tests) covering: - ActionQueue with delta vs absolute actions - RTC denoise step with delta leftovers - Full pipeline roundtrip (delta → RTC → absolute) - State rebasing approximation bounds - Non-delta policy compatibility - Multi-chunk consistency * chore: clean up test comments, add OpenPI attribution, remove debug logging - Replace decorative comment separators in test files with plain section headers - Add attribution comments for 1e-6 epsilon in normalize_processor.py (from OpenPI) - Remove debug logging blocks from lerobot_train.py * refactor: extract compute_delta_action_stats into compute_stats.py Move the ~70-line inline delta action stats block from lerobot_train.py into a dedicated function in compute_stats.py, where all other stats computation already lives. The training script now calls it in 6 lines. * refactor: remove unused get_processed_left_over from ActionQueue This method was never called outside of tests. Leftover actions for RTC guidance are always retrieved via get_left_over() (delta/original space). 
* revert: remove logging-only changes from eval_with_real_robot.py The delta actions detection helper and log message added no functional value — the script already handles delta policies correctly via the processor pipeline. * refactor: use ACTION/OBS_STATE constants instead of hardcoded strings Replace hardcoded "action" and "observation.state" with ACTION and OBS_STATE from utils.constants in compute_stats.py, dataset_tools.py, and lerobot_train.py. * style: remove stray blank lines in training loop * refactor: move delta action stats to preprocessing step, remove on-the-fly computation - Remove on-the-fly compute_delta_action_stats from lerobot_train.py - Rewrite recompute_stats to delegate action stats to compute_delta_action_stats (chunk-based sampling matching what the model sees during training) - Add chunk_size parameter to recompute_stats for delta action computation - Add delta actions documentation to pi0.mdx and pi05.mdx * feat: add recompute_stats CLI operation to lerobot-edit-dataset * fix(tests): relax quantile normalization test tolerance for 1e-6 epsilon * chore: remove agents_memory/pr_details.md from repo * refactor: rename delta actions to relative actions throughout What OpenPI calls "DeltaActions" is actually UMI's "relative trajectory" representation: each action in the chunk is an offset from the current state, not from the previous action. This avoids error accumulation. 
Renamed across all source, tests, docs, and CLI: - DeltaActionsProcessorStep → RelativeActionsProcessorStep - to_delta_actions → to_relative_actions - use_delta_actions → use_relative_actions - delta_exclude_joints → relative_exclude_joints - compute_delta_action_stats → compute_relative_action_stats - delta_action_processor.py → relative_action_processor.py - test_delta_actions.py → test_relative_actions.py Kept as-is: AbsoluteActionsProcessorStep (converts TO absolute), registry ID "delta_actions_processor" (backward compat), and unrelated delta references (IK pipeline, Robosuite, RA-BC metrics, gym envs). * docs: add Action Representations guide Dedicated page explaining absolute, relative, and delta actions with numerical examples, joint vs EE space, and how to use kinematics pipelines and the relative action processor. References UMI paper (Chi et al., 2024) for the terminology. * docs: remove redundant OpenPI naming note from action representations * docs: remove opinionated OpenPI reference from delta actions section * docs: replace ASCII diagram with UMI paper figure * docs: remove OpenPI reference from action representations * docs: use HF-hosted image instead of local asset * docs: clarify figure attribution * revert: restore original normalization epsilon behavior The 1e-6 unconditional epsilon change perturbed all normalized values, breaking backward compatibility tests. The original approach (1e-8 eps for MEAN_STD, conditional torch.where for QUANTILES) already handles division by zero correctly without affecting non-degenerate cases. * fix: restore delta_action_processor.py used by phone/RL teleop The rename commit incorrectly deleted delta_action_processor.py and duplicated its classes into relative_action_processor.py. Restore the original file and import from it instead. 
* fix(processor): address PR #2970 review comments - Remove shebang from relative_action_processor.py (library module, not script) - Add device alignment in to_relative_actions/to_absolute_actions so _last_state on CPU doesn't cause cross-device errors when actions are on CUDA - Rename delta_step → relative_step in AbsoluteActionsProcessorStep for naming consistency; update factory.py, all processor files, and tests - Expand _reconnect_relative_absolute_steps docstring to explain why post-hoc rewiring is needed after deserialization - Fix off-by-one in compute_stats.py: sample_upper_bound = total_frames - chunk_size + 1 so last valid start index is included and total_frames == chunk_size is not rejected - Remove redundant NOTE comment in processor_pi05.py (duplicated two lines below) - Fix pi0_fast processor ordering: move relative_step before NormalizerProcessorStep so normalizer sees delta actions (matching pi0/pi05); flip postprocessor to unnormalize → absolute accordingly. Relative stats are now required for all pi models - Revert use_relative_joint_actions_aloha → use_delta_joint_actions_aloha in configuration_smolvla.py (preserve existing public API) - Update action_representations.mdx: add missing joint to 6-DOF example, fix 'based on a figure', clarify pi family ordering, add RTC compatibility section * update rtc link * feat: compute relative action stats over full dataset with optional parallelism Remove the 100k sample cap from compute_relative_action_stats and process all valid chunks. Vectorize with numpy (pre-load actions/states, fancy indexing + broadcasting) for a large speedup over the per-index HF dataset loop. Add num_workers param for thread-based parallelism (numpy releases the GIL). Update docs to show --push_to_hub for recompute_stats. * style: apply ruff formatting to compute_stats.py * testing on real robot * style: fix ruff format and remove redundant .keys() calls
770 lines
28 KiB
Python
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import logging

import numpy as np

from lerobot.datasets.io_utils import load_image_as_numpy
from lerobot.utils.constants import ACTION, OBS_STATE

DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]


class RunningQuantileStats:
    """
    Maintains running statistics for batches of vectors, including mean,
    standard deviation, min, max, and approximate quantiles.

    Statistics are computed per feature dimension and updated incrementally
    as new batches are observed. Quantiles are estimated using histograms,
    which adapt dynamically if the observed data range expands.
    """

    def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
        self._count = 0
        self._mean = None
        self._mean_of_squares = None
        self._min = None
        self._max = None
        self._histograms = None
        self._bin_edges = None
        self._num_quantile_bins = num_quantile_bins

        self._quantile_list = quantile_list
        if self._quantile_list is None:
            self._quantile_list = DEFAULT_QUANTILES
        self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]

    def update(self, batch: np.ndarray) -> None:
        """Update the running statistics with a batch of vectors.

        Args:
            batch: An array where all dimensions except the last are batch dimensions.
        """
        batch = batch.reshape(-1, batch.shape[-1])
        num_elements, vector_length = batch.shape

        if self._count == 0:
            self._mean = np.mean(batch, axis=0)
            self._mean_of_squares = np.mean(batch**2, axis=0)
            self._min = np.min(batch, axis=0)
            self._max = np.max(batch, axis=0)
            self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
            self._bin_edges = [
                np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
                for i in range(vector_length)
            ]
        else:
            if vector_length != self._mean.size:
                raise ValueError("The length of new vectors does not match the initialized vector length.")

            new_max = np.max(batch, axis=0)
            new_min = np.min(batch, axis=0)
            max_changed = np.any(new_max > self._max)
            min_changed = np.any(new_min < self._min)
            self._max = np.maximum(self._max, new_max)
            self._min = np.minimum(self._min, new_min)

            if max_changed or min_changed:
                self._adjust_histograms()

        self._count += num_elements

        batch_mean = np.mean(batch, axis=0)
        batch_mean_of_squares = np.mean(batch**2, axis=0)

        # Update running mean and mean of squares
        self._mean += (batch_mean - self._mean) * (num_elements / self._count)
        self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (
            num_elements / self._count
        )

        self._update_histograms(batch)

    def get_statistics(self) -> dict[str, np.ndarray]:
        """Compute and return the statistics of the vectors processed so far.

        The quantiles reported are the ones passed as ``quantile_list`` at
        construction time (``DEFAULT_QUANTILES`` if none were given).

        Returns:
            Dictionary with "min", "max", "mean", "std", "count", and one "qXX"
            entry per configured quantile (e.g. "q01", "q99").

        Raises:
            ValueError: If fewer than 2 vectors have been processed.
        """
        if self._count < 2:
            raise ValueError("Cannot compute statistics for less than 2 vectors.")

        variance = self._mean_of_squares - self._mean**2

        stddev = np.sqrt(np.maximum(0, variance))

        stats = {
            "min": self._min.copy(),
            "max": self._max.copy(),
            "mean": self._mean.copy(),
            "std": stddev,
            "count": np.array([self._count]),
        }

        quantile_results = self._compute_quantiles()
        for i, q in enumerate(self._quantile_keys):
            stats[q] = quantile_results[i]

        return stats

    def _adjust_histograms(self):
        """Adjust histograms when min or max changes."""
        for i in range(len(self._histograms)):
            old_edges = self._bin_edges[i]
            old_hist = self._histograms[i]

            # Create new edges with small padding to ensure range coverage
            padding = (self._max[i] - self._min[i]) * 1e-10
            new_edges = np.linspace(
                self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1
            )

            # Redistribute existing histogram counts to new bins: each old bin's
            # whole count moves to the new bin that contains the old bin's center
            old_centers = (old_edges[:-1] + old_edges[1:]) / 2
            new_hist = np.zeros(self._num_quantile_bins)

            for old_center, count in zip(old_centers, old_hist, strict=False):
                if count > 0:
                    # Find which new bin this old center belongs to
                    bin_idx = np.searchsorted(new_edges, old_center) - 1
                    bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1))
                    new_hist[bin_idx] += count

            self._histograms[i] = new_hist
            self._bin_edges[i] = new_edges

    def _update_histograms(self, batch: np.ndarray) -> None:
        """Update histograms with new vectors."""
        for i in range(batch.shape[1]):
            hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
            self._histograms[i] += hist

    def _compute_quantiles(self) -> list[np.ndarray]:
        """Compute quantiles based on histograms."""
        results = []
        for q in self._quantile_list:
            target_count = q * self._count
            q_values = []

            for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
                q_value = self._compute_single_quantile(hist, edges, target_count)
                q_values.append(q_value)

            results.append(np.array(q_values))
        return results

    def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float:
        """Compute a single quantile value from histogram and bin edges."""
        cumsum = np.cumsum(hist)
        idx = np.searchsorted(cumsum, target_count)

        if idx == 0:
            return edges[0]
        if idx >= len(cumsum):
            return edges[-1]

        # Otherwise, interpolate within the bin
        count_before = cumsum[idx - 1]
        count_in_bin = cumsum[idx] - count_before

        # If no samples in this bin, use the bin edge
        if count_in_bin == 0:
            return edges[idx]

        # Linear interpolation within the bin
        fraction = (target_count - count_before) / count_in_bin
        return edges[idx] + fraction * (edges[idx + 1] - edges[idx])


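# Illustrative sketch, not part of the LeRobot API: the histogram quantile rule
# used by RunningQuantileStats (fixed-width bins, cumulative counts, linear
# interpolation inside the bin where the cumulative count crosses q * n),
# restated for a single 1-D feature. The helper name and the num_bins default
# are hypothetical. The estimate is approximate: it normally lands within about
# one bin width of np.quantile, so more bins trade memory for accuracy.
import numpy as np


def _example_histogram_quantile(values, q, num_bins=1000):
    """Approximate the q-quantile of a 1-D array from a fixed-width histogram."""
    values = np.asarray(values, dtype=np.float64)
    edges = np.linspace(values.min() - 1e-10, values.max() + 1e-10, num_bins + 1)
    hist, _ = np.histogram(values, bins=edges)
    cumsum = np.cumsum(hist)
    target = q * len(values)
    idx = int(np.searchsorted(cumsum, target))
    if idx == 0:
        return float(edges[0])  # target falls in the first bin: use its left edge
    if idx >= len(cumsum):
        return float(edges[-1])
    count_before = cumsum[idx - 1]
    count_in_bin = cumsum[idx] - count_before
    if count_in_bin == 0:
        return float(edges[idx])
    # Linear interpolation between the bin's left and right edges
    fraction = (target - count_before) / count_in_bin
    return float(edges[idx] + fraction * (edges[idx + 1] - edges[idx]))


# Usage (illustrative): with x = np.random.default_rng(0).normal(size=100_000),
# _example_histogram_quantile(x, 0.5) is close to np.quantile(x, 0.5).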
def estimate_num_samples(
    dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> int:
    """Heuristic to estimate the number of samples based on dataset size.

    The number of samples grows as ``dataset_len**power``, clamped to
    ``[min_num_samples, max_num_samples]``; datasets with fewer than
    ``min_num_samples`` frames are sampled in full. Lower the power to draw fewer samples.

    For default arguments, we have:
    - below 100, num_samples=dataset_len (every frame)
    - from 100 to ~470, num_samples=100
    - at 1000, num_samples=177
    - at 2000, num_samples=299
    - at 5000, num_samples=594
    - at 10000, num_samples=1000
    - at 20000, num_samples=1681
    - from ~215,000 upward, num_samples=10000 (the cap)
    """
    if dataset_len < min_num_samples:
        min_num_samples = dataset_len
    return max(min_num_samples, min(int(dataset_len**power), max_num_samples))


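# Worked values for estimate_num_samples with its defaults (min 100, max 10_000,
# power 0.75), via a hypothetical standalone restatement so the regimes are easy
# to check: small datasets are sampled in full, the count then stays at the
# floor of 100 until dataset_len**0.75 exceeds it, and finally grows
# sub-linearly up to the cap.
def _example_num_samples(dataset_len, min_num_samples=100, max_num_samples=10_000, power=0.75):
    # Same clamping rule as estimate_num_samples above.
    lower = min(min_num_samples, dataset_len)
    return max(lower, min(int(dataset_len**power), max_num_samples))


# _example_num_samples(50)     -> 50     (fewer frames than min_num_samples: use all)
# _example_num_samples(300)    -> 100    (floor, since int(300**0.75) == 72)
# _example_num_samples(1000)   -> 177
# _example_num_samples(10**6)  -> 10000  (cap)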
def sample_indices(data_len: int) -> list[int]:
    """Return ``estimate_num_samples(data_len)`` evenly spaced indices in ``[0, data_len - 1]``."""
    num_samples = estimate_num_samples(data_len)
    return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()


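def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
    """Downsample a channel-first ``(C, H, W)`` image by integer striding.

    Images whose longer side is below ``max_size_threshold`` are returned unchanged;
    otherwise both spatial axes are strided by ``int(longer_side / target_size)``,
    which brings the longer side down to roughly ``target_size``.
    """
    _, height, width = img.shape

    if max(width, height) < max_size_threshold:
        # no downsampling needed
        return img

    downsample_factor = int(width / target_size) if width > height else int(height / target_size)
    return img[:, ::downsample_factor, ::downsample_factor]

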
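# Worked example of the striding rule in auto_downsample_height_width, using a
# hypothetical helper on dummy frames (the sizes are arbitrary illustrations).
# For a 3x480x640 frame the longer side is 640, so the stride is
# int(640 / 150) = 4 and the result is 3x120x160. Plain striding keeps every
# k-th pixel without anti-aliasing; it is cheap, and only per-channel summaries
# are computed from the result.
import numpy as np


def _example_downsampled_shape(channels, height, width, target_size=150, max_size_threshold=300):
    img = np.zeros((channels, height, width), dtype=np.uint8)
    if max(height, width) < max_size_threshold:
        return img.shape  # small enough: left unchanged
    factor = int(width / target_size) if width > height else int(height / target_size)
    return img[:, ::factor, ::factor].shape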
def sample_images(image_paths: list[str]) -> np.ndarray:
    """Load a subset of ``image_paths`` (chosen by ``sample_indices``) as a uint8 array.

    Each image is loaded channel-first and passed through ``auto_downsample_height_width``.
    The result has shape ``(num_samples, C, H, W)``, which assumes every sampled image
    ends up with the same shape.
    """
    sampled_indices = sample_indices(len(image_paths))

    images = None
    for i, idx in enumerate(sampled_indices):
        path = image_paths[idx]
        # we load as uint8 to reduce memory usage
        img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
        img = auto_downsample_height_width(img)

        if images is None:
            images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)

        images[i] = img

    return images


def _reshape_stats_by_axis(
    stats: dict[str, np.ndarray],
    axis: int | tuple[int, ...] | None,
    keepdims: bool,
    original_shape: tuple[int, ...],
) -> dict[str, np.ndarray]:
    """Reshape all statistics to match NumPy's output conventions.

    Applies consistent reshaping to all statistics (except 'count') based on the
    axis and keepdims parameters. This ensures statistics have the correct shape
    for broadcasting with the original data.

    Args:
        stats: Dictionary of computed statistics
        axis: Axis or axes along which statistics were computed
        keepdims: Whether to keep reduced dimensions as size-1 dimensions
        original_shape: Shape of the original array

    Returns:
        Dictionary with reshaped statistics

    Note:
        The 'count' statistic is never reshaped as it represents metadata
        rather than per-feature statistics.
    """
    if axis == (1,) and not keepdims:
        return stats

    result = {}
    for key, value in stats.items():
        if key == "count":
            result[key] = value
        else:
            result[key] = _reshape_single_stat(value, axis, keepdims, original_shape)

    return result


def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
    """Reshape statistics for image data (axis=(0,2,3))."""
    if keepdims and value.ndim == 1:
        return value.reshape(1, -1, 1, 1)
    return value


def _reshape_for_vector_stats(
    value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray:
    """Reshape statistics for vector data (axis=0 or axis=(0,))."""
    if not keepdims:
        return value

    if len(original_shape) == 1 and value.ndim > 0:
        return value.reshape(1)
    elif len(original_shape) >= 2 and value.ndim == 1:
        return value.reshape(1, -1)
    return value


def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
    """Reshape statistics for feature-wise computation (axis=(1,))."""
    if not keepdims:
        return value

    if value.ndim == 0:
        return value.reshape(1, 1)
    elif value.ndim == 1:
        return value.reshape(-1, 1)
    return value


def _reshape_for_global_stats(
    value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray | float:
    """Reshape statistics for global reduction (axis=None)."""
    if keepdims:
        target_shape = tuple(1 for _ in original_shape)
        return value.reshape(target_shape)
    # Keep at least 1-D arrays to satisfy validator
    return np.atleast_1d(value)


def _reshape_single_stat(
    value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray | float:
    """Apply appropriate reshaping to a single statistic array.

    This function transforms statistic arrays to match expected output shapes
    based on the axis configuration and keepdims parameter.

    Args:
        value: The statistic array to reshape
        axis: Axis or axes that were reduced during computation
        keepdims: Whether to maintain reduced dimensions as size-1 dimensions
        original_shape: Shape of the original data before reduction

    Returns:
        Reshaped array following NumPy broadcasting conventions
    """
    if axis == (0, 2, 3):
        return _reshape_for_image_stats(value, keepdims)

    if axis in [0, (0,)]:
        return _reshape_for_vector_stats(value, keepdims, original_shape)

    if axis == (1,):
        return _reshape_for_feature_stats(value, keepdims)

    if axis is None:
        return _reshape_for_global_stats(value, keepdims, original_shape)

    return value


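# The reshape helpers above aim to reproduce NumPy's own keepdims conventions.
# This standalone check (hypothetical helper name, dummy zero arrays) lists the
# shapes NumPy itself returns for the four supported axis configurations, which
# are the shapes the helpers produce for every non-'count' statistic.
import numpy as np


def _example_numpy_keepdims_shapes():
    images = np.zeros((4, 3, 8, 8))  # (batch, channels, height, width)
    vectors = np.zeros((10, 6))  # (samples, features)
    scalars = np.zeros(10)  # a 1-D feature
    return {
        "image axis=(0,2,3)": images.mean(axis=(0, 2, 3), keepdims=True).shape,  # (1, 3, 1, 1)
        "vector axis=0": vectors.mean(axis=0, keepdims=True).shape,  # (1, 6)
        "1-D axis=0": scalars.mean(axis=0, keepdims=True).shape,  # (1,)
        "feature-wise axis=(1,)": vectors.mean(axis=(1,), keepdims=True).shape,  # (10, 1)
        "global axis=None": vectors.mean(axis=None, keepdims=True).shape,  # (1, 1)
    }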
def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]:
    """Prepare array for statistics computation by reshaping according to axis.

    Args:
        array: Input data array
        axis: Axis or axes along which to compute statistics

    Returns:
        Tuple of (reshaped_array, sample_count), where reshaped_array is 2-D with
        samples along axis 0 and one column per statistic channel

    Raises:
        ValueError: If ``axis`` is not one of (0, 2, 3), 0, (0,), (1,), or None.
    """
    if axis == (0, 2, 3):  # Image data
        batch_size, channels, height, width = array.shape
        reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels)
        return reshaped, batch_size

    if axis == 0 or axis == (0,):  # Vector data
        reshaped = array
        if array.ndim == 1:
            reshaped = array.reshape(-1, 1)
        return reshaped, array.shape[0]

    if axis == (1,):  # Feature-wise statistics
        return array.T, array.shape[1]

    if axis is None:  # Global statistics
        reshaped = array.reshape(-1, 1)
        # For backward compatibility, count represents the first dimension size
        return reshaped, array.shape[0] if array.ndim > 0 else 1

    raise ValueError(f"Unsupported axis configuration: {axis}")


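# Why the image branch of _prepare_array_for_stats transposes before reshaping:
# moving channels last and flattening to (B*H*W, C) puts every pixel of one
# channel in one column, so column-wise stats over axis 0 equal NumPy's
# reductions over axes (0, 2, 3). A plain reshape(-1, C) without the transpose
# would mix channels within each column. Standalone check on random data
# (hypothetical helper name):
import numpy as np


def _example_image_flatten_matches_axis_reduction(seed=0):
    rng = np.random.default_rng(seed)
    images = rng.random((5, 3, 4, 6))  # (batch, channels, height, width)
    channels = images.shape[1]
    flat = images.transpose(0, 2, 3, 1).reshape(-1, channels)  # (5*4*6, 3)
    direct_mean = images.mean(axis=(0, 2, 3))
    per_channel_ok = bool(np.allclose(flat.mean(axis=0), direct_mean))
    # Without the transpose, each column mixes values from different channels.
    naive_ok = bool(np.allclose(images.reshape(-1, channels).mean(axis=0), direct_mean))
    return flat.shape, per_channel_ok, naive_ok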
def _compute_basic_stats(
    array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None
) -> dict[str, np.ndarray]:
    """Compute basic statistics for arrays with insufficient samples for quantiles.

    Args:
        array: Reshaped array ready for statistics computation
        sample_count: Number of samples represented in the data
        quantile_list: Quantiles whose "qXX" keys are emitted; defaults to ``DEFAULT_QUANTILES``

    Returns:
        Dictionary with basic statistics and quantiles set to mean values
    """
    if quantile_list is None:
        quantile_list = DEFAULT_QUANTILES
    quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list]

    stats = {
        "min": np.min(array, axis=0),
        "max": np.max(array, axis=0),
        "mean": np.mean(array, axis=0),
        "std": np.std(array, axis=0),
        "count": np.array([sample_count]),
    }

    for q in quantile_list_keys:
        stats[q] = stats["mean"].copy()

    return stats


def get_feature_stats(
    array: np.ndarray,
    axis: int | tuple[int, ...] | None,
    keepdims: bool,
    quantile_list: list[float] | None = None,
) -> dict[str, np.ndarray]:
    """Compute comprehensive statistics for array features along specified axes.

    This function calculates min, max, mean, std, and quantiles (by default 1%, 10%,
    50%, 90%, 99%) for the input array along the specified axes. It handles different
    data layouts:
    - Image data: axis=(0,2,3) computes per-channel statistics
    - Vector data: axis=0 computes per-feature statistics
    - Feature-wise: axis=(1,) computes statistics across features
    - Global: axis=None computes statistics over entire array

    If fewer than two rows remain after reshaping, quantiles are filled with the mean.

    Args:
        array: Input data array with shape appropriate for the specified axis
        axis: Axis or axes along which to compute statistics
            - (0, 2, 3): For image data (batch, channels, height, width)
            - 0 or (0,): For vector/tabular data (samples, features)
            - (1,): For computing across features (the integer 1 is not accepted)
            - None: For global statistics over entire array
        keepdims: If True, reduced axes are kept as dimensions with size 1
        quantile_list: Quantiles to compute; defaults to ``DEFAULT_QUANTILES``

    Returns:
        Dictionary containing:
        - 'min': Minimum values
        - 'max': Maximum values
        - 'mean': Mean values
        - 'std': Standard deviation
        - 'count': Number of samples (always shape (1,))
        - one 'qXX' key per entry of ``quantile_list`` (by default 'q01', 'q10',
          'q50', 'q90', 'q99'): Quantile values
    """
    if quantile_list is None:
        quantile_list = DEFAULT_QUANTILES

    original_shape = array.shape
    reshaped, sample_count = _prepare_array_for_stats(array, axis)

    if reshaped.shape[0] < 2:
        stats = _compute_basic_stats(reshaped, sample_count, quantile_list)
    else:
        # Pass quantile_list through so both branches emit the same quantile keys
        running_stats = RunningQuantileStats(quantile_list=quantile_list)
        running_stats.update(reshaped)
        stats = running_stats.get_statistics()
        stats["count"] = np.array([sample_count])

    stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
    return stats


def compute_episode_stats(
    episode_data: dict[str, list[str] | np.ndarray],
    features: dict,
    quantile_list: list[float] | None = None,
) -> dict:
    """Compute comprehensive statistics for all features in an episode.

    Processes different data types appropriately:
    - Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1]
    - Numerical arrays: Computes per-feature statistics
    - Strings: Skipped (no statistics computed)

    Args:
        episode_data: Dictionary mapping feature names to data
            - For images/videos: list of file paths
            - For numerical data: numpy arrays
        features: Dictionary describing each feature's dtype and shape
        quantile_list: Quantiles to compute; defaults to ``DEFAULT_QUANTILES``

    Returns:
        Dictionary mapping feature names to their statistics dictionaries.
        Each statistics dictionary contains min, max, mean, std, count, and quantiles.

    Note:
        Image statistics are normalized to [0,1] range and have shape (3,1,1) for
        per-channel values when dtype is 'image' or 'video'.
    """
    if quantile_list is None:
        quantile_list = DEFAULT_QUANTILES

    ep_stats = {}
    for key, data in episode_data.items():
        if features[key]["dtype"] == "string":
            continue

        if features[key]["dtype"] in ["image", "video"]:
            ep_ft_array = sample_images(data)
            axes_to_reduce = (0, 2, 3)
            keepdims = True
        else:
            ep_ft_array = data
            axes_to_reduce = 0
            keepdims = data.ndim == 1

        ep_stats[key] = get_feature_stats(
            ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
        )

        if features[key]["dtype"] in ["image", "video"]:
            ep_stats[key] = {
                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
            }

    return ep_stats


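# Shape walk-through for the image branch of compute_episode_stats, using plain
# NumPy on dummy all-white uint8 frames (hypothetical helper; sizes arbitrary):
# keepdims=True over axes (0, 2, 3) yields (1, C, 1, 1), dividing by 255 maps
# the uint8 range to [0, 1], and squeezing axis 0 leaves (C, 1, 1), a shape
# that broadcasts against a single (C, H, W) frame.
import numpy as np


def _example_image_stat_shapes():
    frames = np.full((2, 3, 4, 4), 255, dtype=np.uint8)  # (num_samples, C, H, W)
    mean = frames.mean(axis=(0, 2, 3), keepdims=True)  # (1, 3, 1, 1)
    mean = np.squeeze(mean / 255.0, axis=0)  # (3, 1, 1), values in [0, 1]
    frame = np.zeros((3, 4, 4))
    normalized = frame - mean  # (3, 4, 4) broadcast against (3, 1, 1)
    return mean.shape, float(mean.max()), normalized.shape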
def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
    """Validate a single statistic value.

    Checks that it is a NumPy array with at least one dimension, that 'count' has
    shape (1,), and that every other statistic of a feature whose key contains
    "image" has shape (3, 1, 1).
    """
    if not isinstance(value, np.ndarray):
        raise ValueError(
            f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' "
            f"is of type '{type(value)}' instead."
        )

    if value.ndim == 0:
        raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")

    if key == "count" and value.shape != (1,):
        raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")

    if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
        raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")


def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
    """Validate that all statistics have correct types and shapes.

    Args:
        stats_list: List of statistics dictionaries to validate

    Raises:
        ValueError: If any statistic has incorrect type or shape
    """
    for stats in stats_list:
        for feature_key, feature_stats in stats.items():
            for stat_key, stat_value in feature_stats.items():
                _validate_stat_value(stat_value, stat_key, feature_key)


def aggregate_feature_stats(stats_ft_list: list[dict[str, np.ndarray]]) -> dict[str, np.ndarray]:
    """Aggregate the stats of a single feature from several sources (e.g. episodes).

    min and max are combined exactly; mean and (population) std are combined
    exactly with the parallel-variance formula. Quantiles are approximated by the
    count-weighted average of the per-source quantiles.
    """
    means = np.stack([s["mean"] for s in stats_ft_list])
    variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
    counts = np.stack([s["count"] for s in stats_ft_list])
    total_count = counts.sum(axis=0)

    # Prepare weighted mean by matching number of dimensions
    while counts.ndim < means.ndim:
        counts = np.expand_dims(counts, axis=-1)

    # Compute the weighted mean
    weighted_means = means * counts
    total_mean = weighted_means.sum(axis=0) / total_count

    # Compute the variance using the parallel algorithm
    delta_means = means - total_mean
    weighted_variances = (variances + delta_means**2) * counts
    total_variance = weighted_variances.sum(axis=0) / total_count

    aggregated = {
        "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
        "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
        "mean": total_mean,
        "std": np.sqrt(total_variance),
        "count": total_count,
    }

    if stats_ft_list:
        quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()]

        # Quantiles cannot be merged exactly from per-source summaries; the
        # count-weighted average below only approximates the pooled quantile.
        for q_key in quantile_keys:
            if all(q_key in s for s in stats_ft_list):
                quantile_values = np.stack([s[q_key] for s in stats_ft_list])
                weighted_quantiles = quantile_values * counts
                aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count

    return aggregated


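# Standalone check (hypothetical helper, random data) that the merge rule in
# aggregate_feature_stats is exact for mean and population std:
#   mean = sum(n_i * mean_i) / N
#   var  = sum(n_i * (var_i + (mean_i - mean) ** 2)) / N
# Splitting one array into uneven "episodes", summarizing each, and merging the
# summaries reproduces np.mean / np.std (ddof=0) of the concatenated data up to
# floating-point error.
import numpy as np


def _example_merge_mean_std(seed=0):
    rng = np.random.default_rng(seed)
    data = rng.normal(loc=2.0, scale=3.0, size=(1000, 4))
    parts = np.split(data, [100, 350, 900])  # four uneven episodes
    counts = np.array([len(p) for p in parts], dtype=np.float64)[:, None]
    means = np.stack([p.mean(axis=0) for p in parts])
    variances = np.stack([p.std(axis=0) ** 2 for p in parts])
    total = counts.sum()
    merged_mean = (means * counts).sum(axis=0) / total
    merged_var = ((variances + (means - merged_mean) ** 2) * counts).sum(axis=0) / total
    return (
        bool(np.allclose(merged_mean, data.mean(axis=0))),
        bool(np.allclose(np.sqrt(merged_var), data.std(axis=0))),
    )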
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
    """Aggregate several stats dicts (e.g. ``compute_episode_stats`` outputs) into one set of stats.

    The final stats will have the union of all data keys from each of the stats dicts.

    For instance:
    - new_min = min(min_dataset_0, min_dataset_1, ...)
    - new_max = max(max_dataset_0, max_dataset_1, ...)
    - new_mean = (mean of all data, weighted by counts)
    - new_std = (std of all data)
    - new_qXX = (count-weighted average of the per-dict quantiles, an approximation)
    """

    _assert_type_and_shape(stats_list)

    data_keys = {key for stats in stats_list for key in stats}
    aggregated_stats = {key: {} for key in data_keys}

    for key in data_keys:
        stats_with_key = [stats[key] for stats in stats_list if key in stats]
        aggregated_stats[key] = aggregate_feature_stats(stats_with_key)

    return aggregated_stats


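# By contrast with mean and std, quantiles are merged as a count-weighted
# average, which can differ sharply from the quantile of the pooled data. A toy
# illustration with two made-up, well-separated episodes (hypothetical helper):
# the pooled 1% quantile stays inside the lower episode (about -0.96), while the
# averaged q01 lands in the empty gap between the episodes (about 4.02).
import numpy as np


def _example_weighted_quantile_vs_pooled(q=0.01):
    ep_a = np.linspace(-1.0, 1.0, 1001)  # hypothetical episode centered at 0
    ep_b = np.linspace(9.0, 11.0, 1001)  # hypothetical episode centered at 10
    counts = np.array([len(ep_a), len(ep_b)], dtype=np.float64)
    per_episode = np.array([np.quantile(ep_a, q), np.quantile(ep_b, q)])
    averaged = float((per_episode * counts).sum() / counts.sum())  # what aggregation computes
    pooled = float(np.quantile(np.concatenate([ep_a, ep_b]), q))  # quantile of all frames
    return averaged, pooled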
def _get_valid_chunk_starts(episode_indices: np.ndarray, chunk_size: int) -> np.ndarray:
    """Return all start indices where a chunk of ``chunk_size`` stays within one episode.

    Only the first and last frame of each window are compared, which assumes the
    frames of every episode are stored as consecutive rows.
    """
    total = len(episode_indices)
    if total < chunk_size:
        return np.array([], dtype=np.int64)
    max_start = total - chunk_size
    starts = np.arange(max_start + 1)
    valid = episode_indices[starts] == episode_indices[starts + chunk_size - 1]
    return starts[valid]


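# Worked example for _get_valid_chunk_starts, restated as a standalone function
# (hypothetical name) so it runs in isolation. With two episodes of 3 and 4
# frames and chunk_size=3, starts 1 and 2 would straddle the episode boundary
# and are dropped, leaving [0, 3, 4]. As in the original, only the first and
# last frame of each window are compared, assuming contiguous episodes.
import numpy as np


def _example_valid_chunk_starts(episode_indices, chunk_size):
    episode_indices = np.asarray(episode_indices)
    if len(episode_indices) < chunk_size:
        return np.array([], dtype=np.int64)
    starts = np.arange(len(episode_indices) - chunk_size + 1)
    same_episode = episode_indices[starts] == episode_indices[starts + chunk_size - 1]
    return starts[same_episode]


# _example_valid_chunk_starts([0, 0, 0, 1, 1, 1, 1], chunk_size=3) -> array([0, 3, 4])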
def _compute_relative_chunk_batch(
    start_indices: np.ndarray,
    all_actions: np.ndarray,
    all_states: np.ndarray,
    chunk_size: int,
    relative_mask: np.ndarray,
) -> np.ndarray:
    """Vectorised relative-action computation for a batch of start indices.

    For each start index ``s``, gathers actions ``s .. s + chunk_size - 1`` and subtracts
    the state at frame ``s`` on the dimensions where ``relative_mask`` is 1; dimensions
    with mask 0 keep their absolute values.

    Returns an ``(N * chunk_size, action_dim)`` array (float32 for the float32 inputs
    that ``compute_relative_action_stats`` passes in).
    """
    if len(start_indices) == 0:
        return np.empty((0, all_actions.shape[1]), dtype=np.float32)
    offsets = np.arange(chunk_size)
    frame_idx = start_indices[:, None] + offsets[None, :]
    chunks = all_actions[frame_idx].copy()
    states = all_states[start_indices]
    mask_dim = len(relative_mask)
    chunks[:, :, :mask_dim] -= states[:, None, :mask_dim] * relative_mask[None, None, :]
    return chunks.reshape(-1, all_actions.shape[1])


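# Worked example of the relative conversion in _compute_relative_chunk_batch,
# with made-up numbers: a 3-D action (two joints plus a gripper), chunk_size=2,
# and a mask that keeps the gripper absolute (the effect exclude_joints has).
# Every action in a chunk is offset by the state at the chunk's first frame,
# not by the previous action. Standalone restatement with a hypothetical name.
import numpy as np


def _example_relative_chunks():
    actions = np.array([[1.0, 2.0, 0.5], [1.5, 2.5, 0.0], [2.0, 3.0, 1.0]], dtype=np.float32)  # (T, 3)
    states = np.array([[1.0, 2.0, 0.3], [1.4, 2.4, 0.1], [2.0, 3.0, 0.9]], dtype=np.float32)  # (T, 3)
    mask = np.array([1.0, 1.0, 0.0], dtype=np.float32)  # 0 = gripper stays absolute
    starts = np.array([0, 1])
    chunk_size = 2
    frame_idx = starts[:, None] + np.arange(chunk_size)[None, :]  # (2 chunks, 2 frames)
    chunks = actions[frame_idx].copy()  # (2, 2, 3)
    chunks -= states[starts][:, None, :] * mask[None, None, :]  # subtract first-frame state
    return chunks.reshape(-1, actions.shape[1])


# Expected rows: [0, 0, 0.5], [0.5, 0.5, 0], [0.1, 0.1, 0], [0.6, 0.6, 1.0]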
def compute_relative_action_stats(
    hf_dataset,
    features: dict,
    chunk_size: int,
    exclude_joints: list[str] | None = None,
    num_workers: int = 0,
) -> dict[str, np.ndarray]:
    """Compute normalization statistics for relative actions over the full dataset.

    Iterates *all* valid action chunks (within single episodes), converts them to
    relative actions (action − state at the chunk's first frame), and computes
    per-dimension statistics suitable for normalization.

    Args:
        hf_dataset: The underlying HuggingFace dataset with "action",
            "observation.state", and "episode_index" columns.
        features: Dataset feature metadata (must contain "action" with "shape"
            and optionally "names").
        chunk_size: Number of consecutive frames per action chunk.
        exclude_joints: Joint names whose dimensions should remain absolute
            (not converted to relative actions).
        num_workers: Number of threads used to build the relative chunks. Values <= 1
            mean single-threaded. NumPy releases the GIL during many array operations,
            so threads can give real parallelism here; the statistics are still
            accumulated in the calling thread. Batches are merged in completion order,
            so the histogram-based quantiles may differ slightly from a
            single-threaded run.

    Returns:
        Statistics dict with keys "mean", "std", "min", "max", "count", and
        "q01", …, "q99".

    Raises:
        RuntimeError: If no valid (single-episode) chunks are found, including when
            the dataset has fewer frames than ``chunk_size``.
    """
    from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep

    if exclude_joints is None:
        exclude_joints = []

    action_dim = features[ACTION]["shape"][0]
    action_names = features.get(ACTION, {}).get("names")
    mask_step = RelativeActionsProcessorStep(
        enabled=True,
        exclude_joints=exclude_joints,
        action_names=action_names,
    )
    relative_mask = np.array(mask_step._build_mask(action_dim), dtype=np.float32)

    logging.info("Loading action/state data for relative action stats...")
    all_actions = np.array(hf_dataset[ACTION], dtype=np.float32)
    all_states = np.array(hf_dataset[OBS_STATE], dtype=np.float32)
    episode_indices = np.array(hf_dataset["episode_index"])

    valid_starts = _get_valid_chunk_starts(episode_indices, chunk_size)
    if len(valid_starts) == 0:
        raise RuntimeError(
            f"No valid chunks found (total_frames={len(episode_indices)}, chunk_size={chunk_size})"
        )

    effective_workers = max(num_workers, 1)
    logging.info(
        f"Computing relative action stats from {len(valid_starts)} chunks "
        f"(chunk_size={chunk_size}, workers={effective_workers})"
    )

    batch_size = 50_000
    batches = [valid_starts[i : i + batch_size] for i in range(0, len(valid_starts), batch_size)]

    running_stats = RunningQuantileStats()

    if num_workers > 1:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            futures = [
                pool.submit(
                    _compute_relative_chunk_batch,
                    batch,
                    all_actions,
                    all_states,
                    chunk_size,
                    relative_mask,
                )
                for batch in batches
            ]
            for future in as_completed(futures):
                running_stats.update(future.result())
    else:
        for batch in batches:
            running_stats.update(
                _compute_relative_chunk_batch(batch, all_actions, all_states, chunk_size, relative_mask)
            )

    stats = running_stats.get_statistics()

    excluded_dims = int(len(relative_mask) - relative_mask.sum())
    total_frames = len(valid_starts) * chunk_size
    logging.info(
        f"Relative action stats ({len(valid_starts)} chunks, {total_frames} frames): "
        f"relative_dims={int(relative_mask.sum())}/{len(relative_mask)} (excluded={excluded_dims}), "
        f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}, "
        f"q01={stats['q01'].mean():.4f}, q99={stats['q99'].mean():.4f}"
    )

    return stats