mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
770 lines
28 KiB
Python
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import logging

import numpy as np

from lerobot.processor import RelativeActionsProcessorStep
from lerobot.utils.constants import ACTION, OBS_STATE

from .io_utils import load_image_as_numpy

DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]


class RunningQuantileStats:
    """
    Maintains running statistics for batches of vectors, including mean,
    standard deviation, min, max, and approximate quantiles.

    Statistics are computed per feature dimension and updated incrementally
    as new batches are observed. Quantiles are estimated from fixed-size
    histograms whose bin edges are rebuilt, with existing counts redistributed,
    whenever a new batch extends the observed min/max range.

    Args:
        quantile_list: Quantiles to estimate, as fractions in [0, 1]. Defaults to
            ``DEFAULT_QUANTILES``.
        num_quantile_bins: Number of histogram bins per feature dimension used to
            approximate the quantiles.
    """

    def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
        self._count = 0
        self._mean = None
        self._mean_of_squares = None
        self._min = None
        self._max = None
        self._histograms = None
        self._bin_edges = None
        self._num_quantile_bins = num_quantile_bins

        self._quantile_list = quantile_list
        if self._quantile_list is None:
            self._quantile_list = DEFAULT_QUANTILES
        self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]

    def update(self, batch: np.ndarray) -> None:
        """Update the running statistics with a batch of vectors.

        Args:
            batch: An array where all dimensions except the last are batch dimensions.
        """
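        batch = batch.reshape(-1, batch.shape[-1])
        if np.issubdtype(batch.dtype, np.integer):
            # Promote integer inputs (e.g. uint8 pixels) so that `batch**2` below
            # cannot overflow and silently wrap around.
            batch = batch.astype(np.float64)
        num_elements, vector_length = batch.shape
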
        if self._count == 0:
            self._mean = np.mean(batch, axis=0)
            self._mean_of_squares = np.mean(batch**2, axis=0)
            self._min = np.min(batch, axis=0)
            self._max = np.max(batch, axis=0)
            self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
            self._bin_edges = [
                np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
                for i in range(vector_length)
            ]
        else:
            if vector_length != self._mean.size:
                raise ValueError("The length of new vectors does not match the initialized vector length.")

            new_max = np.max(batch, axis=0)
            new_min = np.min(batch, axis=0)
            max_changed = np.any(new_max > self._max)
            min_changed = np.any(new_min < self._min)
            self._max = np.maximum(self._max, new_max)
            self._min = np.minimum(self._min, new_min)

            if max_changed or min_changed:
                self._adjust_histograms()

        self._count += num_elements

        batch_mean = np.mean(batch, axis=0)
        batch_mean_of_squares = np.mean(batch**2, axis=0)

        # Update running mean and mean of squares
        self._mean += (batch_mean - self._mean) * (num_elements / self._count)
        self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (
            num_elements / self._count
        )

        self._update_histograms(batch)

    def get_statistics(self) -> dict[str, np.ndarray]:
        """Compute and return the statistics of the vectors processed so far.

        Quantiles are reported for the ``quantile_list`` given at construction
        (``DEFAULT_QUANTILES`` by default), keyed as ``q01``, ``q50``, etc.

        Returns:
            Dictionary with ``min``, ``max``, ``mean``, ``std`` and ``count`` entries
            plus one entry per quantile key.

        Raises:
            ValueError: If fewer than 2 vectors have been processed.
        """
        if self._count < 2:
            raise ValueError("Cannot compute statistics for fewer than 2 vectors.")

        # Population variance as E[x^2] - E[x]^2, clamped at 0 to absorb rounding error
        variance = self._mean_of_squares - self._mean**2

        stddev = np.sqrt(np.maximum(0, variance))

        stats = {
            "min": self._min.copy(),
            "max": self._max.copy(),
            "mean": self._mean.copy(),
            "std": stddev,
            "count": np.array([self._count]),
        }

        quantile_results = self._compute_quantiles()
        for i, q in enumerate(self._quantile_keys):
            stats[q] = quantile_results[i]

        return stats

    def _adjust_histograms(self):
        """Adjust histograms when min or max changes."""
        for i in range(len(self._histograms)):
            old_edges = self._bin_edges[i]
            old_hist = self._histograms[i]

            # Create new edges with small padding to ensure range coverage
            padding = (self._max[i] - self._min[i]) * 1e-10
            new_edges = np.linspace(
                self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1
            )

            # Redistribute existing histogram counts to new bins
            # We need to map each old bin center to the new bins
            old_centers = (old_edges[:-1] + old_edges[1:]) / 2
            new_hist = np.zeros(self._num_quantile_bins)

            for old_center, count in zip(old_centers, old_hist, strict=False):
                if count > 0:
                    # Find which new bin this old center belongs to
                    bin_idx = np.searchsorted(new_edges, old_center) - 1
                    bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1))
                    new_hist[bin_idx] += count

            self._histograms[i] = new_hist
            self._bin_edges[i] = new_edges

    def _update_histograms(self, batch: np.ndarray) -> None:
        """Update histograms with new vectors."""
        for i in range(batch.shape[1]):
            hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
            self._histograms[i] += hist

    def _compute_quantiles(self) -> list[np.ndarray]:
        """Compute quantiles based on histograms."""
        results = []
        for q in self._quantile_list:
            target_count = q * self._count
            q_values = []

            for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
                q_value = self._compute_single_quantile(hist, edges, target_count)
                q_values.append(q_value)

            results.append(np.array(q_values))
        return results

    def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float:
        """Compute a single quantile value from histogram and bin edges."""
        cumsum = np.cumsum(hist)
        idx = np.searchsorted(cumsum, target_count)

        if idx == 0:
            return edges[0]
        if idx >= len(cumsum):
            return edges[-1]

        # If not edge case, interpolate within the bin
        count_before = cumsum[idx - 1]
        count_in_bin = cumsum[idx] - count_before

        # If no samples in this bin, use the bin edge
        if count_in_bin == 0:
            return edges[idx]

        # Linear interpolation within the bin
        fraction = (target_count - count_before) / count_in_bin
        return edges[idx] + fraction * (edges[idx + 1] - edges[idx])


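# Illustrative sketch (hypothetical helpers, not part of the lerobot API): the three
# techniques RunningQuantileStats combines, restated in plain Python for a single
# feature dimension so they can be read and checked in isolation.
# - `_sketch_merge_mean` is the incremental mean used in `update()`: once `count`
#   has grown by `n`, `mean += (batch_mean - mean) * n / count`.
# - `_sketch_rebin` follows `_adjust_histograms()`: each old bin's count moves into
#   the new bin that contains the old bin's center.
# - `_sketch_histogram_quantile` follows `_compute_single_quantile()`: find the first
#   bin whose cumulative count reaches `q * total`, then interpolate inside that bin.
# `bisect.bisect_left` is assumed here to stand in for `np.searchsorted` (side="left").


def _sketch_merge_mean(mean: float, count: int, batch: list[float]) -> tuple[float, int]:
    """Fold one non-empty batch into a running mean; returns the new (mean, count)."""
    n = len(batch)
    count += n
    batch_mean = sum(batch) / n
    mean += (batch_mean - mean) * (n / count)
    return mean, count


def _sketch_rebin(counts: list[float], old_edges: list[float], new_edges: list[float]) -> list[float]:
    """Redistribute histogram counts onto new bin edges by old bin center."""
    import bisect

    num_bins = len(new_edges) - 1
    new_counts = [0.0] * num_bins
    for i, count in enumerate(counts):
        if count > 0:
            center = (old_edges[i] + old_edges[i + 1]) / 2
            idx = bisect.bisect_left(new_edges, center) - 1
            idx = max(0, min(idx, num_bins - 1))  # clamp to a valid bin
            new_counts[idx] += count
    return new_counts


def _sketch_histogram_quantile(counts: list[float], edges: list[float], q: float) -> float:
    """Approximate the q-quantile from per-bin counts and len(counts) + 1 edges."""
    import bisect
    import itertools

    target = q * sum(counts)
    cumsum = list(itertools.accumulate(counts))
    idx = bisect.bisect_left(cumsum, target)  # first bin with cumsum >= target
    if idx == 0:
        return edges[0]
    if idx >= len(cumsum):
        return edges[-1]
    before = cumsum[idx - 1]
    in_bin = cumsum[idx] - before
    if in_bin == 0:
        return edges[idx]
    fraction = (target - before) / in_bin
    return edges[idx] + fraction * (edges[idx + 1] - edges[idx])


# Usage: merging [1, 2, 3] and then [4, 5] gives the same mean (3.0) as averaging all
# five values at once, and a uniform histogram over [0, 4] has its median at 2.0.

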
def estimate_num_samples(
    dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> int:
    """Heuristic to estimate the number of samples based on dataset size.

    The result grows as ``dataset_len**power``, clamped to ``[min_num_samples, max_num_samples]``;
    lower the power to draw fewer samples. Datasets with fewer than ``min_num_samples`` items
    are used in full.

    For default arguments, we have:
    - below 100, num_samples=dataset_len (every item is used)
    - from 100 to ~470, num_samples=100
    - at 1000, num_samples=177
    - at 2000, num_samples=299
    - at 5000, num_samples=594
    - at 10000, num_samples=1000
    - at 20000, num_samples=1681
    - from ~215k onward, num_samples=10000 (the cap)
    """
    if dataset_len < min_num_samples:
        min_num_samples = dataset_len
    return max(min_num_samples, min(int(dataset_len**power), max_num_samples))


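# Illustrative sketch (hypothetical helper, not part of the lerobot API): where the
# piecewise behaviour of `estimate_num_samples` changes for a given parameter set.
# For datasets of at least `min_num_samples` items, the result stays at
# `min_num_samples` until `int(dataset_len**power)` reaches `min_num_samples + 1`,
# i.e. until `dataset_len >= (min_num_samples + 1) ** (1 / power)`, and it saturates
# at `max_num_samples` once `dataset_len >= max_num_samples ** (1 / power)`.


def _sketch_num_samples_breakpoints(
    min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> tuple[int, int]:
    """Return (first length sampled above the floor, first length hitting the cap)."""
    import math

    first_above_floor = math.ceil((min_num_samples + 1) ** (1 / power))
    first_at_cap = math.ceil(max_num_samples ** (1 / power))
    return first_above_floor, first_at_cap


# Usage: with the defaults this returns (471, 215444), so lengths 100..470 all sample
# 100 items and anything from 215444 upward samples 10_000.

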
def sample_indices(data_len: int) -> list[int]:
    """Return ``estimate_num_samples(data_len)`` evenly spaced indices in ``[0, data_len - 1]``."""
    num_samples = estimate_num_samples(data_len)
    return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()


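# Illustrative sketch (hypothetical helper, not part of the lerobot API): the idea
# behind `sample_indices` in plain Python. `num_samples` evenly spaced positions from
# 0 to `data_len - 1` are rounded to the nearest integer. Python's `round`, like
# `np.round`, rounds halves to even; floating-point details may still differ slightly
# from `np.linspace`, so treat this as an approximation of the NumPy version.


def _sketch_even_indices(data_len: int, num_samples: int) -> list[int]:
    """Return `num_samples` evenly spaced integer indices covering [0, data_len - 1]."""
    if num_samples <= 1:
        return [0] * num_samples
    step = (data_len - 1) / (num_samples - 1)
    return [round(i * step) for i in range(num_samples)]


# Usage: _sketch_even_indices(10, 4) gives [0, 3, 6, 9]; the first and last items are
# always included.

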
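def auto_downsample_height_width(
    img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
) -> np.ndarray:
    """Downsample a channel-first ``(C, H, W)`` image by integer striding.

    Images whose longer side is below ``max_size_threshold`` are returned unchanged. Otherwise both
    spatial axes are strided by ``int(longer_side / target_size)``, which brings the longer side
    to roughly ``target_size`` pixels.
    """
    _, height, width = img.shape

    if max(width, height) < max_size_threshold:
        # no downsampling needed
        return img

    downsample_factor = int(width / target_size) if width > height else int(height / target_size)
    return img[:, ::downsample_factor, ::downsample_factor]

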
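# Illustrative sketch (hypothetical helper, not part of the lerobot API): the output
# height and width that `auto_downsample_height_width` produces, computed without an
# image. A stride-`f` slice keeps ceil(n / f) of n pixels along each axis.


def _sketch_downsampled_hw(
    height: int, width: int, target_size: int = 150, max_size_threshold: int = 300
) -> tuple[int, int]:
    """Return (height, width) after the same thresholding and integer striding."""
    longer = max(height, width)
    if longer < max_size_threshold:
        return height, width
    factor = int(longer / target_size)
    return -(-height // factor), -(-width // factor)  # ceiling division


# Usage: a 480x640 frame is strided by 4 to 120x160, a 720x1280 frame by 8 to 90x160,
# and a 224x224 frame is left unchanged.

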
def sample_images(image_paths: list[str]) -> np.ndarray:
    """Load an evenly spaced subset of ``image_paths`` as an ``(N, C, H, W)`` uint8 array.

    Each sampled image is passed through ``auto_downsample_height_width``. The output array is
    preallocated from the first sampled image, so all sampled images are expected to share its shape.
    """
    sampled_indices = sample_indices(len(image_paths))

    images = None
    for i, idx in enumerate(sampled_indices):
        path = image_paths[idx]
        # we load as uint8 to reduce memory usage
        img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
        img = auto_downsample_height_width(img)

        if images is None:
            images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)

        images[i] = img

    return images


def _reshape_stats_by_axis(
    stats: dict[str, np.ndarray],
    axis: int | tuple[int, ...] | None,
    keepdims: bool,
    original_shape: tuple[int, ...],
) -> dict[str, np.ndarray]:
    """Reshape all statistics to match NumPy's output conventions.

    Applies consistent reshaping to all statistics (except 'count') based on the
    axis and keepdims parameters. This ensures statistics have the correct shape
    for broadcasting with the original data.

    Args:
        stats: Dictionary of computed statistics
        axis: Axis or axes along which statistics were computed
        keepdims: Whether to keep reduced dimensions as size-1 dimensions
        original_shape: Shape of the original array

    Returns:
        Dictionary with reshaped statistics

    Note:
        The 'count' statistic is never reshaped as it represents metadata
        rather than per-feature statistics.
    """
    if axis == (1,) and not keepdims:
        return stats

    result = {}
    for key, value in stats.items():
        if key == "count":
            result[key] = value
        else:
            result[key] = _reshape_single_stat(value, axis, keepdims, original_shape)

    return result


def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
    """Reshape statistics for image data (axis=(0,2,3))."""
    if keepdims and value.ndim == 1:
        return value.reshape(1, -1, 1, 1)
    return value


def _reshape_for_vector_stats(
    value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray:
    """Reshape statistics for vector data (axis=0 or axis=(0,))."""
    if not keepdims:
        return value

    if len(original_shape) == 1 and value.ndim > 0:
        return value.reshape(1)
    elif len(original_shape) >= 2 and value.ndim == 1:
        return value.reshape(1, -1)
    return value


def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
    """Reshape statistics for feature-wise computation (axis=(1,))."""
    if not keepdims:
        return value

    if value.ndim == 0:
        return value.reshape(1, 1)
    elif value.ndim == 1:
        return value.reshape(-1, 1)
    return value


def _reshape_for_global_stats(
    value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray | float:
    """Reshape statistics for global reduction (axis=None)."""
    if keepdims:
        target_shape = tuple(1 for _ in original_shape)
        return value.reshape(target_shape)
    # Keep at least 1-D arrays to satisfy validator
    return np.atleast_1d(value)


def _reshape_single_stat(
    value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray | float:
    """Apply appropriate reshaping to a single statistic array.

    This function transforms statistic arrays to match expected output shapes
    based on the axis configuration and keepdims parameter.

    Args:
        value: The statistic array to reshape
        axis: Axis or axes that were reduced during computation
        keepdims: Whether to maintain reduced dimensions as size-1 dimensions
        original_shape: Shape of the original data before reduction

    Returns:
        Reshaped array following NumPy broadcasting conventions
    """
    if axis == (0, 2, 3):
        return _reshape_for_image_stats(value, keepdims)

    if axis in [0, (0,)]:
        return _reshape_for_vector_stats(value, keepdims, original_shape)

    if axis == (1,):
        return _reshape_for_feature_stats(value, keepdims)

    if axis is None:
        return _reshape_for_global_stats(value, keepdims, original_shape)

    return value


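# Illustrative sketch (hypothetical helper, not part of the lerobot API): the result
# shapes NumPy itself uses for a reduction such as `np.mean(x, axis=axis, keepdims=...)`,
# which the `_reshape_*` helpers above aim to reproduce. One deliberate difference:
# where NumPy would return a 0-d result (e.g. `axis=None` with `keepdims=False`), the
# statistics in this module are kept as arrays of shape (1,) instead.


def _sketch_reduced_shape(
    shape: tuple[int, ...], axis: int | tuple[int, ...] | None, keepdims: bool
) -> tuple[int, ...]:
    """Shape of a NumPy reduction over `axis` of an array with the given `shape`."""
    if axis is None:
        axes = set(range(len(shape)))
    elif isinstance(axis, int):
        axes = {axis % len(shape)}
    else:
        axes = {a % len(shape) for a in axis}
    if keepdims:
        return tuple(1 if i in axes else n for i, n in enumerate(shape))
    return tuple(n for i, n in enumerate(shape) if i not in axes)


# Usage: per-channel image stats over (B, C, H, W) with axis=(0, 2, 3), keepdims=True
# have shape (1, C, 1, 1), matching `value.reshape(1, -1, 1, 1)` above.

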
def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]:
    """Prepare array for statistics computation by reshaping according to axis.

    Args:
        array: Input data array
        axis: Axis or axes along which to compute statistics

    Returns:
        Tuple of (reshaped_array, sample_count)
    """
    if axis == (0, 2, 3):  # Image data
        batch_size, channels, height, width = array.shape
        reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels)
        return reshaped, batch_size

    if axis == 0 or axis == (0,):  # Vector data
        reshaped = array
        if array.ndim == 1:
            reshaped = array.reshape(-1, 1)
        return reshaped, array.shape[0]

    if axis == (1,):  # Feature-wise statistics
        return array.T, array.shape[1]

    if axis is None:  # Global statistics
        reshaped = array.reshape(-1, 1)
        # For backward compatibility, count represents the first dimension size
        return reshaped, array.shape[0] if array.ndim > 0 else 1

    raise ValueError(f"Unsupported axis configuration: {axis}")


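# Illustrative sketch (hypothetical helper, not part of the lerobot API): what the
# image branch of `_prepare_array_for_stats` does, written with nested lists.
# `array.transpose(0, 2, 3, 1).reshape(-1, channels)` turns a (B, C, H, W) batch into
# B*H*W rows with one column per channel, ordered by image, then row, then column, so
# that reducing over rows yields per-channel statistics. The reported sample count
# stays B (number of images), not B*H*W.


def _sketch_pixels_by_channel(images: list[list[list[list[float]]]]) -> list[list[float]]:
    """Flatten a nested (B, C, H, W) batch into (B*H*W, C) pixel rows."""
    rows = []
    for image in images:  # image is (C, H, W)
        num_channels = len(image)
        height, width = len(image[0]), len(image[0][0])
        for y in range(height):
            for x in range(width):
                rows.append([image[c][y][x] for c in range(num_channels)])
    return rows


# Usage: two 1x2 images with channels [[1, 2]] / [[10, 20]] and [[3, 4]] / [[30, 40]]
# give rows [1, 10], [2, 20], [3, 30], [4, 40], whose column means are 2.5 and 25.0.

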
def _compute_basic_stats(
|
||
array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None
|
||
) -> dict[str, np.ndarray]:
|
||
"""Compute basic statistics for arrays with insufficient samples for quantiles.
|
||
|
||
Args:
|
||
array: Reshaped array ready for statistics computation
|
||
sample_count: Number of samples represented in the data
|
||
|
||
Returns:
|
||
Dictionary with basic statistics and quantiles set to mean values
|
||
"""
|
||
if quantile_list is None:
|
||
quantile_list = DEFAULT_QUANTILES
|
||
quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list]
|
||
|
||
stats = {
|
||
"min": np.min(array, axis=0),
|
||
"max": np.max(array, axis=0),
|
||
"mean": np.mean(array, axis=0),
|
||
"std": np.std(array, axis=0),
|
||
"count": np.array([sample_count]),
|
||
}
|
||
|
||
for q in quantile_list_keys:
|
||
stats[q] = stats["mean"].copy()
|
||
|
||
return stats
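

# Sketch (hypothetical helper, not used by this module): how quantile fractions map
# to the "qNN" stat keys built in `_compute_basic_stats`. Truncating with int() is
# sensitive to floating-point error. For example, 0.29 * 100 evaluates to
# 28.999999999999996, so int() yields "q28"; rounding first yields "q29".
def _sketch_quantile_key(q: float, use_round: bool = False) -> str:
    """Hypothetical sketch: format a quantile fraction (e.g. 0.01) as a key (e.g. "q01")."""
    pct = round(q * 100) if use_round else int(q * 100)
    return f"q{pct:02d}"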


def get_feature_stats(
    array: np.ndarray,
    axis: int | tuple[int, ...] | None,
    keepdims: bool,
    quantile_list: list[float] | None = None,
) -> dict[str, np.ndarray]:
    """Compute comprehensive statistics for array features along specified axes.

    This function calculates min, max, mean, std, and quantiles (by default 1%, 10%,
    50%, 90%, 99%) for the input array along the specified axes. It handles different
    data layouts:
    - Image data: axis=(0,2,3) computes per-channel statistics
    - Vector data: axis=0 computes per-feature statistics
    - Feature-wise: axis=(1,) computes statistics across features
    - Global: axis=None computes statistics over the entire array

    Args:
        array: Input data array with shape appropriate for the specified axis
        axis: Axis or axes along which to compute statistics
            - (0, 2, 3): For image data (batch, channels, height, width)
            - 0 or (0,): For vector/tabular data (samples, features)
            - (1,): For computing across features
            - None: For global statistics over the entire array
        keepdims: If True, reduced axes are kept as dimensions with size 1
        quantile_list: Quantile fractions to compute. Defaults to ``DEFAULT_QUANTILES``.

    Returns:
        Dictionary containing:
        - 'min': Minimum values
        - 'max': Maximum values
        - 'mean': Mean values
        - 'std': Standard deviation
        - 'count': Number of samples (always shape (1,))
        - 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values (one key per entry
          in ``quantile_list``)

    Note:
        If fewer than two rows remain after reshaping, quantiles cannot be
        estimated and are set to the mean.
    """
    if quantile_list is None:
        quantile_list = DEFAULT_QUANTILES

    original_shape = array.shape
    reshaped, sample_count = _prepare_array_for_stats(array, axis)

    if reshaped.shape[0] < 2:
        # Too few rows to estimate quantiles: fall back to basic statistics.
        stats = _compute_basic_stats(reshaped, sample_count, quantile_list)
    else:
        running_stats = RunningQuantileStats()
        running_stats.update(reshaped)
        stats = running_stats.get_statistics()
        stats["count"] = np.array([sample_count])

    stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
    return stats
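

# Sketch (hypothetical reference helper, not part of this module): exact per-feature
# statistics for 2-D vector data, i.e. the quantities `get_feature_stats` reports for
# axis=0. It uses np.quantile over the full array. A streaming estimator such as
# RunningQuantileStats may only approximate the quantiles, so small differences
# against this reference would not be surprising. Quantile fractions are passed
# explicitly rather than read from DEFAULT_QUANTILES.
import numpy as np


def _sketch_exact_vector_stats(array: np.ndarray, quantiles: list[float]) -> dict[str, np.ndarray]:
    """Hypothetical sketch: exact stats and quantiles over axis 0 of an (N, D) array."""
    stats = {
        "min": array.min(axis=0),
        "max": array.max(axis=0),
        "mean": array.mean(axis=0),
        "std": array.std(axis=0),  # population std (ddof=0), numpy's default
        "count": np.array([array.shape[0]]),
    }
    # np.quantile returns shape (len(quantiles), D); one row per requested quantile.
    for q, values in zip(quantiles, np.quantile(array, quantiles, axis=0)):
        stats[f"q{round(q * 100):02d}"] = values
    return stats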


def compute_episode_stats(
    episode_data: dict[str, list[str] | np.ndarray],
    features: dict,
    quantile_list: list[float] | None = None,
) -> dict:
    """Compute comprehensive statistics for all features in an episode.

    Processes different data types appropriately:
    - Images/videos: Samples frames from the file paths, computes per-channel stats,
      and normalizes them to [0,1]
    - Numerical arrays: Computes per-feature statistics
    - Strings and language: Skipped (no statistics computed)

    Args:
        episode_data: Dictionary mapping feature names to data
            - For images/videos: list of file paths
            - For numerical data: numpy arrays
        features: Dictionary describing each feature's dtype and shape
        quantile_list: Quantile fractions to compute. Defaults to ``DEFAULT_QUANTILES``.

    Returns:
        Dictionary mapping feature names to their statistics dictionaries.
        Each statistics dictionary contains min, max, mean, std, count, and quantiles.

    Note:
        Image statistics are normalized to the [0,1] range and have shape (3,1,1)
        (per-channel values) when dtype is 'image' or 'video'.
    """
    if quantile_list is None:
        quantile_list = DEFAULT_QUANTILES

    ep_stats = {}
    for key, data in episode_data.items():
        if features[key]["dtype"] in {"string", "language"}:
            continue

        if features[key]["dtype"] in ["image", "video"]:
            ep_ft_array = sample_images(data)
            axes_to_reduce = (0, 2, 3)
            keepdims = True
        else:
            ep_ft_array = data
            axes_to_reduce = 0
            keepdims = data.ndim == 1

        ep_stats[key] = get_feature_stats(
            ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
        )

        if features[key]["dtype"] in ["image", "video"]:
            # Rescale pixel stats from [0, 255] to [0, 1] and drop the leading batch axis.
            ep_stats[key] = {
                k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
            }

    return ep_stats
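

# Sketch (hypothetical helper mirroring the image branch of `compute_episode_stats`):
# with keepdims=True, per-channel stats of a (B, 3, H, W) uint8 array have shape
# (1, 3, 1, 1). Dividing by 255 and squeezing axis 0 gives the (3, 1, 1) layout that
# `_validate_stat_value` expects, while "count" is left untouched. Plain numpy
# reductions stand in for `get_feature_stats` and `sample_images` here.
import numpy as np


def _sketch_image_channel_stats(images: np.ndarray) -> dict[str, np.ndarray]:
    """Hypothetical sketch: per-channel min/max/mean/std of uint8 images, rescaled to [0, 1]."""
    raw = {
        "min": images.min(axis=(0, 2, 3), keepdims=True),
        "max": images.max(axis=(0, 2, 3), keepdims=True),
        "mean": images.mean(axis=(0, 2, 3), keepdims=True),
        "std": images.std(axis=(0, 2, 3), keepdims=True),
        "count": np.array([images.shape[0]]),
    }
    return {k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in raw.items()}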


def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
    """Validate the type and shape of a single statistic value.

    Raises:
        ValueError: If ``value`` is not a numpy array, is 0-dimensional, or has an
            unexpected shape for its statistic key or feature.
    """
    if not isinstance(value, np.ndarray):
        raise ValueError(
            f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' "
            f"is of type '{type(value)}' instead."
        )

    if value.ndim == 0:
        raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")

    if key == "count" and value.shape != (1,):
        raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")

    if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
        raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")


def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
    """Validate that all statistics have correct types and shapes.

    Args:
        stats_list: List of statistics dictionaries to validate

    Raises:
        ValueError: If any statistic has incorrect type or shape
    """
    for stats in stats_list:
        for feature_key, feature_stats in stats.items():
            for stat_key, stat_value in feature_stats.items():
                _validate_stat_value(stat_value, stat_key, feature_key)


def aggregate_feature_stats(stats_ft_list: list[dict[str, np.ndarray]]) -> dict[str, np.ndarray]:
    """Aggregate the stats of a single feature computed over several datasets or episodes.

    Min and max are combined element-wise. Mean and std are pooled using the
    per-entry counts. Quantiles are approximated by a count-weighted average.
    """
    means = np.stack([s["mean"] for s in stats_ft_list])
    variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
    counts = np.stack([s["count"] for s in stats_ft_list])
    total_count = counts.sum(axis=0)

    # Prepare weighted mean by matching number of dimensions
    while counts.ndim < means.ndim:
        counts = np.expand_dims(counts, axis=-1)

    # Compute the weighted mean
    weighted_means = means * counts
    total_mean = weighted_means.sum(axis=0) / total_count

    # Compute the variance using the parallel algorithm: within-group variance plus
    # the squared offset of each group mean from the pooled mean.
    delta_means = means - total_mean
    weighted_variances = (variances + delta_means**2) * counts
    total_variance = weighted_variances.sum(axis=0) / total_count

    aggregated = {
        "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
        "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
        "mean": total_mean,
        "std": np.sqrt(total_variance),
        "count": total_count,
    }

    if stats_ft_list:
        quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()]

        for q_key in quantile_keys:
            if all(q_key in s for s in stats_ft_list):
                # Quantiles cannot be pooled exactly from summaries; a count-weighted
                # average of the per-entry quantiles is used as an approximation.
                quantile_values = np.stack([s[q_key] for s in stats_ft_list])
                weighted_quantiles = quantile_values * counts
                aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count

    return aggregated
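

# Sketch (hypothetical, self-contained): the pooled mean/variance formula used in
# `aggregate_feature_stats`, checked against the statistics of the concatenated data.
# With population std (ddof=0), pooling recovers the overall mean and std up to
# floating-point error. No such identity exists for quantiles, which is why the
# count-weighted quantile merge in `aggregate_feature_stats` is only an approximation.
import numpy as np


def _sketch_pooled_mean_std(
    means: np.ndarray, stds: np.ndarray, counts: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical sketch: pool per-group (mean, std, count) along axis 0."""
    weights = counts.reshape(-1, *([1] * (means.ndim - 1)))  # broadcast counts over feature dims
    total = counts.sum()
    mean = (means * weights).sum(axis=0) / total
    # Within-group variance plus the spread of the group means around the pooled mean.
    var = ((stds**2 + (means - mean) ** 2) * weights).sum(axis=0) / total
    return mean, np.sqrt(var)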


def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
    """Aggregate stats from multiple compute_stats outputs into a single set of stats.

    The final stats will have the union of all data keys from each of the stats dicts.
    Each key is aggregated only over the stats dicts that contain it.

    For instance:
    - new_min = min(min_dataset_0, min_dataset_1, ...)
    - new_max = max(max_dataset_0, max_dataset_1, ...)
    - new_mean = (mean of all data, weighted by counts)
    - new_std = (std of all data)
    - new_qXX = (count-weighted average of the per-dataset quantiles, an approximation)
    """
    _assert_type_and_shape(stats_list)

    data_keys = {key for stats in stats_list for key in stats}
    aggregated_stats = {key: {} for key in data_keys}

    for key in data_keys:
        stats_with_key = [stats[key] for stats in stats_list if key in stats]
        aggregated_stats[key] = aggregate_feature_stats(stats_with_key)

    return aggregated_stats
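

# Sketch (hypothetical helper): the grouping step of `aggregate_stats`. Feature keys
# are collected from every stats dict, and each feature is aggregated only over the
# dicts that contain it. A feature missing from some datasets is therefore still kept.
def _sketch_group_by_feature(stats_list: list[dict[str, dict]]) -> dict[str, list[dict]]:
    """Hypothetical sketch: map each feature key to the per-dataset stats that include it."""
    keys = {key for stats in stats_list for key in stats}
    return {key: [stats[key] for stats in stats_list if key in stats] for key in keys}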


def _get_valid_chunk_starts(episode_indices: np.ndarray, chunk_size: int) -> np.ndarray:
    """Return all start indices where a chunk of ``chunk_size`` stays within one episode."""
    total = len(episode_indices)
    if total < chunk_size:
        return np.array([], dtype=np.int64)
    max_start = total - chunk_size
    starts = np.arange(max_start + 1)
    # Comparing only the first and last frame of each window assumes that every
    # episode's frames are stored contiguously.
    valid = episode_indices[starts] == episode_indices[starts + chunk_size - 1]
    return starts[valid]
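

# Sketch (hypothetical reference implementation): a plain-loop version of
# `_get_valid_chunk_starts` that checks every frame in each window instead of only
# the endpoints. For contiguously stored episode indices such as [0, 0, 0, 1, 1],
# it selects the same starts as the vectorized first-vs-last comparison.
import numpy as np


def _sketch_valid_chunk_starts_loop(episode_indices: np.ndarray, chunk_size: int) -> np.ndarray:
    """Hypothetical sketch: starts whose window [start, start + chunk_size) shares one episode."""
    starts = [
        start
        for start in range(len(episode_indices) - chunk_size + 1)
        if np.all(episode_indices[start : start + chunk_size] == episode_indices[start])
    ]
    return np.array(starts, dtype=np.int64)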


def _compute_relative_chunk_batch(
    start_indices: np.ndarray,
    all_actions: np.ndarray,
    all_states: np.ndarray,
    chunk_size: int,
    relative_mask: np.ndarray,
) -> np.ndarray:
    """Vectorised relative-action computation for a batch of start indices.

    For each start index ``s``, the state at frame ``s`` is subtracted from every
    action in ``all_actions[s : s + chunk_size]`` on the dimensions selected by
    ``relative_mask``.

    Returns an ``(N * chunk_size, action_dim)`` float32 array, where ``N`` is the
    number of start indices.
    """
    if len(start_indices) == 0:
        return np.empty((0, all_actions.shape[1]), dtype=np.float32)
    offsets = np.arange(chunk_size)
    frame_idx = start_indices[:, None] + offsets[None, :]  # (N, chunk_size)
    chunks = all_actions[frame_idx].copy()  # (N, chunk_size, action_dim)
    states = all_states[start_indices]  # (N, state_dim)
    mask_dim = len(relative_mask)
    chunks[:, :, :mask_dim] -= states[:, None, :mask_dim] * relative_mask[None, None, :]
    return chunks.reshape(-1, all_actions.shape[1])
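

# Sketch (hypothetical reference implementation): the relative-action transform of
# `_compute_relative_chunk_batch` for a single chunk, written as a loop. Every action
# in the chunk is offset by the state at the chunk's first frame, on masked
# dimensions only (mask value 1.0 = relative, 0.0 = kept absolute).
import numpy as np


def _sketch_relative_chunk(
    actions: np.ndarray, states: np.ndarray, start: int, chunk_size: int, mask: np.ndarray
) -> np.ndarray:
    """Hypothetical sketch: actions[start : start + chunk_size] minus states[start] on masked dims."""
    out = np.array(actions[start : start + chunk_size], dtype=np.float32)  # copy
    d = len(mask)
    for t in range(chunk_size):
        out[t, :d] -= states[start, :d] * mask
    return out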


def compute_relative_action_stats(
    hf_dataset,
    features: dict,
    chunk_size: int,
    exclude_joints: list[str] | None = None,
    num_workers: int = 0,
) -> dict[str, np.ndarray]:
    """Compute normalization statistics for relative actions over the full dataset.

    Iterates over *all* valid action chunks (those lying within a single episode),
    converts them to relative actions (action − state at the chunk's first frame,
    on the dimensions selected by the relative-action mask), and computes
    per-dimension statistics suitable for normalization.

    Args:
        hf_dataset: The underlying HuggingFace dataset with "action",
            "observation.state", and "episode_index" columns.
        features: Dataset feature metadata (must contain "action" with "shape"
            and optionally "names").
        chunk_size: Number of consecutive frames per action chunk.
        exclude_joints: Joint names whose dimensions should remain absolute
            (not converted to relative actions).
        num_workers: Number of parallel threads for computation. Values ≤1
            mean single-threaded. NumPy releases the GIL during most heavy array
            operations, so threads can run in parallel here.

    Returns:
        Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99".

    Raises:
        RuntimeError: If no valid (single-episode) chunks are found, for example
            when the dataset has fewer frames than ``chunk_size``.
    """
    if exclude_joints is None:
        exclude_joints = []

    action_dim = features[ACTION]["shape"][0]
    action_names = features.get(ACTION, {}).get("names")
    mask_step = RelativeActionsProcessorStep(
        enabled=True,
        exclude_joints=exclude_joints,
        action_names=action_names,
    )
    relative_mask = np.array(mask_step._build_mask(action_dim), dtype=np.float32)

    logging.info("Loading action/state data for relative action stats...")
    all_actions = np.array(hf_dataset[ACTION], dtype=np.float32)
    all_states = np.array(hf_dataset[OBS_STATE], dtype=np.float32)
    episode_indices = np.array(hf_dataset["episode_index"])

    valid_starts = _get_valid_chunk_starts(episode_indices, chunk_size)
    if len(valid_starts) == 0:
        raise RuntimeError(
            f"No valid chunks found (total_frames={len(episode_indices)}, chunk_size={chunk_size})"
        )

    effective_workers = max(num_workers, 1)
    logging.info(
        f"Computing relative action stats from {len(valid_starts)} chunks "
        f"(chunk_size={chunk_size}, workers={effective_workers})"
    )

    batch_size = 50_000
    batches = [valid_starts[i : i + batch_size] for i in range(0, len(valid_starts), batch_size)]

    running_stats = RunningQuantileStats()

    if num_workers > 1:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            futures = [
                pool.submit(
                    _compute_relative_chunk_batch,
                    batch,
                    all_actions,
                    all_states,
                    chunk_size,
                    relative_mask,
                )
                for batch in batches
            ]
            # Batches complete in arbitrary order; updates run serially in this thread.
            for future in as_completed(futures):
                running_stats.update(future.result())
    else:
        for batch in batches:
            running_stats.update(
                _compute_relative_chunk_batch(batch, all_actions, all_states, chunk_size, relative_mask)
            )

    stats = running_stats.get_statistics()

    excluded_dims = int(len(relative_mask) - relative_mask.sum())
    total_frames = len(valid_starts) * chunk_size
    logging.info(
        f"Relative action stats ({len(valid_starts)} chunks, {total_frames} frames): "
        f"relative_dims={int(relative_mask.sum())}/{len(relative_mask)} (excluded={excluded_dims}), "
        f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}, "
        f"q01={stats['q01'].mean():.4f}, q99={stats['q99'].mean():.4f}"
    )

    return stats
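

# Sketch (hypothetical, self-contained): the batching pattern used in
# `compute_relative_action_stats`. Work is split into fixed-size batches, the batches
# run on a thread pool, and the results are merged as they complete. Because
# as_completed() yields futures in completion order rather than submission order,
# the merge step should be insensitive to order; a running sum and count (for a
# mean) is. The column-wise mean here is an illustrative stand-in for the real stats.
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed


def _sketch_batched_mean(values: np.ndarray, batch_size: int, num_workers: int) -> np.ndarray:
    """Hypothetical sketch: column-wise mean of an (N, D) array, computed per batch on threads."""

    def partial_sum(batch: np.ndarray) -> tuple[np.ndarray, int]:
        return batch.sum(axis=0), len(batch)

    batches = [values[i : i + batch_size] for i in range(0, len(values), batch_size)]
    total = np.zeros(values.shape[1])
    count = 0
    with ThreadPoolExecutor(max_workers=max(num_workers, 1)) as pool:
        futures = [pool.submit(partial_sum, batch) for batch in batches]
        for future in as_completed(futures):  # completion order, not submission order
            batch_sum, batch_count = future.result()
            total += batch_sum
            count += batch_count
    return total / count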