Compare commits

..

2 Commits

Author SHA1 Message Date
pepijn
1e7c0d6aa1 annotate(plan): force composite-action subtasks; ban ultra-fine splits
Tighten ``module_1_subtasks.txt`` so the VLM emits one composite
atomic action per subtask instead of decomposing every pick into
``move to X`` / ``grasp X`` / ``lift X``:

- Lock the verb vocabulary to the composite set the low-level
  policy actually learns end-to-end: ``pick up`` (approach + grasp +
  lift), ``put``/``place`` (transport + release), ``push``, ``pull``,
  ``turn``, ``press``, ``open``, ``close``, ``pour``, ``insert``.
  ``go to`` is allowed only as a pure relocation between phases.
- Add an explicit ``Forbidden ultra-fine splits`` block enumerating
  the patterns the VLM was tempted to emit (``move to X``,
  ``reach for X``, ``grasp X``, ``lift X``, ``release X``) and
  instructing it to fold each into its parent composite.
- Rewrite the Good/Bad examples to match the composite contract;
  the previous ``"move to blue cube" / "grasp blue cube" / "lift
  blue cube"`` Good list was actively encouraging the over-
  segmentation pattern this prompt is supposed to prevent.
- Tighten the duration rule: candidates shorter than
  ``min_subtask_seconds`` must be merged into a neighbour rather
  than emitted. Pairs with bumping the runtime floor to 3 s so
  composites have room to land.

Pure prompt change — no code or schema change. Existing canonical-
vocabulary retry path is unaffected (the new verb whitelist lives
in prose, not in the validator).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 05:14:30 +00:00
pepijn
920c6ef5a2 docs(annotate): disable phase-0 vocabulary discovery by default in run_hf_job
Heterogeneous datasets (different tasks/scenes across episodes) don't
share a single small subtask + memory vocabulary, so the canonical
vocabulary phase narrowed every episode to the wrong target distribution.
Flip the example to free-form generation by default and document the
``--vocabulary.enabled=true`` switch for homogeneous datasets where the
canonical vocabulary still helps the downstream policy.

No pipeline-code changes: ``VocabularyConfig.enabled`` already gates
phase 0 (see ``executor.py:_run_vocabulary_phase`` and
``VocabularyConfig`` docstring) and falls back to free-form generation.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 04:42:10 +00:00
79 changed files with 164 additions and 11643 deletions

View File

@@ -106,8 +106,7 @@ inside the script — every flag in there maps directly onto a CLI flag of
## Style-to-recipe consumer mapping
The pipeline's outputs are designed to be consumed by recipes (see
[Language Columns and Recipes](./language_and_recipes)) — for the
canonical PI052 blend `src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml`:
[Language Columns and Recipes](./language_and_recipes)) — typically:
- low-level / high-level / memory-update branches consume
`subtask`/`plan`/`memory` from `language_persistent`.

View File

@@ -141,11 +141,6 @@ sample["target_message_indices"]
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
## Blends
Blend recipes select one weighted sub-recipe deterministically from the sample index.
`recipes/subtasks_vqa.yaml` trains the core blend — high-level subtask prediction, low-level execution, and VQA. `recipes/subtask_mem_vqa_speech.yaml` is the fuller variant that also adds memory updates and spoken interjection responses.
## Graceful absence
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.

View File

@@ -1,18 +1,20 @@
#!/usr/bin/env python
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).
Spawns one ``h200x4`` job that:
Spawns one ``h200x2`` job that:
1. installs this branch of ``lerobot`` plus the annotation extras,
2. boots four vllm servers (one per H200) with Qwen3.6-35B-A3B-FP8,
3. runs the plan + vqa modules across the dataset in free-form
mode phase 0 (canonical vocabulary discovery) is disabled so
every episode's subtasks + memory are generated independently;
interjections is also disabled, which short-circuits the
plan_update phase that depends on it,
2. boots two vllm servers (one per GPU) with Qwen3.6-35B-A3B-FP8,
3. runs the plan / interjections / vqa modules across the dataset
in free-form mode (phase 0 canonical-vocabulary discovery is
disabled — each episode generates its own subtasks + memory),
4. uploads the annotated dataset to ``--dest_repo_id`` (when set)
or back to ``--repo_id``.
Re-enable phase 0 with ``--vocabulary.enabled=true`` (optionally
``--vocabulary.sample_episodes=N``) when the dataset is homogeneous
enough to share one subtask + memory vocabulary across all episodes.
Usage:
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
@@ -38,84 +40,50 @@ CMD = (
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
"export VLLM_VIDEO_BACKEND=pyav && "
"lerobot-annotate "
"--repo_id=pepijn223/robocasa_smoke_2atomic_v3 "
"--dest_repo_id=pepijn223/robocasa_smoke_2atomic_v3_annotated "
"--repo_id=imstevenpmwork/super_poulain_draft "
"--dest_repo_id=pepijn223/super_poulain_vocab "
"--push_to_hub=true "
"--vlm.backend=openai "
"--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
"--vlm.parallel_servers=4 "
"--vlm.num_gpus=4 "
"--vlm.parallel_servers=2 "
"--vlm.num_gpus=2 "
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-35B-A3B-FP8 '
# 4× the context (32768 → 131072) so long episodes at 1 Hz fit even
# at full Qwen vision resolution: 90 frames @ ~700 vision tokens/frame
# ≈ 63 k tokens, comfortably under 131 k. On 1× H200 (144 GB) the
# 35B-FP8 model leaves plenty of room for the bigger KV cache.
"--tensor-parallel-size 1 --max-model-len 131072 "
'--gpu-memory-utilization 0.85 --uvicorn-log-level warning --port {port}" '
"--tensor-parallel-size 1 --max-model-len 32768 "
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
"--vlm.serve_ready_timeout_s=1800 "
"--vlm.client_concurrency=256 "
"--vlm.client_concurrency=128 "
"--vlm.max_new_tokens=512 "
# Low temperature for VQA: bbox + keypoint are coordinate-regression
# tasks where sampling noise directly degrades localization
# (overlapping boxes, drifted points). 0.2 keeps the model decisive
# while still letting question/label phrasing vary across frames.
"--vlm.temperature=0.2 "
"--executor.episode_parallelism=64 "
"--vlm.temperature=0.7 "
"--executor.episode_parallelism=16 "
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
# Whole-scene agentview is the right choice for subtask reasoning +
# VQA on robocasa: the wrist (``robot0_eye_in_hand``) usually only
# sees the gripper + nearby object, which hurts "what is happening
# in this episode" decomposition. Override per-dataset if your
# cameras are named differently (inspect ``meta/info.json``).
"--vlm.camera_key=observation.images.robot0_agentview_left "
# Phase 0 — canonical vocabulary discovery DISABLED. This dataset's
# episodes span heterogeneous tasks/scenes, so a single shared
# subtask + memory vocabulary would be too narrow — each episode
# generates its subtasks + memory free-form instead.
"--vlm.camera_key=observation.images.wrist "
# Phase 0 — canonical vocabulary discovery DISABLED by default.
# Heterogeneous datasets (different tasks/scenes across episodes)
# don't share a single small subtask + memory vocabulary, so each
# episode generates its subtasks + memory free-form. Flip to
# ``--vocabulary.enabled=true`` (optionally ``--vocabulary.sample_episodes=N``)
# for homogeneous datasets where a shared canonical vocabulary
# helps the downstream policy.
"--vocabulary.enabled=false "
# Phase 1 — plan module (subtasks + plan + memory + task_aug).
"--plan.enabled=true "
"--plan.frames_per_second=1.0 "
"--plan.use_video_url=true "
"--plan.use_video_url_fps=1.0 "
# Force coarse, composite subtasks (``pick up X`` = approach + grasp
# + lift in one span, not three). 3 s is large enough to host a
# full grasp-or-place composite at typical 20 fps robocasa speeds;
# any candidate span shorter than this gets merged into a neighbour
# by the prompt's authoring rules (see module_1_subtasks.txt).
"--plan.min_subtask_seconds=3.0 "
# Cap so the VLM can't drift into micro-segmentation. Combined with
# the composite-action rules in the prompt, this targets ~3-6
# meaningful spans per episode for typical pick-and-place demos.
"--plan.plan_max_steps=9 "
# ``off`` keeps the dataset's canonical ``record.episode_task`` as-is
# — no per-episode VLM "what is this video about" call. Switch to
# ``if_short`` (default) only if some episodes have placeholder /
# missing canonical tasks; ``always`` overrides every episode's task.
"--plan.derive_task_from_video=off "
# 0 disables the task_aug pass entirely (see PlanConfig.n_task_rephrasings
# docstring) — no per-episode paraphrase generation, no task_aug rows.
"--plan.n_task_rephrasings=0 "
# Phase 2 — interjections OFF (also skips phase 3 plan_update,
# see executor.py:_run_plan_update_phase guard).
"--interjections.enabled=false "
# Phase 4 — general VQA. K=1 keeps each VQA answer on its own
# emission frame (no temporal smear); see VqaConfig.K docstring.
# 3 Hz cadence: at 20 fps source, that's a VQA tick every ~7 frames.
# NOTE: VQA emits per-camera, so for robocasa (3 cameras) each tick
# produces 3 (user, assistant) row pairs — total call volume ~= 3 *
# 3 Hz * mean_episode_seconds * n_episodes.
"--vqa.enabled=true "
"--vqa.K=1 "
"--vqa.vqa_emission_hz=3.0"
"--plan.derive_task_from_video=always "
"--plan.n_task_rephrasings=30 "
# Phase 2 — interjections + speech.
"--interjections.max_interjections_per_episode=6 "
# Phase 4 — general VQA.
"--vqa.K=3 "
"--vqa.vqa_emission_hz=1.0"
)
job = run_job(
image="vllm/vllm-openai:latest",
command=["bash", "-c", CMD],
flavor="h200x4",
flavor="h200x2",
secrets={"HF_TOKEN": token},
timeout="24h",
timeout="2h",
)
print(f"Job URL: {job.url}")
print(f"Job ID: {job.id}")

View File

@@ -1,74 +0,0 @@
#!/bin/bash
#SBATCH --job-name=bench-pi052-kernels
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=01:30:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_kernels_%j.out
# HF kernels exploration via Liger's apply_liger_kernel_to_paligemma.
# Baseline (SDPA, no kernels) vs. per-subkernel ablations vs. all-on.
# Same harness as bench_pi052_step.py — only the --kernels flag varies
# across runs so any delta is attributable to the patched op(s).
#
# Subkernels exercised: rope, rms_norm, geglu, layer_norm.
# Skipped: cross_entropy / fused_linear_cross_entropy — pi052 calls
# F.cross_entropy directly and bypasses PaliGemma's forward, so those
# patches wouldn't fire without model-code changes (separate PR).
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
# /fsx triton cache is shared across nodes with different glibc versions
# — kernels built on one node trip GLIBC_2.34-not-found on another. Use
# a node-local cache per job to side-step that.
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
ldd --version | head -1
# Liger isn't in our standard env yet — install on the compute node so
# the slurm log captures the exact version that produced the numbers.
python -m pip install -q --upgrade 'liger-kernel'
python - <<'PY' || true
from importlib.metadata import version, PackageNotFoundError
try:
print("liger-kernel", version("liger-kernel"))
except PackageNotFoundError:
print("liger-kernel: not importable")
import liger_kernel.transformers as t
print("apply_liger_kernel_to_paligemma:", hasattr(t, "apply_liger_kernel_to_paligemma"))
PY
run() {
echo
echo "--- $* ---"
python examples/benchmark/bench_pi052_step.py "$@" || true
}
# -- Baseline (no kernels) at the BS we actually train at. --
run --attn sdpa --batch-size 8 --kernels none
run --attn sdpa --batch-size 16 --kernels none
# -- Per-subkernel ablations at BS=16 to isolate each contributor. --
run --attn sdpa --batch-size 16 --kernels rms_norm
run --attn sdpa --batch-size 16 --kernels geglu
run --attn sdpa --batch-size 16 --kernels layer_norm
run --attn sdpa --batch-size 16 --kernels rope
# -- All-on, both BS to compare against the matched baselines above. --
run --attn sdpa --batch-size 8 --kernels all
run --attn sdpa --batch-size 16 --kernels all
# -- Headroom check: does kernels-all let BS=24 fit (baseline OOMs near here)? --
run --attn sdpa --batch-size 24 --kernels none
run --attn sdpa --batch-size 24 --kernels all

View File

@@ -1,338 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmark ``PI052Policy.forward + backward`` on a single GPU.
Compares the new SDPA attention path against the eager baseline by
monkeypatching ``sdpa_attention_forward`` before the first model
forward — so both runs share identical Q/K/V plumbing and only the
attention kernel differs. Reports steps/sec and peak GPU memory.
SLURM-only:
sbatch examples/benchmark/bench_pi052_step.slurm
Or one-off:
srun --partition=hopper-prod --qos=high --gpus=1 --time=15 \\
python examples/benchmark/bench_pi052_step.py --attn sdpa --batch-size 8
"""
from __future__ import annotations
import argparse
import gc
import math
import os
import time
import torch
def _maybe_patch_eager() -> None:
"""Swap ``sdpa_attention_forward`` for the original eager forward.
Must be called BEFORE PI052Policy is instantiated — the layer
compute functions resolve the symbol at call time (module-level
lookup), so this patch covers both pi05 and pi052 KI paths."""
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi05 import modeling_pi05
modeling_pi05.sdpa_attention_forward = modeling_gemma.eager_attention_forward
_LIGER_SUBKERNELS = ("rope", "rms_norm", "geglu", "layer_norm")
def _maybe_patch_liger(spec: str) -> dict:
"""Globally patch PaliGemma/Gemma/Siglip modules with Liger Triton kernels.
Must be called BEFORE PI052Policy is instantiated — Liger replaces
classes inside ``transformers.models.{gemma,gemma2,siglip,paligemma}``,
so any model built after the call picks up the fused forwards.
``spec`` is a comma-separated subset of {rope, rms_norm, geglu,
layer_norm} (also ``all`` and ``none``). ``cross_entropy`` and
``fused_linear_cross_entropy`` are intentionally skipped — pi052's
losses use ``F.cross_entropy`` directly (not ``nn.CrossEntropyLoss``)
and never traverse ``PaliGemmaForConditionalGeneration.forward``,
so neither patch would fire without invasive model-code changes.
"""
enabled = dict.fromkeys(_LIGER_SUBKERNELS, False)
if spec in ("", "none"):
return enabled
tokens = [t.strip() for t in spec.split(",") if t.strip()]
if tokens == ["all"]:
enabled = dict.fromkeys(_LIGER_SUBKERNELS, True)
else:
for t in tokens:
if t not in enabled:
raise SystemExit(f"Unknown liger subkernel: {t!r}. Choose from {_LIGER_SUBKERNELS} or 'all'.")
enabled[t] = True
from liger_kernel.transformers import apply_liger_kernel_to_paligemma
apply_liger_kernel_to_paligemma(
rope=enabled["rope"],
rms_norm=enabled["rms_norm"],
geglu=enabled["geglu"],
layer_norm=enabled["layer_norm"],
cross_entropy=False,
fused_linear_cross_entropy=False,
)
return enabled
def _maybe_patch_flex() -> None:
"""Swap ``sdpa_attention_forward`` for a FlexAttention-backed forward.
Experimental: builds a per-call ``score_mod`` from the additive
mask and dispatches to a compiled ``flex_attention`` kernel.
Known issue on torch 2.7.1: dynamo errors out with
``FlexAttentionHigherOrderVariable() has no type`` when the
``score_mod`` closure captures a per-call bias tensor. A proper
port needs ``create_block_mask(mask_mod, ...)`` plumbed at the
PI05Pytorch.forward level so a BlockMask object can be passed
down to the layer compute, not a per-call closure. Left as
future work; keep this stub for benchmark experimentation."""
import torch
from torch.nn.attention.flex_attention import flex_attention
from lerobot.policies.pi05 import modeling_pi05
compiled_flex = torch.compile(flex_attention, dynamic=True)
def flex_forward(module, query, key, value, attention_mask, scaling, dropout=0.0):
n_rep = module.num_key_value_groups
if n_rep > 1:
key = key.repeat_interleave(n_rep, dim=1)
value = value.repeat_interleave(n_rep, dim=1)
bias = attention_mask # (B, 1, Lq, Lk) additive
def score_mod(score, b, h, q_idx, kv_idx):
return score + bias[b, 0, q_idx, kv_idx]
attn_output = compiled_flex(query, key, value, score_mod=score_mod, scale=scaling)
return attn_output.transpose(1, 2).contiguous(), None
modeling_pi05.sdpa_attention_forward = flex_forward
def _build_policy(args, device: torch.device):
"""Random-init PI052Policy at production-relevant shapes."""
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.pi052.configuration_pi052 import PI052Config
from lerobot.policies.pi052.modeling_pi052 import PI052Policy
# Production has ``unfreeze_lm_head=True`` + ``text_loss_weight>0``,
# which flips ``train_expert_only=False`` in __post_init__ and
# makes the whole PaliGemma + Gemma-expert stack trainable. We
# mirror that here so the optimizer-state count reflects reality;
# the loss path still goes through ``PI05Policy.forward`` because
# ``text_labels`` / FAST tokens are absent from the synthetic batch
# (see ``PI052Policy.forward`` early-return).
config = PI052Config(
max_action_dim=args.action_dim,
max_state_dim=args.state_dim,
dtype=args.dtype,
knowledge_insulation=args.knowledge_insulation,
text_loss_weight=1e-3 if args.train_full else 0.0,
flow_loss_weight=1.0,
enable_fast_action_loss=False,
unfreeze_lm_head=args.train_full,
tokenizer_max_length=args.lang_tokens,
device="cuda",
compile_model=args.compile_model,
compile_mode=args.compile_mode,
)
config.input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(args.state_dim,)),
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(args.action_dim,)),
}
policy = PI052Policy(config)
policy.to(device)
if args.gradient_checkpointing:
policy.model.gradient_checkpointing_enable()
policy.train()
return policy, config
def _build_batch(args, config, device: torch.device) -> dict:
"""Synthetic batch matching the training-loop input contract."""
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
B = args.batch_size
L = args.lang_tokens
return {
OBS_LANGUAGE_TOKENS: torch.randint(0, 250000, (B, L), device=device),
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(B, L, dtype=torch.bool, device=device),
"observation.images.base_0_rgb": torch.rand(B, 3, 224, 224, device=device),
"observation.images.base_0_rgb_padding_mask": torch.ones(B, dtype=torch.bool, device=device),
"observation.state": torch.randn(B, args.state_dim, device=device),
ACTION: torch.randn(B, config.chunk_size, args.action_dim, device=device),
"action_is_pad": torch.zeros(B, config.chunk_size, dtype=torch.bool, device=device),
"task": ["bench task"] * B,
}
def _step(policy, batch, optimizer=None) -> torch.Tensor:
loss, _ = policy.forward(batch)
loss.backward()
if optimizer is not None:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
else:
for p in policy.parameters():
if p.grad is not None:
p.grad = None
return loss.detach()
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument("--attn", choices=["sdpa", "eager", "flex"], default="sdpa")
parser.add_argument(
"--kernels",
default="none",
help=(
"Liger sub-kernels to enable, comma-separated. Choose from "
f"{_LIGER_SUBKERNELS} or use 'all' / 'none' (default). Applied "
"via apply_liger_kernel_to_paligemma() BEFORE model build."
),
)
parser.add_argument(
"--compile",
dest="compile_model",
action="store_true",
help="Set policy.config.compile_model=True (torch.compile the forward).",
)
parser.add_argument(
"--compile-mode",
default="default",
help="torch.compile mode (default | reduce-overhead | max-autotune).",
)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--warmup", type=int, default=8)
parser.add_argument("--steps", type=int, default=40)
parser.add_argument("--lang-tokens", type=int, default=512)
parser.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16")
parser.add_argument("--action-dim", type=int, default=14)
parser.add_argument("--state-dim", type=int, default=14)
parser.add_argument("--knowledge-insulation", action="store_true", default=True)
parser.add_argument(
"--gradient-checkpointing",
dest="gradient_checkpointing",
action=argparse.BooleanOptionalAction,
default=True,
)
parser.add_argument(
"--optimizer",
choices=["none", "adamw", "adamw_fused"],
default="adamw_fused",
help=(
"Whether to include an AdamW step in the timed iteration. "
"'none' mirrors the fwd+bwd-only original bench; 'adamw' / "
"'adamw_fused' add the realistic ~2x param-bytes optimizer "
"state and ``optimizer.step()`` cost."
),
)
parser.add_argument(
"--train-full",
action=argparse.BooleanOptionalAction,
default=True,
help=(
"Mirror production: unfreeze the PaliGemma backbone (full "
"~3B trainable params) instead of training only the 300M "
"action expert."
),
)
args = parser.parse_args()
if not torch.cuda.is_available():
raise SystemExit("Benchmark requires CUDA; submit via slurm (srun/sbatch).")
if args.attn == "eager":
_maybe_patch_eager()
elif args.attn == "flex":
_maybe_patch_flex()
liger_flags = _maybe_patch_liger(args.kernels)
device = torch.device("cuda")
torch.cuda.reset_peak_memory_stats()
policy, config = _build_policy(args, device)
batch = _build_batch(args, config, device)
optimizer = None
trainable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
if args.optimizer != "none":
trainable = [p for p in policy.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
trainable, lr=5e-5, fused=(args.optimizer == "adamw_fused")
)
for _ in range(args.warmup):
_step(policy, batch, optimizer)
torch.cuda.synchronize()
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
starter.record()
for _ in range(args.steps):
_step(policy, batch, optimizer)
ender.record()
torch.cuda.synchronize()
total_ms = starter.elapsed_time(ender)
step_ms = total_ms / args.steps
peak_gb = torch.cuda.max_memory_allocated() / (1024**3)
optim_gb = 0.0
if optimizer is not None:
for st in optimizer.state.values():
for v in st.values():
if torch.is_tensor(v):
optim_gb += v.numel() * v.element_size() / (1024**3)
liger_on = ",".join(k for k, v in liger_flags.items() if v) or "none"
name = (
f"{args.attn:>5} | BS={args.batch_size} | L={args.lang_tokens} | "
f"KI={args.knowledge_insulation} | GC={args.gradient_checkpointing} | "
f"compile={args.compile_model} | liger={liger_on} | opt={args.optimizer} | dtype={args.dtype}"
)
print(
f"{name}\n step_ms={step_ms:.1f} steps/sec={1000.0 / step_ms:.3f} "
f"peak_mem={peak_gb:.2f} GiB optim_state={optim_gb:.2f} GiB "
f"trainable_params={trainable_params / 1e9:.2f}B"
)
del policy, batch
gc.collect()
torch.cuda.empty_cache()
return 0
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -1,36 +0,0 @@
#!/bin/bash
#SBATCH --job-name=bench-pi052-attn
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:30:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_%j.out
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
python -c "import torch; print('torch', torch.__version__, 'cuda', torch.version.cuda)"
run() {
echo
echo "--- $* ---"
python examples/benchmark/bench_pi052_step.py "$@" || true
}
# Attention parity benchmark — same shapes, different attention kernel.
run --attn eager --batch-size 8
run --attn sdpa --batch-size 8
# Headroom benchmark — does SDPA's memory cut allow a bigger micro-batch?
run --attn sdpa --batch-size 12
run --attn sdpa --batch-size 16
run --attn sdpa --batch-size 24

View File

@@ -1,39 +0,0 @@
#!/bin/bash
#SBATCH --job-name=bench-pi052-v2
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:45:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v2_%j.out
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
run() {
echo
echo "--- $* ---"
python examples/benchmark/bench_pi052_step.py "$@" || true
}
# A: GC ON — see if the selective-AC change (one less recompute level)
# narrows the eager vs SDPA gap at BS=8.
run --attn eager --batch-size 8
run --attn sdpa --batch-size 8
# B: GC OFF — isolate the raw attention-kernel cost & memory delta.
run --attn eager --batch-size 4 --no-gradient-checkpointing
run --attn sdpa --batch-size 4 --no-gradient-checkpointing
# C: SDPA + GC headroom sweep — where does it OOM?
run --attn sdpa --batch-size 16
run --attn sdpa --batch-size 24
run --attn sdpa --batch-size 32

View File

@@ -1,36 +0,0 @@
#!/bin/bash
#SBATCH --job-name=bench-pi052-v3
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:45:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v3_%j.out
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
run() {
echo
echo "--- $* ---"
python examples/benchmark/bench_pi052_step.py "$@" || true
}
# Compile sweep: does torch.compile + SDPA give a non-trivial boost on
# top of the bare SDPA path?
run --attn sdpa --batch-size 8 --compile
run --attn sdpa --batch-size 16 --compile
# FlexAttention sweep (experimental): score_mod adds the additive bias
# in-kernel; expect a long first-step compile, then SDPA-or-better steady
# state.
run --attn flex --batch-size 8
run --attn flex --batch-size 16

View File

@@ -1,41 +0,0 @@
#!/bin/bash
#SBATCH --job-name=bench-pi052-v4
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=01:00:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v4_%j.out
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
# /fsx triton cache is shared across nodes with different glibc versions
# — kernels built on one node trip GLIBC_2.34-not-found on another. Use
# a node-local cache per job to side-step that.
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
ldd --version | head -1
run() {
echo
echo "--- $* ---"
python examples/benchmark/bench_pi052_step.py "$@" || true
}
# compile path on top of SDPA + selective AC
run --attn sdpa --batch-size 8 --compile
run --attn sdpa --batch-size 16 --compile
# FlexAttention experimental
run --attn flex --batch-size 8
run --attn flex --batch-size 16

View File

@@ -1,33 +0,0 @@
#!/bin/bash
#SBATCH --job-name=bench-pi052-v5
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:45:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v5_%j.out
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
echo "=== Node: $(hostname) ==="
run() {
echo
echo "--- $* ---"
python examples/benchmark/bench_pi052_step.py "$@" || true
}
# compile_mode=default (graph-only, no autotune) is the right knob with
# gradient checkpointing — max-autotune in v4 was 2x slower than no-compile.
run --attn sdpa --batch-size 8 --compile --compile-mode default
run --attn sdpa --batch-size 16 --compile --compile-mode default
run --attn sdpa --batch-size 8 --compile --compile-mode reduce-overhead

View File

@@ -1,31 +0,0 @@
#!/bin/bash
#SBATCH --job-name=bench-pi052-v6-bs32
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:30:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v6_%j.out
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
run() {
echo
echo "--- $* ---"
python examples/benchmark/bench_pi052_step.py "$@" || true
}
# BS=32 with the production settings (SDPA + compile=default).
run --attn sdpa --batch-size 32 --compile --compile-mode default

View File

@@ -1,39 +0,0 @@
#!/bin/bash
#SBATCH --job-name=bench-pi052-v7-opt
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:45:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v7_%j.out
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
run() {
echo
echo "--- $* ---"
python examples/benchmark/bench_pi052_step.py "$@" || true
}
# Realistic full-step memory: fwd + bwd + AdamW step. The original
# sweep was fwd+bwd-only and undercounted memory by the optimizer-
# state size (~2x param bytes for AdamW). This run confirms BS=16
# and BS=32 still fit with the optimizer in residency.
run --attn sdpa --batch-size 16 --compile --compile-mode default --optimizer adamw_fused
run --attn sdpa --batch-size 32 --compile --compile-mode default --optimizer adamw_fused
# Without compile, in case the production cluster has compile issues.
run --attn sdpa --batch-size 16 --optimizer adamw_fused
run --attn sdpa --batch-size 32 --optimizer adamw_fused

View File

@@ -1,36 +0,0 @@
#!/bin/bash
#SBATCH --job-name=bench-pi052-v8-bs40-dtype
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=00:45:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=1
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v8_%j.out
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
echo "=== Node: $(hostname) ==="
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
run() {
echo
echo "--- $* ---"
python examples/benchmark/bench_pi052_step.py "$@" || true
}
# Confirm BS=40 fits on a single H100 with the optimizer in residency.
run --attn sdpa --batch-size 40 --compile --compile-mode default --optimizer adamw_fused
# Dtype A/B at modest batch — fp32 needs ~2x the memory of bf16, so we
# drop to BS=4 to keep both runs comparable instead of OOMing fp32.
run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype bfloat16
run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype float32

View File

@@ -1,29 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_transformer_layer_cls_to_wrap: GemmaDecoderLayer,SiglipEncoderLayer
fsdp_use_orig_params: true
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

File diff suppressed because it is too large Load Diff

View File

@@ -85,11 +85,6 @@ dependencies = [
"termcolor>=2.4.0,<4.0.0",
"tqdm>=4.66.0,<5.0.0",
# Training utilities
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
# pure-python dependency — preferred over a hand-rolled implementation.
"ema-pytorch>=0.7.7,<1.0.0",
# Build tools (required by opencv-python-headless on some platforms)
"cmake>=3.29.0.1,<4.2.0",
"setuptools>=71.0.0,<81.0.0",
@@ -147,7 +142,6 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
placo-dep = ["placo>=0.9.6,<0.9.16"]
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
sentencepiece-dep = ["sentencepiece>=0.2.0,<0.3.0"] # FAST action tokenizer backend (pi052, pi0_fast)
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
@@ -203,7 +197,7 @@ wallx = [
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]", "lerobot[sentencepiece-dep]"]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
@@ -237,14 +231,6 @@ annotations = [
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
]
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
# are isolated so adding a new tool doesn't bloat the base install.
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
tools = [
"pocket-tts>=1.0.0,<3.0.0",
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
@@ -337,8 +323,6 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
# Interactive hierarchical-VLA runtime for PI052 (PaliGemma backbone).
lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main"
# ---------------- Tool Configurations ----------------

View File

@@ -1,47 +0,0 @@
#!/bin/bash
# Build a tiny RoboCasa smoke dataset (2 short atomic tasks, all episodes) for
# fast end-to-end training validation before the real run.
#
# Defaults: target/human, OpenStandMixerHead + NavigateKitchen (~1k episodes,
# ~131k frames, ~109 min @ 20 fps), 2 SLURM workers on hopper-cpu.
#
# Override via env: TASKS, REPO_ID, WORK_DIR, WORKERS, CPUS, PARTITION, LOCAL=1.
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
source ~/miniconda3/etc/profile.d/conda.sh
conda activate lerobot
REPO_ID="${REPO_ID:-${HF_USER:?HF_USER is unset}/robocasa_smoke_2atomic_v3}"
WORK_DIR="${WORK_DIR:-/fsx/${USER}/robocasa/datasets/v1.0}"
ROBOCASA_ROOT="${ROBOCASA_ROOT:-/fsx/${USER}/robocasa}"
LOGS_DIR="${LOGS_DIR:-/fsx/${USER}/logs/robocasa}"
TASKS="${TASKS:-OpenStandMixerHead NavigateKitchen}"
WORKERS="${WORKERS:-2}"
CPUS="${CPUS:-8}"
PARTITION="${PARTITION:-hopper-cpu}"
LOCAL="${LOCAL:-0}"
ARGS=(
examples/port_datasets/slurm_build_robocasa_composite_seen.py
--repo-id="$REPO_ID"
--work-dir="$WORK_DIR"
--robocasa-root="$ROBOCASA_ROOT"
--split=target --source=human
--tasks $TASKS
--workers="$WORKERS"
--cpus-per-task="$CPUS"
--partition="$PARTITION"
--mem-per-cpu=4G
--time=04:00:00
--logs-dir="$LOGS_DIR"
--job-name=port_robocasa_smoke
)
if [[ "$LOCAL" == "1" ]]; then
ARGS+=(--slurm=0)
fi
echo "Smoke dataset: $REPO_ID"
echo "Tasks: $TASKS"
python "${ARGS[@]}"

View File

@@ -134,11 +134,8 @@ class Executor:
written = self.writer.write_all(records, staging_dir, root)
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
# Keep meta/info.json aligned with the parquet schema we just wrote
# (language columns advertised; canonical ``say`` tool registered for
# PI052 / Pi0.5 / dataset-visualizer consumers via
# ``LeRobotDatasetMetadata.tools``). Idempotent and additive: existing
# user metadata is preserved.
# Keep meta/info.json aligned with the parquet schema we just wrote.
# Idempotent and additive: existing user metadata is preserved.
self._ensure_annotation_metadata_in_info(root)
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)

View File

@@ -23,7 +23,7 @@ one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
generated against that camera's frame and stamped with the matching
``camera`` field on the emitted rows. The resolver disambiguates via
``camera=...``; recipes that consume VQA do so through one sub-recipe
per camera (see ``recipes/subtasks_vqa.yaml``).
per camera (see ``recipes/pi05_hirobot.yaml``).
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.

View File

@@ -5,40 +5,15 @@ pixel coordinates, keypoints, counts, attributes, and spatial relations.
The frame shows a robot working on: "{episode_task}".
QUALITY BAR — read before answering:
- Only label objects you are highly confident about. If you are not
sure what an object is, do NOT include it. A short, certain answer
beats a long, speculative one.
- For coordinate-grounded answers (bbox, keypoint) only emit a label
when you can localize the object *tightly and precisely*. If the
object is occluded, ambiguous, off-frame, or you can't pin its
extent, return an empty detections list / pick a different object
rather than guessing.
- Prefer task-relevant objects (the thing the robot is manipulating
or interacting with) over background clutter.
Question types and the EXACT answer JSON shape required for each:
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}}, ...]}}
Pixel coordinates (x_min, y_min, x_max, y_max). Emit
AT MOST 3 detections, and *only* the highest-confidence
ones — 1 tight, certain detection is preferred over 3
loose ones. Each box must be tight (no >10% padding
around the object) and the label must be specific
("red mug" not "object"). Return an empty list if no
object meets the bar.
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
ECoT example: "a white cup [124, 25, 176, 113]".
keypoint => {{"label": "<point>", "point_format": "xy",
"point": [x, y]}}
Pick ONE high-confidence, precisely-localizable point
(e.g. a graspable handle, a button center, the gripper
tip). The point must land within a few pixels of the
feature. Do not emit a coarse "somewhere on the object"
point — pick a different question type if no such
point exists in this frame.
count => {{"label": "<obj>", "count": <int>,
"note": "<optional short note>"}}

View File

@@ -205,149 +205,3 @@ class WandBLogger:
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
def log_training_examples(
self,
batch: dict,
step: int,
*,
camera_keys: list[str],
n_samples: int = 4,
policy=None,
predict_actions: bool = False,
mode: str = "train",
) -> None:
"""Push a ``wandb.Table`` of training-example rows for the current batch.
Each row is one batch element with:
* one ``wandb.Image`` column per camera in ``camera_keys`` (CHW or
HWC, uint8 or float in [0,1] — auto-detected),
* any text fields present in the batch (``task`` / ``subtask`` /
``memory`` / ``instruction``),
* ground-truth action first/last frame (the action chunk's
endpoints — gives a quick sense of trajectory direction),
* if ``predict_actions=True`` and ``policy`` is supplied, the model's
``predict_action_chunk`` first/last frame alongside.
This is opt-in via ``--wandb.log_examples_freq=N`` on the CLI; the
training loop calls it once every N steps. Cheap to keep on: with
N=4 samples and 3 cameras you upload 12 small PNGs per dump and (if
enabled) run one extra inference forward pass.
"""
import logging # noqa: PLC0415
import numpy as np # noqa: PLC0415
import torch # noqa: PLC0415
if mode not in {"train", "eval"}:
raise ValueError(mode)
# Batch size — first tensor-like value wins.
bsz = next(
(int(v.shape[0]) for v in batch.values() if hasattr(v, "shape") and v.ndim > 0),
None,
)
if not bsz:
return
n = min(int(n_samples), bsz)
# Optional predicted-action forward pass on the first n samples.
pred_actions: np.ndarray | None = None
if predict_actions and policy is not None:
was_training = policy.training
try:
policy.eval()
sub_batch = {}
for k, v in batch.items():
if isinstance(v, torch.Tensor):
sub_batch[k] = v[:n]
elif isinstance(v, (list, tuple)):
sub_batch[k] = list(v[:n])
else:
sub_batch[k] = v
with torch.no_grad():
pred = policy.predict_action_chunk(sub_batch)
pred_actions = pred.detach().cpu().float().numpy()
except Exception as exc: # noqa: BLE001
logging.warning(
"log_training_examples: predict_action_chunk failed (%s) — "
"skipping predicted-action columns",
exc,
)
pred_actions = None
finally:
if was_training:
policy.train()
present_cameras = [c for c in camera_keys if c in batch]
text_keys = [k for k in ("task", "subtask", "memory", "instruction") if k in batch]
columns = ["sample"]
columns.extend(c.removeprefix("observation.images.") or c for c in present_cameras)
columns.extend(text_keys)
columns.append("gt_action_first")
columns.append("gt_action_last")
if pred_actions is not None:
columns.append("pred_action_first")
columns.append("pred_action_last")
table = self._wandb.Table(columns=columns)
def _to_uint8_hwc(t: torch.Tensor) -> np.ndarray:
# Strip an outer time dim if present: (T, C, H, W) -> first frame.
if t.ndim == 4:
t = t[0]
# CHW -> HWC.
if t.ndim == 3 and t.shape[0] in (1, 3, 4) and t.shape[-1] not in (1, 3, 4):
t = t.permute(1, 2, 0)
arr = t.detach().cpu().float().numpy()
if arr.size and float(arr.max()) <= 1.5:
arr = arr * 255.0
return np.clip(arr, 0, 255).astype(np.uint8)
def _action_endpoints(a: torch.Tensor) -> tuple[str, str]:
arr = a.detach().cpu().float().numpy()
if arr.ndim == 2: # (T, D)
return (
str(np.round(arr[0], 3).tolist()),
str(np.round(arr[-1], 3).tolist()),
)
if arr.ndim == 1:
rounded = np.round(arr, 3).tolist()
return (str(rounded), str(rounded))
return (str(arr.tolist()), str(arr.tolist()))
for i in range(n):
row: list = [i]
for cam in present_cameras:
try:
row.append(self._wandb.Image(_to_uint8_hwc(batch[cam][i])))
except Exception as exc: # noqa: BLE001
logging.warning(
"log_training_examples: camera %s sample %d failed (%s)",
cam,
i,
exc,
)
row.append(None)
for tk in text_keys:
v = batch[tk]
if isinstance(v, (list, tuple)):
row.append(str(v[i]) if i < len(v) else "")
else:
row.append(str(v))
action = batch.get("action")
if isinstance(action, torch.Tensor) and action.ndim >= 1:
first, last = _action_endpoints(action[i])
row.append(first)
row.append(last)
else:
row.append("")
row.append("")
if pred_actions is not None:
p = torch.from_numpy(pred_actions[i])
pfirst, plast = _action_endpoints(p)
row.append(pfirst)
row.append(plast)
table.add_data(*row)
self._wandb.log({f"{mode}/examples": table}, step=step)

View File

@@ -62,72 +62,6 @@ class WandBConfig:
run_id: str | None = None
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
add_tags: bool = True # If True, save configuration as tags in the WandB run.
# Periodic training-example dump (independent of ``log_freq``). When > 0,
# every ``log_examples_freq`` steps the trainer pushes a ``wandb.Table``
# with one row per sampled batch element containing each camera view
# (rendered as ``wandb.Image``), any text fields present in the batch
# (``task`` / ``subtask`` / ``memory`` / ``instruction``), and the
# ground-truth action chunk's first + last frames. Defaults to 5000 — set
# to 0 to disable. Only fires when ``enable=True``, so runs without wandb
# are unaffected.
log_examples_freq: int = 5000
# Number of batch elements to include in each example dump.
log_examples_n: int = 4
# If True (default), also run ``policy.predict_action_chunk`` on the logged
# samples (in eval mode, no_grad) and add predicted vs ground-truth action
# columns to the table. Costs one extra forward pass per dump — negligible
# at the 5k-step default cadence. Set to ``False`` if your policy doesn't
# implement ``predict_action_chunk`` or you want to skip the extra forward.
log_examples_predict_actions: bool = True
@dataclass
class EMAConfig:
"""Exponential Moving Average of trainable policy parameters.
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
pi052) benefit substantially from averaging late-training
parameter oscillations — see Chi et al. 2023 §V.D. The official
JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and
``0.999`` for its pi05_libero config; the openpi PyTorch port
explicitly lists EMA as unsupported, and LeRobot main inherited
that gap. Enabling this flag plugs ema-pytorch
(https://github.com/lucidrains/ema-pytorch) into the LeRobot
training loop with a shadow ``nn.Module`` clone of the policy.
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
params) + one elementwise update per training step (~1% step time).
On by default — matches openpi (JAX) which ships EMA on for every
config, and closes the gap with the openpi PyTorch port which
explicitly lists EMA as unsupported. Set ``--ema.enable=false`` to
disable for short runs / memory-constrained training where the
extra fp32 shadow copy is the bottleneck.
"""
enable: bool = True
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
# ema-pytorch as ``beta``).
# 0.999 — last ~1000 steps; pi05_libero default in openpi
# 0.99 — last ~100 steps; openpi top-level default
# 0.75 — very fast EMA (Diffusion Policy original setting)
# 0.9999 — very slow EMA (long classification runs)
decay: float = 0.999
# Skip the first N calls to ``ema.update()``; during this window
# the shadow is just a hard copy of the live weights (no averaging).
# Lets early-training rapid changes settle before averaging begins.
# Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay
# ramp like older lerobot EMA implementations).
warmup_steps: int = 0
# When True, the periodic eval block uses the EMA shadow model
# directly (``ema.ema_model``) instead of the live policy. Standard
# practice for diffusion-style policies — eval scores are usually
# 13% higher than the live policy at the same step.
use_for_eval: bool = True
# When True, the periodic wandb training-example dump uses the EMA
# shadow for the optional predicted-action columns (so what you see
# in W&B matches eval behavior).
use_for_wandb_examples: bool = True
@dataclass

View File

@@ -147,16 +147,7 @@ class TrainingRecipe:
return cls.from_dict(data)
def _validate_message_recipe(self) -> None:
"""Ensure every templated binding is known and the recipe supervises something.
A recipe is valid if it has at least one of:
* a ``target: true`` assistant turn (drives text-CE supervision), or
* a ``stream: low_level`` turn (drives flow / action supervision via
``predict_actions=True``, even when no assistant turn is targeted —
e.g. π0.5-style ``low_level_execution`` where the action expert
conditions on a user-only ``${subtask}`` prompt).
"""
"""Ensure every templated binding is known and at least one turn is a target."""
assert self.messages is not None
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
@@ -165,14 +156,8 @@ class TrainingRecipe:
if missing:
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
has_target = any(turn.target for turn in self.messages)
has_low_level = any(turn.stream == "low_level" for turn in self.messages)
if not (has_target or has_low_level):
raise ValueError(
"Message recipes must contain at least one supervised turn — "
"either ``target: true`` (text CE) or ``stream: low_level`` "
"(flow/action loss)."
)
if not any(turn.target for turn in self.messages):
raise ValueError("Message recipes must contain at least one target turn.")
def _validate_blend_recipe(self) -> None:
"""Ensure each blend component is a non-empty, weighted message recipe."""

View File

@@ -1,68 +0,0 @@
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
#
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
# training, and adds two text-supervised tasks:
#
# high_level_subtask — predict the subtask from the task.
# low_level_execution — flow loss with [images, subtask, state].
# memory_update — compress progress into a memory note.
# user_interjection_response — reply to a user interjection with a
# spoken `say` tool call (no plan, no
# subtask text — just the spoken reply).
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# Plan is intentionally left out — memory is the only persistent
# high-level state here, keeping the prompt short.
#
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
# annotations (the annotation pipeline's memory + interjection modules)
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
# bindings are missing simply don't render for that sample, so a
# dataset without interjections still trains the rest of the blend.
#
# Tool-call note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the tokenizer step
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
# marker the runtime parses back (`_split_plan_and_say`).
blend:
high_level_subtask:
weight: 0.30
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.55
messages:
# The action expert is conditioned on the SUBTASK — at inference
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
# here. `stream: low_level` flips `predict_actions=True` so the
# flow loss fires; no text-CE target (subtask prediction is owned
# by `high_level_subtask`).
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# At inference, `MemoryUpdateFwd` is triggered only on
# `subtask_change` events (sparse). Training densely with
# `active_at` — i.e. on every frame inside a subtask interval,
# not just the boundary frame — supervises the same
# (prior_memory, completed_subtask) → current_memory mapping
# against varied observations within the interval. The model
# learns a stateless transformation; the *when* to emit lives in
# the inference trigger, not the model. Annotations only exist
# for ~1% of frames as boundary events, so `emitted_at` would
# waste 99% of the blend draws (and silently leak them into a
# task-conditioned fallback); `active_at` lifts the renderable
# rate to ~87% on this dataset.
weight: 0.15
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}

View File

@@ -1,99 +0,0 @@
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
#
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
# camera-grounded VQA across the three RoboCasa camera keys produced
# by ``slurm_build_robocasa_composite_seen.py``:
#
# observation.images.robot0_agentview_left (left scene view)
# observation.images.robot0_agentview_right (right scene view)
# observation.images.robot0_eye_in_hand (wrist)
#
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
# VQA per camera, so each anchor frame produces three (user, assistant)
# rows tagged with their source camera. Each VQA sub-recipe consumes
# the rows for one camera via ``camera=...`` resolver bindings.
#
# Spatial VQA targets (bbox / point) are rewritten from JSON to
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
# ``register_paligemma_loc_tokens`` already collapses them to single
# detection-vocab ids so the LM head learns the pretrained pointing /
# detection prior, not a 7-piece BPE salad.
#
# Interjections / spoken responses are intentionally absent — the
# annotation job runs with ``--interjections.enabled=false``.
blend:
high_level_subtask:
weight: 0.25
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.45
messages:
# Action expert is conditioned on the SUBTASK; at inference the
# high-level loop generates it via the LM head and feeds it here.
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
# loss fires; subtask CE is owned by ``high_level_subtask``.
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# Trained densely with ``active_at`` — every frame inside a subtask
# interval — so the (prior_memory, completed_subtask) → current_memory
# mapping is supervised against varied observations. The *when* to
# emit lives in the inference trigger (subtask_change), not the
# model. See ``subtask_mem.yaml`` for the long version of this note.
weight: 0.15
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
ask_vqa_agentview_left:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_left}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_agentview_right:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_right}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_eye_in_hand}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}

View File

@@ -1,114 +0,0 @@
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
#
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
# training, and adds two text-supervised tasks:
#
# high_level_subtask — predict the subtask from the task.
# low_level_execution — flow loss with [images, subtask, state].
# memory_update — compress progress into a memory note.
# user_interjection_response — reply to a user interjection with a
# spoken `say` tool call (no plan, no
# subtask text — just the spoken reply).
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# Plan is intentionally left out — memory is the only persistent
# high-level state here, keeping the prompt short.
#
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
# annotations (the annotation pipeline's memory + interjection modules)
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
# bindings are missing simply don't render for that sample, so a
# dataset without interjections still trains the rest of the blend.
#
# Tool-call note: the `say` tool call on the interjection-response turn
# is flattened to a `<say>...</say>` text marker by the tokenizer step
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
# marker the runtime parses back (`_split_plan_and_say`).
blend:
high_level_subtask:
weight: 0.25
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.40
messages:
# The action expert is conditioned on the SUBTASK — at inference
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
# here. `stream: low_level` flips `predict_actions=True` so the
# flow loss fires; no text-CE target (subtask prediction is owned
# by `high_level_subtask`).
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# At inference, `MemoryUpdateFwd` is triggered only on
# `subtask_change` events (sparse). Training densely with
# `active_at` — i.e. on every frame inside a subtask interval,
# not just the boundary frame — supervises the same
# (prior_memory, completed_subtask) → current_memory mapping
# against varied observations within the interval. The model
# learns a stateless transformation; the *when* to emit lives in
# the inference trigger, not the model. Annotations only exist
# for ~1% of frames as boundary events, so `emitted_at` would
# waste 99% of the blend draws (and silently leak them into the
# task-conditioned fallback); `active_at` lifts the renderable
# rate to ~87% on Hi-Robot-style datasets.
weight: 0.10
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
user_interjection_response:
weight: 0.10
bindings:
interjection: "emitted_at(t, style=interjection)"
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
# Spoken reply only: the assistant turn carries no text content,
# just a `say` tool call (`tool_calls_from: speech`). The chat
# tokenizer flattens it to a `<say>...</say>` marker, so the
# supervised target trains the model to respond to an
# interjection with a spoken acknowledgement.
- {role: assistant, stream: high_level, target: true, if_present: speech, tool_calls_from: speech}
# VQA is view-dependent — each camera gets its own sub-recipe so the
# resolver disambiguates via `camera=...`. Camera keys match
# subtasks_vqa.yaml (`front` + `wrist`); adjust to your dataset.
ask_vqa_top:
weight: 0.075
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.front}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.075
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.wrist}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}

View File

@@ -1,61 +0,0 @@
# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone).
#
# Trains two things only: subtasks and VQA. Plan and memory are
# intentionally left out — keeps the prompt short and the training
# surface small. The fuller blend with memory + spoken replies is
# ``subtask_mem_vqa_speech.yaml``.
#
# high_level_subtask — predict the subtask from the task.
# low_level_execution — flow loss with [images, subtask, state].
# ask_vqa_{top,wrist} — camera-grounded VQA.
#
# PI052's text tokenizer renders these messages as plain
# ``Role: content`` text (PaliGemma is not chat-pretrained).
blend:
high_level_subtask:
weight: 0.40
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.40
messages:
# The action expert is conditioned on the SUBTASK — at inference
# the high-level loop (``HighLevelSubtaskFwd``) generates the
# subtask via the LM head and feeds it here. The action expert's
# prefix is [images, subtask, state]. ``stream: low_level`` flips
# ``predict_actions=True`` so the flow loss fires; no text-CE
# target here (subtask prediction is owned by
# ``high_level_subtask``).
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
ask_vqa_top:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.front}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.10
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.wrist}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}

View File

@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
from . import parser
from .default import DatasetConfig, EMAConfig, EvalConfig, PeftConfig, WandBConfig
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .rewards import RewardModelConfig
@@ -111,20 +111,9 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
ema: EMAConfig = field(default_factory=EMAConfig)
peft: PeftConfig | None = None
# VQA oversampling. When set (a fraction in (0, 1)), the training
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
# carrying a `vqa` language annotation often enough that they make
# up roughly this fraction of the training stream. VQA annotations
# are typically sparse, so without this they are underrepresented.
# `None` (default) keeps uniform episode-aware sampling.
vqa_target_fraction: float | None = None
# Sample weighting configuration (e.g., for RA-BC training). Old
# inline ``use_rabc`` / ``rabc_*`` params are migrated to this
# field by ``_migrate_legacy_rabc_keys`` above.
# Sample weighting configuration (e.g., for RA-BC training)
sample_weighting: SampleWeightingConfig | None = None
# Rename map for the observation to override the image and state keys

View File

@@ -35,6 +35,7 @@ from .dataset_tools import (
remove_feature,
split_dataset,
)
from .factory import make_dataset, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
@@ -49,24 +50,11 @@ from .lerobot_dataset import LeRobotDataset
from .multi_dataset import MultiLeRobotDataset
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
from .sampler import EpisodeAwareSampler
from .streaming_dataset import StreamingLeRobotDataset
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
from .video_utils import VideoEncodingManager
def make_dataset(*args, **kwargs):
from .factory import make_dataset as _make_dataset
return _make_dataset(*args, **kwargs)
def resolve_delta_timestamps(*args, **kwargs):
from .factory import resolve_delta_timestamps as _resolve_delta_timestamps
return _resolve_delta_timestamps(*args, **kwargs)
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
# and legacy migration constants are intentionally NOT re-exported here.
# Import directly: ``from lerobot.datasets.io_utils import ...``
@@ -77,7 +65,6 @@ __all__ = [
"DEFAULT_QUANTILES",
"EVENT_ONLY_STYLES",
"EpisodeAwareSampler",
"WeightedEpisodeAwareSampler",
"LANGUAGE_EVENTS",
"LANGUAGE_PERSISTENT",
"LeRobotDataset",

View File

@@ -126,53 +126,10 @@ class DatasetReader:
def _load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self._meta.features)
# Datasets annotated with the PR1 language columns may have been
# written without registering those columns in ``meta/info.json``
# (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were
# back-filled by ``lerobot-annotate``). Probe a single parquet
# shard and graft the column features on so the strict
# ``Dataset.from_parquet`` cast doesn't fail with
# ``column names don't match``.
features = self._extend_features_with_language_columns(features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def _extend_features_with_language_columns(
self, features: datasets.Features
) -> datasets.Features:
"""Add ``language_persistent`` / ``language_events`` to ``features``
when the underlying parquet shards declare them but the metadata
doesn't. No-op when neither column is present or both are
already registered.
"""
# Find any one parquet to peek at; bail if there are none yet
# (the dataset will fail later for an unrelated reason and we
# want that error to surface as-is).
try:
sample = next((self.root / "data").glob("*/*.parquet"))
except StopIteration:
return features
from pyarrow import parquet as _pq # noqa: PLC0415
schema_names = set(_pq.read_schema(sample).names)
from .language import ( # noqa: PLC0415
LANGUAGE_EVENTS,
LANGUAGE_PERSISTENT,
language_events_column_feature,
language_persistent_column_feature,
)
extra: dict[str, object] = {}
if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
extra[LANGUAGE_EVENTS] = language_events_column_feature()
if not extra:
return features
return datasets.Features({**features, **extra})
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:

View File

@@ -70,22 +70,8 @@ def _json_arrow_type() -> pa.DataType:
def _json_feature() -> object:
"""Return the HF feature used for tool-call payloads.
Older ``datasets`` versions do not expose ``datasets.Json``. The
annotation pipeline currently emits the canonical ``say`` tool call
shape, so use that explicit struct instead of falling back to a string
that cannot cast structured parquet values.
"""
if hasattr(datasets, "Json"):
return datasets.Json()
return {
"type": datasets.Value("string"),
"function": {
"name": datasets.Value("string"),
"arguments": {"text": datasets.Value("string")},
},
}
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:

View File

@@ -170,29 +170,6 @@ def render_sample(
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
# VQA-priority routing. A ``vqa`` annotation is sparse and
# view-dependent; the plain weighted blend would (a) waste a draw
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
# sub-recipes and *this* frame carries one of their VQA bindings,
# render VQA here regardless of the weighted draw. That makes VQA's
# recipe-side training share equal the VQA-annotation density (the
# maximum reachable without a dataset-level oversampling sampler).
if recipe.blend is not None:
vqa_rendered = _render_vqa_if_present(
recipe,
persistent=persistent_rows,
events=event_rows,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
if vqa_rendered is not None:
return vqa_rendered
selected_recipe = _select_recipe(recipe, sample_idx)
bindings = _resolve_bindings(
selected_recipe,
@@ -206,59 +183,6 @@ def render_sample(
return _render_message_recipe(selected_recipe, bindings)
def _render_vqa_if_present(
recipe: TrainingRecipe,
*,
persistent: Sequence[LanguageRow],
events: Sequence[LanguageRow],
t: float,
sample_idx: int,
task: str | None,
dataset_ctx: Any | None,
) -> RenderedMessages | None:
"""Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA
annotation; otherwise return ``None`` so the caller falls back to the
normal weighted blend.
When several VQA sub-recipes resolve (e.g. a frame annotated for more
than one camera), one is chosen deterministically by relative weight.
"""
assert recipe.blend is not None
renderable: list[tuple[float, RenderedMessages]] = []
for name, component in recipe.blend.items():
if not name.startswith("ask_vqa"):
continue
bindings = _resolve_bindings(
component,
persistent=persistent,
events=events,
t=t,
sample_idx=sample_idx,
task=task,
dataset_ctx=dataset_ctx,
)
rendered = _render_message_recipe(component, bindings)
if rendered is not None:
renderable.append((float(component.weight or 0.0), rendered))
if not renderable:
return None
if len(renderable) == 1:
return renderable[0][1]
# Multiple cameras have a VQA for this frame — deterministic pick by
# relative weight (fall back to a uniform draw if all weights are 0).
total = sum(w for w, _ in renderable) or float(len(renderable))
digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
draw = int.from_bytes(digest, "big") / 2**64 * total
cumulative = 0.0
for w, rendered in renderable:
cumulative += w or (total / len(renderable))
if draw < cumulative:
return rendered
return renderable[-1][1]
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
if recipe.blend is None:
@@ -422,15 +346,7 @@ def _render_message_recipe(
if turn.target:
target_indices.append(message_idx)
# A render is meaningful if it supervises *something*: either a
# text-CE target turn, or a ``low_level`` stream turn (flow / action
# supervision — e.g. the flow-only ``low_level_execution`` recipe,
# ``user(${subtask})`` with ``stream: low_level`` and no target).
# Without this, a flow-only recipe renders to ``None`` every time
# the blend draws it → ``predict_actions`` is never True → the
# action expert never receives a flow loss.
has_low_level = any(stream == "low_level" for stream in streams)
if not target_indices and not has_low_level:
if not target_indices:
return None
rendered = {
@@ -487,10 +403,8 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
if len(streams) != len(messages):
raise ValueError("message_streams must be aligned with messages.")
# Valid iff it supervises something: a text-CE target turn OR a
# ``low_level`` stream turn (flow / action supervision).
if not target_indices and not any(s == "low_level" for s in streams):
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
if not target_indices:
raise ValueError("Rendered samples must contain at least one target message.")
for idx in target_indices:
if idx < 0 or idx >= len(messages):
raise ValueError(f"Target message index {idx} is out of bounds.")

View File

@@ -84,66 +84,3 @@ class EpisodeAwareSampler:
def __len__(self) -> int:
return len(self.indices)
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
proportion to per-frame weights.
Used to oversample frames carrying a sparse annotation (e.g. a VQA
question) so the policy sees them more often than their natural
dataset density. One epoch still yields ``len(self.indices)``
samples — the weights only change the *composition* of the stream,
not its length. Each epoch re-draws, so the oversampled subset
varies run to run.
"""
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
frame_weights,
*,
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
):
"""
Args:
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
dataset_to_indices: Episode end indices.
frame_weights: 1-D sequence/tensor of non-negative weights, one per
dataset frame (length == total dataset frames). Higher weight ⇒
that frame is sampled more often.
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
frame filtering is applied first, then weighting is restricted
to the surviving frames.
"""
super().__init__(
dataset_from_indices,
dataset_to_indices,
episode_indices_to_use=episode_indices_to_use,
drop_n_first_frames=drop_n_first_frames,
drop_n_last_frames=drop_n_last_frames,
shuffle=False,
)
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
idx = torch.tensor(self.indices, dtype=torch.long)
if weights.numel() <= int(idx.max()):
raise ValueError(
f"frame_weights has {weights.numel()} entries but the sampler "
f"references frame index {int(idx.max())}."
)
selected = weights[idx]
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
raise ValueError("frame_weights must be finite and non-negative.")
if float(selected.sum()) <= 0.0:
# All surviving frames have zero weight — fall back to uniform.
selected = torch.ones_like(selected)
self._weights = selected
def __iter__(self) -> Iterator[int]:
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
for i in picks.tolist():
yield self.indices[i]

View File

@@ -366,24 +366,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
hub_versions = get_repo_versions(repo_id)
if not hub_versions:
msg = (
f"Repo {repo_id!r} has no codebase-version tags. The dataset "
f"either doesn't exist on the Hub yet, or it was uploaded "
f"without a ``v3.x``-style tag. To tag an existing dataset run:\n"
f" from huggingface_hub import HfApi\n"
f" HfApi().create_tag({repo_id!r}, tag='v3.0', repo_type='dataset', exist_ok=True)"
raise RevisionNotFoundError(
f"""Your dataset must be tagged with a codebase version.
Assuming _version_ is the codebase_version value in the info.json, you can run this:
```python
from huggingface_hub import HfApi
hub_api = HfApi()
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
```
"""
)
# ``RevisionNotFoundError`` extends ``HfHubHTTPError`` whose
# ``__init__`` indexes ``response.headers`` unconditionally on
# current ``huggingface_hub`` versions. Constructing it without
# a real ``Response`` object crashes with either
# ``TypeError: missing 1 required keyword-only argument`` (old
# builds) or ``AttributeError: 'NoneType' object has no attribute
# 'headers'`` (new builds). Skip that path entirely — this isn't
# really an HTTP error, it's a configuration issue — and raise a
# plain ``RuntimeError`` so the message actually reaches the
# caller.
raise RuntimeError(msg)
if target_version in hub_versions:
return f"v{target_version}"

View File

@@ -104,8 +104,6 @@ class AdamWConfig(OptimizerConfig):
eps: float = 1e-8
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
foreach: bool | None = None
fused: bool | None = None
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
kwargs = asdict(self)

View File

@@ -24,7 +24,6 @@ from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as M
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pi052.configuration_pi052 import PI052Config as PI052Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -48,7 +47,6 @@ __all__ = [
"PI0Config",
"PI0FastConfig",
"PI05Config",
"PI052Config",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",

View File

@@ -61,79 +61,6 @@ from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
def _restore_pi052_pretrained_state(
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
pretrained_path: str,
) -> None:
"""Transplant saved stateful blobs from a pi052 checkpoint into fresh pipelines.
pi052's preprocessor includes steps whose constructor args don't
JSON-roundtrip (``RenderMessagesStep.recipe`` is a Python object,
``ActionTokenizerProcessorStep.action_tokenizer_name`` is a
fitted-tokenizer path that may not exist at eval time). We rebuild
those pipelines fresh from ``config.recipe_path`` and then walk
over the saved ``policy_{pre,post}processor.json`` files to find
each step's ``state_file`` reference and load the bytes back into
the corresponding fresh step. Today that's only the
NormalizerProcessorStep / UnnormalizerProcessorStep (the action /
state quantile stats), but the loop is generic so any future
stateful step picks up its blob automatically.
Pairing is by ``registry_name`` AND position so a benign reorder
on the saved side surfaces a warning rather than silently feeding
the wrong tensors into the wrong step.
"""
import json # noqa: PLC0415
import logging # noqa: PLC0415
from pathlib import Path # noqa: PLC0415
from safetensors.torch import load_file # noqa: PLC0415
base = Path(pretrained_path)
if not base.exists():
return
log = logging.getLogger(__name__)
for pipeline, config_filename in [
(preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
(postprocessor, f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"),
]:
config_path = base / config_filename
if not config_path.exists():
continue
saved = json.loads(config_path.read_text())
for idx, (saved_step, fresh_step) in enumerate(
zip(saved.get("steps", []), pipeline.steps, strict=False)
):
state_file = saved_step.get("state_file")
if not state_file:
continue
saved_name = saved_step.get("registry_name")
fresh_name = getattr(type(fresh_step), "_registry_name", None)
if saved_name and fresh_name and saved_name != fresh_name:
log.warning(
"PI052 state restore: %s step %d registry name mismatch "
"(saved=%s, fresh=%s); skipping %s",
config_filename, idx, saved_name, fresh_name, state_file,
)
continue
state_path = base / state_file
if not state_path.exists():
log.warning(
"PI052 state restore: %s missing at %s; %s left at fresh init",
state_file, base, fresh_name,
)
continue
fresh_step.load_state_dict(load_file(str(state_path)))
log.info(
"PI052 state restore: loaded %s into %s (step %d)",
state_file, fresh_name, idx,
)
def _reconnect_relative_absolute_steps(
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
) -> None:
@@ -200,10 +127,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "pi052":
from .pi052.modeling_pi052 import PI052Policy
return PI052Policy
elif name == "gaussian_actor":
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
@@ -244,8 +167,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05",
"pi052", "gaussian_actor", "smolvla", "wall_x".
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -268,10 +191,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "pi052":
from .pi052.configuration_pi052 import PI052Config
return PI052Config(**kwargs)
elif policy_type == "gaussian_actor":
return GaussianActorConfig(**kwargs)
elif policy_type == "smolvla":
@@ -312,12 +231,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
# Optional: HF Hub repo id of the dataset the policy is being
# trained on. Used by policies that auto-fit pieces of their
# preprocessing (e.g. pi052's FAST action tokenizer per
# Pertsch et al. 2025 [64], π0.5 §III.C). When omitted, those
# policies fall back to their universal pre-fitted tokenizers.
dataset_repo_id: str | None
def make_pre_post_processors(
@@ -350,29 +263,6 @@ def make_pre_post_processors(
NotImplementedError: If a processor factory is not implemented for the given
policy configuration type.
"""
if pretrained_path and getattr(policy_cfg, "type", None) == "pi052":
# pi052 pipelines don't roundtrip through the saved
# ``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
# Python ``TrainingRecipe`` (not JSON-serializable; saved as
# ``{}``) and ``ActionTokenizerProcessorStep`` saves a host-only
# FAST tokenizer path. Generic ``from_pretrained`` then dies
# with ``RenderMessagesStep.__init__() missing 1 required
# positional argument: 'recipe'`` (job 22164494).
#
# Mirror ``lerobot_pi052_runtime``'s bootstrap: build pipelines
# fresh from ``config.recipe_path`` and transplant the saved
# stateful blobs (normalizer stats) from the checkpoint dir.
from .pi052.processor_pi052 import make_pi052_pre_post_processors
preprocessor, postprocessor = make_pi052_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_repo_id=kwargs.get("dataset_repo_id"),
)
_restore_pi052_pretrained_state(preprocessor, postprocessor, pretrained_path)
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
@@ -467,22 +357,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif policy_cfg.type == "pi052":
# NOTE: PI052Config subclasses PI05Config, so this branch MUST
# come before the PI05Config isinstance check below (otherwise
# pi052 would silently pick up π0.5's processor).
from .pi052.processor_pi052 import make_pi052_pre_post_processors
processors = make_pi052_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
# ``dataset_repo_id`` flows in via kwargs when FAST CE is
# enabled — the train loop sets it from ``--dataset.repo_id``.
# When ``None``, ``make_pi052_pre_post_processors`` skips
# the auto-fit and uses the universal tokenizer.
dataset_repo_id=kwargs.get("dataset_repo_id"),
)
elif isinstance(policy_cfg, PI05Config):
from .pi05.processor_pi05 import make_pi05_pre_post_processors

View File

@@ -178,6 +178,7 @@ N_COLOR_CHANNELS = 3
# config
@strict
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"

View File

@@ -93,21 +93,6 @@ class PI05Config(PreTrainedConfig):
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
optimizer_foreach: bool | None = False
optimizer_fused: bool | None = True
# LM-head LR multiplier. The PaliGemma `lm_head` projection (and its
# tied `embed_tokens`) is the surface the LM head's first-token
# distribution depends on. With ``knowledge_insulation`` blocking
# action→VLM gradients, the LM head only sees gradients on text-CE
# samples — which can be a small fraction of the mix (e.g. ~45% in
# ``subtask_mem.yaml``). Under aggressive cosine LR decay the head's
# first-token distribution can drift back toward PaliGemma's
# pretrained ``<loc>`` detection prior, despite teacher-forced CE
# staying near zero. Boosting just the LM-head LR (e.g. 5x) keeps
# the head pinned to fine-tuning targets without perturbing the
# backbone / vision tower / action expert. Default 1.0 = no change.
lm_head_lr_scale: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
# Note: These will auto-scale if --steps < scheduler_decay_steps
@@ -167,8 +152,6 @@ class PI05Config(PreTrainedConfig):
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
foreach=self.optimizer_foreach,
fused=self.optimizer_fused,
)
def get_scheduler_preset(self):

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
)
else:
CONFIG_MAPPING = None
DynamicCache = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
return att_2d_masks & pad_2d_masks
def clone_past_key_values(past_key_values):
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
return DynamicCache(
tuple(
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
)
)
def pad_vector(vector, new_dim):
"""Pad the last dimension of a vector to new_dim with zeros.
@@ -223,53 +233,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
return padded_images
def sdpa_attention_forward(
module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
):
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
``torch.nn.functional.scaled_dot_product_attention``.
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
bias masks (the FA backend only accepts causal/sliding-window). On
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
"""
n_rep = module.num_key_value_groups
if n_rep > 1:
key = key.repeat_interleave(n_rep, dim=1)
value = value.repeat_interleave(n_rep, dim=1)
if attention_mask is not None and attention_mask.dtype != query.dtype:
attention_mask = attention_mask.to(dtype=query.dtype)
attn_output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=dropout if module.training else 0.0,
is_causal=False,
scale=scaling,
)
return attn_output.transpose(1, 2).contiguous(), None
# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
gates.append(gate)
input_shape = hidden_states.shape[:-1]
@@ -291,14 +262,16 @@ def compute_layer_complete(
device=query_states.device,
dtype=query_states.dtype,
)
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
cos, sin = rotary_emb(dummy_tensor, position_ids)
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
att_output, _ = sdpa_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma_layer = layers[0]
scaling = paligemma_layer.self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma_layer.self_attn,
query_states,
key_states,
value_states,
@@ -306,13 +279,13 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma_layer.self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
start_pos = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
layer = layers[i]
end_pos = start_pos + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
@@ -444,7 +417,6 @@ class PaliGemmaWithExpertModel(
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"lm_head",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -477,13 +449,13 @@ class PaliGemmaWithExpertModel(
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
def forward(
self,
@@ -521,8 +493,9 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
paligemma_layers = self.paligemma.model.language_model.layers
gemma_expert_layers = self.gemma_expert.model.layers
rotary_emb = self.paligemma.model.language_model.rotary_emb
# Check if gradient checkpointing is enabled for any of the models
use_gradient_checkpointing = (
@@ -532,36 +505,39 @@ class PaliGemmaWithExpertModel(
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
use_reentrant=False,
preserve_rng_state=False,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
else:
inputs_embeds = compute_layer_complete(
layer_idx,
inputs_embeds,
attention_mask,
position_ids,
adarms_cond,
paligemma=self.paligemma,
gemma_expert=self.gemma_expert,
layers=layers,
rotary_emb=rotary_emb,
)
# final norm
final_norms = (
self.paligemma.model.language_model.norm,
self.gemma_expert.model.norm,
)
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -653,13 +629,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
return func(*args, **kwargs)
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
def _prepare_attention_masks_4d(self, att_2d_masks):
"""Helper method to prepare 4D attention masks for transformer."""
att_2d_masks_4d = att_2d_masks[:, None, :, :]
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
if dtype is not None:
result = result.to(dtype=dtype)
return result
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
def sample_noise(self, shape, device):
return torch.normal(
@@ -701,8 +674,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Process language tokens
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
return lang_emb
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
@@ -789,22 +761,21 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
# Selective AC: rely on the per-layer checkpoint inside
# ``PaliGemmaWithExpertModel.forward`` (which wraps each
# transformer block individually). The previous outer
# ``_apply_checkpoint(forward_func, ...)`` doubled up — it
# re-ran the full backbone forward during backward *and* each
# block's own checkpoint re-ran during that recompute. Pure
# waste with SDPA, which already streams attention activations.
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
return suffix_out
suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
suffix_out = suffix_out[:, -self.config.chunk_size :]
@@ -848,9 +819,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
prefix_att_2d_masks, dtype=prefix_embs.dtype
)
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
@@ -920,12 +889,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
full_att_2d_masks_4d = self._prepare_attention_masks_4d(
full_att_2d_masks, dtype=suffix_embs.dtype
)
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
past_key_values = clone_past_key_values(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -1060,16 +1027,6 @@ class PI05Policy(PreTrainedPolicy):
if remap_count > 0:
print(f"Remapped {remap_count} state dict keys")
lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight"
embed_tokens_key = (
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
)
if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict:
remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float()
print("Initialized PaliGemma lm_head from language token embeddings")
elif lm_head_key in remapped_state_dict:
remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float()
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
@@ -1163,62 +1120,8 @@ class PI05Policy(PreTrainedPolicy):
return fixed_state_dict
def get_optim_params(self):
"""Return policy parameters, optionally split into LR-scaled groups.
When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head``
and its tied ``embed_tokens`` are placed in their own param
group with ``lr = base_lr * lm_head_lr_scale``. The cosine
scheduler multiplies both groups by the same lambda each step,
so the ratio is preserved across decay. Default ``1.0`` =
return ``self.parameters()`` (back-compat with existing checkpoints
and configs).
"""
scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
if scale == 1.0:
return self.parameters()
head_params: list[torch.nn.Parameter] = []
other_params: list[torch.nn.Parameter] = []
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` —
# boosting only the projection without the embedding pulls them
# apart and breaks the tie that PaliGemma was pre-trained with.
head_substrings = (
"paligemma_with_expert.paligemma.lm_head.",
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.",
)
for name, p in self.named_parameters():
if not p.requires_grad:
continue
if any(s in name for s in head_substrings):
head_params.append(p)
else:
other_params.append(p)
base_lr = float(self.config.optimizer_lr)
groups: list[dict[str, object]] = []
if other_params:
groups.append({"params": other_params, "lr": base_lr, "name": "policy"})
if head_params:
groups.append(
{"params": head_params, "lr": base_lr * scale, "name": "lm_head"}
)
# Sanity: head_substrings must match at least one parameter, otherwise
# the scale silently does nothing — surface that fast.
if not head_params:
raise RuntimeError(
"lm_head_lr_scale != 1.0 but no parameters matched the LM-head "
"name patterns: "
f"{head_substrings!r}. Did the underlying PaliGemma module rename?"
)
logging.info(
"PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over "
"%d head params + %d other params",
scale,
base_lr,
base_lr * scale,
len(head_params),
len(other_params),
)
return groups
def get_optim_params(self) -> dict:
return self.parameters()
def reset(self):
"""Reset internal state - called when environment resets."""

View File

@@ -1,42 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical
inference recipe on lerobot.
Extends :class:`lerobot.policies.pi05.PI05Policy` with:
* recipe-driven training (PR 1's :class:`RenderMessagesStep`),
* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans
(the "high-level subtask prediction" of the paper, §IV.D),
* AR text generation at inference (:meth:`PI052Policy.select_message`),
* per-component prompt dropout (Pi 0.7 §V.E) for regularising the
text head against missing context at inference.
See ``src/lerobot/configs/recipes/subtasks_vqa.yaml`` for the
canonical training recipe and
``examples/training/pi052_hirobot.slurm`` for the launcher.
"""
from .configuration_pi052 import PI052Config
from .modeling_pi052 import PI052Policy
from .processor_pi052 import make_pi052_pre_post_processors
from .text_processor_pi052 import PI052TextTokenizerStep
__all__ = [
"PI052Config",
"PI052Policy",
"PI052TextTokenizerStep",
"make_pi052_pre_post_processors",
]

View File

@@ -1,216 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's
hierarchical inference recipe.
Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM +
~300M Gemma action expert, joint training with FAST tokens during
pre-train and flow matching during post-train), but with the
PaliGemma ``lm_head`` re-enabled so the same model can be supervised
to predict both:
* **subtask strings** at the high level (cross-entropy on the LM
head), and
* **action chunks** at the low level (flow matching on the
action-expert tokens).
This is the dual-head co-training pattern from the paper:
L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(a_τ, o, )‖²
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
inference into a text-prediction step followed by an action-prediction
step, which the multi-rate ``PI052Runtime`` (in
``lerobot.policies.pi052.inference``) drives at separate rates.
"""
from dataclasses import dataclass
from lerobot.configs import PreTrainedConfig
from ..pi05.configuration_pi05 import PI05Config
@PreTrainedConfig.register_subclass("pi052")
@dataclass
class PI052Config(PI05Config):
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
Recipe-driven dual-head training: the flow head supervises actions,
the LM head supervises subtask / plan / memory / VQA text. The
flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
"""
# Recipe / language stack ---------------------------------------------
recipe_path: str | None = "recipes/subtasks_vqa.yaml"
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend
shipped alongside this policy. Set to ``None`` to disable recipe
rendering and fall back to π0.5's single-task ``Task: ... Action:``
prompt path (unannotated datasets keep working that way)."""
apply_chat_template: bool = False
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
chat template, so we don't apply one. The recipe renderer's output
is concatenated as a plain prefix + assistant suffix instead,
mirroring how the π0.5 paper's high-level inference samples text
auto-regressively after the prefix."""
# Loss weights --------------------------------------------------------
# Paper §IV.D uses α=10 between the flow and text terms, assuming
# text is a rare auxiliary task. With the recipe stack the flow-only
# `low_level` branch fires on a large share of samples, so α=10
# swamps the LM head and collapses generation into degenerate
# repetition. We use the milder 5:1 split here.
text_loss_weight: float = 1.0
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
text training entirely (reverts to flow-only / π0.5 behaviour)."""
flow_loss_weight: float = 5.0
"""Weight on the action-expert flow-matching term. ``5.0`` — a milder
flow:text split than the paper's α=10, since the flow-only
``low_level`` recipe already gives the action expert frequent
gradient. Lower it further if the LM head still underfits."""
# Backbone training ---------------------------------------------------
unfreeze_lm_head: bool = True
"""Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning.
The existing ``PI05Policy`` zeroes / freezes the head on load
because it never reads from it. Must be ``True`` for π0.5-style
hierarchical inference."""
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
# Randomly drop non-target context messages so the LM head learns
# to handle missing /
# stale plan / memory at inference. Defaults to 0.0 so behaviour
# is identical until explicitly enabled.
plan_dropout_prob: float = 0.0
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0
# FAST discrete-action supervision — paper §III.B-C ------------------
# When enabled, actions are *also* tokenised via the FAST tokenizer
# ("physical-intelligence/fast") and supervised with cross-entropy
# on the PaliGemma LM head — exactly as in the paper's pre-training
# objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
# ActionTokenizerProcessorStep is wired into the preprocessor
# pipeline when this flag is set; the loss is computed in
# PI052Policy.forward.
enable_fast_action_loss: bool = True
"""If True, tokenise actions with the FAST tokenizer and add a
cross-entropy loss on the LM head. On by default to match the
π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE,
§III.B-C Eq. 1). Set to False if you only want the
post-training-style flow + text recipe."""
action_tokenizer_name: str = "physical-intelligence/fast"
"""HF identifier for the FAST action tokenizer."""
max_action_tokens: int = 256
"""Maximum number of FAST tokens per action chunk."""
fast_skip_tokens: int = 128
"""Number of low-vocab tokens the FAST tokenizer skips to avoid
collisions with PaliGemma's text vocabulary."""
fast_action_loss_weight: float = 1.0
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
auto_fit_fast_tokenizer: bool = False
"""If True, the processor factory checks ``fast_tokenizer_cache_dir``
for a previously-fitted tokenizer keyed on ``(dataset_repo_id,
base_tokenizer_name, fit_samples)``. On cache miss, it loads
``action_tokenizer_name`` as a base, samples
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
``.fit()``, saves the result, and uses *that* fitted path as the
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
explicitly recommend per-dataset fitting for best compression.
Off by default because the fit requires a separate pre-training
pass over the dataset (~1-2 min on a medium dataset) and depends
on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in
when you want paper-faithful compression; leave off to fall back
on the universal ``physical-intelligence/fast`` codebook."""
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
"""Where fitted FAST tokenizers are stored. ``~`` expands."""
fast_tokenizer_fit_samples: int = 1024
"""Number of action chunks to sample for the fit. The FAST paper uses
a few thousand; 1024 is a reasonable default for medium datasets."""
# Knowledge insulation — paper §III.B --------------------------------
# When enabled, gradients from the action expert's flow loss are
# blocked from flowing back into the VLM's K/V projections. This
# prevents the action loss from over-fitting the language backbone
# to robot-specific features. Implemented in ``modeling_pi052`` as
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
# that splits queries into VLM and action halves and ``.detach()``-s
# the VLM K/V tensors used in the action-half's attention.
knowledge_insulation: bool = False
"""If True, route every transformer layer through the KI
attention path that blocks action→VLM gradient flow on K/V."""
# Learning-rate defaults --------------------------------------------
# pi052 inherits π0.5's openpi-validated optimizer config (peak LR
# 2.5e-5, cosine→2.5e-6, 1k warmup, AdamW (0.9, 0.95), wd=0.01,
# grad_clip=1.0). The only place pi052 needs to diverge from pi05
# is the LM-head LR multiplier: pi05 has no text supervision so the
# head doesn't get gradients; pi052 always has text supervision
# (subtask / memory / VQA) via the recipe, and under KI the LM head
# only sees gradients on ~3045% of the batch (the text-CE mask
# share of the recipe). Under aggressive cosine decay this is too
# weak to keep the head pinned, so it drifts back toward PaliGemma's
# pretrained ``<loc>`` first-token bias. 5x is the documented fix
# (see ``PI05Config.lm_head_lr_scale`` docstring); the wiring is
# already in ``PI05Policy.get_optim_params`` — it splits the LM head
# + tied ``embed_tokens`` into their own param group while sharing
# the same cosine lambda, so the 5x ratio is preserved across decay.
lm_head_lr_scale: float = 5.0
# PaLM-style z-loss on text CE. Penalises the log-partition function
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
# while CE stays low, because a uniform additive logit bias cancels in
# softmax. PaLM appendix B / Chinchilla report z-loss is essential for
# stable large-vocab CE; it especially helps under ``lm_head_lr_scale=
# 5.0`` which amplifies drift risk on the LM head. ``1e-4`` is the
# commonly cited weight; set 0 to disable entirely.
text_ce_z_loss_weight: float = 1e-4
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
# unconditionally at model build time — see ``_enable_hf_kernels``
# in ``modeling_pi052``. The patch is process-global, idempotent
# and degrades gracefully if ``liger-kernel`` is missing. Measured
# at -4.5% step time on H100 (bench job 22161421); peak memory
# unchanged. ``fused_linear_cross_entropy`` ships separately via
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
use_hf_kernels: bool = True
"""Deprecated. Liger HF kernels are patched unconditionally by
``_enable_hf_kernels`` — this field is retained as a no-op for
backward compatibility with checkpoints saved before commit
d70c8104 (which still serialize ``use_hf_kernels: true`` into
``config.json``). Loading those configs would otherwise raise
``DecodingError: The fields use_hf_kernels are not valid for
PI052Config`` (job 22164492). Remove in a future major bump."""
def __post_init__(self) -> None:
super().__post_init__()
# Backbone needs gradients flowing through the text head when
# we're training it. Override the π0.5 default
# (``train_expert_only=True``) unless the user explicitly opts
# out of text training via ``text_loss_weight=0``.
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
self.train_expert_only = False

View File

@@ -1,263 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset-specific FAST action tokenizer fitting.
The published ``physical-intelligence/fast`` tokenizer is a *universal*
codebook fitted on a heterogeneous mix of robot datasets. Per Pertsch
et al. 2025 (the FAST paper, [64] in the π0.5 paper) and §III.C of
π0.5 itself, the recommended practice is to **finetune the tokenizer on
your specific dataset's action distribution** before training the
policy — same way one would adapt a language tokenizer to a domain
corpus. Without this finetune step, action sequences from your robot
may require more tokens per chunk than necessary, lowering effective
compression and slowing convergence of the action-CE loss.
This module provides a single utility, :func:`fit_fast_tokenizer`,
that does the finetune. The training entry point invokes it
automatically when the policy's ``enable_fast_action_loss`` and
``auto_fit_fast_tokenizer`` flags are both ``True`` and no cached
fitted tokenizer is found at ``fast_tokenizer_cache_dir``.
The fitted tokenizer is saved to
``{cache_dir}/{dataset_hash}_{base_hash}/`` so successive training
runs over the same dataset re-use it.
"""
from __future__ import annotations
import hashlib
import logging
import os
import time
from pathlib import Path
import numpy as np
logger = logging.getLogger(__name__)
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
# that's the image / feature-extractor convention). Centralised here so
# the cache-hit check and the rank-N readiness wait agree on the same
# sentinel.
_CACHE_SENTINEL = "processor_config.json"
def _dataset_signature(
dataset_repo_id: str,
base_tokenizer_name: str,
n_samples: int,
chunk_size: int,
) -> str:
"""Deterministic short hash for naming the cache directory.
Keys on (dataset, base tokenizer, sample count, chunk size) so any
of those changing re-runs the fit. ``chunk_size`` matters because
the tokenizer is fit on chunks of that length.
"""
h = hashlib.sha256()
h.update(dataset_repo_id.encode("utf-8"))
h.update(b"\0")
h.update(base_tokenizer_name.encode("utf-8"))
h.update(b"\0")
h.update(str(n_samples).encode("utf-8"))
h.update(b"\0")
h.update(str(chunk_size).encode("utf-8"))
return h.hexdigest()[:16]
def fit_fast_tokenizer(
*,
dataset_repo_id: str,
cache_dir: str | Path,
base_tokenizer_name: str = "physical-intelligence/fast",
n_samples: int = 1024,
chunk_size: int = 50,
seed: int = 42,
) -> str:
"""Fit a FAST tokenizer on a LeRobot dataset's action distribution.
Args:
dataset_repo_id: HF Hub repo id of the LeRobotDataset to fit on.
cache_dir: Directory under which to save (and look up) fitted
tokenizers. The actual save path is
``{cache_dir}/{signature}``.
base_tokenizer_name: HF identifier for the base FAST tokenizer
to finetune from. ``physical-intelligence/fast`` is the
universal one.
n_samples: Number of action chunks to sample for the fit. The
FAST paper uses a few thousand; ``1024`` is a good default
for medium datasets.
chunk_size: Length of each action chunk (matches
``policy.chunk_size``). The FAST tokenizer is fit on
sequences of this length.
seed: RNG seed for sample selection.
Returns:
The local path to the fitted tokenizer. Passed directly to
``--policy.action_tokenizer_name`` for the training run.
Raises:
ImportError: If the ``transformers`` library doesn't expose
``AutoProcessor`` or the FAST tokenizer doesn't have a
``.fit()`` method (then you're on an older FAST snapshot —
update to the current published model).
FileNotFoundError: If the dataset can't be loaded.
"""
cache_dir = Path(cache_dir)
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
out_dir = cache_dir / sig
if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
logger.info(
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
"dataset=%s base=%s n_samples=%d",
out_dir, dataset_repo_id, base_tokenizer_name, n_samples,
)
return str(out_dir)
# DDP-safe fit: only the (local) main process actually fits + saves;
# other ranks poll the cache sentinel until the leader is done.
# Without this guard, all N ranks fit concurrently and race on
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
# and compiles a ``.pyc`` — concurrent writers occasionally produce
# a stale / partial ``.pyc`` and the subsequent ``from .. import
# UniversalActionProcessor`` raises ``AttributeError``.
is_leader = (
int(os.environ.get("RANK", "0")) == 0
and int(os.environ.get("LOCAL_RANK", "0")) == 0
)
if not is_leader:
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
start = time.monotonic()
while not (out_dir / _CACHE_SENTINEL).exists():
if time.monotonic() - start > timeout_s:
raise RuntimeError(
f"FAST tokenizer fit: non-leader rank timed out after "
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
"Leader rank likely crashed during the fit."
)
time.sleep(2.0)
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
return str(out_dir)
logger.info(
"FAST tokenizer cache miss — fitting on dataset=%s "
"base=%s n_samples=%d chunk_size=%d%s",
dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, out_dir,
)
from transformers import AutoProcessor # noqa: PLC0415
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
# Stream a single episode's worth of action chunks at a time so
# we don't blow memory on huge datasets. Random episode +
# random start offset gives a reasonable spread.
#
# Actions are read straight from the underlying HF dataset's
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
# training item (delta-timestamp expansion + video decode + image
# transforms); a single bad video frame would then throw and, since
# the failure was swallowed at debug level, silently starve the fit
# of every chunk. The action column carries no video, so reading it
# directly is both faster and immune to decode errors.
rng = np.random.default_rng(seed)
actions_buf: list[np.ndarray] = []
# Load just the metadata first to know episode boundaries.
ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
num_episodes = ds_meta_only.meta.total_episodes
if "action" not in ds_meta_only.features:
available = ", ".join(sorted(ds_meta_only.features)) or "<none>"
raise RuntimeError(
f"FAST fit: dataset {dataset_repo_id!r} has no ``action`` feature. "
f"Available features: {available}."
)
del ds_meta_only
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
collected = 0
eps_visited = 0
short_episodes = 0
for ep_idx in rng.permutation(num_episodes):
if collected >= n_samples:
break
ep_idx = int(ep_idx)
try:
ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
ep_actions = np.asarray(ds.hf_dataset["action"], dtype=np.float32)
except Exception as exc: # noqa: BLE001
logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
continue
if ep_actions.ndim != 2 or ep_actions.shape[0] < chunk_size:
short_episodes += 1
continue
# Sample ``samples_per_episode`` contiguous chunks uniformly.
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
for s in starts:
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
collected += 1
if collected >= n_samples:
break
eps_visited += 1
if not actions_buf:
raise RuntimeError(
f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
f"all {num_episodes} episodes were shorter than chunk_size="
f"{chunk_size} ({short_episodes} too short) or had an unreadable "
"``action`` column. Lower ``chunk_size`` to match your episode "
"lengths."
)
actions = np.stack(actions_buf, axis=0).astype(np.float32) # (N, H, D)
logger.info(
"FAST fit: collected %d chunks of shape %s from %d episodes",
actions.shape[0], actions.shape[1:], eps_visited,
)
# Quantile-normalise per dimension before fitting.
#
# The FAST tokenizer DCT-transforms actions, scales by ``scale`` and
# rounds to integer tokens; the integer *range* must fit the
# codebook (vocab_size, default 1024). Raw motor units (e.g. encoder
# ticks) blow that range up — hence "Vocab size 1024 is too small".
# More importantly, at training time ``ActionTokenizerProcessorStep``
# runs *after* the QUANTILES ``NormalizerProcessorStep``, so it
# encodes normalised actions. Fitting on raw actions would mismatch
# that space. We replicate QUANTILES normalisation here (per-dim
# [q01, q99] → [-1, 1], clipped) so the fit and the training-time
# encode see the same distribution.
flat = actions.reshape(-1, actions.shape[-1])
q01 = np.quantile(flat, 0.01, axis=0)
q99 = np.quantile(flat, 0.99, axis=0)
span = np.where((q99 - q01) > 1e-6, q99 - q01, 1.0)
actions = np.clip((actions - q01) / span * 2.0 - 1.0, -1.0, 1.0).astype(np.float32)
base = AutoProcessor.from_pretrained(base_tokenizer_name, trust_remote_code=True)
if not hasattr(base, "fit"):
raise ImportError(
f"Base FAST tokenizer {base_tokenizer_name!r} has no ``.fit()`` "
"method — your transformers / model snapshot is too old. Update "
"to the current ``physical-intelligence/fast`` revision."
)
fitted = base.fit(actions)
out_dir.mkdir(parents=True, exist_ok=True)
fitted.save_pretrained(str(out_dir))
logger.info("FAST fit: saved fitted tokenizer to %s", out_dir)
return str(out_dir)

View File

@@ -1,73 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PI052 inference / runtime orchestration.
Multi-rate runtime that mirrors the recipe-time training shape:
low_level_execution → LowLevelForward + DispatchAction (high Hz)
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
memory_update → MemoryUpdateFwd (event: subtask_change)
user_interjection_response → UserInterjectionFwd (event: stdin)
ask_vqa_* → AskVQAFwd (event: stdin question)
speech tool calls → DispatchToolCalls (event: tool_call_pending)
The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls
``run()``.
"""
from .repl import StdinReader
from .runtime import PI052Runtime
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
from .steps import (
AskVQAFwd,
DispatchAction,
DispatchToolCalls,
HighLevelSubtaskFwd,
InferenceStep,
LowLevelForward,
MemoryUpdateFwd,
UserInterjectionFwd,
)
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
from .ui import make_state_panel, print_robot_lines, print_user_line
__all__ = [
# runtime
"PI052Runtime",
"StdinReader",
# state helpers
"initial_runtime_state",
"push_log",
"set_if_changed",
"take_event",
# triggers
"Trigger",
"Tick",
"TickClock",
"HzTrigger",
"EventTrigger",
# steps
"InferenceStep",
"LowLevelForward",
"DispatchAction",
"HighLevelSubtaskFwd",
"MemoryUpdateFwd",
"UserInterjectionFwd",
"AskVQAFwd",
"DispatchToolCalls",
# UI
"make_state_panel",
"print_robot_lines",
"print_user_line",
]

View File

@@ -1,105 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stdin REPL event collector for the PI052 runtime.
Reads non-blocking stdin lines, classifies each one heuristically:
"stop" / "quit" / "exit" → state["stop"] = True
"/action" / "/pause" → set state["mode"]
ends with "?" → user_vqa_query event
starts with "task:" or first line → set runtime task
anything else → user_interjection event
Plugged into the runtime via ``event_collector=StdinReader().poll``.
Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin
directly in its REPL / autonomous loops and does *not* wire this
collector; it's kept as the documented embedding hook and for tests.
"""
from __future__ import annotations
import select
import sys
from dataclasses import dataclass, field
from typing import Any
@dataclass
class StdinReader:
"""Non-blocking stdin line collector for the runtime loop."""
prompt: str = "> "
_seen_first_line: bool = field(default=False, init=False)
_prompted: bool = field(default=False, init=False)
def poll(self, state: dict[str, Any]) -> None:
"""Drain pending stdin lines into runtime events."""
# Print the input prompt once on every fresh tick if we don't
# already have a pending line; matches the expected REPL feel.
if not self._prompted:
print(self.prompt, end="", flush=True)
self._prompted = True
# ``select`` with timeout=0 makes this non-blocking. Only works
# for actual TTY / pipe stdins; CI / scripted runs hit EOF.
try:
ready, _, _ = select.select([sys.stdin], [], [], 0)
except (ValueError, OSError):
return
if not ready:
return
line = sys.stdin.readline()
if not line: # EOF
state["stop"] = True
return
line = line.strip()
self._prompted = False # we'll re-prompt next tick
if not line:
return
lower = line.lower()
if lower in {"stop", "quit", "exit"}:
state["stop"] = True
return
# Slash commands flip the run mode. ``/pause`` stops the action
# loop (the action steps gate on ``state["mode"]``); ``/action``
# resumes it.
if lower.split(" ", 1)[0] in {"/action", "/act", "/run"}:
state["mode"] = "action"
return
if lower in {"/pause", "/p"}:
state["mode"] = "paused"
queue = state.get("action_queue")
if hasattr(queue, "clear"):
queue.clear()
return
# First non-control line sets the task if no task is active.
if not state.get("task"):
task = line[5:].strip() if lower.startswith("task:") else line
state["task"] = task
print(f"[pi052] Task: {task}", flush=True)
self._seen_first_line = True
return
# Question → VQA; statement → interjection.
if lower.endswith("?"):
state["recent_vqa_query"] = line
state.setdefault("events_this_tick", []).append("user_vqa_query")
else:
state["recent_interjection"] = line
state.setdefault("events_this_tick", []).append("user_interjection")

View File

@@ -1,205 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PI052 runtime loop.
Threads the multi-rate inference pipeline together with a stdin REPL
event collector, drives ticks through :class:`TickClock`, and prints
state-change updates to the user.
"""
from __future__ import annotations
import logging
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable
from .runtime_state import initial_runtime_state, push_log
from .steps import (
AskVQAFwd,
DispatchAction,
DispatchToolCalls,
HighLevelSubtaskFwd,
InferenceStep,
LowLevelForward,
MemoryUpdateFwd,
)
from .triggers import EventTrigger, HzTrigger, TickClock
logger = logging.getLogger(__name__)
@dataclass
class PI052Runtime:
"""Compose the inference pipeline and drive it tick-by-tick."""
policy: Any
tools: dict[str, Any] = field(default_factory=dict)
"""Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
from :func:`lerobot.tools.get_tools(meta)` when wiring the
runtime."""
observation_provider: Callable[[], dict | None] | None = None
"""Closure returning the current preprocessed observation batch.
``None`` for dry-run / language-only sessions."""
robot_executor: Callable[[Any], None] | None = None
"""Closure that takes one action chunk and forwards it to the
robot. ``None`` for dry-run."""
event_collector: Callable[[dict], None] | None = None
"""Per-tick hook that polls external sources (stdin, network) and
appends event names to ``state["events_this_tick"]``."""
chunk_hz: float = 4.0
ctrl_hz: float = 50.0
high_level_hz: float = 1.0
max_rate_hz: float = 50.0
pipeline: list[InferenceStep] = field(init=False)
state: dict[str, Any] = field(init=False)
_stop: bool = field(default=False, init=False)
def __post_init__(self) -> None:
# Subtask + memory + VQA configuration. Pipeline:
#
# HighLevelSubtaskFwd → generate the next subtask via the LM
# head at ~``high_level_hz``; writes
# ``current_subtask`` and emits
# ``subtask_change`` on a transition.
# MemoryUpdateFwd → on ``subtask_change``, refresh
# ``current_memory`` from the
# ``memory_update`` head.
# AskVQAFwd → answer camera-grounded stdin questions.
# LowLevelForward → action chunk conditioned on the
# generated ``current_subtask``.
# DispatchAction → drain the chunk to the robot.
# DispatchToolCalls → fire any pending tool calls.
#
# Order matters: ``HighLevelSubtaskFwd`` must run before
# ``MemoryUpdateFwd`` so the event is visible the same tick, and
# both must run before ``LowLevelForward`` (which is gated on
# "action queue empty") so the chunk consumes the freshest
# subtask. ``UserInterjectionFwd`` is still importable but
# disabled until plan generation is wired in.
self.pipeline = [
HighLevelSubtaskFwd(
trigger=HzTrigger(self.high_level_hz),
policy=self.policy,
observation_provider=self.observation_provider,
),
# Listens for the ``subtask_change`` event raised by
# ``HighLevelSubtaskFwd`` and refreshes ``current_memory``.
MemoryUpdateFwd(
trigger=EventTrigger("subtask_change"),
policy=self.policy,
observation_provider=self.observation_provider,
),
AskVQAFwd(
policy=self.policy,
observation_provider=self.observation_provider,
),
LowLevelForward(
trigger=HzTrigger(self.chunk_hz),
policy=self.policy,
observation_provider=self.observation_provider,
),
DispatchAction(
trigger=HzTrigger(self.ctrl_hz),
robot_executor=self.robot_executor,
),
DispatchToolCalls(tools=self.tools),
]
self.state = initial_runtime_state()
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
def set_task(self, task: str) -> None:
"""Set or replace the active task. Logged for the REPL."""
self.state["task"] = task
push_log(self.state, f"Task: {task}")
def stop(self) -> None:
self._stop = True
def run(self, *, max_ticks: int | None = None) -> None:
"""Main loop. Returns when ``stop()`` is called or after
``max_ticks`` ticks (useful for tests / dry-run)."""
clock = TickClock(max_rate_hz=self.max_rate_hz)
while not self._stop:
tick = clock.advance()
self.state["_tick"] = tick
self.state["events_this_tick"] = []
self.state["log_lines"] = []
if self.event_collector is not None:
self.event_collector(self.state)
if self.state.get("stop"):
self._stop = True
break
for step in self.pipeline:
self.state = step(self.state)
self._flush_logs()
if max_ticks is not None and tick.index >= max_ticks:
break
self._on_shutdown()
# ------------------------------------------------------------------
# REPL helper: drive one full pipeline pass and return its logs
# ------------------------------------------------------------------
def step_once(self) -> list[str]:
"""Run one tick of the pipeline and return the log lines.
Used by the interactive REPL: instead of a background thread,
the CLI drives ticks synchronously after each user input. Logs
are returned (not printed) so the caller can route them into
the rich-Live chat scrollback.
"""
from .triggers import Tick # noqa: PLC0415
# Synthesize a tick. We don't need the real wall-clock pacing
# here — the REPL drives the runtime, not vice versa — but
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
# bump it generously so every Hz-triggered step considers
# itself due.
import time as _time # noqa: PLC0415
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
self.state["log_lines"] = []
# ``events_this_tick`` is set up by the caller before
# ``step_once`` (the REPL pushes user-driven events first).
self.state.setdefault("events_this_tick", [])
for step in self.pipeline:
self.state = step(self.state)
return list(self.state.get("log_lines") or [])
# ------------------------------------------------------------------
# I/O
# ------------------------------------------------------------------
def _flush_logs(self) -> None:
for line in self.state.get("log_lines") or []:
print(f"[pi052] {line}", flush=True)
def _on_shutdown(self) -> None:
# Drain any queued action chunks safely.
queue = self.state.get("action_queue")
if isinstance(queue, deque):
queue.clear()
print("[pi052] runtime stopped", flush=True)

View File

@@ -1,95 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runtime state passed between inference steps each tick.
The runtime threads a single dict through the pipeline; this module
documents the shape and provides factories. We use a plain ``dict``
rather than a frozen dataclass because steps freely add and remove
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
…) and dataclass field churn would just get in the way.
Stable keys (read by multiple steps):
task str the current top-level task
current_plan str | None latest plan emitted by the planner
current_subtask str | None latest subtask the policy is executing
current_memory str | None latest compressed memory
recent_interjection str | None most recent user interjection text (consumed)
action_queue collections.deque[Tensor] pending action chunks
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
events_this_tick list[str] triggers consumed this tick
_tick Tick current tick (set by the loop)
mode str "action" (run the robot) | "paused"
(action loop stopped — robot holds)
log_lines list[str] human-readable status lines printed each tick
"""
from __future__ import annotations
from collections import deque
from typing import Any
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
"""Build a fresh runtime state dict with sensible defaults."""
return {
"task": task,
"current_plan": None,
"current_subtask": None,
"current_memory": None,
"recent_interjection": None,
"action_queue": deque(),
"tool_calls_pending": [],
"events_this_tick": [],
"log_lines": [],
"mode": "action",
"stop": False,
}
def take_event(state: dict[str, Any], event_name: str) -> bool:
"""Pop ``event_name`` from ``events_this_tick`` if present.
Steps that consume an event call this so the same event doesn't
re-fire on a sibling step within the same tick.
"""
events: list[str] = state.get("events_this_tick") or []
if event_name in events:
events.remove(event_name)
return True
return False
def push_log(state: dict[str, Any], line: str) -> None:
"""Append ``line`` to the per-tick log buffer; the runtime prints
it at the end of the tick."""
state.setdefault("log_lines", []).append(line)
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
"""Update ``state[key]`` and log a diff line if the value changed.
Returns ``True`` if the value actually changed.
"""
prev = state.get(key)
if prev == value:
return False
state[key] = value
if label is not None:
push_log(state, f" {label}: {value}")
return True

View File

@@ -1,936 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference steps for the PI052 multi-rate runtime.
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
the runtime applies them in order each tick. When a step's trigger
doesn't fire, the step is a no-op and the runtime moves on.
Stream-to-step mapping mirrors the ``subtasks_vqa.yaml`` recipe:
* ``LowLevelForward`` — calls ``policy.select_action`` for the
action chunk; trained by
``low_level_execution``
* ``EnqueueChunk`` — pushes the chunk to ``action_queue``
* ``DispatchAction`` — pops one action per control tick and
forwards to the robot
* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
next subtask; trained by
``high_level_subtask``
* ``MemoryUpdateFwd`` — fires on subtask boundary; trained by
``memory_update``
* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
``user_interjection_response``
* ``AskVQAFwd`` — fires on stdin question; trained by
``ask_vqa_*``
* ``DispatchToolCalls`` — pops ``tool_calls_pending`` and calls
the matching ``Tool`` instance
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import Any
from .runtime_state import push_log, set_if_changed, take_event
from .triggers import EventTrigger, HzTrigger, Trigger
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Step base + runner
# ---------------------------------------------------------------------------
@dataclass
class InferenceStep:
"""A trigger-gated callable. Subclasses override :meth:`run`."""
trigger: Trigger
def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
if not self.trigger.should_fire(state["_tick"], state):
return state
return self.run(state) or state
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: # pragma: no cover
raise NotImplementedError
# ---------------------------------------------------------------------------
# Low-level (action) path
# ---------------------------------------------------------------------------
@dataclass
class LowLevelForward(InferenceStep):
"""Run the policy's action head and produce one action chunk."""
policy: Any = None
observation_provider: Any = None
"""Callable ``() -> dict``: returns the current observation batch
(already preprocessed). Typically wraps the robot's camera /
proprio reads. ``None`` in dry-run mode → step skips."""
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or self.observation_provider is None:
return None
# ``/vlm`` mode pauses the whole action loop so the robot holds
# position while the operator probes the VLM with VQA.
if state.get("mode", "action") != "action":
return None
if not state.get("task"):
return None
# PI052 produces *action chunks* (typically 50 steps via
# flow-matching). Every step gets dispatched to the robot;
# popping one per dispatch tick is essentially free. Only
# generate a new chunk once the previous one has fully
# drained — this is the canonical "sense → think → act"
# loop. Refreshing while a chunk is still queued causes the
# new chunk to "telescope" past the old one (planned from an
# observation that's already 25+ steps stale by the time it
# starts dispatching).
queue = state.setdefault("action_queue", [])
if len(queue) > 0:
return None
observation = self.observation_provider()
if observation is None:
return None
# The action expert is conditioned on the SUBTASK generated by
# the high-level loop (``HighLevelSubtaskFwd`` runs earlier in
# the pipeline and writes ``current_subtask``). Matches the
# training-time ``low_level_execution`` recipe — ``user(${subtask})``.
# Falls back to the task string only on the very first frame,
# before the high-level loop has produced a subtask.
subtask = state.get("current_subtask") or state.get("task") or ""
ctx = [{"role": "user", "content": subtask}]
# ``add_generation_prompt=False`` to match the training-time
# prefix shape: at training the action expert sees the rendered
# user turn ending at ``<|im_end|>`` (no trailing
# ``<|im_start|>assistant\n``). Passing True here would append
# extra role-marker tokens the action expert never saw during
# training.
text_batch = _build_text_batch(self.policy, ctx, add_generation_prompt=False)
from lerobot.utils.constants import ( # noqa: PLC0415
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
observation = dict(observation)
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
try:
# ``predict_action_chunk`` returns the *full* chunk shape
# ``(batch, n_action_steps, action_dim)``. Enqueue every
# step so DispatchAction at ctrl_hz can drain them
# smoothly until the next refresh.
chunk = self.policy.predict_action_chunk(observation)
except Exception as exc: # noqa: BLE001
logger.warning(
"predict_action_chunk failed: %s",
exc,
exc_info=logger.isEnabledFor(logging.DEBUG),
)
push_log(
state,
f" [warn] predict_action_chunk failed: "
f"{type(exc).__name__}: {exc}",
)
return None
# ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push
# each step as a ``(1, action_dim)`` tensor so the existing
# action executor's batch-squeeze logic works unchanged.
if chunk.ndim == 3:
chunk_iter = chunk[0] # ``(n_action_steps, action_dim)``
elif chunk.ndim == 2:
chunk_iter = chunk
else:
chunk_iter = chunk.unsqueeze(0)
for step in chunk_iter:
queue.append(step.unsqueeze(0))
state["last_chunk_size"] = int(chunk_iter.shape[0])
return None
@dataclass
class DispatchAction(InferenceStep):
"""Pop one action per tick and hand it to the robot.
In dry-run mode (``robot_executor=None``) the step still pops the
queue so it doesn't grow unbounded — the popped tensor is logged
instead of executed.
Wall-clock catch-up: the action queue represents an open-loop
trajectory at a fixed step rate (``trigger.hz`` ≈ ``ctrl_hz``).
When the main loop stalls — e.g. an LLM call for the high-level
subtask blocks for ~2 s on MPS — the dispatch trigger fires only
once over that whole interval. Naively popping a single entry per
fire makes the robot lag further and further behind the planned
timeline, and a 50-step chunk would take ~125 s to drain instead
of ~1.7 s. Track real elapsed time between dispatches and pop
``round(elapsed * hz)`` entries, sending the most recent one. The
skipped intermediate joint targets are stale anyway — the dynamixel
will smooth toward the latest goal position.
"""
robot_executor: Any = None
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))
_last_dispatch_t: float | None = field(default=None, init=False)
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
import time as _time # noqa: PLC0415
# ``/vlm`` mode pauses dispatch — the robot holds its last
# commanded position while the operator runs VQA.
if state.get("mode", "action") != "action":
self._last_dispatch_t = None
return None
queue = state.get("action_queue")
if not queue:
# Reset wall-clock anchor when the queue is empty so the
# next chunk doesn't see a huge fake "elapsed" window.
self._last_dispatch_t = None
return None
now = _time.monotonic()
hz = getattr(self.trigger, "hz", 30.0)
if self._last_dispatch_t is None or hz <= 0:
n_to_pop = 1
else:
elapsed = now - self._last_dispatch_t
# ``max(1, ...)`` so we always pop at least one when the
# trigger fires; ``min(len(queue), ...)`` so we don't run
# off the end of the chunk.
n_to_pop = max(1, min(len(queue), int(round(elapsed * hz))))
self._last_dispatch_t = now
# Drain ``n_to_pop`` stale entries, keep only the latest as the
# action actually sent. The intermediate joint targets would
# all be ~1030 ms apart in chunk time — the robot can't track
# them individually anyway when the host loop is slow.
latest = None
for _ in range(n_to_pop):
if not queue:
break
latest = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0)
state["actions_dispatched"] = state.get("actions_dispatched", 0) + 1
if latest is not None and self.robot_executor is not None:
self.robot_executor(latest)
return None
# ---------------------------------------------------------------------------
# High-level (text) paths — all use policy.select_message
# ---------------------------------------------------------------------------
def _build_text_batch(
policy: Any,
prompt_messages: list[dict[str, Any]],
*,
add_generation_prompt: bool = True,
) -> dict[str, Any]:
"""Tokenize chat messages into the batch ``select_message`` expects.
PI052's backbone (PaliGemma) ships no chat template, so we train on
a plain role-prefixed concatenation built by
``PI052TextTokenizerStep``. We reuse that exact formatter so the
inference prefix matches training; ``add_generation_prompt`` appends
the bare ``Assistant: `` header the LM head continues from.
"""
import torch # noqa: PLC0415
from transformers import AutoTokenizer # noqa: PLC0415
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
_flatten_say_tool_calls,
_format_messages,
_strip_blocks,
register_paligemma_loc_tokens,
)
tok_name = (
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
)
# Register PaliGemma's <locDDDD> tokens so inference encoding /
# decoding sees them as single vocab ids — must match training.
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
prompt, _spans = _format_messages(messages)
if add_generation_prompt:
prompt = prompt + "Assistant: "
encoded = tokenizer(prompt, return_tensors="pt")
ids = encoded["input_ids"]
attn = encoded.get("attention_mask")
if attn is None and tokenizer.pad_token_id is not None:
attn = ids != tokenizer.pad_token_id
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
attn = attn.bool()
# Move tokens onto the policy's device — otherwise prefix embedding
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
# model), which the caller's broad except would swallow silently.
device = getattr(getattr(policy, "config", None), "device", None)
if device is not None:
try:
ids = ids.to(device)
if attn is not None and hasattr(attn, "to"):
attn = attn.to(device)
except Exception as exc: # noqa: BLE001
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
new = dict(m)
new.pop("stream", None)
new.pop("target", None)
return new
@dataclass
class HighLevelSubtaskFwd(InferenceStep):
"""At ~1 Hz, ask the policy for the next subtask.
Mirrors the ``high_level_subtask`` recipe layout exactly:
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
user: "Current subtask: ${subtask}" (if subtask present)
↓ generate ↓
assistant: <next subtask>
"""
policy: Any = None
observation_provider: Any = None
"""Same shape as ``LowLevelForward.observation_provider``. When
set, the resulting observation is merged into ``select_message``'s
batch so text generation runs against real video + state."""
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not state.get("task"):
return None
# ``/vlm`` mode pauses subtask generation along with the rest of
# the action loop.
if state.get("mode", "action") != "action":
return None
# Gate to chunk boundaries: only generate a fresh subtask when
# the action queue is empty (i.e. right before LowLevelForward
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
# and running it every loop iteration starves DispatchAction
# at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
# of 30/sec and the robot barely moves. Tying it to the same
# "queue empty" condition as the chunk refresh produces a
# clean sense → think → act cycle.
#
# Rearm the trigger when skipping so a low-hz schedule
# (e.g. ``--high_level_hz=0.2`` = once per 5 s) doesn't lose
# the slot: the trigger fires once on the timer but the brief
# queue-empty window almost never coincides, so without rearm
# HL would effectively never run.
queue = state.get("action_queue") or []
if len(queue) > 0:
if hasattr(self.trigger, "rearm"):
self.trigger.rearm()
return None
# Per-chunk-boundary throttle: at each "queue empty" moment we
# increment a counter; subtask gen only fires once the counter
# reaches ``subtask_chunks_per_gen``. Lets the operator run e.g.
# 5 action chunks per subtask-gen so the LM head doesn't churn
# every 1.7 s (a fresh subtask while the previous one is still
# being executed is wasted compute *and* causes the action
# expert's flow trajectory to be re-planned mid-grasp).
chunks_per_gen = max(1, int(state.get("subtask_chunks_per_gen", 1) or 1))
# Initialise so the first chunk boundary fires immediately
# (counter starts at chunks_per_gen, decrements per skip,
# generates and resets when it hits 0).
if "_hl_chunks_until_gen" not in state:
state["_hl_chunks_until_gen"] = 0
if state["_hl_chunks_until_gen"] > 0:
state["_hl_chunks_until_gen"] -= 1
if hasattr(self.trigger, "rearm"):
self.trigger.rearm()
return None
state["_hl_chunks_until_gen"] = chunks_per_gen - 1
ctx = _msgs_for_subtask(state)
observation = _maybe_observation(self.observation_provider)
# Default: greedy argmax, no min_new_tokens, no special-token
# suppression — matches training. Operator can override via
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
# on the CLI; useful for under-trained checkpoints whose LM
# head still favours EOS at position 0 (pre-trained chat
# backbone's short-turn prior hasn't been fully overridden
# by the fine-tuning supervision yet).
msg = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="subtask gen",
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
temperature=float(state.get("text_gen_temperature") or 0.0),
top_p=float(state.get("text_gen_top_p") or 1.0),
# Subtasks never legitimately contain PaliGemma ``<loc>``
# tokens — suppress them so a checkpoint whose LM head
# has drifted toward the pretrained loc-prior falls back
# to its (still-correct) text mass.
suppress_loc_tokens=True,
)
# Diagnostics: surface what the model is *actually* producing
# at chunk boundaries, even when the output gets rejected or
# repeats. Memorisation collapse looks like "same accepted
# subtask N times in a row" or "gibberish_count rising while
# current_subtask is stuck". The state panel renders these.
state["last_subtask_raw"] = msg or ""
# Persistent empty completion is its own failure mode (model
# immediately EOS-es from the chat-template generation
# prompt) — surface it once every N occurrences so the
# operator can distinguish "generation failing silently"
# from "generating fine but filter rejecting".
if not msg:
empties = state.get("subtask_empty_count", 0) + 1
state["subtask_empty_count"] = empties
if empties == 1 or empties % 5 == 0:
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
if debug:
push_log(
state,
f" [info] subtask gen empty (×{empties}); {debug}",
)
else:
push_log(
state,
f" [info] subtask gen returned empty (×{empties}) — "
"no tokens generated (head EOS-ing before any "
"non-special token).",
)
if msg and _looks_like_gibberish(msg):
# Bump a counter so the operator can see the model is
# struggling without spamming the log every tick. A first
# rejection still logs once so the failure is visible.
count = state.get("subtask_gibberish_count", 0) + 1
state["subtask_gibberish_count"] = count
if count == 1 or count % 30 == 0:
push_log(
state,
f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
)
return None
if msg:
prev_subtask = state.get("current_subtask")
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
if changed:
# Stash the just-completed subtask so ``MemoryUpdateFwd``
# can drop it into its prompt as ``Completed subtask:``
# — the recipe binds ``completed_subtask`` to
# ``nth_prev(style=subtask, offset=1)``, i.e. the subtask
# that was active *before* the change.
if prev_subtask:
state["prior_subtask"] = prev_subtask
# Subtask change is a downstream trigger.
state.setdefault("events_this_tick", []).append("subtask_change")
state["subtask_repeat_count"] = 0
else:
# Same accepted string regenerated — memorisation tell.
# Once this counter climbs past a few, you're seeing
# the model unable to move past the current subtask
# despite the chunk having drained (visual scene may
# have changed but the LM is replaying training
# tokens).
state["subtask_repeat_count"] = (
state.get("subtask_repeat_count", 0) + 1
)
# Silently skip empty completions — common when the model
# warms up or generates only EOS; logging it every tick at
# ctrl_hz is just noise.
return None
@dataclass
class MemoryUpdateFwd(InferenceStep):
"""On subtask boundary, refresh the compressed memory.
Mirrors the ``memory_update`` recipe layout exactly:
user: "${task}"
assistant: "Previous memory: ${prior_memory}" (if prior memory)
user: "Completed subtask: ${completed_subtask}" (if subtask)
↓ generate ↓
assistant: <new memory>
"""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
# Don't consume the event — multiple steps may want to react.
if self.policy is None:
return None
ctx = _msgs_for_memory(state)
observation = _maybe_observation(self.observation_provider)
new_memory = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="memory gen",
suppress_loc_tokens=True,
)
state["last_memory_raw"] = new_memory or ""
if new_memory and _looks_like_gibberish(new_memory):
count = state.get("memory_gibberish_count", 0) + 1
state["memory_gibberish_count"] = count
push_log(
state,
f" [info] memory gen rejected (gibberish ×{count}): {new_memory[:60]!r}",
)
return None
if new_memory:
set_if_changed(state, "current_memory", new_memory, label="memory")
return None
@dataclass
class UserInterjectionFwd(InferenceStep):
"""On stdin interjection, refresh the plan + emit a paired ``say``.
Mirrors the ``user_interjection_response`` recipe layout exactly:
user: "${task}"
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
user: "${interjection}" (the new utterance)
↓ generate ↓
assistant: <plan + <say>...</say>>
"""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not take_event(state, "user_interjection"):
return None
ctx = _msgs_for_interjection(state)
observation = _maybe_observation(self.observation_provider)
out = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="plan/say gen",
suppress_loc_tokens=True,
)
if not out:
# Don't log every empty completion — happens repeatedly on
# MPS during warm-up and floods the panel. The user can
# re-trigger by typing again.
return None
if _looks_like_gibberish(out):
count = state.get("plan_gibberish_count", 0) + 1
state["plan_gibberish_count"] = count
push_log(
state,
f" [info] plan/say gen rejected (gibberish ×{count}): {out[:60]!r}",
)
return None
# Heuristic split: model is trained to emit one assistant turn
# carrying both plan text AND a `say` tool call. Look for a
# "<say>...</say>" or "say(...)" marker; fall back to whole
# text → plan, no speech.
plan_text, speech_text = _split_plan_and_say(out)
if plan_text and _looks_like_gibberish(plan_text):
plan_text = ""
if plan_text:
set_if_changed(state, "current_plan", plan_text, label="plan")
if speech_text:
push_log(state, f" speech: {speech_text}")
state.setdefault("tool_calls_pending", []).append(
{
"type": "function",
"function": {"name": "say", "arguments": {"text": speech_text}},
}
)
state.setdefault("events_this_tick", []).append("tool_call_pending")
# Mark interjection consumed.
state["recent_interjection"] = None
return None
@dataclass
class AskVQAFwd(InferenceStep):
"""On stdin question, answer a frame-grounded VQA.
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
turn carrying just the VQA question, plus the camera image block
in training (we drop the image at inference because the dataset's
image preprocessing doesn't match SmolVLM's vision tower input).
user: <question>
↓ generate ↓
assistant: <vqa answer>
"""
policy: Any = None
observation_provider: Any = None
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
if self.policy is None or not take_event(state, "user_vqa_query"):
return None
question = state.get("recent_vqa_query")
if not question:
return None
ctx = _msgs_for_vqa(question)
observation = _maybe_observation(self.observation_provider)
answer = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="vqa gen",
)
# VQA answers are intentionally JSON-like during training, so
# ``_looks_like_gibberish`` would false-positive on them. Keep
# the answer as-is — the VQA panel line lets the user judge.
if answer:
push_log(state, f" vqa: {answer}")
state["recent_vqa_query"] = None
return None
# ---------------------------------------------------------------------------
# Tool dispatch
# ---------------------------------------------------------------------------
@dataclass
class DispatchToolCalls(InferenceStep):
"""Pop ``tool_calls_pending`` and execute them via :data:`TOOL_REGISTRY`."""
tools: dict[str, Any] = field(default_factory=dict)
trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
take_event(state, "tool_call_pending")
pending = state.get("tool_calls_pending") or []
for call in pending:
try:
fn = (call or {}).get("function") or {}
name = fn.get("name")
args = fn.get("arguments") or {}
tool = self.tools.get(name)
if tool is None:
push_log(state, f" [warn] tool {name!r} not registered — skipping call")
continue
tool.call(args)
except Exception as exc: # noqa: BLE001
push_log(state, f" [error] tool dispatch failed: {exc}")
state["tool_calls_pending"] = []
return None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _looks_like_gibberish(text: str) -> bool:
"""Heuristically detect generation that's clearly off the rails.
Memorised models can collapse to dominant-mode outputs when the
prompt drifts even slightly from training distribution. Reject:
* empty / whitespace-only
* too few alphabetic characters (mostly punctuation)
* a single character repeated past the threshold
* starts with ``":"`` and contains no letters
* too few unique tokens — e.g. ``"the"``, ``"the the the"``,
``"Ass\\n::\\nthe"`` (the collapse seen on real-robot frames
where the model emits one or two memorised tokens repeatedly)
* chat-template fragment leakage (``Assistant:``, ``User:``,
``Ass\\n``)
Real subtasks look like ``"close the gripper to grasp the blue
cube"`` — multiple unique alphabetic tokens, no role-marker
fragments. Anything materially shorter than that is rejected.
"""
if not text or not text.strip():
return True
stripped = text.strip()
alpha = sum(1 for c in stripped if c.isalpha())
if alpha < max(3, len(stripped) // 8):
return True
if stripped.startswith('":') and stripped.count('"') > stripped.count(" "):
return True
# Single repeating char: e.g. ``""""""``.
if len(set(stripped)) <= 2 and len(stripped) > 4:
return True
# Chat-template fragment leakage — the model emits ``Ass``,
# ``Assistant:``, ``User:``, often with extra newlines/colons.
# Reject if the cleaned text is mostly role-marker shards.
cleaned = stripped.replace("\n", " ").replace(":", " ")
for marker in ("Assistant", "User", "Ass "):
if marker in cleaned and len(cleaned.split()) < 4:
return True
tokens = [t for t in cleaned.split() if any(c.isalpha() for c in t)]
unique_alpha = {t.lower() for t in tokens}
# Short degenerate output — model stuck on ``the`` or a couple of
# memorised single-token continuations.
if len(unique_alpha) < 3 and len(stripped) < 80:
return True
# Long repetition collapse — the LM head loops an n-gram for the
# whole generation budget ("the arm the arm … the the the the").
# Length-independent: many tokens but a tiny unique ratio. The
# earlier ``< 80`` check missed these because the looped string
# blows well past 80 chars.
if len(tokens) >= 8 and len(unique_alpha) <= max(3, len(tokens) // 10):
return True
return False
def _control_context_messages(
state: dict[str, Any],
*,
include_completed: bool = False,
extra_user: str | None = None,
) -> list[dict[str, Any]]:
"""Build a chat-template-ready prompt from current runtime state.
Mirrors what ``subtasks_vqa.yaml`` renders into ``${task}\nPlan:
${plan}\nMemory: ${memory}`` for the high-level branches.
"""
# Always emit ``Plan: `` / ``Memory: `` labels — even with empty
# values — to mirror the training-time recipe substitution.
task = state.get("task") or ""
plan = state.get("current_plan") or ""
memory = state.get("current_memory") or ""
parts = [task, f"Plan: {plan}", f"Memory: {memory}"]
if include_completed and state.get("current_subtask"):
parts.append(f"Completed subtask: {state['current_subtask']}")
head = "\n".join(parts)
msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
if extra_user:
msgs.append({"role": "user", "content": extra_user})
return msgs
# ---------------------------------------------------------------------------
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
# message layout in ``subtasks_vqa.yaml`` so the chat-templated
# prompt at inference matches what the model saw during training.
# Generic ``_control_context_messages`` is kept around as a fallback
# for ad-hoc callers but the four high-level steps now use these.
# ---------------------------------------------------------------------------
def _hirobot_user_head(state: dict[str, Any]) -> str:
"""Build the ``task\\nPlan: …\\nMemory: …`` user content string.
Mirrors what the recipe renders at training time, where
``language_render._substitute`` substitutes empty strings for
missing ``${plan}`` / ``${memory}`` bindings — i.e. the
``Plan: `` / ``Memory: `` prefix labels are *always* in the
user turn, even when their values aren't set yet. Skipping them
here (the previous behaviour) produced a different prompt shape
on early frames before plan / memory are populated and on
samples where the dataset has no plan / memory annotation.
"""
task = state.get("task") or ""
plan = state.get("current_plan") or ""
memory = state.get("current_memory") or ""
return f"{task}\nPlan: {plan}\nMemory: {memory}"
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``high_level_subtask`` recipe layout — predict the subtask from the
task. The v-current recipe's user turn is just ``${task}`` (plan and
memory are not trained), so the inference prompt is the bare task —
no ``Plan: `` / ``Memory: `` lines.
"""
return [{"role": "user", "content": state.get("task") or ""}]
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
"""Memory-update prompt — mirrors ``memory_update`` recipe layout.
Recipe layout (``subtask_mem.yaml``):
user: "${task}"
assistant: "Previous memory: ${prior_memory}" (if_present prior)
user: "Completed subtask: ${completed}" (if_present completed)
assistant: → predicts new memory
Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event:
``state['current_memory']`` is the memory the policy last emitted
(= the ``prior_memory`` binding at training), and
``state['prior_subtask']`` is the subtask that just got replaced
(= the ``completed_subtask`` binding at training).
"""
msgs: list[dict[str, Any]] = [
{"role": "user", "content": state.get("task") or ""},
]
prior_memory = state.get("current_memory")
if prior_memory:
msgs.append(
{"role": "assistant", "content": f"Previous memory: {prior_memory}"}
)
completed_subtask = state.get("prior_subtask")
if completed_subtask:
msgs.append(
{"role": "user", "content": f"Completed subtask: {completed_subtask}"}
)
return msgs
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``user_interjection_response`` recipe layout."""
msgs: list[dict[str, Any]] = [
{"role": "user", "content": state.get("task") or ""}
]
if state.get("current_plan"):
msgs.append(
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
)
interjection = state.get("recent_interjection")
if interjection:
msgs.append({"role": "user", "content": interjection})
return msgs
def _msgs_for_plan(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``plan_generation`` recipe layout — bare task → plan.
The assistant turn is the generation target, so we only render
the user turn at inference; the runtime appends the predicted
plan after sampling.
"""
return [{"role": "user", "content": state.get("task") or ""}]
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
return [{"role": "user", "content": question}]
def _maybe_observation(provider: Any) -> dict | None:
"""Pull one observation from ``provider`` if it's set, else ``None``.
Errors from the provider are logged at debug level and swallowed —
text generation still runs (in text-only mode) so a flaky frame
source doesn't kill the REPL.
"""
if provider is None:
return None
try:
return provider()
except Exception as exc: # noqa: BLE001
logger.debug("observation_provider raised %s — falling back to text-only", exc)
return None
def _generate_with_policy(
policy: Any,
messages: list[dict[str, Any]],
*,
observation: dict | None = None,
state: dict[str, Any] | None = None,
label: str = "select_message",
min_new_tokens: int = 0,
temperature: float = 0.0,
top_p: float = 1.0,
suppress_loc_tokens: bool = False,
) -> str:
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
When ``observation`` carries ``observation.images.*`` and
``observation.state``, those are merged into the batch so
``select_message`` runs the same VLM prefix the policy was trained
on. Without an observation the runtime falls back to a text-only
prompt — the text head still runs, but generations may drift from
the training distribution.
Failures are surfaced both to the module logger (``warning``) and,
when ``state`` is given, to the runtime's user-visible log via
:func:`push_log`, so the REPL no longer "looks dead" when
something goes wrong inside generation.
"""
if not hasattr(policy, "select_message"):
if state is not None:
push_log(state, f" [warn] policy has no select_message — skipping {label}")
return ""
text_batch = _build_text_batch(policy, messages)
try:
from lerobot.utils.constants import ( # noqa: PLC0415
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
batch: dict[str, Any] = {
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
}
if observation:
for k, v in observation.items():
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
batch[k] = v
kwargs: dict[str, Any] = {
"tokenizer": text_batch["tokenizer"],
"min_new_tokens": min_new_tokens,
"temperature": temperature,
"top_p": top_p,
}
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
return policy.select_message(batch, **kwargs)
except Exception as exc: # noqa: BLE001
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
if state is not None:
push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}")
return ""
_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
def _split_plan_and_say(text: str) -> tuple[str, str]:
"""Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
The training-time tool-call serializer wraps ``say(text="")`` in a
deterministic textual marker so prefix-LM-style training learns to
emit it. The runtime parses it back here. If no marker is present,
the entire text is treated as plan with no speech.
"""
if not text:
return "", ""
match = _SAY_RE.search(text)
if not match:
return text.strip(), ""
speech = match.group(1).strip().strip('"').strip("'")
plan = (text[: match.start()] + text[match.end() :]).strip()
return plan, speech

View File

@@ -1,134 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trigger primitives for PI052's multi-rate inference runtime.
Mirrors the plan's Section "Runtime orchestration": each
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
whether the step fires. Two trigger flavours cover all the cadences
the canonical recipe needs:
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
* :class:`EventTrigger` for one-shot reactions (subtask boundary →
memory update; user interjection → plan refresh; user VQA query →
vqa answer; pending tool call → dispatcher)
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
The runtime stores the :class:`Tick` clock as ``state["_tick"]`` so
every step shares a single time source.
"""
from __future__ import annotations
import time
from dataclasses import dataclass, field
from typing import Any, Protocol
@dataclass
class Tick:
"""Single tick from :class:`TickClock`. Carries time references the
runtime steps consume to gate themselves."""
index: int
"""Monotonic counter — increments by one per tick."""
monotonic_seconds: float
"""``time.monotonic()`` at the start of this tick."""
@dataclass
class TickClock:
"""Drives the runtime loop at up to ``max_rate_hz``.
Sleeps just enough between :meth:`advance` calls to enforce the
rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the
higher-level ``HzTrigger`` slices that timeline into sub-cadences.
"""
max_rate_hz: float = 50.0
_index: int = field(default=0, init=False)
_last_seconds: float | None = field(default=None, init=False)
def advance(self) -> Tick:
period = 1.0 / max(self.max_rate_hz, 0.1)
now = time.monotonic()
if self._last_seconds is not None:
sleep_for = (self._last_seconds + period) - now
if sleep_for > 0:
time.sleep(sleep_for)
now = time.monotonic()
self._last_seconds = now
self._index += 1
return Tick(index=self._index, monotonic_seconds=now)
class Trigger(Protocol):
"""Decide whether the next ``InferenceStep`` should fire."""
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
@dataclass
class HzTrigger:
"""Fire at most ``hz`` times per second.
A step that gates further (e.g. ``HighLevelSubtaskFwd`` skipping
when the action queue is non-empty) and wants the trigger to
retry next tick instead of waiting a full period can call
:meth:`rearm` from inside ``run``. Without this, a low-hz trigger
(e.g. ``hz=0.2`` = once per 5 s) almost never coincides with the
brief queue-empty window and the step never fires at all.
"""
hz: float
_last_seconds: float | None = field(default=None, init=False)
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
period = 1.0 / max(self.hz, 1e-6)
if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period:
self._last_seconds = tick.monotonic_seconds
return True
return False
def rearm(self) -> None:
"""Mark the trigger as not having fired, so the next tick re-evaluates.
Used by a step that decided to skip after ``should_fire`` already
committed the firing — keeps the cadence honest without losing
the slot.
"""
self._last_seconds = None
@dataclass
class EventTrigger:
"""Fire when ``event_name`` is in ``state["events_this_tick"]``.
The runtime fills ``events_this_tick`` once per tick from:
* stdin / network input (``user_interjection``, ``user_vqa_query``,
``stop``)
* internal state transitions (``subtask_change``,
``tool_call_pending``)
The list is consumed (cleared at the end of the tick) so events
fire at most once.
"""
event_name: str
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
events: list[str] = state.get("events_this_tick") or []
return self.event_name in events

View File

@@ -1,127 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rich-based REPL layout for the PI052 runtime.
Two-zone terminal layout:
[chat scrollback — user messages / robot responses, scrolls naturally]
┌── State ──────────────────────────────────────────┐
│ task please clean up the kitchen │
│ subtask grasp the handle of the sponge │
│ plan 1. grasp sponge 2. wipe 3. tidy │
│ memory sponge picked up; counter still dirty │
└───────────────────────────────────────────────────┘
> _
The state panel re-renders on every state change. Chat lines are
``console.print``'d above the live region so they accumulate naturally
in scrollback. Implemented with :class:`rich.live.Live` plus
:func:`rich.console.Console.input` for the prompt — when an input is
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
panel for cursor position.
"""
from __future__ import annotations
from typing import Any
try: # rich is optional; only required for the interactive REPL.
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
_HAS_RICH = True
except ImportError: # pragma: no cover
_HAS_RICH = False
Console = Any # type: ignore[assignment]
Panel = Any # type: ignore[assignment]
Table = Any # type: ignore[assignment]
Text = Any # type: ignore[assignment]
_STATE_KEYS = (
("task", "task"),
("current_subtask", "subtask"),
("current_plan", "plan"),
("current_memory", "memory"),
)
def make_state_panel(state: dict[str, Any]) -> Any:
"""Render the persistent state panel for the live region.
Returns a :class:`rich.panel.Panel`. Caller passes it to
``Live.update(panel)`` whenever the state changes.
"""
if not _HAS_RICH:
raise RuntimeError(
"rich is required for the interactive REPL. "
"`pip install rich` (it's a transitive dep of lerobot)."
)
table = Table.grid(padding=(0, 2), expand=True)
table.add_column(justify="right", style="dim", no_wrap=True, width=10)
table.add_column(justify="left")
for key, label in _STATE_KEYS:
value = state.get(key)
if value is None:
rendered = Text("(not set)", style="dim italic")
else:
rendered = Text(str(value), style="bold")
table.add_row(label, rendered)
queue = state.get("action_queue")
queue_len = len(queue) if hasattr(queue, "__len__") else 0
pending = state.get("tool_calls_pending") or []
footer = Text.assemble(
("queued actions: ", "dim"),
(str(queue_len), "bold cyan"),
(" pending tool calls: ", "dim"),
(str(len(pending)), "bold magenta"),
)
table.add_row("", footer)
run_mode = state.get("mode", "action")
mode_tag = (
"[green]action[/]" if run_mode == "action" else "[yellow]paused[/]"
)
return Panel(
table,
title=f"[bold]PI052 state[/] · mode: {mode_tag}",
border_style="cyan",
)
def print_user_line(console: Any, line: str) -> None:
"""Append a user-typed line to the chat scrollback."""
if not _HAS_RICH:
print(f"you: {line}", flush=True)
return
console.print(f"[bold cyan]you:[/] {line}")
def print_robot_lines(console: Any, lines: list[str]) -> None:
"""Append robot/runtime log lines to the chat scrollback."""
if not _HAS_RICH:
for line in lines:
print(f"robot: {line.lstrip()}", flush=True)
return
for line in lines:
# The runtime uses leading whitespace + "label: text"; render
# the label in green and the value in default for readability.
stripped = line.lstrip()
if ":" in stripped:
label, _, value = stripped.partition(":")
console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}")
else:
console.print(f"[bold green]robot:[/] {stripped}")

View File

@@ -1,423 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Interactive VQA for the PI052 runtime.
In ``/vlm`` mode a typed line is treated as a VQA question. This module
runs the full interactive flow:
1. pull the current observation and list available cameras,
2. ask the operator which camera to ground the question on,
3. generate the answer with the VLM conditioned on that one camera,
4. parse the JSON answer; if it carries a bounding box (``bbox``) or a
point (``keypoint``), draw the overlay on the camera frame, save a
PNG to ``./vqa_overlays/`` and auto-open it.
VQA answer schemas mirror the annotation pipeline's ``VQA_ANSWER_SHAPES``
(see ``lerobot.annotations.steerable_pipeline.validator``):
* ``bbox`` — ``{"detections": [{"label", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}, ...]}``
* ``keypoint`` — ``{"label", "point_format": "xy", "point": [x, y]}``
* ``count`` / ``attribute`` / ``spatial`` — text-only, no overlay.
"""
from __future__ import annotations
import json
import logging
import os
import re
import subprocess
import sys
import time
import webbrowser
from pathlib import Path
from typing import Any
from .runtime_state import push_log
logger = logging.getLogger(__name__)
_IMAGE_PREFIX = "observation.images."
# PaliGemma detection / pointing vocabulary. PI052 trains spatial VQA
# answers in this native ``<locNNNN>`` format (index in [0, 1023],
# normalized to the image axis) instead of pixel-coordinate JSON, so the
# answer string the runtime parses can be e.g.
# ``<loc0512><loc0301> blue cube`` (point) or
# ``<loc0100><loc0080><loc0400><loc0360> blue cube`` (box).
_LOC_RE = re.compile(r"<loc(\d{1,4})>")
# Iteration order for shape matching — most specific keys first so an
# answer is classified deterministically.
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
_BBOX_COLOR = (255, 64, 64)
_POINT_COLOR = (64, 220, 64)
# ---------------------------------------------------------------------------
# Camera selection
# ---------------------------------------------------------------------------
def available_cameras(observation: dict | None) -> list[str]:
"""Return the sorted ``observation.images.*`` keys present in ``observation``."""
if not observation:
return []
return sorted(k for k in observation if isinstance(k, str) and k.startswith(_IMAGE_PREFIX))
def camera_short_name(camera_key: str) -> str:
"""Strip the ``observation.images.`` prefix for display."""
return camera_key[len(_IMAGE_PREFIX) :] if camera_key.startswith(_IMAGE_PREFIX) else camera_key
def prompt_camera_choice(
cameras: list[str],
*,
input_fn: Any = input,
print_fn: Any = print,
) -> str | None:
"""Ask the operator which camera frame to draw a VQA overlay on.
Accepts either the menu number or the (short or full) camera name.
A single-camera setup auto-selects without prompting. Returns the
chosen ``observation.images.*`` key, or ``None`` if the operator
cancels / gives an invalid answer.
"""
if not cameras:
return None
if len(cameras) == 1:
return cameras[0]
print_fn("Draw the result on which camera?")
for i, cam in enumerate(cameras, 1):
print_fn(f" [{i}] {camera_short_name(cam)}")
try:
raw = str(input_fn("camera> ")).strip()
except (EOFError, KeyboardInterrupt):
return None
if not raw:
return cameras[0]
if raw.isdigit():
idx = int(raw) - 1
return cameras[idx] if 0 <= idx < len(cameras) else None
for cam in cameras:
if raw == cam or raw == camera_short_name(cam):
return cam
return None
# ---------------------------------------------------------------------------
# Answer parsing
# ---------------------------------------------------------------------------
def _loc_to_norm(idx: int) -> float:
"""PaliGemma ``<locNNNN>`` index → normalized [0, 1] axis coordinate."""
return max(0.0, min(1023.0, float(idx))) / 1023.0
def parse_loc_answer(answer: str) -> dict | None:
"""Parse a PaliGemma ``<loc>``-format spatial VQA answer.
PI052 trains spatial answers in PaliGemma's native detection
vocabulary, label-first: a point is ``<label> <locY><locX>``, a box
is ``<label> <locY0><locX0><locY1><locX1>``, and multiple boxes are
joined by `` ; `` (e.g. ``cube <loc..><loc..><loc..><loc..> ; box
<loc..><loc..><loc..><loc..>``). Loc-first formats are also accepted
— this parser strips loc tokens and treats the remainder as the
label, so order is irrelevant. Coordinates come back *normalized*
([0, 1]); the overlay denormalizes them against the chosen camera
frame's pixel size.
Returns ``{"kind", "payload", "normalized": True}`` on success
(``payload`` mirrors the JSON shapes so the overlay code is shared),
or ``None`` when the answer carries no ``<loc>`` tokens.
"""
if not answer or "<loc" not in answer:
return None
segments = [seg for seg in answer.split(";") if "<loc" in seg]
points: list[tuple[float, float, str]] = []
boxes: list[tuple[float, float, float, float, str]] = []
for seg in segments:
locs = [int(m) for m in _LOC_RE.findall(seg)]
label = _LOC_RE.sub("", seg).strip()
if len(locs) == 2:
y, x = (_loc_to_norm(v) for v in locs[:2])
points.append((x, y, label))
elif len(locs) >= 4:
y1, x1, y2, x2 = (_loc_to_norm(v) for v in locs[:4])
boxes.append((x1, y1, x2, y2, label))
if boxes:
detections = [
{"label": lbl, "bbox_format": "xyxy", "bbox": [x1, y1, x2, y2]}
for (x1, y1, x2, y2, lbl) in boxes
]
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
if len(points) == 1:
x, y, lbl = points[0]
return {
"kind": "keypoint",
"payload": {"label": lbl, "point_format": "xy", "point": [x, y]},
"normalized": True,
}
if points: # several bare points → treat as detections-as-points
detections = [
{"label": lbl, "bbox_format": "xyxy", "bbox": [x, y, x, y]} for (x, y, lbl) in points
]
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
return None
def parse_vqa_answer(answer: str) -> dict | None:
"""Parse a VQA answer string into ``{"kind", "payload"}``.
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
when the JSON doesn't match any known shape. PaliGemma ``<loc>``
spatial answers are detected first (PI052 trains them in that native
format). Returns ``None`` when the answer is neither ``<loc>`` text
nor a parseable JSON object.
"""
if not answer or not answer.strip():
return None
loc_parsed = parse_loc_answer(answer)
if loc_parsed is not None:
return loc_parsed
try:
payload = json.loads(answer)
except (ValueError, TypeError):
return None
if not isinstance(payload, dict):
return None
try:
from lerobot.annotations.steerable_pipeline.validator import ( # noqa: PLC0415
VQA_ANSWER_SHAPES,
)
shapes = VQA_ANSWER_SHAPES
except ImportError: # pragma: no cover - annotation extra not installed
shapes = {
"bbox": {"detections"},
"keypoint": {"label", "point_format", "point"},
"count": {"label", "count"},
"attribute": {"label", "attribute", "value"},
"spatial": {"subject", "relation", "object"},
}
keys = set(payload)
for kind in _SHAPE_ORDER:
required = shapes.get(kind)
if required and required <= keys:
return {"kind": kind, "payload": payload}
return {"kind": "unknown", "payload": payload}
def answer_has_overlay(parsed: dict | None) -> bool:
"""True iff ``parsed`` carries drawable spatial coordinates."""
return bool(parsed) and parsed.get("kind") in ("bbox", "keypoint")
# ---------------------------------------------------------------------------
# Overlay drawing
# ---------------------------------------------------------------------------
def observation_image_to_pil(image_tensor: Any) -> Any:
"""Convert an ``observation.images.*`` tensor to a PIL RGB image.
The runtime observation stores images as ``(1, C, H, W)`` (or
``(C, H, W)``) float tensors in ``[0, 1]``. Reuses
``image_array_to_pil_image`` which handles the CHW→HWC transpose and
the float→uint8 scaling.
"""
from lerobot.datasets.image_writer import image_array_to_pil_image # noqa: PLC0415
arr = image_tensor
if hasattr(arr, "detach"):
arr = arr.detach().cpu()
if hasattr(arr, "numpy"):
arr = arr.numpy()
while arr.ndim > 3: # drop leading batch dim(s)
arr = arr[0]
return image_array_to_pil_image(arr).convert("RGB")
def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
``unknown``) are returned as an unmodified copy. When ``parsed`` has
``normalized=True`` (PaliGemma ``<loc>`` answers) the [0, 1]
coordinates are scaled to the image's pixel size.
"""
from PIL import ImageDraw # noqa: PLC0415
img = image.convert("RGB").copy()
kind = parsed.get("kind")
payload = parsed.get("payload") or {}
draw = ImageDraw.Draw(img)
w, h = img.size
sx, sy = (w, h) if parsed.get("normalized") else (1, 1)
if kind == "bbox":
for det in payload.get("detections") or []:
if not isinstance(det, dict):
continue
box = det.get("bbox")
if not (isinstance(box, list | tuple) and len(box) == 4):
continue
try:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
x1, x2 = x1 * sx, x2 * sx
y1, y2 = y1 * sy, y2 * sy
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
label = str(det.get("label", "")).strip()
if label:
draw.text((x1 + 3, max(0.0, y1 - 12)), label, fill=_BBOX_COLOR)
elif kind == "keypoint":
point = payload.get("point")
if isinstance(point, list | tuple) and len(point) == 2:
try:
x, y = float(point[0]) * sx, float(point[1]) * sy
except (TypeError, ValueError):
return img
r = 6
draw.ellipse([x - r, y - r, x + r, y + r], outline=_POINT_COLOR, width=3)
draw.line([x - 2 * r, y, x + 2 * r, y], fill=_POINT_COLOR, width=2)
draw.line([x, y - 2 * r, x, y + 2 * r], fill=_POINT_COLOR, width=2)
label = str(payload.get("label", "")).strip()
if label:
draw.text((x + r + 3, y - r), label, fill=_POINT_COLOR)
return img
def _open_file(path: Path) -> None:
"""Best-effort open ``path`` in the OS default viewer."""
try:
if sys.platform == "darwin":
subprocess.run(["open", str(path)], check=False)
elif sys.platform.startswith("linux"):
subprocess.run(["xdg-open", str(path)], check=False)
elif os.name == "nt":
os.startfile(str(path)) # type: ignore[attr-defined] # noqa: S606
else: # pragma: no cover - exotic platform
webbrowser.open(path.resolve().as_uri())
except Exception as exc: # noqa: BLE001
logger.debug("could not auto-open %s: %s", path, exc)
def save_and_open_overlay(image: Any, out_dir: str | Path = "./vqa_overlays") -> Path:
"""Save ``image`` as a timestamped PNG under ``out_dir`` and auto-open it."""
out = Path(out_dir)
out.mkdir(parents=True, exist_ok=True)
path = out / f"vqa_{int(time.time() * 1000)}.png"
image.save(path)
_open_file(path)
return path
# ---------------------------------------------------------------------------
# Orchestrator
# ---------------------------------------------------------------------------
def handle_vqa_query(
*,
policy: Any,
observation_provider: Any,
question: str,
state: dict[str, Any],
input_fn: Any = input,
print_fn: Any = print,
) -> None:
"""Run one interactive VQA question end to end.
Called synchronously from the input layer while the runtime is in
``/question`` mode (the action loop is gated off, so the policy is
not in concurrent use). Progress is reported via both
:func:`push_log` (REPL panel scrollback) and ``print_fn`` (direct
stdout) — in autonomous question mode the panel redraw is suspended,
so the direct print is what the operator actually sees.
"""
from .steps import _generate_with_policy, _msgs_for_vqa # noqa: PLC0415
def report(line: str) -> None:
"""Surface a line both to the panel scrollback and to stdout."""
push_log(state, line)
try:
print_fn(line)
except Exception: # noqa: BLE001
pass
if policy is None or not hasattr(policy, "select_message"):
report(" [warn] vqa: policy has no select_message — skipping")
return
observation: dict | None = None
if observation_provider is not None:
try:
observation = observation_provider()
except Exception as exc: # noqa: BLE001
logger.debug("observation_provider raised %s", exc)
# Feed the FULL observation (every camera + state) to the VLM. The
# ``ask_vqa_*`` recipes look single-camera, but the image *block* is
# stripped before tokenization — the actual frames reach the model
# via PI052's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
# consumes *all* ``config.image_features`` regardless of which
# camera the sub-recipe was tagged for. So the model always sees
# every camera; the operator never has to name one to ask.
answer = _generate_with_policy(
policy,
_msgs_for_vqa(question),
observation=observation,
state=state,
label="vqa gen",
)
if not answer:
report(" [info] vqa gen returned empty")
return
report(f" vqa: {answer}")
parsed = parse_vqa_answer(answer)
if not answer_has_overlay(parsed):
if parsed is None:
report(" [info] vqa answer is not JSON — no overlay")
return
# The answer carries a bounding box / point. Its pixel coordinates
# are camera-specific and the text answer doesn't say which camera,
# so ask the operator *now* — only when there is actually something
# to draw — which camera frame to render the overlay on.
cameras = available_cameras(observation)
if observation is None or not cameras:
report(" [info] no camera image — cannot draw overlay")
return
chosen = prompt_camera_choice(cameras, input_fn=input_fn, print_fn=print_fn)
if chosen is None:
report(" [info] overlay skipped — no camera selected")
return
try:
pil = observation_image_to_pil(observation[chosen])
overlay = draw_vqa_overlay(pil, parsed)
path = save_and_open_overlay(overlay)
report(f" vqa overlay ({camera_short_name(chosen)}) saved: {path}")
except Exception as exc: # noqa: BLE001
logger.warning("vqa overlay failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
report(f" [warn] vqa overlay failed: {type(exc).__name__}: {exc}")

File diff suppressed because it is too large Load Diff

View File

@@ -1,198 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 pre/post-processor factory.
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
rename observations
add batch dim
relative-action prep (inherited from π0.5)
NormalizerProcessorStep
RenderMessagesStep — recipe → messages, target_message_indices,
message_streams (PR 1 of the steerable
stack)
PI052TextTokenizerStep — messages → input_ids + label mask +
predict_actions
DeviceProcessorStep
When ``recipe_path`` is ``None`` we delegate to the plain π0.5 pipeline
so unannotated datasets keep working.
Post-processor is unchanged from π0.5.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from lerobot.configs.recipe import TrainingRecipe
from lerobot.processor import (
AbsoluteActionsProcessorStep,
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
RelativeActionsProcessorStep,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
policy_action_to_transition,
transition_to_policy_action,
)
# RenderMessagesStep is intentionally not re-exported from
# ``lerobot.processor`` because it pulls in optional language-stack deps;
# import it directly.
from lerobot.processor.render_messages_processor import RenderMessagesStep
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from ..pi05.processor_pi05 import make_pi05_pre_post_processors
from .configuration_pi052 import PI052Config
from .text_processor_pi052 import PI052TextTokenizerStep
def make_pi052_pre_post_processors(
config: PI052Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_repo_id: str | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Build PI0.5-v2's pre/post-processor pipelines.
Falls through to π0.5's stock pipeline when ``recipe_path`` is unset.
"""
if not config.recipe_path:
return make_pi05_pre_post_processors(config, dataset_stats=dataset_stats)
recipe = _load_recipe(config.recipe_path)
relative_step = RelativeActionsProcessorStep(
enabled=config.use_relative_actions,
exclude_joints=getattr(config, "relative_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
relative_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
RenderMessagesStep(recipe=recipe),
PI052TextTokenizerStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
),
]
# FAST tokenizer for discrete-action CE supervision (paper §III.C).
# Only inserted when explicitly enabled — keeps the post-training-
# style recipe (flow + text) as the default. When on, the step
# writes ACTION_TOKENS / ACTION_TOKEN_MASK into
# ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
if getattr(config, "enable_fast_action_loss", False):
# Per Pertsch et al. 2025 (FAST [64], π0.5 §III.C): fit the
# tokenizer on this dataset's action distribution rather than
# using the universal codebook off the shelf. We do this once
# and cache to disk, keyed on (dataset, base, n_samples).
action_tokenizer_path = config.action_tokenizer_name
if (
getattr(config, "auto_fit_fast_tokenizer", False)
and dataset_repo_id is not None
):
from .fit_fast_tokenizer import fit_fast_tokenizer # noqa: PLC0415
cache_dir = Path(config.fast_tokenizer_cache_dir).expanduser()
try:
action_tokenizer_path = fit_fast_tokenizer(
dataset_repo_id=dataset_repo_id,
cache_dir=cache_dir,
base_tokenizer_name=config.action_tokenizer_name,
n_samples=config.fast_tokenizer_fit_samples,
chunk_size=config.chunk_size,
)
except Exception as exc: # noqa: BLE001
import logging # noqa: PLC0415
logging.getLogger(__name__).warning(
"FAST tokenizer fit failed (%s) — falling back to "
"the universal base tokenizer %r. Train will still "
"work but compression will be suboptimal.",
exc, config.action_tokenizer_name,
)
input_steps.append(
ActionTokenizerProcessorStep(
action_tokenizer_name=action_tokenizer_path,
max_action_tokens=config.max_action_tokens,
fast_skip_tokens=config.fast_skip_tokens,
paligemma_tokenizer_name="google/paligemma-3b-pt-224",
)
)
input_steps.append(DeviceProcessorStep(device=config.device))
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AbsoluteActionsProcessorStep(
enabled=config.use_relative_actions,
relative_step=relative_step,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
def _load_recipe(path_str: str) -> TrainingRecipe:
"""Resolve ``path_str`` to a ``TrainingRecipe``.
Accepts an absolute path or a path relative to
``src/lerobot/configs/``.
"""
p = Path(path_str)
if not p.is_absolute() and not p.exists():
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
configs_dir = Path(_recipe_module.__file__).resolve().parent
candidate = configs_dir / path_str
if candidate.exists():
p = candidate
return TrainingRecipe.from_yaml(p)

View File

@@ -1,598 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""π0.5 v2 text-tokenisation step.
PaliGemma is *not* chat-pretrained, so we can't lean on
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
messages as plain text with simple ``User: ... Assistant: ...`` role
delimiters — matching the prompt format π0.5 uses in the paper
(``Task: ... State: ... Action: ...``).
Outputs:
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` — the
concatenated prompt tokenised by the PaliGemma tokenizer (the same
one ``processor_pi05`` already uses).
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
positions belonging to messages whose index is in
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
those positions via the PaliGemma ``lm_head``.
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
target messages has ``message_streams[i] == "low_level"``.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
logger = logging.getLogger(__name__)
def _content_to_text(content: Any) -> str:
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [
b["text"]
for b in content
if isinstance(b, dict) and b.get("type") == "text" and isinstance(b.get("text"), str)
]
return "\n".join(parts)
return ""
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
"""Serialize assistant ``say`` tool calls into a ``<say>...</say>`` marker.
PaliGemma's flat text prompt has no notion of structured tool calls,
and ``_format_messages`` only reads ``role`` / ``content`` — so
without this a ``say`` tool call is dropped entirely and never
supervised. Rewriting it into the content text as a ``<say>...</say>``
marker lets the LM head learn to emit it; the runtime parses it back
via ``_split_plan_and_say``. Messages without ``say`` tool calls are
returned unchanged (the structured calls, if any, are still dropped).
"""
tool_calls = message.get("tool_calls")
if not tool_calls:
return message
say_texts: list[str] = []
for call in tool_calls:
if not isinstance(call, dict):
continue
fn = call.get("function") or {}
if fn.get("name") != "say":
continue
args = fn.get("arguments")
if isinstance(args, str):
try:
import json # noqa: PLC0415
args = json.loads(args)
except (ValueError, TypeError):
args = {}
text = args.get("text", "") if isinstance(args, dict) else ""
if text:
say_texts.append(str(text))
new = dict(message)
new.pop("tool_calls", None)
if not say_texts:
return new
base = _content_to_text(new.get("content")).strip()
marker = "".join(f"<say>{t}</say>" for t in say_texts)
new["content"] = f"{base}\n{marker}" if base else marker
return new
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
"""Normalise a message's content to a plain string.
The recipe renderer can emit ``content`` as a string OR as a list
of HF-style multimodal blocks (``{type: text, text: ...}``,
``{type: image, feature: ...}``). PaliGemma's text tokenizer can
only consume strings, so we flatten: drop image blocks (cameras
flow through ``observation.images.*`` separately) and join text
block texts.
"""
new = dict(message)
new.pop("stream", None)
new.pop("target", None)
content = new.get("content")
if content is None:
new["content"] = ""
elif isinstance(content, str):
pass
elif isinstance(content, list):
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
continue
if block.get("type") == "text":
t = block.get("text", "")
if isinstance(t, str):
parts.append(t)
new["content"] = "\n".join(parts)
else:
new["content"] = str(content)
return new
def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
if value is None:
return [None] * batch_size
if isinstance(value, torch.Tensor):
if value.numel() == 1:
return [int(value.item())] * batch_size
values = value.reshape(-1).tolist()
return [int(v) for v in values[:batch_size]]
if isinstance(value, (list, tuple)):
if len(value) == 1:
return _sample_indices(value[0], batch_size)
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
return [int(value)] * batch_size
# ---------------------------------------------------------------------------
# VQA spatial answers → PaliGemma <loc> format (PI052 only)
#
# PaliGemma is pre-trained on detection / pointing with a ``<locNNNN>``
# vocabulary (normalized [0, 1023]). The recipe's bbox / keypoint VQA
# answers are stored as JSON in Qwen2.5-VL's grounding convention:
# **01000 normalized coordinates**, NOT pixels. (Verified empirically
# on the published datasets: x and y both span 0..1000 with ~30% of
# values exceeding the camera's pixel dimensions — they're not pixels.)
# Converting to ``<loc>`` is therefore camera-resolution-independent:
# ``loc_idx = round(coord / 1000 * 1023)``. We do the conversion here —
# not in the dataset — so the dataset keeps the raw JSON and stays
# backbone-agnostic.
# ---------------------------------------------------------------------------
# The 01000 scale Qwen2.5-VL emits for grounding coordinates.
_VQA_COORD_SCALE = 1000.0
def register_paligemma_loc_tokens(tokenizer: Any) -> Any:
"""Make PaliGemma's ``<locDDDD>`` ids match on raw text — single tokens.
PaliGemma reserves vocab ids [256000, 257023] for ``<locDDDD>``
(detection / pointing) tokens, but the *stock* tokenizer does NOT
match them when encoding raw text — it BPE-splits ``<loc0162>`` into
7 pieces (``<``, ``loc``, ``0``, ``1``, ``6``, ``2``, ``>``). Training
the LM head on a ``<loc>`` target then supervises those 7 generic
BPE pieces instead of one detection-vocab id, the LM head learns to
emit the *character sequence*, and those pieces' logits dominate
other turns (the ``<loc>``-salad on subtasks). Registering the loc
tokens once makes them tokenize as their single ids (256000+idx),
leveraging PaliGemma's detection prior properly. Idempotent.
"""
if "<loc0000>" in getattr(tokenizer, "added_tokens_encoder", {}):
return tokenizer
tokenizer.add_tokens([f"<loc{i:04d}>" for i in range(1024)])
return tokenizer
def _loc_token(coord: float, scale: float = _VQA_COORD_SCALE) -> str:
"""PaliGemma ``<locNNNN>`` for a coord on a ``[0, scale]`` axis."""
idx = round(float(coord) / scale * 1023) if scale > 0 else 0
return f"<loc{max(0, min(1023, idx)):04d}>"
def _vqa_answer_to_loc(answer: dict[str, Any]) -> str | None:
"""Convert a bbox / keypoint VQA answer dict to PaliGemma ``<loc>`` text.
Input coordinates are in Qwen2.5-VL's 01000 normalized space (see
module-level note). y is emitted before x for each coordinate pair
(PaliGemma convention), with the integer indices in [0, 1023].
**Format: label first, locs after.** PaliGemma's pretraining puts
locs first (``<loc><loc> label``), but for our small-dataset VQA
blend that turns the LM head into a loc-emission attractor at every
``Assistant:`` position — VQA targets share their first supervised
token with ~25% of all text samples, and the head collapses to
emitting ``<loc>`` regardless of the prompt. Putting the label
first (``label <locY><locX>``) means every text sample (subtask,
memory, VQA, …) starts the supervised target with a real word,
breaking the attractor. The model still learns the loc vocabulary
for the *spatial* portion of the answer; it just can't fire it as
the first generation step from a clean prompt.
Returns ``None`` for non-spatial answers (count / attribute /
spatial-relation) — those keep their JSON form.
"""
point = answer.get("point")
if isinstance(point, list | tuple) and len(point) == 2 and "point_format" in answer:
try:
x, y = float(point[0]), float(point[1])
except (TypeError, ValueError):
return None
label = str(answer.get("label", "")).strip()
if not label:
return None
return f"{label} {_loc_token(y)}{_loc_token(x)}"
detections = answer.get("detections")
if isinstance(detections, list) and detections:
parts: list[str] = []
for det in detections:
if not isinstance(det, dict):
continue
box = det.get("bbox")
if not (isinstance(box, list | tuple) and len(box) == 4):
continue
try:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
label = str(det.get("label", "")).strip()
if not label:
continue
toks = (
f"{_loc_token(y1)}{_loc_token(x1)}"
f"{_loc_token(y2)}{_loc_token(x2)}"
)
parts.append(f"{label} {toks}")
return " ; ".join(parts) if parts else None
return None
def _messages_vqa_to_loc(
messages: list[dict[str, Any]],
target_indices: list[int],
) -> list[dict[str, Any]]:
"""Rewrite bbox / keypoint VQA *target* answers from JSON to ``<loc>`` text.
Each target turn whose content parses as a spatial VQA answer is
converted. Non-spatial answers and subtask / memory targets (plain
text → not JSON) are left untouched. Camera-independent: VQA coords
are 01000 normalized, so no observation lookup is needed.
"""
if not target_indices:
return messages
out = list(messages)
for idx in target_indices:
if not (0 <= idx < len(out)):
continue
content = out[idx].get("content")
if not isinstance(content, str) or not content.strip():
continue
try:
answer = json.loads(content)
except (ValueError, TypeError):
continue # subtask / memory targets are plain text — skip
if not isinstance(answer, dict):
continue
loc_text = _vqa_answer_to_loc(answer)
if loc_text is not None:
out[idx] = {**out[idx], "content": loc_text}
return out
def _format_messages(
messages: list[dict[str, Any]],
target_indices: list[int] | None = None,
eos_token: str | None = None,
) -> tuple[str, list[tuple[int, int]]]:
"""Concatenate messages into the π0.5-style flat prompt.
When both ``target_indices`` and ``eos_token`` are given, the EOS
string is appended to each supervised target turn's content and the
returned span covers it — so the label builder marks the EOS token
as a supervised label. That teaches the LM head where the answer
*ends*: without an EOS in the target span the model is never given a
stop signal and rambles to ``max_length`` at inference. Inference
callers omit both args (no EOS baked into the prompt — the model
generates it and ``select_message`` stops on it).
Returns:
prompt: the full text the tokenizer will consume.
msg_spans: list of ``(char_start, char_end)`` covering each
message's supervised payload (content, plus the
appended EOS for target turns) within ``prompt``.
"""
targets = set(target_indices or [])
parts: list[str] = []
spans: list[tuple[int, int]] = []
cursor = 0
for i, m in enumerate(messages):
role = m.get("role", "user")
content = m.get("content", "") or ""
# Role tag + newline. The model has to learn to emit the same
# role tokens at generation time, which is fine for greedy
# decoding because the chat template is implicit in the
# supervised target span.
header = f"{role.capitalize()}: "
# A supervised target turn ends with EOS so the model learns to
# terminate; the span below covers content + EOS. Non-target
# turns (and inference) carry no EOS.
body = content + eos_token if (eos_token and i in targets) else content
# span covers the content (+ EOS) portion only — never the role
# tag — so labels are computed over the supervised payload.
full = header + body + "\n"
start = cursor + len(header)
end = start + len(body)
parts.append(full)
spans.append((start, end))
cursor += len(full)
return "".join(parts), spans
@dataclass
@ProcessorStepRegistry.register(name="pi052_text_tokenizer")
class PI052TextTokenizerStep(ProcessorStep):
"""Render messages → token ids + label mask + predict_actions flag.
No chat template; concatenates messages as
``User: ... \\nAssistant: ...`` text.
"""
tokenizer_name: str = "google/paligemma-3b-pt-224"
max_length: int = 200
padding: str = "max_length"
padding_side: str = "right"
plan_dropout_prob: float = 0.0
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0
interjection_dropout_prob: float = 0.0
dropout_seed: int | None = None
def __post_init__(self) -> None:
self._tokenizer: Any = None
def _ensure_tokenizer(self) -> Any:
if self._tokenizer is not None:
return self._tokenizer
from transformers import AutoTokenizer # noqa: PLC0415
self._tokenizer = register_paligemma_loc_tokens(
AutoTokenizer.from_pretrained(self.tokenizer_name)
)
return self._tokenizer
# ------------------------------------------------------------------
# Pipeline step
# ------------------------------------------------------------------
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
transition = transition.copy()
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
messages = complementary.get("messages") or []
if not messages:
# No recipe was rendered — caller will fall back to the
# plain Pi0.5 prompt path. We pass the transition through
# unmodified.
return transition
tokenizer = self._ensure_tokenizer()
# VQA coords are 01000 normalized (Qwen2.5-VL convention) — the
# <loc> conversion is camera-resolution-independent and needs no
# observation lookup here.
if _is_batched_messages(messages):
indices_iter = _sample_indices(complementary.get("index"), len(messages))
encoded = [
self._encode_messages(
tokenizer,
msg,
list(streams),
list(tgt_indices),
complementary,
sample_idx=int(s_idx) if s_idx is not None else None,
)
for msg, streams, tgt_indices, s_idx in zip(
messages,
complementary.get("message_streams") or [[] for _ in messages],
complementary.get("target_message_indices") or [[] for _ in messages],
indices_iter,
strict=False,
)
]
else:
sample_idx = _sample_indices(complementary.get("index"), 1)[0]
encoded = [
self._encode_messages(
tokenizer,
messages,
list(complementary.get("message_streams") or []),
list(complementary.get("target_message_indices") or []),
complementary,
sample_idx=sample_idx,
)
]
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = {
**complementary,
"text_labels": torch.stack([labels for _, _, labels, _, _ in encoded]),
"predict_actions": torch.stack([pred for _, _, _, pred, _ in encoded]),
}
return transition
def _encode_messages(
self,
tokenizer: Any,
messages: list[dict[str, Any]],
message_streams: list[str | None],
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
# Optional: drop non-target messages per the dropout config.
# Keeps the supervised-target indices stable by re-mapping
# after removal.
if (
self.plan_dropout_prob
or self.memory_dropout_prob
or self.subtask_dropout_prob
or self.interjection_dropout_prob
):
messages, target_indices = self._apply_prompt_dropout(
messages,
target_indices,
complementary,
sample_idx=sample_idx,
)
# Rewrite bbox / keypoint VQA target answers from JSON to
# PaliGemma <loc> text. Coords are 01000 normalized so this is
# camera-independent.
messages = _messages_vqa_to_loc(messages, target_indices)
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
# stripping, so the spoken reply is actually tokenized and
# supervised (PaliGemma's flat prompt has no structured calls).
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
# Append EOS to supervised target turns so the LM head learns to
# stop (the span covers it → it becomes a supervised label).
prompt, spans = _format_messages(
messages, target_indices, getattr(tokenizer, "eos_token", None)
)
encoded = tokenizer(
prompt,
max_length=self.max_length,
padding=self.padding,
truncation=True,
return_tensors="pt",
return_offsets_mapping=True,
padding_side=self.padding_side,
)
input_ids = encoded["input_ids"][0]
attention_mask = encoded["attention_mask"][0].bool()
offsets = encoded["offset_mapping"][0] # (seq, 2), char (start,end)
# Build label mask: -100 everywhere except over supervised
# target message char ranges.
labels = torch.full_like(input_ids, fill_value=-100)
for idx in target_indices:
if idx >= len(spans):
continue
char_start, char_end = spans[idx]
for token_pos in range(input_ids.shape[0]):
if not attention_mask[token_pos]:
continue
tok_start, tok_end = int(offsets[token_pos, 0]), int(offsets[token_pos, 1])
if tok_end <= char_start or tok_start >= char_end:
continue
labels[token_pos] = input_ids[token_pos]
# Scan ALL message streams (not just targets): the
# ``low_level_execution`` recipe drops ``target: true`` on
# the assistant to avoid trivial copy-from-user text-CE; the
# flow loss still needs to fire, gated by ``stream: low_level``.
predict_actions = torch.tensor(
bool(any(s == "low_level" for s in message_streams)),
dtype=torch.bool,
)
return input_ids, attention_mask, labels, predict_actions, prompt
# ------------------------------------------------------------------
# Per-component prompt dropout (Pi0.7 §V.E)
# ------------------------------------------------------------------
def _apply_prompt_dropout(
self,
messages: list[dict[str, Any]],
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
) -> tuple[list[dict[str, Any]], list[int]]:
"""Drop messages classified as plan/memory/subtask context.
Targets are *never* dropped (they're the supervised payload).
Re-maps target_indices to the new positions after drops.
"""
import random # noqa: PLC0415
seed = self.dropout_seed
if seed is None:
# Canonical row-index key set by ``BatchProcessor`` /
# ``render_messages_processor``. Falling back to other
# keys silently gave every sample seed=0 → identical
# dropout pattern across the whole epoch.
seed_src = sample_idx if sample_idx is not None else complementary.get("index", 0)
try:
if hasattr(seed_src, "item"):
seed_src = seed_src.item()
seed = int(seed_src)
except (TypeError, ValueError):
seed = 0
rng = random.Random(seed)
keep_indices: list[int] = []
for idx, msg in enumerate(messages):
if idx in target_indices:
keep_indices.append(idx)
continue
kind = _classify_for_dropout(msg)
prob = {
"plan": self.plan_dropout_prob,
"memory": self.memory_dropout_prob,
"subtask": self.subtask_dropout_prob,
"interjection": self.interjection_dropout_prob,
}.get(kind, 0.0)
if prob > 0.0 and rng.random() < prob:
continue
keep_indices.append(idx)
# Build remap and apply
new_messages = [messages[i] for i in keep_indices]
old_to_new = {old: new for new, old in enumerate(keep_indices)}
new_targets = [old_to_new[t] for t in target_indices if t in old_to_new]
return new_messages, new_targets
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
"""Heuristic content-prefix classifier (plan / memory / subtask)."""
content = message.get("content")
if isinstance(content, list):
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
content = " ".join(text_parts)
elif content is None:
return None
elif not isinstance(content, str):
return None
s = content.strip()
if s.startswith("Plan:") or s.startswith("Previous plan"):
return "plan"
if s.startswith("Memory:") or s.startswith("Previous memory"):
return "memory"
if s.startswith("Current subtask") or s.startswith("Completed subtask"):
return "subtask"
return None

View File

@@ -275,8 +275,6 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
if causal_mask is not None and torch.is_floating_point(causal_mask):
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)

View File

@@ -175,6 +175,9 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):

View File

@@ -52,9 +52,6 @@ class RenderMessagesStep(ProcessorStep):
if not persistent and not events:
return transition
if _is_batched_language(persistent) or _is_batched_language(events):
return self._call_batch(transition, complementary_data, persistent, events)
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
@@ -70,131 +67,18 @@ class RenderMessagesStep(ProcessorStep):
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
rendered = _fallback_low_level_render(complementary_data.get("task"))
if rendered is None:
return None
return None
new_transition = transition.copy()
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data = dict(complementary_data)
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def _call_batch(
self,
transition: EnvTransition,
complementary_data: dict[str, Any],
persistent_batch: list,
events_batch: list,
) -> EnvTransition | None:
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
batch_size = max(len(persistent_batch), len(events_batch))
messages: list[list[dict[str, Any]]] = []
message_streams: list[list[str | None]] = []
target_message_indices: list[list[int]] = []
keep_indices: list[int] = []
for i in range(batch_size):
rendered = render_sample(
recipe=self.recipe,
persistent=persistent_batch[i] if i < len(persistent_batch) else [],
events=events_batch[i] if i < len(events_batch) else [],
t=_batch_value(timestamp, i),
sample_idx=int(_batch_value(complementary_data.get("index", 0), i)),
task=_batch_value(complementary_data.get("task"), i),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i))
if rendered is None:
continue
keep_indices.append(i)
messages.append(rendered["messages"])
message_streams.append(rendered["message_streams"])
target_message_indices.append(rendered["target_message_indices"])
if not messages:
return None
new_transition = (
_select_batch_indices(transition, keep_indices)
if len(keep_indices) != batch_size
else transition.copy()
)
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data["messages"] = messages
new_complementary_data["message_streams"] = message_streams
new_complementary_data["target_message_indices"] = target_message_indices
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features
def _scalar(value: Any) -> float | int:
"""Unwrap a tensor/array/single-element list into a Python scalar."""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list):
if len(value) != 1:
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
return _scalar(value[0])
return value
def _is_batched_language(value: Any) -> bool:
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
def _batch_value(value: Any, index: int) -> Any:
if value is None:
return None
if isinstance(value, list):
return value[index]
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
return _scalar(value[index])
return _scalar(value)
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
selected = transition.copy()
for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
data = selected.get(key)
if isinstance(data, dict):
selected[key] = {k: _select_value(v, indices) for k, v in data.items()}
action = selected.get(TransitionKey.ACTION)
if action is not None:
selected[TransitionKey.ACTION] = _select_value(action, indices)
return selected
def _select_value(value: Any, indices: list[int]) -> Any:
if isinstance(value, list) and len(value) >= len(indices):
return [value[i] for i in indices]
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
return value.index_select(0, value.new_tensor(indices).long())
return value
def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
"""Keep action-only samples trainable when no recipe branch matches."""
if hasattr(task, "item"):
task = task.item()
if not isinstance(task, str) or not task:
return None
return {
"messages": [{"role": "user", "content": task}],
"message_streams": ["low_level"],
"target_message_indices": [],
}

View File

@@ -32,7 +32,6 @@ import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
from lerobot.utils.constants import (
ACTION_CODE_TOKEN_MASK,
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -413,15 +412,14 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# During inference, no action is available, skip tokenization
return new_transition
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
tokens, mask, code_mask = self._tokenize_action(action)
# Tokenize and get both tokens and mask
tokens, mask = self._tokenize_action(action)
# Store mask in complementary data
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if complementary_data is None:
complementary_data = {}
complementary_data[ACTION_TOKEN_MASK] = mask
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
complementary_data[ACTION_TOKENS] = tokens
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
@@ -432,7 +430,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"""
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Tokenizes the action tensor and creates a mask.
@@ -461,7 +459,6 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# The fast tokenizer expects action data and returns token IDs
tokens_list = []
masks_list = []
code_masks_list = []
for i in range(batch_size):
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
@@ -479,26 +476,19 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
if tokens.dim() > 1:
tokens = tokens.flatten()
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
bos_id = self._paligemma_tokenizer.bos_token_id
prompt_tokens = torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
)
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
code_start = 1 + len(prompt_tokens)
code_end = code_start + len(action_code_tokens)
# add bos
tokens = torch.cat(
[
torch.tensor([bos_id], device=action.device),
prompt_tokens,
action_code_tokens,
end_tokens,
torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
),
self._act_tokens_to_paligemma_tokens(tokens),
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
]
)
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
code_mask[code_start:code_end] = True
# Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens:
@@ -507,49 +497,44 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
)
tokens = tokens[: self.max_action_tokens]
code_mask = code_mask[: self.max_action_tokens]
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
else:
pad_len = self.max_action_tokens - len(tokens)
mask = torch.cat(
[
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
torch.zeros(
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
),
]
)
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
# Pad tokens with zeros
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
tokens_list.append(tokens)
masks_list.append(mask)
code_masks_list.append(code_mask)
# Stack into batched tensors
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
# Remove batch dimension if input was single sample
if single_sample:
tokens_batch = tokens_batch.squeeze(0)
masks_batch = masks_batch.squeeze(0)
code_masks_batch = code_masks_batch.squeeze(0)
# Move to the same device as the input
if device is not None:
tokens_batch = tokens_batch.to(device)
masks_batch = masks_batch.to(device)
code_masks_batch = code_masks_batch.to(device)
return tokens_batch, masks_batch, code_masks_batch
return tokens_batch, masks_batch
def action(self, action: torch.Tensor) -> torch.Tensor:
"""
This method is not used since we override __call__.
Required by ActionProcessorStep ABC.
"""
tokens, _, _ = self._tokenize_action(action)
tokens, _ = self._tokenize_action(action)
return tokens
def get_config(self) -> dict[str, Any]:

View File

@@ -21,8 +21,6 @@ from lerobot.utils.import_utils import make_device_from_device_class
from .config import RobotConfig
from .robot import Robot
logger = logging.getLogger(__name__)
def make_robot_from_config(config: RobotConfig) -> Robot:
# TODO(Steven): Consider just using the make_device_from_device_class for all types
@@ -120,7 +118,7 @@ def ensure_safe_goal_position(
}
if warnings_dict:
logger.warning(
logging.warning(
"Relative goal position magnitude had to be clamped to be safe.\n"
f"{pformat(warnings_dict, indent=4)}"
)

View File

@@ -1,345 +0,0 @@
#!/usr/bin/env python3
"""Build a single combined LeRobotDataset from RoboCasa's 16 composite_seen tasks.
RoboCasa 1.0 already ships in LeRobot format (parquet + mp4), distributed as
``lerobot.tar`` archives from Box. This script:
1. Downloads each composite_seen task's ``target/human`` archive via RoboCasa's
official ``download_datasets`` helper (idempotent — skipped if already on
disk).
2. Opens each extracted directory as a ``LeRobotDataset``.
3. Merges all 16 into one unified dataset via ``merge_datasets`` (a thin wrapper
over ``aggregate_datasets`` that revalidates fps / robot_type / features,
unifies task indices, concatenates videos and parquet, and recomputes stats).
4. Optionally pushes the merged dataset to the Hub.
The result is one ~8,000-trajectory dataset where each episode carries its
source task as the ``task`` field — ready for downstream annotation
(subtasks / memory / VQA / tool calls) without per-task bookkeeping.
Usage::
uv run python -m lerobot.scripts.build_robocasa_composite_seen \\
--output-dir=/data/lerobot/robocasa_composite_seen \\
--hub-repo-id=${HF_USER}/robocasa_composite_seen \\
--push-to-hub
Prereqs: ``robocasa`` and ``robosuite`` installed (see
``docs/source/benchmarks/robocasa.mdx`` for the editable-install dance — they
are not on PyPI and RoboCasa's own ``setup.py`` pins an old LeRobot version).
The 16 composite_seen tasks are the multi-step subset of the official
RoboCasa365 target benchmark — exactly the slice used to compute the
``Composite-Seen`` column of the leaderboard.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
from lerobot.datasets.dataset_tools import merge_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset
logger = logging.getLogger(__name__)
# Canonical 16 composite_seen tasks (RoboCasa365 target benchmark).
# Order matches the leaderboard docs.
COMPOSITE_SEEN_TASKS: list[str] = [
"DeliverStraw",
"GetToastedBread",
"KettleBoiling",
"LoadDishwasher",
"PackIdenticalLunches",
"PreSoakPan",
"PrepareCoffee",
"RinseSinkBasin",
"ScrubCuttingBoard",
"SearingMeat",
"SetUpCuttingStation",
"StackBowlsCabinet",
"SteamInMicrowave",
"StirVegetables",
"StoreLeftoversInBowl",
"WashLettuce",
]
def _require_robocasa() -> None:
"""Fail fast with an actionable message if robocasa is missing.
RoboCasa is not on PyPI and is not a LeRobot extra — see the installation
notes in ``docs/source/benchmarks/robocasa.mdx``.
"""
try:
import robocasa # noqa: F401, PLC0415
from robocasa.scripts import download_datasets as _dl # noqa: F401, PLC0415
from robocasa.utils import dataset_registry as _reg # noqa: F401, PLC0415
except ImportError as exc:
sys.exit(
"[build_robocasa_composite_seen] robocasa is not importable.\n"
"Install it (and robosuite) per the LeRobot RoboCasa docs:\n"
" git clone https://github.com/robocasa/robocasa.git ~/robocasa\n"
" git clone https://github.com/ARISE-Initiative/robosuite.git ~/robosuite\n"
" pip install -e ~/robocasa --no-deps\n"
" pip install -e ~/robosuite\n"
f"(original error: {exc})"
)
def _resolve_task_root(task: str) -> Path:
"""Resolve the local extracted ``LeRobotDataset`` root for a target/human task.
Uses RoboCasa's own ``dataset_registry`` so we follow whatever directory
layout RoboCasa picks (currently ``v1.0/target/composite/<task>/<date>/``
under ``robocasa.macros.DATASET_BASE_DIR``). Falls back to discovering the
extracted directory if the helper's signature drifted between releases.
"""
from robocasa.utils import dataset_registry # noqa: PLC0415
# ``get_ds_path`` is the canonical helper. RoboCasa 1.0 signature is
# ``get_ds_path(task, ds_type, return_info=False)`` with ``ds_type`` like
# ``"human_im"`` (image-observation human demos). We try the common
# ``split=`` kwarg first (newer registry); if it's rejected, fall back.
try:
ds_path = dataset_registry.get_ds_path(
task=task,
ds_type="human_im",
return_info=False,
split="target",
)
except TypeError:
# Older registry — ds_type alone disambiguates target/human.
ds_path = dataset_registry.get_ds_path(
task=task,
ds_type="human_im",
return_info=False,
)
root = Path(ds_path)
# ``get_ds_path`` may return either the extracted dir or the .tar; normalize.
if root.suffix == ".tar":
root = root.parent
return root
def _download_task(task: str, *, overwrite: bool = False) -> Path:
"""Download (or locate) a single target/human task and return its extracted root."""
from robocasa.scripts import download_datasets as dl # noqa: PLC0415
# Try the documented programmatic API. The CLI is
# python -m robocasa.scripts.download_datasets --tasks <T> --source human --split target
# which is a thin wrapper over a function of the same name.
if hasattr(dl, "download_datasets"):
try:
dl.download_datasets(
tasks=[task],
source="human",
split="target",
overwrite=overwrite,
)
except TypeError:
# Older signature — drop the kwargs RoboCasa didn't have yet.
dl.download_datasets(tasks=[task])
else:
# No public function — shell out to the CLI as a last resort. This
# guarantees we use whatever entrypoint RoboCasa's authors maintain.
import subprocess # noqa: PLC0415
cmd = [
sys.executable,
"-m",
"robocasa.scripts.download_datasets",
"--tasks",
task,
"--source",
"human",
"--split",
"target",
]
if overwrite:
cmd.append("--overwrite")
subprocess.run(cmd, check=True)
root = _resolve_task_root(task)
if not root.exists():
raise RuntimeError(
f"Expected {root} after download, but it doesn't exist. "
"RoboCasa may have changed its data layout — verify with "
"`robocasa.utils.dataset_registry.get_ds_path()`."
)
return root
def _open_as_lerobot_dataset(task: str, root: Path) -> LeRobotDataset:
"""Open an extracted RoboCasa target/human task as a ``LeRobotDataset``.
The placeholder ``repo_id`` (``robocasa/<task>_target_human``) is only used
by the aggregator for logging and for the unified task table — the actual
data is loaded from ``root``.
"""
repo_id = f"robocasa/{task}_target_human"
return LeRobotDataset(repo_id=repo_id, root=root)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Aggregate the 16 RoboCasa composite_seen target tasks into one LeRobotDataset.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="Local directory for the merged dataset (will be created).",
)
parser.add_argument(
"--hub-repo-id",
type=str,
default=None,
help=(
"Hub repo_id for the merged dataset (e.g. ``yourname/"
"robocasa_composite_seen``). Required for ``--push-to-hub``; also "
"becomes the merged dataset's canonical ``repo_id``."
),
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the merged dataset to the Hub after building. Requires "
"``--hub-repo-id`` and a prior ``huggingface-cli login``.",
)
parser.add_argument(
"--private",
action="store_true",
help="When pushing, create the Hub repo as private.",
)
parser.add_argument(
"--tasks",
type=str,
default=None,
help="Comma-separated task names to override the default 16 "
"composite_seen list (useful for smoke-testing with 12 tasks).",
)
parser.add_argument(
"--skip-download",
action="store_true",
help="Skip the download step entirely; assume each task is already "
"extracted on disk at the path ``dataset_registry.get_ds_path`` "
"returns.",
)
parser.add_argument(
"--overwrite-download",
action="store_true",
help="Force re-download even when a complete local extraction exists.",
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
)
return parser.parse_args()
def main() -> int:
args = parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level),
format="[%(levelname)s] %(message)s",
)
tasks = (
[t.strip() for t in args.tasks.split(",") if t.strip()]
if args.tasks
else list(COMPOSITE_SEEN_TASKS)
)
if not tasks:
sys.exit("No tasks selected.")
if args.push_to_hub and not args.hub_repo_id:
sys.exit("--push-to-hub requires --hub-repo-id.")
output_repo_id = args.hub_repo_id or "local/robocasa_composite_seen"
logger.info(
"Building merged RoboCasa dataset: %d tasks → %s (output dir: %s)",
len(tasks),
output_repo_id,
args.output_dir,
)
_require_robocasa()
# 1. Download (or locate) each task's extracted directory.
task_roots: list[tuple[str, Path]] = []
for i, task in enumerate(tasks, 1):
logger.info("[%d/%d] %s", i, len(tasks), task)
if args.skip_download:
root = _resolve_task_root(task)
if not root.exists():
sys.exit(
f"--skip-download set but extracted directory does not "
f"exist for {task}: {root}"
)
else:
root = _download_task(task, overwrite=args.overwrite_download)
logger.info(" extracted at: %s", root)
task_roots.append((task, root))
# 2. Open each as a LeRobotDataset (validation happens inside aggregator).
datasets: list[LeRobotDataset] = []
for task, root in task_roots:
logger.info("Opening %s", task)
ds = _open_as_lerobot_dataset(task, root)
logger.info(
" %s: %d episodes, %d frames, %d FPS",
task,
ds.num_episodes,
ds.num_frames,
ds.fps,
)
datasets.append(ds)
# 3. Merge — re-validates features/fps/robot_type, unifies tasks, concats
# videos + parquet, recomputes stats.
logger.info("Merging %d datasets into %s", len(datasets), output_repo_id)
merged = merge_datasets(
datasets=datasets,
output_repo_id=output_repo_id,
output_dir=args.output_dir,
)
logger.info(
"Merged: %d episodes, %d frames across %d unique task strings",
merged.num_episodes,
merged.num_frames,
len(merged.meta.tasks) if merged.meta.tasks is not None else 0,
)
# 4. Push to Hub.
if args.push_to_hub:
logger.info("Pushing %s to the Hub (private=%s)", args.hub_repo_id, args.private)
# ``upload_large_folder=True`` is the right mode for tens-of-GB
# datasets — uses multipart uploads + resumable transfers.
merged.push_to_hub(
private=args.private,
upload_large_folder=True,
tags=["lerobot", "robocasa", "composite_seen", "manipulation"],
)
logger.info(
"Push complete: https://huggingface.co/datasets/%s",
args.hub_repo_id,
)
else:
logger.info(
"Skipping Hub push (no --push-to-hub). Merged dataset is at %s.",
args.output_dir,
)
return 0
if __name__ == "__main__":
raise SystemExit(main())

File diff suppressed because it is too large Load Diff

View File

@@ -20,7 +20,6 @@ Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wand
import dataclasses
import logging
import os
import time
from contextlib import nullcontext
from pprint import pformat
@@ -44,7 +43,7 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, WeightedEpisodeAwareSampler, make_dataset
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -162,196 +161,6 @@ def update_policy(
return train_metrics, output_dict
def _print_debug_text_predictions(
policy: Any, batch: dict[str, Any], step: int, n_samples: int = 5
) -> None:
"""Forward the current batch and print head-argmax vs label per supervised position.
Opt-in via ``LEROBOT_DEBUG_PREDS_EVERY=<step_interval>``. Only the
policy types that expose ``debug_text_predictions`` participate
(currently PI052); others are silently skipped. Pretty-prints up to
``n_samples`` samples from the current batch, showing the prompt,
every supervised position's (label, prediction, ✓/✗), and a
per-sample token-accuracy summary — the cheapest "is text training
actually learning anything" signal.
"""
# Accelerator/DDP wraps the policy in a ``module`` attribute and
# doesn't proxy custom methods through, so a naive
# ``hasattr(policy, "debug_text_predictions")`` returns False on the
# wrapper — and the helper would silently no-op. Walk through any
# ``.module`` indirection (DDP, FSDP, ``accelerator.prepare`` wrappers)
# to reach the raw policy that actually defines the method.
inner = policy
while hasattr(inner, "module") and not hasattr(inner, "debug_text_predictions"):
inner = inner.module
if not hasattr(inner, "debug_text_predictions"):
logging.warning(
"LEROBOT_DEBUG_PREDS_EVERY set but policy %s has no "
"debug_text_predictions method — skipping dump.",
type(inner).__name__,
)
return
try:
debug = inner.debug_text_predictions(batch, max_samples=n_samples)
except Exception as exc: # noqa: BLE001
logging.warning("debug_text_predictions failed: %s", exc, exc_info=True)
return
if not debug:
logging.warning(
"debug_text_predictions returned no supervised samples — "
"current batch has no text labels."
)
return
policy = inner # used below for select_message-style decoding parity
# Build a tokenizer for decoding — match training side exactly.
try:
from transformers import AutoTokenizer # noqa: PLC0415
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
register_paligemma_loc_tokens,
)
tok_name = (
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
)
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
except Exception as exc: # noqa: BLE001
logging.warning("debug preds: tokenizer load failed: %s", exc)
return
ids = debug["input_ids"]
labels = debug["labels"]
preds = debug["predictions"]
attn = debug["attention_mask"]
inference = debug.get("inference") or []
n = ids.shape[0]
print(
f"\n========== STEP {step} DEBUG PREDICTIONS ({n} samples) ==========",
flush=True,
)
for s in range(n):
a = attn[s].tolist()
real = sum(a)
sid = ids[s].tolist()
sl = labels[s].tolist()
sp = preds[s].tolist()
prompt = tokenizer.decode(sid[:real], skip_special_tokens=False)
print(f"\n --- sample {s + 1}/{n} ---", flush=True)
print(f" prompt: {prompt!r}", flush=True)
# Ground-truth target (the contiguous supervised label span).
sup_ids = [int(sid[i]) for i in range(real) if sl[i] != -100]
if sup_ids:
print(
f" target (ground truth) : {tokenizer.decode(sup_ids, skip_special_tokens=False)!r}",
flush=True,
)
# Training-side teacher-forced argmax on the same prompt+target.
n_sup = n_ok = 0
first_sup_pred: int | None = None
teacher_chars: list[int] = []
for i in range(1, real):
label = sl[i]
if label == -100:
continue
n_sup += 1
pred = int(sp[i - 1])
if first_sup_pred is None:
first_sup_pred = pred
teacher_chars.append(pred)
if label == pred:
n_ok += 1
teacher_text = (
tokenizer.decode(teacher_chars, skip_special_tokens=False) if teacher_chars else ""
)
acc = n_ok / max(n_sup, 1)
print(
f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}",
flush=True,
)
# Inference-side autoregressive output from the same prompt prefix.
inf_entry = inference[s] if s < len(inference) else None
if inf_entry:
inf_decoded = inf_entry.get("decoded", "")
print(f" inference (autoregressive) : {inf_decoded!r}", flush=True)
# First-token parity: training-side argmax at the prompt-end
# position MUST equal inference's first generated token —
# both compute argmax(lm_head(h_last_prompt)) on identical
# context. Any divergence signals a training↔inference bug.
if first_sup_pred is not None and inf_decoded and not inf_decoded.startswith("<inference"):
inf_ids = tokenizer(inf_decoded, add_special_tokens=False)["input_ids"]
if inf_ids:
inf_first = int(inf_ids[0])
match = inf_first == first_sup_pred
print(
f" first-token parity : "
f"train={first_sup_pred} ({tokenizer.decode([first_sup_pred])!r}) "
f"vs infer={inf_first} ({tokenizer.decode([inf_first])!r}) "
f"{'✓ MATCH' if match else '✗ DIVERGED — training/inference mismatch'}",
flush=True,
)
print("=" * 60 + "\n", flush=True)
def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None":
"""Build per-frame sampling weights that oversample VQA-annotated frames.
Scans the dataset's ``language_events`` column for frames carrying a
``vqa``-style annotation and returns a weight tensor (length == total
dataset frames) such that, under multinomial sampling, VQA frames make up
roughly ``target_fraction`` of the training stream.
Returns ``None`` (⇒ fall back to uniform episode-aware sampling) when VQA
frames cannot be detected or there are none.
"""
if not 0.0 < target_fraction < 1.0:
logging.warning(
"vqa_target_fraction must be in (0, 1); got %s — VQA oversampling disabled.",
target_fraction,
)
return None
hf = getattr(dataset, "hf_dataset", None)
if hf is None or "language_events" not in getattr(hf, "column_names", []):
logging.warning(
"Dataset has no `language_events` column — VQA oversampling disabled."
)
return None
events_col = hf["language_events"]
n_frames = len(events_col)
is_vqa = torch.zeros(n_frames, dtype=torch.bool)
for i, rows in enumerate(events_col):
if rows and any((row or {}).get("style") == "vqa" for row in rows):
is_vqa[i] = True
n_vqa = int(is_vqa.sum())
if n_vqa == 0:
logging.warning("No `vqa` annotations found in the dataset — VQA oversampling disabled.")
return None
n_other = n_frames - n_vqa
# Solve target = (n_vqa·w) / (n_vqa·w + n_other) for the VQA weight w.
# Clamp to ≥ 1 so VQA frames are never *down*-weighted below uniform.
weight = (target_fraction * n_other) / ((1.0 - target_fraction) * max(n_vqa, 1))
weight = max(weight, 1.0)
weights = torch.ones(n_frames, dtype=torch.double)
weights[is_vqa] = weight
logging.info(
"VQA oversampling: %d/%d frames carry a `vqa` annotation (%.2f%%); "
"weighting them x%.2f to target ~%.0f%% of the training stream.",
n_vqa,
n_frames,
100.0 * n_vqa / n_frames,
weight,
100.0 * target_fraction,
)
return weights
@parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
"""
@@ -483,17 +292,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
active_cfg = cfg.trainable_config
processor_pretrained_path = active_cfg.pretrained_path
# pi052: even when loading pretrained weights, build the processors
# from the current pi052 config so the recipe text-label and FAST
# action-label steps are generated and not silently swapped for the
# checkpoint's older processor stack.
if cfg.policy.type == "pi052" and processor_pretrained_path is not None and not cfg.resume:
logging.warning(
"pi052 is loading pretrained weights from %s, but building processors from the current "
"pi052 config so recipe text labels and FAST action labels are generated.",
processor_pretrained_path,
)
processor_pretrained_path = None
if (
getattr(active_cfg, "use_relative_actions", False)
and processor_pretrained_path is not None
@@ -513,14 +311,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if cfg.is_reward_model_training:
processor_kwargs["dataset_meta"] = dataset.meta
# For pi052 (and any future policy that auto-fits part of its
# preprocessing per-dataset), pass the dataset repo id so the
# processor factory can locate/refresh dataset-specific artifacts
# (e.g. fitted FAST tokenizers per Pertsch et al. 2025 [64],
# π0.5 §III.C).
if cfg.policy.type == "pi052":
processor_kwargs["dataset_repo_id"] = cfg.dataset.repo_id
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
@@ -601,29 +391,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
from_indices = dataset.meta.episodes["dataset_from_index"]
to_indices = dataset.meta.episodes["dataset_to_index"]
# When `vqa_target_fraction` is set, oversample VQA-annotated
# frames via a weighted sampler; otherwise plain episode-aware.
vqa_weights = None
if cfg.vqa_target_fraction is not None and not cfg.dataset.streaming:
vqa_weights = _build_vqa_oversample_weights(dataset, cfg.vqa_target_fraction)
if vqa_weights is not None:
sampler = WeightedEpisodeAwareSampler(
from_indices,
to_indices,
vqa_weights,
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
)
else:
sampler = EpisodeAwareSampler(
from_indices,
to_indices,
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
@@ -654,54 +428,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
policy.train()
# ------------------------------------------------------------------
# EMA setup
# ------------------------------------------------------------------
# Shadow copy of the trainable params for late-training averaging
# (Chi et al. 2023 Diffusion Policy §V.D; openpi JAX trainer ships
# this with decay=0.999 for pi05_libero; openpi PyTorch port and
# LeRobot main both skip it). Off by default; opt in with
# ``--ema.enable=true``. Implemented via ema-pytorch
# (https://github.com/lucidrains/ema-pytorch) — the standard PyTorch
# EMA library, also used by lucidrains' diffusion repos.
ema = None
if cfg.ema.enable and is_main_process:
from ema_pytorch import EMA # noqa: PLC0415
ema = EMA(
accelerator.unwrap_model(policy),
beta=cfg.ema.decay,
update_after_step=cfg.ema.warmup_steps,
update_every=1, # update on every ema.update() call
# Don't register the live model as an ema submodule — accelerator
# already owns its lifecycle, and double-registration would
# double-count its params in ``ema.state_dict()``.
include_online_model=False,
)
ema.to(accelerator.device)
logging.info(
"EMA enabled (ema-pytorch): beta=%g, update_after_step=%d, "
"use_for_eval=%s, use_for_wandb_examples=%s",
cfg.ema.decay,
cfg.ema.warmup_steps,
cfg.ema.use_for_eval,
cfg.ema.use_for_wandb_examples,
)
# Resume the EMA shadow if a previous run wrote one.
if cfg.checkpoint_path is not None:
ema_path = cfg.checkpoint_path / "training_state" / "ema_state.pt"
if ema_path.exists():
logging.info("Resuming EMA shadow from %s", ema_path)
try:
ema.load_state_dict(torch.load(ema_path, map_location=accelerator.device))
except Exception as exc: # noqa: BLE001
logging.warning(
"Failed to load EMA shadow (%s) — restarting EMA from "
"current live weights",
exc,
)
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"),
@@ -754,14 +480,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
sample_weighter=sample_weighter,
)
# EMA update: pull one step of the live weights into the shadow.
# Runs only on the main process (the shadow lives there); other
# ranks rely on the live model staying in sync via accelerator.
# ``ema-pytorch`` holds an internal reference to the online model
# (set at construction), so ``ema.update()`` takes no args.
if ema is not None:
ema.update()
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
@@ -772,27 +490,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
# Optional periodic head-prediction dump for the LM head:
# ``LEROBOT_DEBUG_PREDS_EVERY=1000`` prints 5 samples + per-token
# (label, argmax, ✓/✗) every 1000 steps. Cheap diagnostic to see
# whether the text head is actually learning what we expect, vs
# collapsing to a fixed token. Refilling the recipe-sample dump
# budget at the same cadence also redumps the raw input shapes.
_debug_preds_every = int(os.environ.get("LEROBOT_DEBUG_PREDS_EVERY", "0"))
if (
_debug_preds_every > 0
and step % _debug_preds_every == 0
and is_main_process
):
try:
from lerobot.policies.pi052 import text_processor_pi052 as _tp # noqa: PLC0415
_tp._DUMPED_SO_FAR = 0
_tp._DUMP_BUDGET = max(_tp._DUMP_BUDGET, 5)
except Exception: # noqa: BLE001
pass
_print_debug_text_predictions(policy, batch, step, n_samples=5)
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
@@ -803,49 +500,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
# EMA observability: ``ema.step`` is the count of
# ``ema.update()`` calls (= optimizer steps once EMA is
# enabled); ``ema.initted`` flips to True once we've
# crossed ``update_after_step``.
if ema is not None:
wandb_log_dict["ema/step"] = int(ema.step.item())
wandb_log_dict["ema/initted"] = float(ema.initted.item())
wandb_log_dict["ema/beta"] = float(cfg.ema.decay)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
# Periodic training-example dump to wandb (camera images + text
# fields + action endpoints). Opt-in via ``--wandb.log_examples_freq``;
# independent of ``--log_freq`` so you can keep scalar logs frequent
# and the heavier visual dump rare (e.g. every 5000 steps).
if (
wandb_logger is not None
and cfg.wandb.log_examples_freq > 0
and step % cfg.wandb.log_examples_freq == 0
and is_main_process
):
try:
# Optionally use the EMA shadow model directly for the
# predicted-action columns (matches what eval / deployment
# would see). ``ema-pytorch`` exposes the shadow as a
# full ``nn.Module`` at ``ema.ema_model``, so we just
# pass that instead of swap-and-restore.
target_policy = (
ema.ema_model
if (ema is not None and cfg.ema.use_for_wandb_examples)
else accelerator.unwrap_model(policy)
)
wandb_logger.log_training_examples(
batch=batch,
step=step,
camera_keys=list(dataset.meta.camera_keys),
n_samples=cfg.wandb.log_examples_n,
policy=target_policy,
predict_actions=cfg.wandb.log_examples_predict_actions,
)
except Exception as exc: # noqa: BLE001
logging.warning("wandb log_training_examples failed: %s", exc)
if cfg.save_checkpoint and is_saving_step:
if is_main_process:
logging.info(f"Checkpoint policy after step {step}")
@@ -861,18 +518,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
postprocessor=postprocessor,
)
update_last_checkpoint(checkpoint_dir)
# Save the EMA shadow alongside the training state so a
# resumed run picks up exactly where the live EMA left off.
# ``ema-pytorch.state_dict()`` returns the full shadow
# nn.Module's state dict + step/initted buffers; saved as
# .pt (the rest of training_state mixes formats already).
if ema is not None:
try:
ema_path = checkpoint_dir / "training_state" / "ema_state.pt"
ema_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(ema.state_dict(), ema_path)
except Exception as exc: # noqa: BLE001
logging.warning("Failed to save EMA shadow: %s", exc)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
@@ -882,20 +527,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
# Use the EMA shadow model for eval when enabled —
# standard practice for diffusion-style policies (~13%
# lift on closed-loop success). ``ema.ema_model`` is a
# full nn.Module clone, so we just pass it through; no
# swap/restore on the live policy needed.
eval_target_policy = (
ema.ema_model
if (ema is not None and cfg.ema.use_for_eval)
else accelerator.unwrap_model(policy)
)
with torch.no_grad(), accelerator.autocast():
eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env
policy=eval_target_policy,
policy=accelerator.unwrap_model(policy),
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,

View File

@@ -1,29 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LeRobot tool implementations.
Storage of the tool catalog (``meta/info.json["tools"]``) and the
``SAY_TOOL_SCHEMA`` constant live in PR 1
(``lerobot.datasets.language``). This package holds the *runnable*
implementations one file per tool, plus the registry that maps tool
names to classes.
See ``docs/source/tools.mdx`` for the authoring guide.
"""
from .base import Tool
from .registry import TOOL_REGISTRY, get_tools
from .say import SayTool
__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"]

View File

@@ -1,58 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool protocol — the contract every runnable tool implementation honors.
Tools are the executable side of the OpenAI-style function-calling
abstraction the v3.1 language schema (PR 1) carries on assistant
messages: the schema describes *what can be called*, the tool
implementation describes *how to call it*.
Implementations live one-per-file under :mod:`lerobot.tools` (e.g.
``say.py`` for ``SayTool``) and are registered in
:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so
heavy dependencies (torch models, audio backends, network clients,
hardware drivers) only load when the dataset actually declares the tool.
"""
from __future__ import annotations
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class Tool(Protocol):
"""Minimum surface every tool must expose."""
#: Name matching ``schema["function"]["name"]``. The runtime dispatcher
#: routes incoming ``tool_calls`` to the implementation by this key.
name: str
#: OpenAI-style function-call schema. Same dict the dataset stores in
#: ``meta/info.json["tools"]`` and the chat template renders into the
#: prompt.
schema: dict[str, Any]
def call(self, arguments: dict[str, Any]) -> Any:
"""Execute the tool with the model-provided arguments.
``arguments`` is the parsed dict from
``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded
when the model emits a JSON-string by the chat-template
convention). Implementations validate the dict against their own
schema; the runtime only routes by name.
Return value is implementation-defined — typically a tensor
(TTS audio), a Path (saved file), a dict (structured result), or
``None`` (side-effect-only call).
"""

View File

@@ -1,70 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tool registry — name → implementation class.
Adding a new tool:
1. Drop a file under ``src/lerobot/tools/`` that defines a class
conforming to :class:`lerobot.tools.base.Tool` (must expose ``name``,
``schema``, ``call(arguments)``).
2. Register the class here under :data:`TOOL_REGISTRY`.
3. (Optional) Pre-populate ``meta/info.json["tools"]`` on your dataset
to advertise the schema to the chat-template + policy. The PR 2
annotation pipeline preserves anything you put there.
See ``docs/source/tools.mdx`` for the full authoring guide.
"""
from __future__ import annotations
from typing import Any
from .base import Tool
from .say import SayTool
#: Map from ``function.name`` to a class implementing :class:`Tool`.
#: The runtime instantiates entries lazily — registering a tool here is
#: essentially free (no model load happens until ``call`` runs).
TOOL_REGISTRY: dict[str, type] = {
"say": SayTool,
}
def get_tools(meta: Any, **kwargs: Any) -> dict[str, Tool]:
"""Build name → tool-instance dict from a dataset's declared catalog.
``meta`` is anything with a ``.tools`` attribute returning the
OpenAI-style schema list — typically a
:class:`lerobot.datasets.dataset_metadata.LeRobotDatasetMetadata`.
Each entry whose ``function.name`` is registered here is
instantiated with the schema dict; tools whose name is unknown to
the registry are skipped (the schema still rides through the chat
template, the model just can't actually invoke that tool at
inference).
Extra keyword arguments are forwarded to every constructor — useful
for runtime defaults like ``output_dir=Path("./tts_log")``.
"""
declared = list(meta.tools)
instances: dict[str, Tool] = {}
for schema in declared:
try:
name = schema["function"]["name"]
except (KeyError, TypeError):
continue
cls = TOOL_REGISTRY.get(name)
if cls is None:
continue
instances[name] = cls(schema=schema, **kwargs)
return instances

View File

@@ -1,169 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.
The first concrete tool implementation. PI052 and downstream runtime
dispatchers consume this when the model emits an assistant message
with ``tool_calls=[{function: {name: "say", arguments: {text: ...}}}]``.
Why pocket-tts:
- runs on CPU (no GPU dependency); ~6× real-time on a MacBook Air M4
- ~100M parameters, ~200ms first-chunk latency
- streamable, voice-cloneable
- pip-installable, MIT-style permissive license
The pocket-tts model is loaded **lazily** the first time ``call(...)``
runs (or eagerly via ``preload()``). Loading takes a few seconds and
several hundred MB of RAM, so we don't pay the cost when the tool is
merely *registered* — only when it's *invoked*.
Optional dependency. Install with::
pip install lerobot[tools]
# or directly:
pip install pocket-tts
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from lerobot.datasets.language import SAY_TOOL_SCHEMA
logger = logging.getLogger(__name__)
@dataclass
class SayTool:
"""Speak a short utterance via Kyutai's pocket-tts.
Parameters
----------
schema:
Optional schema override; defaults to the canonical
``SAY_TOOL_SCHEMA`` from PR 1. Custom voices or extended
argument shapes can pass in a modified schema, but the
implementation only reads ``arguments["text"]``.
voice:
One of the pocket-tts catalog voices (``alba``, ``marius``,
``javert``, ``jean``, ``fantine``, ``cosette``, ``eponine``,
``azelma``) or a path to a ``.wav`` / ``.safetensors`` voice
file for cloning. See the pocket-tts model card for licensing.
output_dir:
If set, every ``call(...)`` writes a ``<timestamp>.wav`` audio
file there in addition to returning the PCM tensor.
``None`` (default) skips disk writes — useful for live
playback paths that hand the tensor directly to a sounddevice
/ WebAudio sink.
"""
schema: dict[str, Any] = field(default_factory=lambda: dict(SAY_TOOL_SCHEMA))
voice: str = "alba"
output_dir: Path | None = None
name: str = field(init=False, default="say")
_model: Any = field(init=False, default=None, repr=False)
_voice_state: Any = field(init=False, default=None, repr=False)
_sample_rate: int = field(init=False, default=24000, repr=False)
# ------------------------------------------------------------------
# Lazy model load
# ------------------------------------------------------------------
def preload(self) -> None:
"""Load the pocket-tts model + voice state into memory.
Optional — ``call(...)`` triggers this automatically on first
invocation. Useful when you want the multi-second load to
happen at startup rather than on the first ``say`` the policy
emits.
"""
if self._model is not None and self._voice_state is not None:
return
try:
from pocket_tts import TTSModel # noqa: PLC0415 (optional dep)
except ImportError as exc: # pragma: no cover (env-dependent)
raise ImportError(
"SayTool requires pocket-tts. Install with `pip install "
"lerobot[tools]` or `pip install pocket-tts`."
) from exc
logger.info("SayTool: loading pocket-tts model + voice=%r", self.voice)
self._model = TTSModel.load_model()
self._voice_state = self._model.get_state_for_audio_prompt(self.voice)
self._sample_rate = int(getattr(self._model, "sample_rate", 24000))
# ------------------------------------------------------------------
# Tool protocol
# ------------------------------------------------------------------
def call(self, arguments: dict[str, Any]) -> Any:
"""Speak ``arguments["text"]`` and return the PCM tensor.
Optionally also writes ``<output_dir>/<timestamp>.wav`` when
``self.output_dir`` is set. The returned tensor is a 1-D
``torch.Tensor`` of float32 PCM samples at
``self.sample_rate`` Hz — directly playable by
``sounddevice.play(audio.numpy(), self.sample_rate)`` or
encodable by ``scipy.io.wavfile.write``.
"""
text = arguments.get("text")
if not isinstance(text, str) or not text.strip():
raise ValueError(
f"SayTool.call expects arguments={{'text': str}}, got {arguments!r}"
)
self.preload()
audio = self._model.generate_audio(self._voice_state, text)
if self.output_dir is not None:
self._write_wav(audio, text)
return audio
@property
def sample_rate(self) -> int:
"""PCM sample rate of the returned tensor (Hz)."""
return self._sample_rate
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _write_wav(self, audio: Any, text: str) -> Path:
"""Write a ``.wav`` next to ``output_dir`` for offline inspection."""
import time as _time # noqa: PLC0415
try:
import scipy.io.wavfile # noqa: PLC0415
except ImportError as exc: # pragma: no cover
raise ImportError(
"SayTool.output_dir requires scipy. `pip install scipy`."
) from exc
out_dir = Path(self.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
# One file per call; suffix with a millisecond timestamp + a
# short text snippet so a directory listing is informative.
snippet = "".join(c if c.isalnum() else "_" for c in text[:32]).strip("_")
ts_ms = int(_time.time() * 1000)
path = out_dir / f"say_{ts_ms}_{snippet}.wav"
# ``audio`` is a torch tensor; pocket-tts uses CPU, so a plain
# ``.numpy()`` is safe.
scipy.io.wavfile.write(path, self.sample_rate, audio.numpy())
return path

View File

@@ -22,7 +22,7 @@ from torch.utils.data._utils.collate import default_collate
from lerobot.datasets.language import LANGUAGE_COLUMNS
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices", *LANGUAGE_COLUMNS}
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:

View File

@@ -34,7 +34,6 @@ ACTION = "action"
ACTION_PREFIX = ACTION + "."
ACTION_TOKENS = ACTION + ".tokens"
ACTION_TOKEN_MASK = ACTION + ".token_mask"
ACTION_CODE_TOKEN_MASK = ACTION + ".code_token_mask"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"

View File

@@ -40,6 +40,7 @@ from lerobot.datasets.language_render import render_sample
from ._helpers import make_canned_responder
def _build_pr1_style_blend_recipe() -> TrainingRecipe:
"""Inline blend recipe that consumes every style this pipeline produces.
@@ -144,29 +145,22 @@ def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output(
recipe = _build_pr1_style_blend_recipe()
rendered_any = False
for sample_idx, (ts, persistent, events) in enumerate(
zip(timestamps, persistent_lists, events_lists, strict=True)
):
for ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True):
result = render_sample(
recipe=recipe,
persistent=persistent,
events=events,
t=float(ts),
sample_idx=sample_idx,
sample_idx=0,
dataset_ctx={"task": "Pour water from the bottle into the cup."},
)
if result is None:
continue
if result["messages"]:
rendered_any = True
# A valid render supervises something: a text-CE target turn
# OR a flow-only ``low_level``-stream turn (action loss).
assert (
result["target_message_indices"]
or "low_level" in result["message_streams"]
)
assert result["target_message_indices"]
break
assert rendered_any, "recipe rendered no messages from pipeline output"
assert rendered_any, "PR 1 recipe rendered no messages from pipeline output"
# Sanity: speech atom appears in events column intact
flat_events = [r for ev in events_lists for r in ev]

View File

@@ -29,15 +29,6 @@ def test_message_recipe_validates_unknown_binding():
)
def test_canonical_recipe_loads():
"""The canonical PI052 blend YAML loads + validates."""
recipe = TrainingRecipe.from_yaml(
Path("src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml")
)
assert recipe.blend is not None
assert sum(c.weight for c in recipe.blend.values()) == pytest.approx(1.0)
def test_message_turn_requires_a_stream():
"""Every turn must declare a stream — None is rejected at construction.

View File

@@ -343,84 +343,6 @@ def test_resolve_task_explicit_override_beats_rephrasings():
assert rendered["messages"][0]["content"] == "explicit override wins"
def test_flow_only_low_level_recipe_renders_without_target():
"""Regression: a flow-only ``low_level`` recipe has no ``target`` turn —
its supervision is the action-expert flow loss, not text-CE. It must
still render (not ``None``), otherwise every blend draw of it is dropped
and the action expert never receives a flow loss."""
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="user",
content="${subtask}",
stream="low_level",
if_present="subtask",
),
],
bindings={"subtask": "active_at(t, style=subtask)"},
)
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT,
events=[],
t=0.5,
sample_idx=0,
task="clean kitchen",
)
assert rendered is not None
assert rendered["messages"] == [{"role": "user", "content": "subtask 0"}]
assert rendered["message_streams"] == ["low_level"]
assert rendered["target_message_indices"] == []
def test_vqa_frame_is_consumed_over_the_weighted_blend():
"""A frame carrying a VQA annotation renders the ``ask_vqa*`` sub-recipe
even when its blend weight is tiny — VQA annotations are sparse and must
never be wasted on a subtask/action draw."""
recipe = TrainingRecipe(
blend={
"high_level_subtask": TrainingRecipe(
weight=0.99,
messages=[
MessageTurn(role="user", content="${task}", stream="high_level"),
MessageTurn(role="assistant", content="a subtask", stream="high_level", target=True),
],
),
"ask_vqa_top": TrainingRecipe(
weight=0.01,
bindings={
"vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)",
"vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)",
},
messages=[
MessageTurn(
role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"
),
MessageTurn(
role="assistant",
content="${vqa}",
stream="high_level",
target=True,
if_present="vqa",
),
],
),
}
)
# A frame WITH a vqa event renders VQA on every sample_idx, despite the
# ask_vqa weight being only 0.01.
for sample_idx in range(20):
rendered = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_1, t=1.0, sample_idx=sample_idx, task="x"
)
assert rendered["messages"][-1]["content"] == '{"count": 2}', sample_idx
# A frame WITHOUT a vqa event falls back to the normal weighted blend.
rendered = render_sample(recipe=recipe, persistent=PERSISTENT, events=[], t=1.0, sample_idx=0, task="x")
assert rendered["messages"][-1]["content"] == "a subtask"
def test_emitted_at_persistent_tolerates_small_timestamp_drift():
"""Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
so callers that derive ``t`` arithmetically (``frame_idx / fps``) still

View File

@@ -25,7 +25,7 @@ from datasets import Dataset # noqa: E402
from lerobot.datasets.io_utils import (
hf_transform_to_torch,
)
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
from lerobot.datasets.sampler import EpisodeAwareSampler
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
@@ -137,49 +137,3 @@ def test_partial_episode_drop_warns(caplog):
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text
# --- WeightedEpisodeAwareSampler --------------------------------------------
def test_weighted_sampler_respects_episode_drop_and_length():
"""The episode-boundary frame filtering is applied before weighting,
and one epoch still yields ``len(indices)`` samples."""
# One episode, 10 frames; drop the last 2.
sampler = WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(10), drop_n_last_frames=2)
assert sampler.indices == list(range(8))
assert len(sampler) == 8
draws = list(sampler)
assert len(draws) == 8
# Dropped frames 8 and 9 must never be sampled.
assert all(d in set(range(8)) for d in draws)
def test_weighted_sampler_oversamples_high_weight_frames():
"""A heavily-weighted frame dominates the draws."""
torch.manual_seed(0)
# 100 frames, frame 7 is weighted 1000x.
weights = torch.ones(100)
weights[7] = 1000.0
sampler = WeightedEpisodeAwareSampler([0], [100], frame_weights=weights)
counts = {}
for _ in range(20): # 20 epochs
for d in sampler:
counts[d] = counts.get(d, 0) + 1
total = sum(counts.values())
# Frame 7 should be the overwhelming majority of the 2000 draws.
assert counts.get(7, 0) / total > 0.9
def test_weighted_sampler_zero_weights_fall_back_to_uniform():
"""If every surviving frame has zero weight, sampling is uniform
rather than crashing."""
sampler = WeightedEpisodeAwareSampler([0], [6], frame_weights=torch.zeros(6))
draws = set(sampler)
assert draws.issubset(set(range(6)))
assert len(list(sampler)) == 6
def test_weighted_sampler_rejects_short_weight_vector():
with pytest.raises(ValueError, match="frame_weights"):
WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(5))

View File

@@ -1,149 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention-masking tests for the PI052 (π0.5 v2) text head.
Regression coverage for the text-CE collapse bug: PaliGemma's
``embed_prefix`` flags every language token ``att=0``, which
``make_att_2d_masks`` turns into one fully *bidirectional* block. Under
that mask the text cross-entropy degenerates into a copy task — a
supervised target token attends to the tokens it is trained to predict —
and the LM head never learns causal generation, so ``select_message``
collapses at inference.
``_mark_target_span_causal`` sets ``att=1`` on the supervised target
language positions so each target token attends causally among the
targets while staying bidirectional to images + the user prompt. These
tests pin that behaviour for the PaliGemma prefix layout.
"""
import pytest
import torch
# modeling_pi052 / modeling_pi05 import transformers transitively.
pytest.importorskip("transformers")
from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402
from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal, _shifted_ce # noqa: E402
# ---------------------------------------------------------------------------
# A synthetic PI052 prefix layout: [images, prompt-lang, target-lang]
#
# indices 0-1 : 2 image tokens (att = 0)
# indices 2-4 : 3 user-prompt lang (att = 0)
# indices 5-8 : 4 supervised target lang(att = 0 from embed_prefix)
#
# ``text_labels`` covers the 7 language tokens; -100 on the prompt span,
# real ids on the 4-token target span. PaliGemma's prefix has no state
# token (unlike SmolVLA), so the lang span ends at the prefix end.
# ---------------------------------------------------------------------------
N_IMAGE = 2
N_PROMPT = 3
N_TARGET = 4
LANG_START = N_IMAGE
LANG_END = N_IMAGE + N_PROMPT + N_TARGET # = prefix length
PREFIX_LEN = LANG_END
def _embed_prefix_att_masks() -> torch.Tensor:
"""Mimic PaliGemma ``embed_prefix``: images + lang all att=0."""
return torch.zeros(1, PREFIX_LEN, dtype=torch.bool)
def _text_labels() -> torch.Tensor:
"""-100 over the prompt span, real ids over the target span."""
labels = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
labels[0, N_PROMPT:] = torch.arange(10, 10 + N_TARGET)
return labels
def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor:
"""2D boolean attendance matrix; ``[i, j]`` True ⇒ i attends to j."""
pad = torch.ones(1, PREFIX_LEN, dtype=torch.bool)
return make_att_2d_masks(pad, prefix_att_masks)[0]
def test_mark_sets_att_on_targets_only():
"""Only the supervised target language positions flip to att=1."""
marked = _mark_target_span_causal(
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
)
expected = [False] * PREFIX_LEN
for i in range(LANG_START + N_PROMPT, LANG_END): # target span
expected[i] = True
assert marked[0].tolist() == expected
def test_target_tokens_attend_causally_among_themselves():
"""A target token must NOT attend to later targets, but must attend
to earlier ones — genuine causal next-token prediction."""
marked = _mark_target_span_causal(
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
)
attends = _attends(marked)
tgt = range(LANG_START + N_PROMPT, LANG_END)
for i in tgt:
for j in tgt:
if j > i:
assert not attends[i, j], f"target {i} must not see future target {j}"
else:
assert attends[i, j], f"target {i} must see earlier/self target {j}"
def test_target_tokens_attend_prompt_and_images_bidirectionally():
"""Targets keep full visibility of images + the user prompt."""
marked = _mark_target_span_causal(
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
)
attends = _attends(marked)
context = list(range(0, LANG_START + N_PROMPT)) # images + prompt
for i in range(LANG_START + N_PROMPT, LANG_END):
for j in context:
assert attends[i, j], f"target {i} must attend context {j}"
def test_non_target_subtask_stays_bidirectional():
"""A flow-only / non-target language span (all -100 labels) leaves the
mask untouched — the action expert reads it bidirectionally."""
all_ignored = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
marked = _mark_target_span_causal(
_embed_prefix_att_masks(), all_ignored, LANG_START, LANG_END
)
assert torch.equal(marked, _embed_prefix_att_masks())
def test_unmarked_mask_is_bidirectional_the_bug():
"""Documents the bug the fix prevents: without ``_mark_target_span_causal``
a target token attends *bidirectionally* to later targets — the
text-CE can copy the answer it is trained to predict."""
attends = _attends(_embed_prefix_att_masks())
first_tgt = LANG_START + N_PROMPT
last_tgt = LANG_END - 1
assert attends[first_tgt, last_tgt], (
"raw embed_prefix mask is bidirectional over language — the first "
"target token can see the last, which is the collapse bug"
)
def test_shifted_ce_returns_zero_when_no_text_positions_are_supervised():
logits = torch.randn(2, 4, 8, requires_grad=True)
labels = torch.full((2, 4), -100, dtype=torch.long)
loss = _shifted_ce(logits, labels)
assert loss.item() == 0
loss.backward()
assert logits.grad is not None

View File

@@ -1,87 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Regression tests for PI052 FAST action-code supervision."""
import pytest
import torch
from torch.nn import functional as F
pytest.importorskip("transformers")
from lerobot.policies.pi052.modeling_pi052 import _fast_ce # noqa: E402
def test_fast_ce_supervises_only_discrete_action_codes():
"""Wrapper tokens can be wrong without affecting the FAST action-code loss."""
vocab_size = 8
action_tokens = torch.tensor([[1, 2, 3, 4, 5, 0]])
action_code_mask = torch.tensor([[False, False, True, True, False, False]])
logits = torch.zeros(1, action_tokens.shape[1], vocab_size)
# Deliberately bad wrapper-token predictions. These should be ignored.
logits[0, 0, 7] = 10.0 # target would be token 2
logits[0, 3, 7] = 10.0 # target would be delimiter token 5
# Correct action-code predictions: hidden t predicts target t + 1.
logits[0, 1, 3] = 10.0
logits[0, 2, 4] = 10.0
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
expected = F.cross_entropy(
torch.stack([logits[0, 1], logits[0, 2]]),
torch.tensor([3, 4]),
reduction="mean",
)
assert torch.allclose(loss, expected)
def test_fast_ce_masks_non_action_samples():
"""Recipe samples with predict_actions=False do not contribute FAST loss."""
vocab_size = 8
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
action_code_mask = torch.tensor(
[[False, False, True, True], [False, False, True, True]]
)
predict_actions = torch.tensor([True, False])
logits = torch.zeros(2, action_tokens.shape[1], vocab_size)
logits[0, 1, 3] = 10.0
logits[0, 2, 4] = 10.0
# Bad predictions in the masked sample should not matter.
logits[1, 1, 7] = 10.0
logits[1, 2, 7] = 10.0
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions)
expected = F.cross_entropy(
torch.stack([logits[0, 1], logits[0, 2]]),
torch.tensor([3, 4]),
reduction="mean",
)
assert torch.allclose(loss, expected)
def test_fast_ce_returns_zero_when_no_action_code_positions_are_valid():
logits = torch.randn(2, 4, 8, requires_grad=True)
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
action_code_mask = torch.zeros_like(action_tokens, dtype=torch.bool)
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
assert loss.item() == 0
loss.backward()
assert logits.grad is not None

View File

@@ -1,155 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Numerical-parity tests for the SDPA attention port.
``pi05`` / ``pi052`` replaced the per-layer call from
``modeling_gemma.eager_attention_forward`` with
``sdpa_attention_forward`` (PyTorch SDPA + GQA repeat). The forward
output must be bit-equivalent (within bf16 tolerance) on the masks
this model actually uses — block-bidirectional with an arbitrary
additive bias — otherwise we silently change training behaviour.
"""
from types import SimpleNamespace
import pytest
import torch
pytest.importorskip("transformers")
from transformers.models.gemma import modeling_gemma # noqa: E402
from lerobot.policies.pi05.modeling_pi05 import ( # noqa: E402
make_att_2d_masks,
sdpa_attention_forward,
)
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE # noqa: E402
def _mock_self_attn(num_kv_groups: int, training: bool = False):
"""Bare module surface that both forwards read."""
return SimpleNamespace(
num_key_value_groups=num_kv_groups,
training=training,
)
def _build_inputs(
bsize: int,
num_heads: int,
num_kv_heads: int,
seq_len: int,
head_dim: int,
dtype: torch.dtype,
seed: int = 0,
):
g = torch.Generator(device="cpu").manual_seed(seed)
q = torch.randn(bsize, num_heads, seq_len, head_dim, dtype=dtype, generator=g)
k = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
v = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
return q, k, v
def _block_bidirectional_mask(
bsize: int, seq_len: int, block_sizes: list[int], dtype: torch.dtype
) -> torch.Tensor:
"""Mimic ``_prepare_attention_masks_4d`` on a block layout that
matches ``[images, language, suffix]`` from ``embed_prefix`` +
``embed_suffix``: every block bidirectional internally, later
blocks visible to earlier ones via the cumulative-block rule.
"""
assert sum(block_sizes) == seq_len
att_marks = []
for i, n in enumerate(block_sizes):
att_marks += [1 if i > 0 else 0] + [0] * (n - 1)
pad = torch.ones(bsize, seq_len, dtype=torch.bool)
att = torch.tensor(att_marks, dtype=torch.bool)[None].expand(bsize, seq_len)
att_2d = make_att_2d_masks(pad, att)
bias = torch.where(
att_2d[:, None, :, :],
torch.zeros((), dtype=dtype),
torch.tensor(OPENPI_ATTENTION_MASK_VALUE, dtype=dtype),
)
return bias
@pytest.mark.parametrize(
"num_heads,num_kv_heads,head_dim",
[
(8, 1, 256), # gemma_2b / paligemma config
(8, 8, 64), # MHA control (no GQA repeat)
],
)
def test_sdpa_parity_with_eager_block_bidirectional(num_heads, num_kv_heads, head_dim):
"""SDPA forward output matches the eager softmax(QK^T)@V on the
block-bidirectional mask layout pi05 actually uses."""
bsize, seq_len = 2, 13
block_sizes = [4, 5, 4] # images, language, suffix-style blocks
dtype = torch.float32 # cpu math kernel — keep fp32 for tight tol
scaling = head_dim ** -0.5
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, dtype)
mask = _block_bidirectional_mask(bsize, seq_len, block_sizes, dtype)
module = _mock_self_attn(num_heads // num_kv_heads)
out_eager, _ = modeling_gemma.eager_attention_forward(
module, q, k, v, mask, scaling
)
out_sdpa, _ = sdpa_attention_forward(
module, q, k, v, mask, scaling
)
assert out_eager.shape == out_sdpa.shape
torch.testing.assert_close(out_sdpa, out_eager, atol=1e-5, rtol=1e-4)
def test_sdpa_parity_bf16():
"""bf16 path — looser tolerance, must still match eager."""
bsize, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 1, 17, 256
scaling = head_dim ** -0.5
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.bfloat16)
mask = _block_bidirectional_mask(bsize, seq_len, [5, 6, 6], torch.bfloat16)
module = _mock_self_attn(num_heads // num_kv_heads)
out_eager, _ = modeling_gemma.eager_attention_forward(
module, q, k, v, mask, scaling
)
out_sdpa, _ = sdpa_attention_forward(
module, q, k, v, mask, scaling
)
torch.testing.assert_close(out_sdpa, out_eager, atol=2e-2, rtol=2e-2)
def test_sdpa_parity_backward():
"""Gradients flow through SDPA and match the eager path within
bf16 tolerance — critical for any training-side parity claim."""
bsize, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 9, 32
scaling = head_dim ** -0.5
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.float32)
q.requires_grad_(True); k.requires_grad_(True); v.requires_grad_(True)
mask = _block_bidirectional_mask(bsize, seq_len, [3, 3, 3], torch.float32)
module = _mock_self_attn(num_heads // num_kv_heads)
out_e, _ = modeling_gemma.eager_attention_forward(module, q, k, v, mask, scaling)
g_q_e, g_k_e, g_v_e = torch.autograd.grad(out_e.sum(), [q, k, v])
out_s, _ = sdpa_attention_forward(module, q, k, v, mask, scaling)
g_q_s, g_k_s, g_v_s = torch.autograd.grad(out_s.sum(), [q, k, v])
torch.testing.assert_close(g_q_s, g_q_e, atol=1e-5, rtol=1e-4)
torch.testing.assert_close(g_k_s, g_k_e, atol=1e-5, rtol=1e-4)
torch.testing.assert_close(g_v_s, g_v_e, atol=1e-5, rtol=1e-4)

View File

@@ -1,196 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PI052's text tokenizer.
Covers ``say`` tool-call flattening (PaliGemma's flat prompt has no
structured tool calls, so a ``say`` call must be serialized into a
``<say>...</say>`` text marker) and EOS-termination supervision (the
supervised target span must end with an EOS token so the LM head learns
to stop instead of rambling to ``max_length`` at inference).
"""
import torch
from lerobot.policies.pi052.text_processor_pi052 import (
PI052TextTokenizerStep,
_flatten_say_tool_calls,
_format_messages,
)
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
def _say_call(text):
return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
def test_flatten_appends_say_marker_and_drops_tool_calls():
msg = {"role": "assistant", "content": "Heading to the cube.", "tool_calls": [_say_call("On it!")]}
out = _flatten_say_tool_calls(msg)
assert "tool_calls" not in out
assert out["content"] == "Heading to the cube.\n<say>On it!</say>"
def test_flatten_marker_only_when_content_empty_or_none():
out = _flatten_say_tool_calls({"role": "assistant", "tool_calls": [_say_call("hi")]})
assert out["content"] == "<say>hi</say>"
def test_flatten_accepts_json_string_arguments():
call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}}
out = _flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]})
assert out["content"] == "p\n<say>hello there</say>"
def test_flatten_leaves_messages_without_tool_calls_untouched():
msg = {"role": "assistant", "content": "just a plan"}
assert _flatten_say_tool_calls(msg) == msg
def test_flatten_drops_non_say_tool_calls_but_keeps_content():
weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}}
out = _flatten_say_tool_calls(
{"role": "assistant", "content": "plan only", "tool_calls": [weather]}
)
assert out["content"] == "plan only"
assert "tool_calls" not in out
# ---------------------------------------------------------------------------
# EOS-termination supervision
# ---------------------------------------------------------------------------
def test_format_messages_appends_eos_to_target_turns_only():
msgs = [
{"role": "user", "content": "pick cube"},
{"role": "assistant", "content": "move to cube"},
]
prompt, spans = _format_messages(msgs, target_indices=[1], eos_token="<eos>")
# EOS is appended to the supervised target (assistant) turn only.
assert prompt == "User: pick cube\nAssistant: move to cube<eos>\n"
# The user span is unchanged; the target span covers content + EOS.
assert prompt[spans[0][0] : spans[0][1]] == "pick cube"
assert prompt[spans[1][0] : spans[1][1]] == "move to cube<eos>"
def test_format_messages_without_eos_args_is_unchanged():
"""Inference callers omit target_indices / eos_token — no EOS baked in."""
prompt, spans = _format_messages([{"role": "user", "content": "hi"}])
assert prompt == "User: hi\n"
assert prompt[spans[0][0] : spans[0][1]] == "hi"
def _eos_char_id() -> int:
"""Token id _CharTokenizer assigns to its 1-char EOS."""
return ord("\x1f") % 251 + 1
def test_pi052_text_tokenizer_supervises_eos_at_target_end():
"""The appended EOS is the last supervised label on a target turn —
that's the signal that teaches the LM head to stop. The trailing
newline right after it stays unsupervised (-100)."""
step = PI052TextTokenizerStep(max_length=64)
step._tokenizer = _CharTokenizer()
transition = {
TransitionKey.OBSERVATION: {},
TransitionKey.COMPLEMENTARY_DATA: {
"messages": [
{"role": "user", "content": "pick cube"},
{"role": "assistant", "content": "move to cube"},
],
"target_message_indices": [1],
"message_streams": ["high_level", "high_level"],
"index": torch.tensor(10),
},
}
out = step(transition)
ids = out[TransitionKey.OBSERVATION][OBS_LANGUAGE_TOKENS][0]
labels = out[TransitionKey.COMPLEMENTARY_DATA]["text_labels"][0]
supervised = (labels != -100).nonzero().flatten().tolist()
assert supervised, "target turn produced no supervised labels"
last = supervised[-1]
# The last supervised token is the appended EOS.
assert int(ids[last]) == _eos_char_id()
assert int(labels[last]) == _eos_char_id()
# The token right after the EOS (the trailing newline) is NOT supervised.
assert int(labels[last + 1]) == -100
class _CharTokenizer:
pad_token_id = 0
eos_token = "\x1f" # unit separator — a 1-char "EOS" for testing
def __call__(
self,
text,
max_length,
padding,
truncation,
return_tensors,
return_offsets_mapping,
padding_side,
):
ids = [ord(c) % 251 + 1 for c in text[:max_length]]
offsets = [(i, i + 1) for i in range(len(ids))]
attention = [1] * len(ids)
if padding == "max_length" and len(ids) < max_length:
pad = max_length - len(ids)
ids += [self.pad_token_id] * pad
offsets += [(0, 0)] * pad
attention += [0] * pad
return {
"input_ids": torch.tensor([ids], dtype=torch.long),
"attention_mask": torch.tensor([attention], dtype=torch.long),
"offset_mapping": torch.tensor([offsets], dtype=torch.long),
}
def decode(self, token_ids, skip_special_tokens=False):
return "".join(chr(max(int(i) - 1, 0)) for i in token_ids if int(i) != self.pad_token_id)
def test_pi052_text_tokenizer_handles_batched_rendered_messages():
step = PI052TextTokenizerStep(max_length=64)
step._tokenizer = _CharTokenizer()
transition = {
TransitionKey.OBSERVATION: {},
TransitionKey.COMPLEMENTARY_DATA: {
"messages": [
[
{"role": "user", "content": "pick cube"},
{"role": "assistant", "content": "move to cube"},
],
[{"role": "user", "content": "open drawer"}],
],
"target_message_indices": [[1], []],
"message_streams": [["high_level", "high_level"], ["low_level"]],
"index": torch.tensor([10, 11]),
},
}
out = step(transition)
obs = out[TransitionKey.OBSERVATION]
comp = out[TransitionKey.COMPLEMENTARY_DATA]
assert obs[OBS_LANGUAGE_TOKENS].shape == (2, 64)
assert obs[OBS_LANGUAGE_ATTENTION_MASK].shape == (2, 64)
assert comp["text_labels"].shape == (2, 64)
assert comp["predict_actions"].tolist() == [False, True]
assert (comp["text_labels"][0] != -100).any()
assert not (comp["text_labels"][1] != -100).any()

View File

@@ -1,187 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training-side conversion of VQA answers to PaliGemma ``<loc>`` text.
PI052 trains spatial VQA answers (``bbox`` / ``keypoint``) in
PaliGemma's native ``<locNNNN>`` detection vocabulary so the LM head
reuses the detection prior instead of fighting it (the ``<loc>``-salad
bug). The dataset stores Qwen2.5-VL's grounding output — **01000
normalized** coordinates, *not* pixels. (Verified empirically on the
published datasets: x and y both span 0..1000 with ~30% of values
exceeding the camera's pixel dimensions.) The conversion is therefore
camera-resolution-independent. The dataset stays backbone-agnostic
JSON; the conversion lives in PI052's tokenizer. These tests pin the
JSON → ``<loc>`` rewrite.
"""
import pytest
pytest.importorskip("transformers")
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
_loc_token,
_messages_vqa_to_loc,
_vqa_answer_to_loc,
register_paligemma_loc_tokens,
)
class _FakeTokenizer:
"""Tracks ``add_tokens`` calls; mimics the bits ``register_paligemma_loc_tokens`` reads."""
def __init__(self, prepopulate: bool = False):
self.added_tokens_encoder: dict[str, int] = {}
self.calls: list[list[str]] = []
if prepopulate:
self.added_tokens_encoder["<loc0000>"] = 256000
def add_tokens(self, tokens: list[str]) -> int:
self.calls.append(list(tokens))
for t in tokens:
self.added_tokens_encoder.setdefault(t, len(self.added_tokens_encoder) + 256000)
return len(tokens)
def test_register_loc_tokens_adds_full_1024_range():
tok = _FakeTokenizer()
out = register_paligemma_loc_tokens(tok)
assert out is tok # returns same instance
assert len(tok.calls) == 1
added = tok.calls[0]
assert len(added) == 1024
assert added[0] == "<loc0000>"
assert added[-1] == "<loc1023>"
# Spot check a few in the middle.
assert added[162] == "<loc0162>"
assert added[759] == "<loc0759>"
def test_register_loc_tokens_is_idempotent():
"""If the loc tokens are already present we skip re-adding them."""
tok = _FakeTokenizer(prepopulate=True)
register_paligemma_loc_tokens(tok)
register_paligemma_loc_tokens(tok)
assert tok.calls == [] # never called add_tokens
def test_loc_token_normalizes_and_clamps():
# Default scale is the 01000 Qwen convention.
assert _loc_token(0) == "<loc0000>"
assert _loc_token(1000) == "<loc1023>"
assert _loc_token(500) == f"<loc{round(500 / 1000 * 1023):04d}>"
# out-of-range coordinates clamp into [0, 1023]
assert _loc_token(9999) == "<loc1023>"
assert _loc_token(-5) == "<loc0000>"
def test_vqa_answer_to_loc_keypoint_normalized():
# Label-first: avoids the "Assistant: → <loc>" attractor at training.
answer = {"label": "blue cube", "point_format": "xy", "point": [500, 500]}
assert _vqa_answer_to_loc(answer) == "blue cube <loc0512><loc0512>"
def test_vqa_answer_to_loc_bbox_normalized():
answer = {
"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [0, 0, 1000, 1000]}]
}
assert _vqa_answer_to_loc(answer) == "cube <loc0000><loc0000><loc1023><loc1023>"
def test_vqa_answer_to_loc_multiple_detections_separator():
answer = {
"detections": [
{"label": "blue", "bbox_format": "xyxy", "bbox": [0, 0, 500, 500]},
{"label": "yellow", "bbox_format": "xyxy", "bbox": [500, 500, 1000, 1000]},
]
}
out = _vqa_answer_to_loc(answer)
# Each segment is "label <locs>", joined by " ; "
assert out == (
"blue <loc0000><loc0000><loc0512><loc0512> ; "
"yellow <loc0512><loc0512><loc1023><loc1023>"
)
def test_vqa_answer_to_loc_returns_none_for_non_spatial():
assert _vqa_answer_to_loc({"label": "cubes", "count": 2}) is None
assert _vqa_answer_to_loc({"weird": "payload"}) is None
def test_messages_vqa_to_loc_rewrites_target_turn():
messages = [
{"role": "user", "content": [{"type": "text", "text": "where is the cube?"}]},
{
"role": "assistant",
"content": '{"label": "cube", "point_format": "xy", "point": [500, 500]}',
},
]
out = _messages_vqa_to_loc(messages, target_indices=[1])
assert out[1]["content"] == "cube <loc0512><loc0512>"
# input messages are not mutated
assert messages[1]["content"].startswith("{")
def test_messages_vqa_to_loc_leaves_plain_text_targets_untouched():
messages = [
{"role": "user", "content": "pick the cube"},
{"role": "assistant", "content": "pick up the cube"},
]
out = _messages_vqa_to_loc(messages, target_indices=[1])
assert out[1]["content"] == "pick up the cube"
def test_messages_vqa_to_loc_noop_without_target_indices():
messages = [
{"role": "assistant", "content": '{"label": "c", "point_format": "xy", "point": [1, 2]}'}
]
assert _messages_vqa_to_loc(messages, []) is messages
# ---------------------------------------------------------------------------
# Round-trip: training-side JSON -> <loc> -> runtime-side parse back
#
# Pins that the conversion preserves coordinate *order* (JSON is x-first,
# PaliGemma <loc> is y-first) and the 01000 → [0, 1023] scaling. The
# only loss is quantization to the 1024-bucket <loc> grid, so a coord
# survives within half a bucket (~1000/2046 ≈ 0.49 on the 01000 scale).
# ---------------------------------------------------------------------------
def test_loc_round_trip_keypoint_preserves_normalized_coords():
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
answer = {"label": "blue cube", "point_format": "xy", "point": [640, 480]}
loc = _vqa_answer_to_loc(answer)
parsed = parse_vqa_answer(loc)
nx, ny = parsed["payload"]["point"]
# parse_vqa_answer returns [0, 1] normalized; rescale back to 01000.
assert abs(nx * 1000.0 - 640) <= 1000.0 / 2046 + 1e-6
assert abs(ny * 1000.0 - 480) <= 1000.0 / 2046 + 1e-6
assert parsed["payload"]["label"] == "blue cube"
def test_loc_round_trip_bbox_preserves_order_and_scale():
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
answer = {
"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [100, 200, 800, 900]}]
}
loc = _vqa_answer_to_loc(answer)
parsed = parse_vqa_answer(loc)
x1, y1, x2, y2 = parsed["payload"]["detections"][0]["bbox"]
for got, want in ((x1, 100), (y1, 200), (x2, 800), (y2, 900)):
assert abs(got * 1000.0 - want) <= 1000.0 / 2046 + 1e-6

View File

@@ -58,70 +58,3 @@ def test_render_messages_step_renders_and_drops_raw_language():
assert data["messages"][-1]["content"] == "reach carefully"
assert data["message_streams"] == ["high_level", "low_level"]
assert data["target_message_indices"] == [1]
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="assistant",
content="${subtask}",
stream="high_level",
target=True,
if_present="subtask",
),
]
)
transition = create_transition(
complementary_data={
"task": "pick the cube",
"timestamp": torch.tensor(0.0),
"index": torch.tensor(7),
"language_persistent": [],
"language_events": [{"style": "unmatched", "timestamp": 0.0}],
}
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["messages"] == [{"role": "user", "content": "pick the cube"}]
assert data["message_streams"] == ["low_level"]
assert data["target_message_indices"] == []
def test_render_messages_step_falls_back_per_sample_in_batched_language():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="assistant",
content="${subtask}",
stream="high_level",
target=True,
if_present="subtask",
),
]
)
transition = create_transition(
action=torch.arange(4).reshape(2, 2),
complementary_data={
"task": ["pick the cube", "open the drawer"],
"timestamp": torch.tensor([0.0, 1.0]),
"index": torch.tensor([7, 8]),
"language_persistent": [[], []],
"language_events": [
[{"style": "unmatched", "timestamp": 0.0}],
[{"style": "unmatched", "timestamp": 1.0}],
],
},
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["messages"] == [
[{"role": "user", "content": "pick the cube"}],
[{"role": "user", "content": "open the drawer"}],
]
assert data["message_streams"] == [["low_level"], ["low_level"]]
assert data["target_message_indices"] == [[], []]

81
uv.lock generated
View File

@@ -416,15 +416,6 @@ dependencies = [
]
sdist = { url = "https://files.pythonhosted.org/packages/5c/37/0211f82891a9f14efcfd2b2096f8d9e4351398ad637fdd1ee59cfc580b0e/bddl-1.0.1.tar.gz", hash = "sha256:1fa4e6e5050b93888ff6fd8455c39bfb29d3864ce06b4c37c0f781f513a2ae26", size = 164809, upload-time = "2022-03-08T01:48:23.564Z" }
[[package]]
name = "beartype"
version = "0.22.9"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c7/94/1009e248bbfbab11397abca7193bea6626806be9a327d399810d523a07cb/beartype-0.22.9.tar.gz", hash = "sha256:8f82b54aa723a2848a56008d18875f91c1db02c32ef6a62319a002e3e25a975f", size = 1608866, upload-time = "2025-12-13T06:50:30.72Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl", hash = "sha256:d16c9bbc61ea14637596c5f6fbff2ee99cbe3573e46a716401734ef50c3060c2", size = 1333658, upload-time = "2025-12-13T06:50:28.266Z" },
]
[[package]]
name = "beautifulsoup4"
version = "4.14.3"
@@ -3203,7 +3194,6 @@ all = [
{ name = "ruff" },
{ name = "scikit-image" },
{ name = "scipy" },
{ name = "sentencepiece" },
{ name = "teleop" },
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
{ name = "torchdiffeq" },
@@ -3425,7 +3415,6 @@ phone = [
]
pi = [
{ name = "scipy" },
{ name = "sentencepiece" },
{ name = "transformers" },
]
placo-dep = [
@@ -3477,9 +3466,6 @@ sarm = [
scipy-dep = [
{ name = "scipy" },
]
sentencepiece-dep = [
{ name = "sentencepiece" },
]
smolvla = [
{ name = "accelerate" },
{ name = "num2words" },
@@ -3491,10 +3477,6 @@ test = [
{ name = "pytest-cov" },
{ name = "pytest-timeout" },
]
tools = [
{ name = "pocket-tts" },
{ name = "scipy" },
]
training = [
{ name = "accelerate" },
{ name = "av" },
@@ -3652,7 +3634,6 @@ requires-dist = [
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'phone'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'pi'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["sentencepiece-dep"], marker = "extra == 'pi'" },
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
@@ -3694,7 +3675,6 @@ requires-dist = [
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" },
{ name = "pocket-tts", marker = "extra == 'tools'", specifier = ">=1.0.0,<3.0.0" },
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
@@ -3719,8 +3699,6 @@ requires-dist = [
{ name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" },
{ name = "scipy", marker = "extra == 'all'", specifier = ">=1.14.0,<2.0.0" },
{ name = "scipy", marker = "extra == 'scipy-dep'", specifier = ">=1.14.0,<2.0.0" },
{ name = "scipy", marker = "extra == 'tools'", specifier = ">=1.11.0,<2.0.0" },
{ name = "sentencepiece", marker = "extra == 'sentencepiece-dep'", specifier = ">=0.2.0,<0.3.0" },
{ name = "setuptools", specifier = ">=71.0.0,<81.0.0" },
{ name = "teleop", marker = "extra == 'phone'", specifier = ">=0.1.0,<0.2.0" },
{ name = "termcolor", specifier = ">=2.4.0,<4.0.0" },
@@ -3738,7 +3716,7 @@ requires-dist = [
{ name = "vllm", marker = "sys_platform == 'linux' and extra == 'annotations'", specifier = ">=0.6.0,<1.0.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "sentencepiece-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "annotations", "tools", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "annotations", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"
@@ -5256,33 +5234,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "pocket-tts"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "beartype" },
{ name = "einops" },
{ name = "fastapi" },
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "pydantic" },
{ name = "python-multipart" },
{ name = "requests" },
{ name = "safetensors" },
{ name = "scipy" },
{ name = "sentencepiece" },
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
{ name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" },
{ name = "typer" },
{ name = "typing-extensions" },
{ name = "uvicorn" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f9/2c/7445f57163bb40e2b2fab4df70d18a4216c4965cdf74196344d95859fc07/pocket_tts-2.1.0.tar.gz", hash = "sha256:6f244f445413400f686506f5ccfb75048547caab7b455b927f4a854c551c60a8", size = 642108, upload-time = "2026-05-04T14:00:29.207Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cf/63/d16958d388efee3f0fc7287e1418ed652ddbc2b61ff4f581f0ad0abce625/pocket_tts-2.1.0-py3-none-any.whl", hash = "sha256:7b8f01d3e52aa7df84887b711994586bdc875e024a8b40a15f757feeeb29f752", size = 68096, upload-time = "2026-05-04T14:00:27.547Z" },
]
[[package]]
name = "pre-commit"
version = "4.6.0"
@@ -6868,46 +6819,16 @@ version = "0.2.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/15/15/2e7a025fc62d764b151ae6d0f2a92f8081755ebe8d4a64099accc6f77ba6/sentencepiece-0.2.1.tar.gz", hash = "sha256:8138cec27c2f2282f4a34d9a016e3374cd40e5c6e9cb335063db66a0a3b71fad", size = 3228515, upload-time = "2025-08-12T07:00:51.718Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4a/be/32ce495aa1d0e0c323dcb1ba87096037358edee539cac5baf8755a6bd396/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:57cae326c8727de58c85977b175af132a7138d84c764635d7e71bbee7e774133", size = 1943152, upload-time = "2025-08-12T06:59:40.048Z" },
{ url = "https://files.pythonhosted.org/packages/88/7e/ff23008899a58678e98c6ff592bf4d368eee5a71af96d0df6b38a039dd4f/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:56dd39a3c4d6493db3cdca7e8cc68c6b633f0d4195495cbadfcf5af8a22d05a6", size = 1325651, upload-time = "2025-08-12T06:59:41.536Z" },
{ url = "https://files.pythonhosted.org/packages/19/84/42eb3ce4796777a1b5d3699dfd4dca85113e68b637f194a6c8d786f16a04/sentencepiece-0.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d9381351182ff9888cc80e41c632e7e274b106f450de33d67a9e8f6043da6f76", size = 1253645, upload-time = "2025-08-12T06:59:42.903Z" },
{ url = "https://files.pythonhosted.org/packages/89/fa/d3d5ebcba3cb9e6d3775a096251860c41a6bc53a1b9461151df83fe93255/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99f955df238021bf11f0fc37cdb54fd5e5b5f7fd30ecc3d93fb48b6815437167", size = 1316273, upload-time = "2025-08-12T06:59:44.476Z" },
{ url = "https://files.pythonhosted.org/packages/04/88/14f2f4a2b922d8b39be45bf63d79e6cd3a9b2f248b2fcb98a69b12af12f5/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cdfecef430d985f1c2bcbfff3defd1d95dae876fbd0173376012d2d7d24044b", size = 1387881, upload-time = "2025-08-12T06:59:46.09Z" },
{ url = "https://files.pythonhosted.org/packages/fd/b8/903e5ccb77b4ef140605d5d71b4f9e0ad95d456d6184688073ed11712809/sentencepiece-0.2.1-cp312-cp312-win32.whl", hash = "sha256:a483fd29a34c3e34c39ac5556b0a90942bec253d260235729e50976f5dba1068", size = 999540, upload-time = "2025-08-12T06:59:48.023Z" },
{ url = "https://files.pythonhosted.org/packages/2d/81/92df5673c067148c2545b1bfe49adfd775bcc3a169a047f5a0e6575ddaca/sentencepiece-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4cdc7c36234fda305e85c32949c5211faaf8dd886096c7cea289ddc12a2d02de", size = 1054671, upload-time = "2025-08-12T06:59:49.895Z" },
{ url = "https://files.pythonhosted.org/packages/fe/02/c5e3bc518655d714622bec87d83db9cdba1cd0619a4a04e2109751c4f47f/sentencepiece-0.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:daeb5e9e9fcad012324807856113708614d534f596d5008638eb9b40112cd9e4", size = 1033923, upload-time = "2025-08-12T06:59:51.952Z" },
{ url = "https://files.pythonhosted.org/packages/ba/4a/85fbe1706d4d04a7e826b53f327c4b80f849cf1c7b7c5e31a20a97d8f28b/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dcd8161eee7b41aae57ded06272905dbd680a0a04b91edd0f64790c796b2f706", size = 1943150, upload-time = "2025-08-12T06:59:53.588Z" },
{ url = "https://files.pythonhosted.org/packages/c2/83/4cfb393e287509fc2155480b9d184706ef8d9fa8cbf5505d02a5792bf220/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c6c8f42949f419ff8c7e9960dbadcfbc982d7b5efc2f6748210d3dd53a7de062", size = 1325651, upload-time = "2025-08-12T06:59:55.073Z" },
{ url = "https://files.pythonhosted.org/packages/8d/de/5a007fb53b1ab0aafc69d11a5a3dd72a289d5a3e78dcf2c3a3d9b14ffe93/sentencepiece-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:097f3394e99456e9e4efba1737c3749d7e23563dd1588ce71a3d007f25475fff", size = 1253641, upload-time = "2025-08-12T06:59:56.562Z" },
{ url = "https://files.pythonhosted.org/packages/2c/d2/f552be5928105588f4f4d66ee37dd4c61460d8097e62d0e2e0eec41bc61d/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d7b670879c370d350557edabadbad1f6561a9e6968126e6debca4029e5547820", size = 1316271, upload-time = "2025-08-12T06:59:58.109Z" },
{ url = "https://files.pythonhosted.org/packages/96/df/0cfe748ace5485be740fed9476dee7877f109da32ed0d280312c94ec259f/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7f0fd2f2693309e6628aeeb2e2faf6edd221134dfccac3308ca0de01f8dab47", size = 1387882, upload-time = "2025-08-12T07:00:00.701Z" },
{ url = "https://files.pythonhosted.org/packages/ac/dd/f7774d42a881ced8e1739f393ab1e82ece39fc9abd4779e28050c2e975b5/sentencepiece-0.2.1-cp313-cp313-win32.whl", hash = "sha256:92b3816aa2339355fda2c8c4e021a5de92180b00aaccaf5e2808972e77a4b22f", size = 999541, upload-time = "2025-08-12T07:00:02.709Z" },
{ url = "https://files.pythonhosted.org/packages/dd/e9/932b9eae6fd7019548321eee1ab8d5e3b3d1294df9d9a0c9ac517c7b636d/sentencepiece-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:10ed3dab2044c47f7a2e7b4969b0c430420cdd45735d78c8f853191fa0e3148b", size = 1054669, upload-time = "2025-08-12T07:00:04.915Z" },
{ url = "https://files.pythonhosted.org/packages/c9/3a/76488a00ea7d6931689cda28726a1447d66bf1a4837943489314593d5596/sentencepiece-0.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac650534e2251083c5f75dde4ff28896ce7c8904133dc8fef42780f4d5588fcd", size = 1033922, upload-time = "2025-08-12T07:00:06.496Z" },
{ url = "https://files.pythonhosted.org/packages/4a/b6/08fe2ce819e02ccb0296f4843e3f195764ce9829cbda61b7513f29b95718/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:8dd4b477a7b069648d19363aad0cab9bad2f4e83b2d179be668efa672500dc94", size = 1946052, upload-time = "2025-08-12T07:00:08.136Z" },
{ url = "https://files.pythonhosted.org/packages/ab/d9/1ea0e740591ff4c6fc2b6eb1d7510d02f3fb885093f19b2f3abd1363b402/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0c0f672da370cc490e4c59d89e12289778310a0e71d176c541e4834759e1ae07", size = 1327408, upload-time = "2025-08-12T07:00:09.572Z" },
{ url = "https://files.pythonhosted.org/packages/99/7e/1fb26e8a21613f6200e1ab88824d5d203714162cf2883248b517deb500b7/sentencepiece-0.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad8493bea8432dae8d6830365352350f3b4144415a1d09c4c8cb8d30cf3b6c3c", size = 1254857, upload-time = "2025-08-12T07:00:11.021Z" },
{ url = "https://files.pythonhosted.org/packages/bc/85/c72fd1f3c7a6010544d6ae07f8ddb38b5e2a7e33bd4318f87266c0bbafbf/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b81a24733726e3678d2db63619acc5a8dccd074f7aa7a54ecd5ca33ca6d2d596", size = 1315722, upload-time = "2025-08-12T07:00:12.989Z" },
{ url = "https://files.pythonhosted.org/packages/4a/e8/661e5bd82a8aa641fd6c1020bd0e890ef73230a2b7215ddf9c8cd8e941c2/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0a81799d0a68d618e89063fb423c3001a034c893069135ffe51fee439ae474d6", size = 1387452, upload-time = "2025-08-12T07:00:15.088Z" },
{ url = "https://files.pythonhosted.org/packages/99/5e/ae66c361023a470afcbc1fbb8da722c72ea678a2fcd9a18f1a12598c7501/sentencepiece-0.2.1-cp313-cp313t-win32.whl", hash = "sha256:89a3ea015517c42c0341d0d962f3e6aaf2cf10d71b1932d475c44ba48d00aa2b", size = 1002501, upload-time = "2025-08-12T07:00:16.966Z" },
{ url = "https://files.pythonhosted.org/packages/c1/03/d332828c4ff764e16c1b56c2c8f9a33488bbe796b53fb6b9c4205ddbf167/sentencepiece-0.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:33f068c9382dc2e7c228eedfd8163b52baa86bb92f50d0488bf2b7da7032e484", size = 1057555, upload-time = "2025-08-12T07:00:18.573Z" },
{ url = "https://files.pythonhosted.org/packages/88/14/5aee0bf0864df9bd82bd59e7711362908e4935e3f9cdc1f57246b5d5c9b9/sentencepiece-0.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:b3616ad246f360e52c85781e47682d31abfb6554c779e42b65333d4b5f44ecc0", size = 1036042, upload-time = "2025-08-12T07:00:20.209Z" },
{ url = "https://files.pythonhosted.org/packages/24/9c/89eb8b2052f720a612478baf11c8227dcf1dc28cd4ea4c0c19506b5af2a2/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:5d0350b686c320068702116276cfb26c066dc7e65cfef173980b11bb4d606719", size = 1943147, upload-time = "2025-08-12T07:00:21.809Z" },
{ url = "https://files.pythonhosted.org/packages/82/0b/a1432bc87f97c2ace36386ca23e8bd3b91fb40581b5e6148d24b24186419/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:c7f54a31cde6fa5cb030370566f68152a742f433f8d2be458463d06c208aef33", size = 1325624, upload-time = "2025-08-12T07:00:23.289Z" },
{ url = "https://files.pythonhosted.org/packages/ea/99/bbe054ebb5a5039457c590e0a4156ed073fb0fe9ce4f7523404dd5b37463/sentencepiece-0.2.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c83b85ab2d6576607f31df77ff86f28182be4a8de6d175d2c33ca609925f5da1", size = 1253670, upload-time = "2025-08-12T07:00:24.69Z" },
{ url = "https://files.pythonhosted.org/packages/19/ad/d5c7075f701bd97971d7c2ac2904f227566f51ef0838dfbdfdccb58cd212/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1855f57db07b51fb51ed6c9c452f570624d2b169b36f0f79ef71a6e6c618cd8b", size = 1316247, upload-time = "2025-08-12T07:00:26.435Z" },
{ url = "https://files.pythonhosted.org/packages/fb/03/35fbe5f3d9a7435eebd0b473e09584bd3cc354ce118b960445b060d33781/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01e6912125cb45d3792f530a4d38f8e21bf884d6b4d4ade1b2de5cf7a8d2a52b", size = 1387894, upload-time = "2025-08-12T07:00:28.339Z" },
{ url = "https://files.pythonhosted.org/packages/dc/aa/956ef729aafb6c8f9c443104c9636489093bb5c61d6b90fc27aa1a865574/sentencepiece-0.2.1-cp314-cp314-win32.whl", hash = "sha256:c415c9de1447e0a74ae3fdb2e52f967cb544113a3a5ce3a194df185cbc1f962f", size = 1096698, upload-time = "2025-08-12T07:00:29.764Z" },
{ url = "https://files.pythonhosted.org/packages/b8/cb/fe400d8836952cc535c81a0ce47dc6875160e5fedb71d2d9ff0e9894c2a6/sentencepiece-0.2.1-cp314-cp314-win_amd64.whl", hash = "sha256:881b2e44b14fc19feade3cbed314be37de639fc415375cefaa5bc81a4be137fd", size = 1155115, upload-time = "2025-08-12T07:00:32.865Z" },
{ url = "https://files.pythonhosted.org/packages/32/89/047921cf70f36c7b6b6390876b2399b3633ab73b8d0cb857e5a964238941/sentencepiece-0.2.1-cp314-cp314-win_arm64.whl", hash = "sha256:2005242a16d2dc3ac5fe18aa7667549134d37854823df4c4db244752453b78a8", size = 1133890, upload-time = "2025-08-12T07:00:34.763Z" },
{ url = "https://files.pythonhosted.org/packages/a1/11/5b414b9fae6255b5fb1e22e2ed3dc3a72d3a694e5703910e640ac78346bb/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:a19adcec27c524cb7069a1c741060add95f942d1cbf7ad0d104dffa0a7d28a2b", size = 1946081, upload-time = "2025-08-12T07:00:36.97Z" },
{ url = "https://files.pythonhosted.org/packages/77/eb/7a5682bb25824db8545f8e5662e7f3e32d72a508fdce086029d89695106b/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:e37e4b4c4a11662b5db521def4e44d4d30ae69a1743241412a93ae40fdcab4bb", size = 1327406, upload-time = "2025-08-12T07:00:38.669Z" },
{ url = "https://files.pythonhosted.org/packages/03/b0/811dae8fb9f2784e138785d481469788f2e0d0c109c5737372454415f55f/sentencepiece-0.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:477c81505db072b3ab627e7eab972ea1025331bd3a92bacbf798df2b75ea86ec", size = 1254846, upload-time = "2025-08-12T07:00:40.611Z" },
{ url = "https://files.pythonhosted.org/packages/ef/23/195b2e7ec85ebb6a547969f60b723c7aca5a75800ece6cc3f41da872d14e/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:010f025a544ef770bb395091d57cb94deb9652d8972e0d09f71d85d5a0816c8c", size = 1315721, upload-time = "2025-08-12T07:00:42.914Z" },
{ url = "https://files.pythonhosted.org/packages/7e/aa/553dbe4178b5f23eb28e59393dddd64186178b56b81d9b8d5c3ff1c28395/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:733e59ff1794d26db706cd41fc2d7ca5f6c64a820709cb801dc0ea31780d64ab", size = 1387458, upload-time = "2025-08-12T07:00:44.56Z" },
{ url = "https://files.pythonhosted.org/packages/66/7c/08ff0012507297a4dd74a5420fdc0eb9e3e80f4e88cab1538d7f28db303d/sentencepiece-0.2.1-cp314-cp314t-win32.whl", hash = "sha256:d3233770f78e637dc8b1fda2cd7c3b99ec77e7505041934188a4e7fe751de3b0", size = 1099765, upload-time = "2025-08-12T07:00:46.058Z" },
{ url = "https://files.pythonhosted.org/packages/91/d5/2a69e1ce15881beb9ddfc7e3f998322f5cedcd5e4d244cb74dade9441663/sentencepiece-0.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:5e4366c97b68218fd30ea72d70c525e6e78a6c0a88650f57ac4c43c63b234a9d", size = 1157807, upload-time = "2025-08-12T07:00:47.673Z" },
{ url = "https://files.pythonhosted.org/packages/f3/16/54f611fcfc2d1c46cbe3ec4169780b2cfa7cf63708ef2b71611136db7513/sentencepiece-0.2.1-cp314-cp314t-win_arm64.whl", hash = "sha256:105e36e75cbac1292642045458e8da677b2342dcd33df503e640f0b457cb6751", size = 1136264, upload-time = "2025-08-12T07:00:49.485Z" },
]
[[package]]