mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Compare commits
2 Commits
feat/smolv
...
feat/langu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e7c0d6aa1 | ||
|
|
920c6ef5a2 |
@@ -106,8 +106,7 @@ inside the script — every flag in there maps directly onto a CLI flag of
|
||||
## Style-to-recipe consumer mapping
|
||||
|
||||
The pipeline's outputs are designed to be consumed by recipes (see
|
||||
[Language Columns and Recipes](./language_and_recipes)) — for the
|
||||
canonical PI052 blend `src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml`:
|
||||
[Language Columns and Recipes](./language_and_recipes)) — typically:
|
||||
|
||||
- low-level / high-level / memory-update branches consume
|
||||
`subtask`/`plan`/`memory` from `language_persistent`.
|
||||
|
||||
@@ -141,11 +141,6 @@ sample["target_message_indices"]
|
||||
|
||||
The renderer does not apply a tokenizer chat template. Policy processors decide how to serialize the messages for their backbone, which keeps the same dataset usable across SmolVLA, Pi0.5, and any future VLM that expects OpenAI-style chat messages.
|
||||
|
||||
## Blends
|
||||
|
||||
Blend recipes select one weighted sub-recipe deterministically from the sample index.
|
||||
`recipes/subtasks_vqa.yaml` trains the core blend — high-level subtask prediction, low-level execution, and VQA. `recipes/subtask_mem_vqa_speech.yaml` is the fuller variant that also adds memory updates and spoken interjection responses.
|
||||
|
||||
## Graceful absence
|
||||
|
||||
If both language columns are missing, `None`, or empty, `RenderMessagesStep` is a no-op.
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6 MoE).
|
||||
|
||||
Spawns one ``h200x4`` job that:
|
||||
Spawns one ``h200x2`` job that:
|
||||
|
||||
1. installs this branch of ``lerobot`` plus the annotation extras,
|
||||
2. boots four vllm servers (one per H200) with Qwen3.6-35B-A3B-FP8,
|
||||
3. runs the plan + vqa modules across the dataset in free-form
|
||||
mode — phase 0 (canonical vocabulary discovery) is disabled so
|
||||
every episode's subtasks + memory are generated independently;
|
||||
interjections is also disabled, which short-circuits the
|
||||
plan_update phase that depends on it,
|
||||
2. boots two vllm servers (one per GPU) with Qwen3.6-35B-A3B-FP8,
|
||||
3. runs the plan / interjections / vqa modules across the dataset
|
||||
in free-form mode (phase 0 canonical-vocabulary discovery is
|
||||
disabled — each episode generates its own subtasks + memory),
|
||||
4. uploads the annotated dataset to ``--dest_repo_id`` (when set)
|
||||
or back to ``--repo_id``.
|
||||
|
||||
Re-enable phase 0 with ``--vocabulary.enabled=true`` (optionally
|
||||
``--vocabulary.sample_episodes=N``) when the dataset is homogeneous
|
||||
enough to share one subtask + memory vocabulary across all episodes.
|
||||
|
||||
Usage:
|
||||
|
||||
HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py
|
||||
@@ -38,84 +40,50 @@ CMD = (
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
"--repo_id=pepijn223/robocasa_smoke_2atomic_v3 "
|
||||
"--dest_repo_id=pepijn223/robocasa_smoke_2atomic_v3_annotated "
|
||||
"--repo_id=imstevenpmwork/super_poulain_draft "
|
||||
"--dest_repo_id=pepijn223/super_poulain_vocab "
|
||||
"--push_to_hub=true "
|
||||
"--vlm.backend=openai "
|
||||
"--vlm.model_id=Qwen/Qwen3.6-35B-A3B-FP8 "
|
||||
"--vlm.parallel_servers=4 "
|
||||
"--vlm.num_gpus=4 "
|
||||
"--vlm.parallel_servers=2 "
|
||||
"--vlm.num_gpus=2 "
|
||||
'--vlm.serve_command="vllm serve Qwen/Qwen3.6-35B-A3B-FP8 '
|
||||
# 4× the context (32768 → 131072) so long episodes at 1 Hz fit even
|
||||
# at full Qwen vision resolution: 90 frames @ ~700 vision tokens/frame
|
||||
# ≈ 63 k tokens, comfortably under 131 k. On 1× H200 (144 GB) the
|
||||
# 35B-FP8 model leaves plenty of room for the bigger KV cache.
|
||||
"--tensor-parallel-size 1 --max-model-len 131072 "
|
||||
'--gpu-memory-utilization 0.85 --uvicorn-log-level warning --port {port}" '
|
||||
"--tensor-parallel-size 1 --max-model-len 32768 "
|
||||
'--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" '
|
||||
"--vlm.serve_ready_timeout_s=1800 "
|
||||
"--vlm.client_concurrency=256 "
|
||||
"--vlm.client_concurrency=128 "
|
||||
"--vlm.max_new_tokens=512 "
|
||||
# Low temperature for VQA: bbox + keypoint are coordinate-regression
|
||||
# tasks where sampling noise directly degrades localization
|
||||
# (overlapping boxes, drifted points). 0.2 keeps the model decisive
|
||||
# while still letting question/label phrasing vary across frames.
|
||||
"--vlm.temperature=0.2 "
|
||||
"--executor.episode_parallelism=64 "
|
||||
"--vlm.temperature=0.7 "
|
||||
"--executor.episode_parallelism=16 "
|
||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
|
||||
# Whole-scene agentview is the right choice for subtask reasoning +
|
||||
# VQA on robocasa: the wrist (``robot0_eye_in_hand``) usually only
|
||||
# sees the gripper + nearby object, which hurts "what is happening
|
||||
# in this episode" decomposition. Override per-dataset if your
|
||||
# cameras are named differently (inspect ``meta/info.json``).
|
||||
"--vlm.camera_key=observation.images.robot0_agentview_left "
|
||||
# Phase 0 — canonical vocabulary discovery DISABLED. This dataset's
|
||||
# episodes span heterogeneous tasks/scenes, so a single shared
|
||||
# subtask + memory vocabulary would be too narrow — each episode
|
||||
# generates its subtasks + memory free-form instead.
|
||||
"--vlm.camera_key=observation.images.wrist "
|
||||
# Phase 0 — canonical vocabulary discovery DISABLED by default.
|
||||
# Heterogeneous datasets (different tasks/scenes across episodes)
|
||||
# don't share a single small subtask + memory vocabulary, so each
|
||||
# episode generates its subtasks + memory free-form. Flip to
|
||||
# ``--vocabulary.enabled=true`` (optionally ``--vocabulary.sample_episodes=N``)
|
||||
# for homogeneous datasets where a shared canonical vocabulary
|
||||
# helps the downstream policy.
|
||||
"--vocabulary.enabled=false "
|
||||
# Phase 1 — plan module (subtasks + plan + memory + task_aug).
|
||||
"--plan.enabled=true "
|
||||
"--plan.frames_per_second=1.0 "
|
||||
"--plan.use_video_url=true "
|
||||
"--plan.use_video_url_fps=1.0 "
|
||||
# Force coarse, composite subtasks (``pick up X`` = approach + grasp
|
||||
# + lift in one span, not three). 3 s is large enough to host a
|
||||
# full grasp-or-place composite at typical 20 fps robocasa speeds;
|
||||
# any candidate span shorter than this gets merged into a neighbour
|
||||
# by the prompt's authoring rules (see module_1_subtasks.txt).
|
||||
"--plan.min_subtask_seconds=3.0 "
|
||||
# Cap so the VLM can't drift into micro-segmentation. Combined with
|
||||
# the composite-action rules in the prompt, this targets ~3-6
|
||||
# meaningful spans per episode for typical pick-and-place demos.
|
||||
"--plan.plan_max_steps=9 "
|
||||
# ``off`` keeps the dataset's canonical ``record.episode_task`` as-is
|
||||
# — no per-episode VLM "what is this video about" call. Switch to
|
||||
# ``if_short`` (default) only if some episodes have placeholder /
|
||||
# missing canonical tasks; ``always`` overrides every episode's task.
|
||||
"--plan.derive_task_from_video=off "
|
||||
# 0 disables the task_aug pass entirely (see PlanConfig.n_task_rephrasings
|
||||
# docstring) — no per-episode paraphrase generation, no task_aug rows.
|
||||
"--plan.n_task_rephrasings=0 "
|
||||
# Phase 2 — interjections OFF (also skips phase 3 plan_update,
|
||||
# see executor.py:_run_plan_update_phase guard).
|
||||
"--interjections.enabled=false "
|
||||
# Phase 4 — general VQA. K=1 keeps each VQA answer on its own
|
||||
# emission frame (no temporal smear); see VqaConfig.K docstring.
|
||||
# 3 Hz cadence: at 20 fps source, that's a VQA tick every ~7 frames.
|
||||
# NOTE: VQA emits per-camera, so for robocasa (3 cameras) each tick
|
||||
# produces 3 (user, assistant) row pairs — total call volume ~= 3 *
|
||||
# 3 Hz * mean_episode_seconds * n_episodes.
|
||||
"--vqa.enabled=true "
|
||||
"--vqa.K=1 "
|
||||
"--vqa.vqa_emission_hz=3.0"
|
||||
"--plan.derive_task_from_video=always "
|
||||
"--plan.n_task_rephrasings=30 "
|
||||
# Phase 2 — interjections + speech.
|
||||
"--interjections.max_interjections_per_episode=6 "
|
||||
# Phase 4 — general VQA.
|
||||
"--vqa.K=3 "
|
||||
"--vqa.vqa_emission_hz=1.0"
|
||||
)
|
||||
|
||||
job = run_job(
|
||||
image="vllm/vllm-openai:latest",
|
||||
command=["bash", "-c", CMD],
|
||||
flavor="h200x4",
|
||||
flavor="h200x2",
|
||||
secrets={"HF_TOKEN": token},
|
||||
timeout="24h",
|
||||
timeout="2h",
|
||||
)
|
||||
print(f"Job URL: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-kernels
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=01:30:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_kernels_%j.out
|
||||
|
||||
# HF kernels exploration via Liger's apply_liger_kernel_to_paligemma.
|
||||
# Baseline (SDPA, no kernels) vs. per-subkernel ablations vs. all-on.
|
||||
# Same harness as bench_pi052_step.py — only the --kernels flag varies
|
||||
# across runs so any delta is attributable to the patched op(s).
|
||||
#
|
||||
# Subkernels exercised: rope, rms_norm, geglu, layer_norm.
|
||||
# Skipped: cross_entropy / fused_linear_cross_entropy — pi052 calls
|
||||
# F.cross_entropy directly and bypasses PaliGemma's forward, so those
|
||||
# patches wouldn't fire without model-code changes (separate PR).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
# /fsx triton cache is shared across nodes with different glibc versions
|
||||
# — kernels built on one node trip GLIBC_2.34-not-found on another. Use
|
||||
# a node-local cache per job to side-step that.
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
ldd --version | head -1
|
||||
|
||||
# Liger isn't in our standard env yet — install on the compute node so
|
||||
# the slurm log captures the exact version that produced the numbers.
|
||||
python -m pip install -q --upgrade 'liger-kernel'
|
||||
python - <<'PY' || true
|
||||
from importlib.metadata import version, PackageNotFoundError
|
||||
try:
|
||||
print("liger-kernel", version("liger-kernel"))
|
||||
except PackageNotFoundError:
|
||||
print("liger-kernel: not importable")
|
||||
import liger_kernel.transformers as t
|
||||
print("apply_liger_kernel_to_paligemma:", hasattr(t, "apply_liger_kernel_to_paligemma"))
|
||||
PY
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# -- Baseline (no kernels) at the BS we actually train at. --
|
||||
run --attn sdpa --batch-size 8 --kernels none
|
||||
run --attn sdpa --batch-size 16 --kernels none
|
||||
|
||||
# -- Per-subkernel ablations at BS=16 to isolate each contributor. --
|
||||
run --attn sdpa --batch-size 16 --kernels rms_norm
|
||||
run --attn sdpa --batch-size 16 --kernels geglu
|
||||
run --attn sdpa --batch-size 16 --kernels layer_norm
|
||||
run --attn sdpa --batch-size 16 --kernels rope
|
||||
|
||||
# -- All-on, both BS to compare against the matched baselines above. --
|
||||
run --attn sdpa --batch-size 8 --kernels all
|
||||
run --attn sdpa --batch-size 16 --kernels all
|
||||
|
||||
# -- Headroom check: does kernels-all let BS=24 fit (baseline OOMs near here)? --
|
||||
run --attn sdpa --batch-size 24 --kernels none
|
||||
run --attn sdpa --batch-size 24 --kernels all
|
||||
@@ -1,338 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Benchmark ``PI052Policy.forward + backward`` on a single GPU.
|
||||
|
||||
Compares the new SDPA attention path against the eager baseline by
|
||||
monkeypatching ``sdpa_attention_forward`` before the first model
|
||||
forward — so both runs share identical Q/K/V plumbing and only the
|
||||
attention kernel differs. Reports steps/sec and peak GPU memory.
|
||||
|
||||
SLURM-only:
|
||||
|
||||
sbatch examples/benchmark/bench_pi052_step.slurm
|
||||
|
||||
Or one-off:
|
||||
|
||||
srun --partition=hopper-prod --qos=high --gpus=1 --time=15 \\
|
||||
python examples/benchmark/bench_pi052_step.py --attn sdpa --batch-size 8
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _maybe_patch_eager() -> None:
|
||||
"""Swap ``sdpa_attention_forward`` for the original eager forward.
|
||||
|
||||
Must be called BEFORE PI052Policy is instantiated — the layer
|
||||
compute functions resolve the symbol at call time (module-level
|
||||
lookup), so this patch covers both pi05 and pi052 KI paths."""
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi05 import modeling_pi05
|
||||
|
||||
modeling_pi05.sdpa_attention_forward = modeling_gemma.eager_attention_forward
|
||||
|
||||
|
||||
_LIGER_SUBKERNELS = ("rope", "rms_norm", "geglu", "layer_norm")
|
||||
|
||||
|
||||
def _maybe_patch_liger(spec: str) -> dict:
|
||||
"""Globally patch PaliGemma/Gemma/Siglip modules with Liger Triton kernels.
|
||||
|
||||
Must be called BEFORE PI052Policy is instantiated — Liger replaces
|
||||
classes inside ``transformers.models.{gemma,gemma2,siglip,paligemma}``,
|
||||
so any model built after the call picks up the fused forwards.
|
||||
|
||||
``spec`` is a comma-separated subset of {rope, rms_norm, geglu,
|
||||
layer_norm} (also ``all`` and ``none``). ``cross_entropy`` and
|
||||
``fused_linear_cross_entropy`` are intentionally skipped — pi052's
|
||||
losses use ``F.cross_entropy`` directly (not ``nn.CrossEntropyLoss``)
|
||||
and never traverse ``PaliGemmaForConditionalGeneration.forward``,
|
||||
so neither patch would fire without invasive model-code changes.
|
||||
"""
|
||||
enabled = dict.fromkeys(_LIGER_SUBKERNELS, False)
|
||||
if spec in ("", "none"):
|
||||
return enabled
|
||||
tokens = [t.strip() for t in spec.split(",") if t.strip()]
|
||||
if tokens == ["all"]:
|
||||
enabled = dict.fromkeys(_LIGER_SUBKERNELS, True)
|
||||
else:
|
||||
for t in tokens:
|
||||
if t not in enabled:
|
||||
raise SystemExit(f"Unknown liger subkernel: {t!r}. Choose from {_LIGER_SUBKERNELS} or 'all'.")
|
||||
enabled[t] = True
|
||||
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_paligemma
|
||||
|
||||
apply_liger_kernel_to_paligemma(
|
||||
rope=enabled["rope"],
|
||||
rms_norm=enabled["rms_norm"],
|
||||
geglu=enabled["geglu"],
|
||||
layer_norm=enabled["layer_norm"],
|
||||
cross_entropy=False,
|
||||
fused_linear_cross_entropy=False,
|
||||
)
|
||||
return enabled
|
||||
|
||||
|
||||
def _maybe_patch_flex() -> None:
|
||||
"""Swap ``sdpa_attention_forward`` for a FlexAttention-backed forward.
|
||||
|
||||
Experimental: builds a per-call ``score_mod`` from the additive
|
||||
mask and dispatches to a compiled ``flex_attention`` kernel.
|
||||
|
||||
Known issue on torch 2.7.1: dynamo errors out with
|
||||
``FlexAttentionHigherOrderVariable() has no type`` when the
|
||||
``score_mod`` closure captures a per-call bias tensor. A proper
|
||||
port needs ``create_block_mask(mask_mod, ...)`` plumbed at the
|
||||
PI05Pytorch.forward level so a BlockMask object can be passed
|
||||
down to the layer compute, not a per-call closure. Left as
|
||||
future work; keep this stub for benchmark experimentation."""
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import flex_attention
|
||||
|
||||
from lerobot.policies.pi05 import modeling_pi05
|
||||
|
||||
compiled_flex = torch.compile(flex_attention, dynamic=True)
|
||||
|
||||
def flex_forward(module, query, key, value, attention_mask, scaling, dropout=0.0):
|
||||
n_rep = module.num_key_value_groups
|
||||
if n_rep > 1:
|
||||
key = key.repeat_interleave(n_rep, dim=1)
|
||||
value = value.repeat_interleave(n_rep, dim=1)
|
||||
|
||||
bias = attention_mask # (B, 1, Lq, Lk) additive
|
||||
|
||||
def score_mod(score, b, h, q_idx, kv_idx):
|
||||
return score + bias[b, 0, q_idx, kv_idx]
|
||||
|
||||
attn_output = compiled_flex(query, key, value, score_mod=score_mod, scale=scaling)
|
||||
return attn_output.transpose(1, 2).contiguous(), None
|
||||
|
||||
modeling_pi05.sdpa_attention_forward = flex_forward
|
||||
|
||||
|
||||
def _build_policy(args, device: torch.device):
|
||||
"""Random-init PI052Policy at production-relevant shapes."""
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.pi052.configuration_pi052 import PI052Config
|
||||
from lerobot.policies.pi052.modeling_pi052 import PI052Policy
|
||||
|
||||
# Production has ``unfreeze_lm_head=True`` + ``text_loss_weight>0``,
|
||||
# which flips ``train_expert_only=False`` in __post_init__ and
|
||||
# makes the whole PaliGemma + Gemma-expert stack trainable. We
|
||||
# mirror that here so the optimizer-state count reflects reality;
|
||||
# the loss path still goes through ``PI05Policy.forward`` because
|
||||
# ``text_labels`` / FAST tokens are absent from the synthetic batch
|
||||
# (see ``PI052Policy.forward`` early-return).
|
||||
config = PI052Config(
|
||||
max_action_dim=args.action_dim,
|
||||
max_state_dim=args.state_dim,
|
||||
dtype=args.dtype,
|
||||
knowledge_insulation=args.knowledge_insulation,
|
||||
text_loss_weight=1e-3 if args.train_full else 0.0,
|
||||
flow_loss_weight=1.0,
|
||||
enable_fast_action_loss=False,
|
||||
unfreeze_lm_head=args.train_full,
|
||||
tokenizer_max_length=args.lang_tokens,
|
||||
device="cuda",
|
||||
compile_model=args.compile_model,
|
||||
compile_mode=args.compile_mode,
|
||||
)
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(args.state_dim,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(args.action_dim,)),
|
||||
}
|
||||
policy = PI052Policy(config)
|
||||
policy.to(device)
|
||||
if args.gradient_checkpointing:
|
||||
policy.model.gradient_checkpointing_enable()
|
||||
policy.train()
|
||||
return policy, config
|
||||
|
||||
|
||||
def _build_batch(args, config, device: torch.device) -> dict:
|
||||
"""Synthetic batch matching the training-loop input contract."""
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
B = args.batch_size
|
||||
L = args.lang_tokens
|
||||
return {
|
||||
OBS_LANGUAGE_TOKENS: torch.randint(0, 250000, (B, L), device=device),
|
||||
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(B, L, dtype=torch.bool, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(B, 3, 224, 224, device=device),
|
||||
"observation.images.base_0_rgb_padding_mask": torch.ones(B, dtype=torch.bool, device=device),
|
||||
"observation.state": torch.randn(B, args.state_dim, device=device),
|
||||
ACTION: torch.randn(B, config.chunk_size, args.action_dim, device=device),
|
||||
"action_is_pad": torch.zeros(B, config.chunk_size, dtype=torch.bool, device=device),
|
||||
"task": ["bench task"] * B,
|
||||
}
|
||||
|
||||
|
||||
def _step(policy, batch, optimizer=None) -> torch.Tensor:
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
if optimizer is not None:
|
||||
optimizer.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
else:
|
||||
for p in policy.parameters():
|
||||
if p.grad is not None:
|
||||
p.grad = None
|
||||
return loss.detach()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--attn", choices=["sdpa", "eager", "flex"], default="sdpa")
|
||||
parser.add_argument(
|
||||
"--kernels",
|
||||
default="none",
|
||||
help=(
|
||||
"Liger sub-kernels to enable, comma-separated. Choose from "
|
||||
f"{_LIGER_SUBKERNELS} or use 'all' / 'none' (default). Applied "
|
||||
"via apply_liger_kernel_to_paligemma() BEFORE model build."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compile",
|
||||
dest="compile_model",
|
||||
action="store_true",
|
||||
help="Set policy.config.compile_model=True (torch.compile the forward).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compile-mode",
|
||||
default="default",
|
||||
help="torch.compile mode (default | reduce-overhead | max-autotune).",
|
||||
)
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--warmup", type=int, default=8)
|
||||
parser.add_argument("--steps", type=int, default=40)
|
||||
parser.add_argument("--lang-tokens", type=int, default=512)
|
||||
parser.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16")
|
||||
parser.add_argument("--action-dim", type=int, default=14)
|
||||
parser.add_argument("--state-dim", type=int, default=14)
|
||||
parser.add_argument("--knowledge-insulation", action="store_true", default=True)
|
||||
parser.add_argument(
|
||||
"--gradient-checkpointing",
|
||||
dest="gradient_checkpointing",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--optimizer",
|
||||
choices=["none", "adamw", "adamw_fused"],
|
||||
default="adamw_fused",
|
||||
help=(
|
||||
"Whether to include an AdamW step in the timed iteration. "
|
||||
"'none' mirrors the fwd+bwd-only original bench; 'adamw' / "
|
||||
"'adamw_fused' add the realistic ~2x param-bytes optimizer "
|
||||
"state and ``optimizer.step()`` cost."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train-full",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help=(
|
||||
"Mirror production: unfreeze the PaliGemma backbone (full "
|
||||
"~3B trainable params) instead of training only the 300M "
|
||||
"action expert."
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise SystemExit("Benchmark requires CUDA; submit via slurm (srun/sbatch).")
|
||||
|
||||
if args.attn == "eager":
|
||||
_maybe_patch_eager()
|
||||
elif args.attn == "flex":
|
||||
_maybe_patch_flex()
|
||||
|
||||
liger_flags = _maybe_patch_liger(args.kernels)
|
||||
|
||||
device = torch.device("cuda")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
policy, config = _build_policy(args, device)
|
||||
batch = _build_batch(args, config, device)
|
||||
|
||||
optimizer = None
|
||||
trainable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
if args.optimizer != "none":
|
||||
trainable = [p for p in policy.parameters() if p.requires_grad]
|
||||
optimizer = torch.optim.AdamW(
|
||||
trainable, lr=5e-5, fused=(args.optimizer == "adamw_fused")
|
||||
)
|
||||
|
||||
for _ in range(args.warmup):
|
||||
_step(policy, batch, optimizer)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
starter = torch.cuda.Event(enable_timing=True)
|
||||
ender = torch.cuda.Event(enable_timing=True)
|
||||
starter.record()
|
||||
for _ in range(args.steps):
|
||||
_step(policy, batch, optimizer)
|
||||
ender.record()
|
||||
torch.cuda.synchronize()
|
||||
total_ms = starter.elapsed_time(ender)
|
||||
step_ms = total_ms / args.steps
|
||||
peak_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
optim_gb = 0.0
|
||||
if optimizer is not None:
|
||||
for st in optimizer.state.values():
|
||||
for v in st.values():
|
||||
if torch.is_tensor(v):
|
||||
optim_gb += v.numel() * v.element_size() / (1024**3)
|
||||
|
||||
liger_on = ",".join(k for k, v in liger_flags.items() if v) or "none"
|
||||
name = (
|
||||
f"{args.attn:>5} | BS={args.batch_size} | L={args.lang_tokens} | "
|
||||
f"KI={args.knowledge_insulation} | GC={args.gradient_checkpointing} | "
|
||||
f"compile={args.compile_model} | liger={liger_on} | opt={args.optimizer} | dtype={args.dtype}"
|
||||
)
|
||||
print(
|
||||
f"{name}\n step_ms={step_ms:.1f} steps/sec={1000.0 / step_ms:.3f} "
|
||||
f"peak_mem={peak_gb:.2f} GiB optim_state={optim_gb:.2f} GiB "
|
||||
f"trainable_params={trainable_params / 1e9:.2f}B"
|
||||
)
|
||||
|
||||
del policy, batch
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-attn
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:30:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
|
||||
python -c "import torch; print('torch', torch.__version__, 'cuda', torch.version.cuda)"
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# Attention parity benchmark — same shapes, different attention kernel.
|
||||
run --attn eager --batch-size 8
|
||||
run --attn sdpa --batch-size 8
|
||||
|
||||
# Headroom benchmark — does SDPA's memory cut allow a bigger micro-batch?
|
||||
run --attn sdpa --batch-size 12
|
||||
run --attn sdpa --batch-size 16
|
||||
run --attn sdpa --batch-size 24
|
||||
@@ -1,39 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v2
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v2_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# A: GC ON — see if the selective-AC change (one less recompute level)
|
||||
# narrows the eager vs SDPA gap at BS=8.
|
||||
run --attn eager --batch-size 8
|
||||
run --attn sdpa --batch-size 8
|
||||
|
||||
# B: GC OFF — isolate the raw attention-kernel cost & memory delta.
|
||||
run --attn eager --batch-size 4 --no-gradient-checkpointing
|
||||
run --attn sdpa --batch-size 4 --no-gradient-checkpointing
|
||||
|
||||
# C: SDPA + GC headroom sweep — where does it OOM?
|
||||
run --attn sdpa --batch-size 16
|
||||
run --attn sdpa --batch-size 24
|
||||
run --attn sdpa --batch-size 32
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v3
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v3_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# Compile sweep: does torch.compile + SDPA give a non-trivial boost on
|
||||
# top of the bare SDPA path?
|
||||
run --attn sdpa --batch-size 8 --compile
|
||||
run --attn sdpa --batch-size 16 --compile
|
||||
|
||||
# FlexAttention sweep (experimental): score_mod adds the additive bias
|
||||
# in-kernel; expect a long first-step compile, then SDPA-or-better steady
|
||||
# state.
|
||||
run --attn flex --batch-size 8
|
||||
run --attn flex --batch-size 16
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v4
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=01:00:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v4_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
|
||||
# /fsx triton cache is shared across nodes with different glibc versions
|
||||
# — kernels built on one node trip GLIBC_2.34-not-found on another. Use
|
||||
# a node-local cache per job to side-step that.
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader
|
||||
ldd --version | head -1
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# compile path on top of SDPA + selective AC
|
||||
run --attn sdpa --batch-size 8 --compile
|
||||
run --attn sdpa --batch-size 16 --compile
|
||||
|
||||
# FlexAttention experimental
|
||||
run --attn flex --batch-size 8
|
||||
run --attn flex --batch-size 16
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v5
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v5_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# compile_mode=default (graph-only, no autotune) is the right knob with
|
||||
# gradient checkpointing — max-autotune in v4 was 2x slower than no-compile.
|
||||
run --attn sdpa --batch-size 8 --compile --compile-mode default
|
||||
run --attn sdpa --batch-size 16 --compile --compile-mode default
|
||||
run --attn sdpa --batch-size 8 --compile --compile-mode reduce-overhead
|
||||
@@ -1,31 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v6-bs32
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:30:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v6_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# BS=32 with the production settings (SDPA + compile=default).
|
||||
run --attn sdpa --batch-size 32 --compile --compile-mode default
|
||||
@@ -1,39 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v7-opt
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v7_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# Realistic full-step memory: fwd + bwd + AdamW step. The original
|
||||
# sweep was fwd+bwd-only and undercounted memory by the optimizer-
|
||||
# state size (~2x param bytes for AdamW). This run confirms BS=16
|
||||
# and BS=32 still fit with the optimizer in residency.
|
||||
run --attn sdpa --batch-size 16 --compile --compile-mode default --optimizer adamw_fused
|
||||
run --attn sdpa --batch-size 32 --compile --compile-mode default --optimizer adamw_fused
|
||||
|
||||
# Without compile, in case the production cluster has compile issues.
|
||||
run --attn sdpa --batch-size 16 --optimizer adamw_fused
|
||||
run --attn sdpa --batch-size 32 --optimizer adamw_fused
|
||||
@@ -1,36 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=bench-pi052-v8-bs40-dtype
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --time=00:45:00
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --gpus-per-task=1
|
||||
#SBATCH --output=/fsx/pepijn/logs/bench_pi052_v8_%j.out
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
|
||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
||||
export TRITON_CACHE_DIR="/tmp/triton_${SLURM_JOB_ID}"
|
||||
export TORCHINDUCTOR_CACHE_DIR="/tmp/torchinductor_${SLURM_JOB_ID}"
|
||||
mkdir -p "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"
|
||||
|
||||
echo "=== Node: $(hostname) ==="
|
||||
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
|
||||
|
||||
run() {
|
||||
echo
|
||||
echo "--- $* ---"
|
||||
python examples/benchmark/bench_pi052_step.py "$@" || true
|
||||
}
|
||||
|
||||
# Confirm BS=40 fits on a single H100 with the optimizer in residency.
|
||||
run --attn sdpa --batch-size 40 --compile --compile-mode default --optimizer adamw_fused
|
||||
|
||||
# Dtype A/B at modest batch — fp32 needs ~2x the memory of bf16, so we
|
||||
# drop to BS=4 to keep both runs comparable instead of OOMing fp32.
|
||||
run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype bfloat16
|
||||
run --attn sdpa --batch-size 4 --optimizer adamw_fused --dtype float32
|
||||
@@ -1,29 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
fsdp_config:
|
||||
fsdp_activation_checkpointing: false
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_transformer_layer_cls_to_wrap: GemmaDecoderLayer,SiglipEncoderLayer
|
||||
fsdp_use_orig_params: true
|
||||
fsdp_version: 2
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
File diff suppressed because it is too large
Load Diff
@@ -85,11 +85,6 @@ dependencies = [
|
||||
"termcolor>=2.4.0,<4.0.0",
|
||||
"tqdm>=4.66.0,<5.0.0",
|
||||
|
||||
# Training utilities
|
||||
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
|
||||
# pure-python dependency — preferred over a hand-rolled implementation.
|
||||
"ema-pytorch>=0.7.7,<1.0.0",
|
||||
|
||||
# Build tools (required by opencv-python-headless on some platforms)
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
@@ -147,7 +142,6 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
# (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available.
|
||||
placo-dep = ["placo>=0.9.6,<0.9.16"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
sentencepiece-dep = ["sentencepiece>=0.2.0,<0.3.0"] # FAST action tokenizer backend (pi052, pi0_fast)
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -203,7 +197,7 @@ wallx = [
|
||||
"torchdiffeq>=0.2.4,<0.3.0",
|
||||
"lerobot[qwen-vl-utils-dep]",
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]", "lerobot[sentencepiece-dep]"]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
|
||||
groot = [
|
||||
@@ -237,14 +231,6 @@ annotations = [
|
||||
"vllm>=0.6.0,<1.0.0; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
# Tool implementations under src/lerobot/tools/. Each tool's dependencies
|
||||
# are isolated so adding a new tool doesn't bloat the base install.
|
||||
# Currently only `say` (Kyutai pocket-tts; CPU-only, ~100M params).
|
||||
tools = [
|
||||
"pocket-tts>=1.0.0,<3.0.0",
|
||||
"scipy>=1.11.0,<2.0.0", # SayTool.output_dir uses scipy.io.wavfile
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"]
|
||||
notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"]
|
||||
@@ -337,8 +323,6 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-annotate="lerobot.scripts.lerobot_annotate:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
# Interactive hierarchical-VLA runtime for PI052 (PaliGemma backbone).
|
||||
lerobot-pi052-runtime="lerobot.scripts.lerobot_pi052_runtime:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Build a tiny RoboCasa smoke dataset (2 short atomic tasks, all episodes) for
|
||||
# fast end-to-end training validation before the real run.
|
||||
#
|
||||
# Defaults: target/human, OpenStandMixerHead + NavigateKitchen (~1k episodes,
|
||||
# ~131k frames, ~109 min @ 20 fps), 2 SLURM workers on hopper-cpu.
|
||||
#
|
||||
# Override via env: TASKS, REPO_ID, WORK_DIR, WORKERS, CPUS, PARTITION, LOCAL=1.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||
source ~/miniconda3/etc/profile.d/conda.sh
|
||||
conda activate lerobot
|
||||
|
||||
REPO_ID="${REPO_ID:-${HF_USER:?HF_USER is unset}/robocasa_smoke_2atomic_v3}"
|
||||
WORK_DIR="${WORK_DIR:-/fsx/${USER}/robocasa/datasets/v1.0}"
|
||||
ROBOCASA_ROOT="${ROBOCASA_ROOT:-/fsx/${USER}/robocasa}"
|
||||
LOGS_DIR="${LOGS_DIR:-/fsx/${USER}/logs/robocasa}"
|
||||
TASKS="${TASKS:-OpenStandMixerHead NavigateKitchen}"
|
||||
WORKERS="${WORKERS:-2}"
|
||||
CPUS="${CPUS:-8}"
|
||||
PARTITION="${PARTITION:-hopper-cpu}"
|
||||
LOCAL="${LOCAL:-0}"
|
||||
|
||||
ARGS=(
|
||||
examples/port_datasets/slurm_build_robocasa_composite_seen.py
|
||||
--repo-id="$REPO_ID"
|
||||
--work-dir="$WORK_DIR"
|
||||
--robocasa-root="$ROBOCASA_ROOT"
|
||||
--split=target --source=human
|
||||
--tasks $TASKS
|
||||
--workers="$WORKERS"
|
||||
--cpus-per-task="$CPUS"
|
||||
--partition="$PARTITION"
|
||||
--mem-per-cpu=4G
|
||||
--time=04:00:00
|
||||
--logs-dir="$LOGS_DIR"
|
||||
--job-name=port_robocasa_smoke
|
||||
)
|
||||
if [[ "$LOCAL" == "1" ]]; then
|
||||
ARGS+=(--slurm=0)
|
||||
fi
|
||||
|
||||
echo "Smoke dataset: $REPO_ID"
|
||||
echo "Tasks: $TASKS"
|
||||
python "${ARGS[@]}"
|
||||
@@ -134,11 +134,8 @@ class Executor:
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote
|
||||
# (language columns advertised; canonical ``say`` tool registered for
|
||||
# PI052 / Pi0.5 / dataset-visualizer consumers via
|
||||
# ``LeRobotDatasetMetadata.tools``). Idempotent and additive: existing
|
||||
# user metadata is preserved.
|
||||
# Keep meta/info.json aligned with the parquet schema we just wrote.
|
||||
# Idempotent and additive: existing user metadata is preserved.
|
||||
self._ensure_annotation_metadata_in_info(root)
|
||||
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
@@ -23,7 +23,7 @@ one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/subtasks_vqa.yaml``).
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
@@ -5,40 +5,15 @@ pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
||||
|
||||
The frame shows a robot working on: "{episode_task}".
|
||||
|
||||
QUALITY BAR — read before answering:
|
||||
|
||||
- Only label objects you are highly confident about. If you are not
|
||||
sure what an object is, do NOT include it. A short, certain answer
|
||||
beats a long, speculative one.
|
||||
- For coordinate-grounded answers (bbox, keypoint) only emit a label
|
||||
when you can localize the object *tightly and precisely*. If the
|
||||
object is occluded, ambiguous, off-frame, or you can't pin its
|
||||
extent, return an empty detections list / pick a different object
|
||||
rather than guessing.
|
||||
- Prefer task-relevant objects (the thing the robot is manipulating
|
||||
or interacting with) over background clutter.
|
||||
|
||||
Question types and the EXACT answer JSON shape required for each:
|
||||
|
||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||
Pixel coordinates (x_min, y_min, x_max, y_max). Emit
|
||||
AT MOST 3 detections, and *only* the highest-confidence
|
||||
ones — 1 tight, certain detection is preferred over 3
|
||||
loose ones. Each box must be tight (no >10% padding
|
||||
around the object) and the label must be specific
|
||||
("red mug" not "object"). Return an empty list if no
|
||||
object meets the bar.
|
||||
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||
|
||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||
"point": [x, y]}}
|
||||
Pick ONE high-confidence, precisely-localizable point
|
||||
(e.g. a graspable handle, a button center, the gripper
|
||||
tip). The point must land within a few pixels of the
|
||||
feature. Do not emit a coarse "somewhere on the object"
|
||||
point — pick a different question type if no such
|
||||
point exists in this frame.
|
||||
|
||||
count => {{"label": "<obj>", "count": <int>,
|
||||
"note": "<optional short note>"}}
|
||||
|
||||
@@ -205,149 +205,3 @@ class WandBLogger:
|
||||
|
||||
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
|
||||
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
|
||||
|
||||
def log_training_examples(
|
||||
self,
|
||||
batch: dict,
|
||||
step: int,
|
||||
*,
|
||||
camera_keys: list[str],
|
||||
n_samples: int = 4,
|
||||
policy=None,
|
||||
predict_actions: bool = False,
|
||||
mode: str = "train",
|
||||
) -> None:
|
||||
"""Push a ``wandb.Table`` of training-example rows for the current batch.
|
||||
|
||||
Each row is one batch element with:
|
||||
* one ``wandb.Image`` column per camera in ``camera_keys`` (CHW or
|
||||
HWC, uint8 or float in [0,1] — auto-detected),
|
||||
* any text fields present in the batch (``task`` / ``subtask`` /
|
||||
``memory`` / ``instruction``),
|
||||
* ground-truth action first/last frame (the action chunk's
|
||||
endpoints — gives a quick sense of trajectory direction),
|
||||
* if ``predict_actions=True`` and ``policy`` is supplied, the model's
|
||||
``predict_action_chunk`` first/last frame alongside.
|
||||
|
||||
This is opt-in via ``--wandb.log_examples_freq=N`` on the CLI; the
|
||||
training loop calls it once every N steps. Cheap to keep on: with
|
||||
N=4 samples and 3 cameras you upload 12 small PNGs per dump and (if
|
||||
enabled) run one extra inference forward pass.
|
||||
"""
|
||||
import logging # noqa: PLC0415
|
||||
import numpy as np # noqa: PLC0415
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
if mode not in {"train", "eval"}:
|
||||
raise ValueError(mode)
|
||||
|
||||
# Batch size — first tensor-like value wins.
|
||||
bsz = next(
|
||||
(int(v.shape[0]) for v in batch.values() if hasattr(v, "shape") and v.ndim > 0),
|
||||
None,
|
||||
)
|
||||
if not bsz:
|
||||
return
|
||||
n = min(int(n_samples), bsz)
|
||||
|
||||
# Optional predicted-action forward pass on the first n samples.
|
||||
pred_actions: np.ndarray | None = None
|
||||
if predict_actions and policy is not None:
|
||||
was_training = policy.training
|
||||
try:
|
||||
policy.eval()
|
||||
sub_batch = {}
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
sub_batch[k] = v[:n]
|
||||
elif isinstance(v, (list, tuple)):
|
||||
sub_batch[k] = list(v[:n])
|
||||
else:
|
||||
sub_batch[k] = v
|
||||
with torch.no_grad():
|
||||
pred = policy.predict_action_chunk(sub_batch)
|
||||
pred_actions = pred.detach().cpu().float().numpy()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"log_training_examples: predict_action_chunk failed (%s) — "
|
||||
"skipping predicted-action columns",
|
||||
exc,
|
||||
)
|
||||
pred_actions = None
|
||||
finally:
|
||||
if was_training:
|
||||
policy.train()
|
||||
|
||||
present_cameras = [c for c in camera_keys if c in batch]
|
||||
text_keys = [k for k in ("task", "subtask", "memory", "instruction") if k in batch]
|
||||
|
||||
columns = ["sample"]
|
||||
columns.extend(c.removeprefix("observation.images.") or c for c in present_cameras)
|
||||
columns.extend(text_keys)
|
||||
columns.append("gt_action_first")
|
||||
columns.append("gt_action_last")
|
||||
if pred_actions is not None:
|
||||
columns.append("pred_action_first")
|
||||
columns.append("pred_action_last")
|
||||
|
||||
table = self._wandb.Table(columns=columns)
|
||||
|
||||
def _to_uint8_hwc(t: torch.Tensor) -> np.ndarray:
|
||||
# Strip an outer time dim if present: (T, C, H, W) -> first frame.
|
||||
if t.ndim == 4:
|
||||
t = t[0]
|
||||
# CHW -> HWC.
|
||||
if t.ndim == 3 and t.shape[0] in (1, 3, 4) and t.shape[-1] not in (1, 3, 4):
|
||||
t = t.permute(1, 2, 0)
|
||||
arr = t.detach().cpu().float().numpy()
|
||||
if arr.size and float(arr.max()) <= 1.5:
|
||||
arr = arr * 255.0
|
||||
return np.clip(arr, 0, 255).astype(np.uint8)
|
||||
|
||||
def _action_endpoints(a: torch.Tensor) -> tuple[str, str]:
|
||||
arr = a.detach().cpu().float().numpy()
|
||||
if arr.ndim == 2: # (T, D)
|
||||
return (
|
||||
str(np.round(arr[0], 3).tolist()),
|
||||
str(np.round(arr[-1], 3).tolist()),
|
||||
)
|
||||
if arr.ndim == 1:
|
||||
rounded = np.round(arr, 3).tolist()
|
||||
return (str(rounded), str(rounded))
|
||||
return (str(arr.tolist()), str(arr.tolist()))
|
||||
|
||||
for i in range(n):
|
||||
row: list = [i]
|
||||
for cam in present_cameras:
|
||||
try:
|
||||
row.append(self._wandb.Image(_to_uint8_hwc(batch[cam][i])))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"log_training_examples: camera %s sample %d failed (%s)",
|
||||
cam,
|
||||
i,
|
||||
exc,
|
||||
)
|
||||
row.append(None)
|
||||
for tk in text_keys:
|
||||
v = batch[tk]
|
||||
if isinstance(v, (list, tuple)):
|
||||
row.append(str(v[i]) if i < len(v) else "")
|
||||
else:
|
||||
row.append(str(v))
|
||||
action = batch.get("action")
|
||||
if isinstance(action, torch.Tensor) and action.ndim >= 1:
|
||||
first, last = _action_endpoints(action[i])
|
||||
row.append(first)
|
||||
row.append(last)
|
||||
else:
|
||||
row.append("")
|
||||
row.append("")
|
||||
if pred_actions is not None:
|
||||
p = torch.from_numpy(pred_actions[i])
|
||||
pfirst, plast = _action_endpoints(p)
|
||||
row.append(pfirst)
|
||||
row.append(plast)
|
||||
table.add_data(*row)
|
||||
|
||||
self._wandb.log({f"{mode}/examples": table}, step=step)
|
||||
|
||||
@@ -62,72 +62,6 @@ class WandBConfig:
|
||||
run_id: str | None = None
|
||||
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
|
||||
add_tags: bool = True # If True, save configuration as tags in the WandB run.
|
||||
# Periodic training-example dump (independent of ``log_freq``). When > 0,
|
||||
# every ``log_examples_freq`` steps the trainer pushes a ``wandb.Table``
|
||||
# with one row per sampled batch element containing each camera view
|
||||
# (rendered as ``wandb.Image``), any text fields present in the batch
|
||||
# (``task`` / ``subtask`` / ``memory`` / ``instruction``), and the
|
||||
# ground-truth action chunk's first + last frames. Defaults to 5000 — set
|
||||
# to 0 to disable. Only fires when ``enable=True``, so runs without wandb
|
||||
# are unaffected.
|
||||
log_examples_freq: int = 5000
|
||||
# Number of batch elements to include in each example dump.
|
||||
log_examples_n: int = 4
|
||||
# If True (default), also run ``policy.predict_action_chunk`` on the logged
|
||||
# samples (in eval mode, no_grad) and add predicted vs ground-truth action
|
||||
# columns to the table. Costs one extra forward pass per dump — negligible
|
||||
# at the 5k-step default cadence. Set to ``False`` if your policy doesn't
|
||||
# implement ``predict_action_chunk`` or you want to skip the extra forward.
|
||||
log_examples_predict_actions: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class EMAConfig:
|
||||
"""Exponential Moving Average of trainable policy parameters.
|
||||
|
||||
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
|
||||
pi052) benefit substantially from averaging late-training
|
||||
parameter oscillations — see Chi et al. 2023 §V.D. The official
|
||||
JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and
|
||||
``0.999`` for its pi05_libero config; the openpi PyTorch port
|
||||
explicitly lists EMA as unsupported, and LeRobot main inherited
|
||||
that gap. Enabling this flag plugs ema-pytorch
|
||||
(https://github.com/lucidrains/ema-pytorch) into the LeRobot
|
||||
training loop with a shadow ``nn.Module`` clone of the policy.
|
||||
|
||||
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
|
||||
params) + one elementwise update per training step (~1% step time).
|
||||
|
||||
On by default — matches openpi (JAX) which ships EMA on for every
|
||||
config, and closes the gap with the openpi PyTorch port which
|
||||
explicitly lists EMA as unsupported. Set ``--ema.enable=false`` to
|
||||
disable for short runs / memory-constrained training where the
|
||||
extra fp32 shadow copy is the bottleneck.
|
||||
"""
|
||||
|
||||
enable: bool = True
|
||||
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
|
||||
# ema-pytorch as ``beta``).
|
||||
# 0.999 — last ~1000 steps; pi05_libero default in openpi
|
||||
# 0.99 — last ~100 steps; openpi top-level default
|
||||
# 0.75 — very fast EMA (Diffusion Policy original setting)
|
||||
# 0.9999 — very slow EMA (long classification runs)
|
||||
decay: float = 0.999
|
||||
# Skip the first N calls to ``ema.update()``; during this window
|
||||
# the shadow is just a hard copy of the live weights (no averaging).
|
||||
# Lets early-training rapid changes settle before averaging begins.
|
||||
# Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay
|
||||
# ramp like older lerobot EMA implementations).
|
||||
warmup_steps: int = 0
|
||||
# When True, the periodic eval block uses the EMA shadow model
|
||||
# directly (``ema.ema_model``) instead of the live policy. Standard
|
||||
# practice for diffusion-style policies — eval scores are usually
|
||||
# 1–3% higher than the live policy at the same step.
|
||||
use_for_eval: bool = True
|
||||
# When True, the periodic wandb training-example dump uses the EMA
|
||||
# shadow for the optional predicted-action columns (so what you see
|
||||
# in W&B matches eval behavior).
|
||||
use_for_wandb_examples: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -147,16 +147,7 @@ class TrainingRecipe:
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and the recipe supervises something.
|
||||
|
||||
A recipe is valid if it has at least one of:
|
||||
|
||||
* a ``target: true`` assistant turn (drives text-CE supervision), or
|
||||
* a ``stream: low_level`` turn (drives flow / action supervision via
|
||||
``predict_actions=True``, even when no assistant turn is targeted —
|
||||
e.g. π0.5-style ``low_level_execution`` where the action expert
|
||||
conditions on a user-only ``${subtask}`` prompt).
|
||||
"""
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
@@ -165,14 +156,8 @@ class TrainingRecipe:
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
has_target = any(turn.target for turn in self.messages)
|
||||
has_low_level = any(turn.stream == "low_level" for turn in self.messages)
|
||||
if not (has_target or has_low_level):
|
||||
raise ValueError(
|
||||
"Message recipes must contain at least one supervised turn — "
|
||||
"either ``target: true`` (text CE) or ``stream: low_level`` "
|
||||
"(flow/action loss)."
|
||||
)
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
|
||||
#
|
||||
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
|
||||
# training, and adds two text-supervised tasks:
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# memory_update — compress progress into a memory note.
|
||||
# user_interjection_response — reply to a user interjection with a
|
||||
# spoken `say` tool call (no plan, no
|
||||
# subtask text — just the spoken reply).
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# Plan is intentionally left out — memory is the only persistent
|
||||
# high-level state here, keeping the prompt short.
|
||||
#
|
||||
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
|
||||
# annotations (the annotation pipeline's memory + interjection modules)
|
||||
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.30
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.55
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
|
||||
# here. `stream: low_level` flips `predict_actions=True` so the
|
||||
# flow loss fires; no text-CE target (subtask prediction is owned
|
||||
# by `high_level_subtask`).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# At inference, `MemoryUpdateFwd` is triggered only on
|
||||
# `subtask_change` events (sparse). Training densely with
|
||||
# `active_at` — i.e. on every frame inside a subtask interval,
|
||||
# not just the boundary frame — supervises the same
|
||||
# (prior_memory, completed_subtask) → current_memory mapping
|
||||
# against varied observations within the interval. The model
|
||||
# learns a stateless transformation; the *when* to emit lives in
|
||||
# the inference trigger, not the model. Annotations only exist
|
||||
# for ~1% of frames as boundary events, so `emitted_at` would
|
||||
# waste 99% of the blend draws (and silently leak them into a
|
||||
# task-conditioned fallback); `active_at` lifts the renderable
|
||||
# rate to ~87% on this dataset.
|
||||
weight: 0.15
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
@@ -1,99 +0,0 @@
|
||||
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
|
||||
#
|
||||
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
|
||||
# camera-grounded VQA across the three RoboCasa camera keys produced
|
||||
# by ``slurm_build_robocasa_composite_seen.py``:
|
||||
#
|
||||
# observation.images.robot0_agentview_left (left scene view)
|
||||
# observation.images.robot0_agentview_right (right scene view)
|
||||
# observation.images.robot0_eye_in_hand (wrist)
|
||||
#
|
||||
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
|
||||
# VQA per camera, so each anchor frame produces three (user, assistant)
|
||||
# rows tagged with their source camera. Each VQA sub-recipe consumes
|
||||
# the rows for one camera via ``camera=...`` resolver bindings.
|
||||
#
|
||||
# Spatial VQA targets (bbox / point) are rewritten from JSON to
|
||||
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
|
||||
# ``register_paligemma_loc_tokens`` already collapses them to single
|
||||
# detection-vocab ids so the LM head learns the pretrained pointing /
|
||||
# detection prior, not a 7-piece BPE salad.
|
||||
#
|
||||
# Interjections / spoken responses are intentionally absent — the
|
||||
# annotation job runs with ``--interjections.enabled=false``.
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.25
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.45
|
||||
messages:
|
||||
# Action expert is conditioned on the SUBTASK; at inference the
|
||||
# high-level loop generates it via the LM head and feeds it here.
|
||||
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
|
||||
# loss fires; subtask CE is owned by ``high_level_subtask``.
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# Trained densely with ``active_at`` — every frame inside a subtask
|
||||
# interval — so the (prior_memory, completed_subtask) → current_memory
|
||||
# mapping is supervised against varied observations. The *when* to
|
||||
# emit lives in the inference trigger (subtask_change), not the
|
||||
# model. See ``subtask_mem.yaml`` for the long version of this note.
|
||||
weight: 0.15
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
ask_vqa_agentview_left:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_agentview_left}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_agentview_right:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_agentview_right}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.05
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.robot0_eye_in_hand}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -1,114 +0,0 @@
|
||||
# subtask_mem_vqa_speech — Hi-Robot blend + memory + spoken responses.
|
||||
#
|
||||
# Superset of subtasks_vqa.yaml. Keeps the core subtask + action + VQA
|
||||
# training, and adds two text-supervised tasks:
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# memory_update — compress progress into a memory note.
|
||||
# user_interjection_response — reply to a user interjection with a
|
||||
# spoken `say` tool call (no plan, no
|
||||
# subtask text — just the spoken reply).
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# Plan is intentionally left out — memory is the only persistent
|
||||
# high-level state here, keeping the prompt short.
|
||||
#
|
||||
# Requires the dataset to carry `memory`, `interjection` and `say`-tool
|
||||
# annotations (the annotation pipeline's memory + interjection modules)
|
||||
# in addition to `subtask` and `vqa`. Sub-recipes whose `if_present`
|
||||
# bindings are missing simply don't render for that sample, so a
|
||||
# dataset without interjections still trains the rest of the blend.
|
||||
#
|
||||
# Tool-call note: the `say` tool call on the interjection-response turn
|
||||
# is flattened to a `<say>...</say>` text marker by the tokenizer step
|
||||
# (`_flatten_say_tool_calls`) so the LM head learns to emit exactly the
|
||||
# marker the runtime parses back (`_split_plan_and_say`).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.25
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.40
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# `HighLevelSubtaskFwd` generates it via the LM head and feeds it
|
||||
# here. `stream: low_level` flips `predict_actions=True` so the
|
||||
# flow loss fires; no text-CE target (subtask prediction is owned
|
||||
# by `high_level_subtask`).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
memory_update:
|
||||
# At inference, `MemoryUpdateFwd` is triggered only on
|
||||
# `subtask_change` events (sparse). Training densely with
|
||||
# `active_at` — i.e. on every frame inside a subtask interval,
|
||||
# not just the boundary frame — supervises the same
|
||||
# (prior_memory, completed_subtask) → current_memory mapping
|
||||
# against varied observations within the interval. The model
|
||||
# learns a stateless transformation; the *when* to emit lives in
|
||||
# the inference trigger, not the model. Annotations only exist
|
||||
# for ~1% of frames as boundary events, so `emitted_at` would
|
||||
# waste 99% of the blend draws (and silently leak them into the
|
||||
# task-conditioned fallback); `active_at` lifts the renderable
|
||||
# rate to ~87% on Hi-Robot-style datasets.
|
||||
weight: 0.10
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "active_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
user_interjection_response:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
interjection: "emitted_at(t, style=interjection)"
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
||||
# Spoken reply only: the assistant turn carries no text content,
|
||||
# just a `say` tool call (`tool_calls_from: speech`). The chat
|
||||
# tokenizer flattens it to a `<say>...</say>` marker, so the
|
||||
# supervised target trains the model to respond to an
|
||||
# interjection with a spoken acknowledgement.
|
||||
- {role: assistant, stream: high_level, target: true, if_present: speech, tool_calls_from: speech}
|
||||
|
||||
# VQA is view-dependent — each camera gets its own sub-recipe so the
|
||||
# resolver disambiguates via `camera=...`. Camera keys match
|
||||
# subtasks_vqa.yaml (`front` + `wrist`); adjust to your dataset.
|
||||
ask_vqa_top:
|
||||
weight: 0.075
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.075
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -1,61 +0,0 @@
|
||||
# subtasks_vqa — Hi-Robot blend for PI052 (PaliGemma backbone).
|
||||
#
|
||||
# Trains two things only: subtasks and VQA. Plan and memory are
|
||||
# intentionally left out — keeps the prompt short and the training
|
||||
# surface small. The fuller blend with memory + spoken replies is
|
||||
# ``subtask_mem_vqa_speech.yaml``.
|
||||
#
|
||||
# high_level_subtask — predict the subtask from the task.
|
||||
# low_level_execution — flow loss with [images, subtask, state].
|
||||
# ask_vqa_{top,wrist} — camera-grounded VQA.
|
||||
#
|
||||
# PI052's text tokenizer renders these messages as plain
|
||||
# ``Role: content`` text (PaliGemma is not chat-pretrained).
|
||||
|
||||
blend:
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.40
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.40
|
||||
messages:
|
||||
# The action expert is conditioned on the SUBTASK — at inference
|
||||
# the high-level loop (``HighLevelSubtaskFwd``) generates the
|
||||
# subtask via the LM head and feeds it here. The action expert's
|
||||
# prefix is [images, subtask, state]. ``stream: low_level`` flips
|
||||
# ``predict_actions=True`` so the flow loss fires; no text-CE
|
||||
# target here (subtask prediction is owned by
|
||||
# ``high_level_subtask``).
|
||||
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||
|
||||
ask_vqa_top:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.front)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.front)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.front}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
|
||||
ask_vqa_wrist:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
|
||||
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
|
||||
messages:
|
||||
- role: user
|
||||
stream: high_level
|
||||
if_present: vqa_query
|
||||
content:
|
||||
- {type: image, feature: observation.images.wrist}
|
||||
- {type: text, text: "${vqa_query}"}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EMAConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
@@ -111,20 +111,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
ema: EMAConfig = field(default_factory=EMAConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# VQA oversampling. When set (a fraction in (0, 1)), the training
|
||||
# dataloader uses a WeightedEpisodeAwareSampler that draws frames
|
||||
# carrying a `vqa` language annotation often enough that they make
|
||||
# up roughly this fraction of the training stream. VQA annotations
|
||||
# are typically sparse, so without this they are underrepresented.
|
||||
# `None` (default) keeps uniform episode-aware sampling.
|
||||
vqa_target_fraction: float | None = None
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training). Old
|
||||
# inline ``use_rabc`` / ``rabc_*`` params are migrated to this
|
||||
# field by ``_migrate_legacy_rabc_keys`` above.
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
|
||||
@@ -35,6 +35,7 @@ from .dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
@@ -49,24 +50,11 @@ from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav
|
||||
from .sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
|
||||
|
||||
def make_dataset(*args, **kwargs):
|
||||
from .factory import make_dataset as _make_dataset
|
||||
|
||||
return _make_dataset(*args, **kwargs)
|
||||
|
||||
|
||||
def resolve_delta_timestamps(*args, **kwargs):
|
||||
from .factory import resolve_delta_timestamps as _resolve_delta_timestamps
|
||||
|
||||
return _resolve_delta_timestamps(*args, **kwargs)
|
||||
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
# and legacy migration constants are intentionally NOT re-exported here.
|
||||
# Import directly: ``from lerobot.datasets.io_utils import ...``
|
||||
@@ -77,7 +65,6 @@ __all__ = [
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"WeightedEpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
|
||||
@@ -126,53 +126,10 @@ class DatasetReader:
|
||||
def _load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
features = get_hf_features_from_features(self._meta.features)
|
||||
# Datasets annotated with the PR1 language columns may have been
|
||||
# written without registering those columns in ``meta/info.json``
|
||||
# (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were
|
||||
# back-filled by ``lerobot-annotate``). Probe a single parquet
|
||||
# shard and graft the column features on so the strict
|
||||
# ``Dataset.from_parquet`` cast doesn't fail with
|
||||
# ``column names don't match``.
|
||||
features = self._extend_features_with_language_columns(features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def _extend_features_with_language_columns(
|
||||
self, features: datasets.Features
|
||||
) -> datasets.Features:
|
||||
"""Add ``language_persistent`` / ``language_events`` to ``features``
|
||||
when the underlying parquet shards declare them but the metadata
|
||||
doesn't. No-op when neither column is present or both are
|
||||
already registered.
|
||||
"""
|
||||
# Find any one parquet to peek at; bail if there are none yet
|
||||
# (the dataset will fail later for an unrelated reason and we
|
||||
# want that error to surface as-is).
|
||||
try:
|
||||
sample = next((self.root / "data").glob("*/*.parquet"))
|
||||
except StopIteration:
|
||||
return features
|
||||
|
||||
from pyarrow import parquet as _pq # noqa: PLC0415
|
||||
|
||||
schema_names = set(_pq.read_schema(sample).names)
|
||||
from .language import ( # noqa: PLC0415
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
|
||||
extra: dict[str, object] = {}
|
||||
if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features:
|
||||
extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature()
|
||||
if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features:
|
||||
extra[LANGUAGE_EVENTS] = language_events_column_feature()
|
||||
if not extra:
|
||||
return features
|
||||
return datasets.Features({**features, **extra})
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
|
||||
@@ -70,22 +70,8 @@ def _json_arrow_type() -> pa.DataType:
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
"""Return the HF feature used for tool-call payloads.
|
||||
|
||||
Older ``datasets`` versions do not expose ``datasets.Json``. The
|
||||
annotation pipeline currently emits the canonical ``say`` tool call
|
||||
shape, so use that explicit struct instead of falling back to a string
|
||||
that cannot cast structured parquet values.
|
||||
"""
|
||||
if hasattr(datasets, "Json"):
|
||||
return datasets.Json()
|
||||
return {
|
||||
"type": datasets.Value("string"),
|
||||
"function": {
|
||||
"name": datasets.Value("string"),
|
||||
"arguments": {"text": datasets.Value("string")},
|
||||
},
|
||||
}
|
||||
"""Return the HF ``datasets`` JSON feature, falling back to a string value."""
|
||||
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
|
||||
@@ -170,29 +170,6 @@ def render_sample(
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
|
||||
# VQA-priority routing. A ``vqa`` annotation is sparse and
|
||||
# view-dependent; the plain weighted blend would (a) waste a draw
|
||||
# whenever it picks an ``ask_vqa*`` sub-recipe for a frame that has
|
||||
# no VQA, and (b) silently drop a VQA-annotated frame whenever it
|
||||
# picks a non-VQA sub-recipe. So: if the blend has ``ask_vqa*``
|
||||
# sub-recipes and *this* frame carries one of their VQA bindings,
|
||||
# render VQA here regardless of the weighted draw. That makes VQA's
|
||||
# recipe-side training share equal the VQA-annotation density (the
|
||||
# maximum reachable without a dataset-level oversampling sampler).
|
||||
if recipe.blend is not None:
|
||||
vqa_rendered = _render_vqa_if_present(
|
||||
recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
if vqa_rendered is not None:
|
||||
return vqa_rendered
|
||||
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
@@ -206,59 +183,6 @@ def render_sample(
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _render_vqa_if_present(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render an ``ask_vqa*`` sub-recipe iff this frame carries a VQA
|
||||
annotation; otherwise return ``None`` so the caller falls back to the
|
||||
normal weighted blend.
|
||||
|
||||
When several VQA sub-recipes resolve (e.g. a frame annotated for more
|
||||
than one camera), one is chosen deterministically by relative weight.
|
||||
"""
|
||||
assert recipe.blend is not None
|
||||
renderable: list[tuple[float, RenderedMessages]] = []
|
||||
for name, component in recipe.blend.items():
|
||||
if not name.startswith("ask_vqa"):
|
||||
continue
|
||||
bindings = _resolve_bindings(
|
||||
component,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=t,
|
||||
sample_idx=sample_idx,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
rendered = _render_message_recipe(component, bindings)
|
||||
if rendered is not None:
|
||||
renderable.append((float(component.weight or 0.0), rendered))
|
||||
|
||||
if not renderable:
|
||||
return None
|
||||
if len(renderable) == 1:
|
||||
return renderable[0][1]
|
||||
|
||||
# Multiple cameras have a VQA for this frame — deterministic pick by
|
||||
# relative weight (fall back to a uniform draw if all weights are 0).
|
||||
total = sum(w for w, _ in renderable) or float(len(renderable))
|
||||
digest = hashlib.blake2b(f"vqa:{sample_idx}".encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total
|
||||
cumulative = 0.0
|
||||
for w, rendered in renderable:
|
||||
cumulative += w or (total / len(renderable))
|
||||
if draw < cumulative:
|
||||
return rendered
|
||||
return renderable[-1][1]
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
"""Pick a deterministic blend component for ``sample_idx`` (or return ``recipe``)."""
|
||||
if recipe.blend is None:
|
||||
@@ -422,15 +346,7 @@ def _render_message_recipe(
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
# A render is meaningful if it supervises *something*: either a
|
||||
# text-CE target turn, or a ``low_level`` stream turn (flow / action
|
||||
# supervision — e.g. the flow-only ``low_level_execution`` recipe,
|
||||
# ``user(${subtask})`` with ``stream: low_level`` and no target).
|
||||
# Without this, a flow-only recipe renders to ``None`` every time
|
||||
# the blend draws it → ``predict_actions`` is never True → the
|
||||
# action expert never receives a flow loss.
|
||||
has_low_level = any(stream == "low_level" for stream in streams)
|
||||
if not target_indices and not has_low_level:
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
@@ -487,10 +403,8 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
# Valid iff it supervises something: a text-CE target turn OR a
|
||||
# ``low_level`` stream turn (flow / action supervision).
|
||||
if not target_indices and not any(s == "low_level" for s in streams):
|
||||
raise ValueError("Rendered samples must contain a target message or a low_level-stream message.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
|
||||
@@ -84,66 +84,3 @@ class EpisodeAwareSampler:
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.indices)
|
||||
|
||||
|
||||
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
|
||||
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
|
||||
proportion to per-frame weights.
|
||||
|
||||
Used to oversample frames carrying a sparse annotation (e.g. a VQA
|
||||
question) so the policy sees them more often than their natural
|
||||
dataset density. One epoch still yields ``len(self.indices)``
|
||||
samples — the weights only change the *composition* of the stream,
|
||||
not its length. Each epoch re-draws, so the oversampled subset
|
||||
varies run to run.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_indices: list[int],
|
||||
dataset_to_indices: list[int],
|
||||
frame_weights,
|
||||
*,
|
||||
episode_indices_to_use: list | None = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
|
||||
dataset_to_indices: Episode end indices.
|
||||
frame_weights: 1-D sequence/tensor of non-negative weights, one per
|
||||
dataset frame (length == total dataset frames). Higher weight ⇒
|
||||
that frame is sampled more often.
|
||||
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
|
||||
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
|
||||
frame filtering is applied first, then weighting is restricted
|
||||
to the surviving frames.
|
||||
"""
|
||||
super().__init__(
|
||||
dataset_from_indices,
|
||||
dataset_to_indices,
|
||||
episode_indices_to_use=episode_indices_to_use,
|
||||
drop_n_first_frames=drop_n_first_frames,
|
||||
drop_n_last_frames=drop_n_last_frames,
|
||||
shuffle=False,
|
||||
)
|
||||
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
|
||||
idx = torch.tensor(self.indices, dtype=torch.long)
|
||||
if weights.numel() <= int(idx.max()):
|
||||
raise ValueError(
|
||||
f"frame_weights has {weights.numel()} entries but the sampler "
|
||||
f"references frame index {int(idx.max())}."
|
||||
)
|
||||
selected = weights[idx]
|
||||
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
|
||||
raise ValueError("frame_weights must be finite and non-negative.")
|
||||
if float(selected.sum()) <= 0.0:
|
||||
# All surviving frames have zero weight — fall back to uniform.
|
||||
selected = torch.ones_like(selected)
|
||||
self._weights = selected
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
|
||||
for i in picks.tolist():
|
||||
yield self.indices[i]
|
||||
|
||||
@@ -366,24 +366,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
msg = (
|
||||
f"Repo {repo_id!r} has no codebase-version tags. The dataset "
|
||||
f"either doesn't exist on the Hub yet, or it was uploaded "
|
||||
f"without a ``v3.x``-style tag. To tag an existing dataset run:\n"
|
||||
f" from huggingface_hub import HfApi\n"
|
||||
f" HfApi().create_tag({repo_id!r}, tag='v3.0', repo_type='dataset', exist_ok=True)"
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
)
|
||||
# ``RevisionNotFoundError`` extends ``HfHubHTTPError`` whose
|
||||
# ``__init__`` indexes ``response.headers`` unconditionally on
|
||||
# current ``huggingface_hub`` versions. Constructing it without
|
||||
# a real ``Response`` object crashes with either
|
||||
# ``TypeError: missing 1 required keyword-only argument`` (old
|
||||
# builds) or ``AttributeError: 'NoneType' object has no attribute
|
||||
# 'headers'`` (new builds). Skip that path entirely — this isn't
|
||||
# really an HTTP error, it's a configuration issue — and raise a
|
||||
# plain ``RuntimeError`` so the message actually reaches the
|
||||
# caller.
|
||||
raise RuntimeError(msg)
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
@@ -104,8 +104,6 @@ class AdamWConfig(OptimizerConfig):
|
||||
eps: float = 1e-8
|
||||
weight_decay: float = 1e-2
|
||||
grad_clip_norm: float = 10.0
|
||||
foreach: bool | None = None
|
||||
fused: bool | None = None
|
||||
|
||||
def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
|
||||
kwargs = asdict(self)
|
||||
|
||||
@@ -24,7 +24,6 @@ from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as M
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pi052.configuration_pi052 import PI052Config as PI052Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
@@ -48,7 +47,6 @@ __all__ = [
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"PI052Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -61,79 +61,6 @@ from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
|
||||
|
||||
def _restore_pi052_pretrained_state(
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
pretrained_path: str,
|
||||
) -> None:
|
||||
"""Transplant saved stateful blobs from a pi052 checkpoint into fresh pipelines.
|
||||
|
||||
pi052's preprocessor includes steps whose constructor args don't
|
||||
JSON-roundtrip (``RenderMessagesStep.recipe`` is a Python object,
|
||||
``ActionTokenizerProcessorStep.action_tokenizer_name`` is a
|
||||
fitted-tokenizer path that may not exist at eval time). We rebuild
|
||||
those pipelines fresh from ``config.recipe_path`` and then walk
|
||||
over the saved ``policy_{pre,post}processor.json`` files to find
|
||||
each step's ``state_file`` reference and load the bytes back into
|
||||
the corresponding fresh step. Today that's only the
|
||||
NormalizerProcessorStep / UnnormalizerProcessorStep (the action /
|
||||
state quantile stats), but the loop is generic so any future
|
||||
stateful step picks up its blob automatically.
|
||||
|
||||
Pairing is by ``registry_name`` AND position so a benign reorder
|
||||
on the saved side surfaces a warning rather than silently feeding
|
||||
the wrong tensors into the wrong step.
|
||||
"""
|
||||
import json # noqa: PLC0415
|
||||
import logging # noqa: PLC0415
|
||||
from pathlib import Path # noqa: PLC0415
|
||||
|
||||
from safetensors.torch import load_file # noqa: PLC0415
|
||||
|
||||
base = Path(pretrained_path)
|
||||
if not base.exists():
|
||||
return
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
for pipeline, config_filename in [
|
||||
(preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
|
||||
(postprocessor, f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"),
|
||||
]:
|
||||
config_path = base / config_filename
|
||||
if not config_path.exists():
|
||||
continue
|
||||
saved = json.loads(config_path.read_text())
|
||||
|
||||
for idx, (saved_step, fresh_step) in enumerate(
|
||||
zip(saved.get("steps", []), pipeline.steps, strict=False)
|
||||
):
|
||||
state_file = saved_step.get("state_file")
|
||||
if not state_file:
|
||||
continue
|
||||
saved_name = saved_step.get("registry_name")
|
||||
fresh_name = getattr(type(fresh_step), "_registry_name", None)
|
||||
if saved_name and fresh_name and saved_name != fresh_name:
|
||||
log.warning(
|
||||
"PI052 state restore: %s step %d registry name mismatch "
|
||||
"(saved=%s, fresh=%s); skipping %s",
|
||||
config_filename, idx, saved_name, fresh_name, state_file,
|
||||
)
|
||||
continue
|
||||
state_path = base / state_file
|
||||
if not state_path.exists():
|
||||
log.warning(
|
||||
"PI052 state restore: %s missing at %s; %s left at fresh init",
|
||||
state_file, base, fresh_name,
|
||||
)
|
||||
continue
|
||||
fresh_step.load_state_dict(load_file(str(state_path)))
|
||||
log.info(
|
||||
"PI052 state restore: loaded %s into %s (step %d)",
|
||||
state_file, fresh_name, idx,
|
||||
)
|
||||
|
||||
|
||||
def _reconnect_relative_absolute_steps(
|
||||
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
|
||||
) -> None:
|
||||
@@ -200,10 +127,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "pi052":
|
||||
from .pi052.modeling_pi052 import PI052Policy
|
||||
|
||||
return PI052Policy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
|
||||
@@ -244,8 +167,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05",
|
||||
"pi052", "gaussian_actor", "smolvla", "wall_x".
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -268,10 +191,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "pi052":
|
||||
from .pi052.configuration_pi052 import PI052Config
|
||||
|
||||
return PI052Config(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
@@ -312,12 +231,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
# Optional: HF Hub repo id of the dataset the policy is being
|
||||
# trained on. Used by policies that auto-fit pieces of their
|
||||
# preprocessing (e.g. pi052's FAST action tokenizer per
|
||||
# Pertsch et al. 2025 [64], π0.5 §III.C). When omitted, those
|
||||
# policies fall back to their universal pre-fitted tokenizers.
|
||||
dataset_repo_id: str | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
@@ -350,29 +263,6 @@ def make_pre_post_processors(
|
||||
NotImplementedError: If a processor factory is not implemented for the given
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path and getattr(policy_cfg, "type", None) == "pi052":
|
||||
# pi052 pipelines don't roundtrip through the saved
|
||||
# ``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
|
||||
# Python ``TrainingRecipe`` (not JSON-serializable; saved as
|
||||
# ``{}``) and ``ActionTokenizerProcessorStep`` saves a host-only
|
||||
# FAST tokenizer path. Generic ``from_pretrained`` then dies
|
||||
# with ``RenderMessagesStep.__init__() missing 1 required
|
||||
# positional argument: 'recipe'`` (job 22164494).
|
||||
#
|
||||
# Mirror ``lerobot_pi052_runtime``'s bootstrap: build pipelines
|
||||
# fresh from ``config.recipe_path`` and transplant the saved
|
||||
# stateful blobs (normalizer stats) from the checkpoint dir.
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
preprocessor, postprocessor = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
_restore_pi052_pretrained_state(preprocessor, postprocessor, pretrained_path)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
@@ -467,22 +357,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif policy_cfg.type == "pi052":
|
||||
# NOTE: PI052Config subclasses PI05Config, so this branch MUST
|
||||
# come before the PI05Config isinstance check below (otherwise
|
||||
# pi052 would silently pick up π0.5's processor).
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
processors = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
# ``dataset_repo_id`` flows in via kwargs when FAST CE is
|
||||
# enabled — the train loop sets it from ``--dataset.repo_id``.
|
||||
# When ``None``, ``make_pi052_pre_post_processors`` skips
|
||||
# the auto-fit and uses the universal tokenizer.
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05Config):
|
||||
from .pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
|
||||
@@ -93,21 +93,6 @@ class PI05Config(PreTrainedConfig):
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.01
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
optimizer_foreach: bool | None = False
|
||||
optimizer_fused: bool | None = True
|
||||
|
||||
# LM-head LR multiplier. The PaliGemma `lm_head` projection (and its
|
||||
# tied `embed_tokens`) is the surface the LM head's first-token
|
||||
# distribution depends on. With ``knowledge_insulation`` blocking
|
||||
# action→VLM gradients, the LM head only sees gradients on text-CE
|
||||
# samples — which can be a small fraction of the mix (e.g. ~45% in
|
||||
# ``subtask_mem.yaml``). Under aggressive cosine LR decay the head's
|
||||
# first-token distribution can drift back toward PaliGemma's
|
||||
# pretrained ``<loc>`` detection prior, despite teacher-forced CE
|
||||
# staying near zero. Boosting just the LM-head LR (e.g. 5x) keeps
|
||||
# the head pinned to fine-tuning targets without perturbing the
|
||||
# backbone / vision tower / action expert. Default 1.0 = no change.
|
||||
lm_head_lr_scale: float = 1.0
|
||||
|
||||
# Scheduler settings: see openpi `CosineDecaySchedule`
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
@@ -167,8 +152,6 @@ class PI05Config(PreTrainedConfig):
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
foreach=self.optimizer_foreach,
|
||||
fused=self.optimizer_fused,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -223,53 +233,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
return padded_images
|
||||
|
||||
|
||||
def sdpa_attention_forward(
|
||||
module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
"""Drop-in for ``modeling_gemma.eager_attention_forward`` using
|
||||
``torch.nn.functional.scaled_dot_product_attention``.
|
||||
|
||||
PyTorch SDPA picks the memory-efficient kernel for arbitrary additive
|
||||
bias masks (the FA backend only accepts causal/sliding-window). On
|
||||
H100 that is ~1.3-1.7x faster and uses ~30-40% less attention memory
|
||||
than the eager softmax(QK^T)+matmul path. Mirrors eager's signature
|
||||
and output shape (``(B, Lq, H, D)``) so call sites are unchanged.
|
||||
"""
|
||||
n_rep = module.num_key_value_groups
|
||||
if n_rep > 1:
|
||||
key = key.repeat_interleave(n_rep, dim=1)
|
||||
value = value.repeat_interleave(n_rep, dim=1)
|
||||
if attention_mask is not None and attention_mask.dtype != query.dtype:
|
||||
attention_mask = attention_mask.to(dtype=query.dtype)
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=dropout if module.training else 0.0,
|
||||
is_causal=False,
|
||||
scale=scaling,
|
||||
)
|
||||
return attn_output.transpose(1, 2).contiguous(), None
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -291,14 +262,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
att_output, _ = sdpa_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -306,13 +279,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -444,7 +417,6 @@ class PaliGemmaWithExpertModel(
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"lm_head",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
@@ -477,13 +449,13 @@ class PaliGemmaWithExpertModel(
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -521,8 +493,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -532,36 +505,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -653,13 +629,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None):
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
||||
"""Helper method to prepare 4D attention masks for transformer."""
|
||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
||||
result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
if dtype is not None:
|
||||
result = result.to(dtype=dtype)
|
||||
return result
|
||||
return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
@@ -701,8 +674,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -789,22 +761,21 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
||||
|
||||
# Selective AC: rely on the per-layer checkpoint inside
|
||||
# ``PaliGemmaWithExpertModel.forward`` (which wraps each
|
||||
# transformer block individually). The previous outer
|
||||
# ``_apply_checkpoint(forward_func, ...)`` doubled up — it
|
||||
# re-ran the full backbone forward during backward *and* each
|
||||
# block's own checkpoint re-ran during that recompute. Pure
|
||||
# waste with SDPA, which already streams attention activations.
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
return suffix_out
|
||||
|
||||
suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
@@ -848,9 +819,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
|
||||
prefix_att_2d_masks, dtype=prefix_embs.dtype
|
||||
)
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
@@ -920,12 +889,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(
|
||||
full_att_2d_masks, dtype=suffix_embs.dtype
|
||||
)
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
@@ -1060,16 +1027,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
if remap_count > 0:
|
||||
print(f"Remapped {remap_count} state dict keys")
|
||||
|
||||
lm_head_key = "model.paligemma_with_expert.paligemma.lm_head.weight"
|
||||
embed_tokens_key = (
|
||||
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
)
|
||||
if lm_head_key not in remapped_state_dict and embed_tokens_key in remapped_state_dict:
|
||||
remapped_state_dict[lm_head_key] = remapped_state_dict[embed_tokens_key].clone().float()
|
||||
print("Initialized PaliGemma lm_head from language token embeddings")
|
||||
elif lm_head_key in remapped_state_dict:
|
||||
remapped_state_dict[lm_head_key] = remapped_state_dict[lm_head_key].float()
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
|
||||
@@ -1163,62 +1120,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
return fixed_state_dict
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return policy parameters, optionally split into LR-scaled groups.
|
||||
|
||||
When ``config.lm_head_lr_scale != 1.0``, the PaliGemma ``lm_head``
|
||||
and its tied ``embed_tokens`` are placed in their own param
|
||||
group with ``lr = base_lr * lm_head_lr_scale``. The cosine
|
||||
scheduler multiplies both groups by the same lambda each step,
|
||||
so the ratio is preserved across decay. Default ``1.0`` =
|
||||
return ``self.parameters()`` (back-compat with existing checkpoints
|
||||
and configs).
|
||||
"""
|
||||
scale = float(getattr(self.config, "lm_head_lr_scale", 1.0))
|
||||
if scale == 1.0:
|
||||
return self.parameters()
|
||||
head_params: list[torch.nn.Parameter] = []
|
||||
other_params: list[torch.nn.Parameter] = []
|
||||
# Both ``lm_head.weight`` and the tied ``embed_tokens.weight`` —
|
||||
# boosting only the projection without the embedding pulls them
|
||||
# apart and breaks the tie that PaliGemma was pre-trained with.
|
||||
head_substrings = (
|
||||
"paligemma_with_expert.paligemma.lm_head.",
|
||||
"paligemma_with_expert.paligemma.model.language_model.embed_tokens.",
|
||||
)
|
||||
for name, p in self.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
if any(s in name for s in head_substrings):
|
||||
head_params.append(p)
|
||||
else:
|
||||
other_params.append(p)
|
||||
base_lr = float(self.config.optimizer_lr)
|
||||
groups: list[dict[str, object]] = []
|
||||
if other_params:
|
||||
groups.append({"params": other_params, "lr": base_lr, "name": "policy"})
|
||||
if head_params:
|
||||
groups.append(
|
||||
{"params": head_params, "lr": base_lr * scale, "name": "lm_head"}
|
||||
)
|
||||
# Sanity: head_substrings must match at least one parameter, otherwise
|
||||
# the scale silently does nothing — surface that fast.
|
||||
if not head_params:
|
||||
raise RuntimeError(
|
||||
"lm_head_lr_scale != 1.0 but no parameters matched the LM-head "
|
||||
"name patterns: "
|
||||
f"{head_substrings!r}. Did the underlying PaliGemma module rename?"
|
||||
)
|
||||
logging.info(
|
||||
"PI05Policy: LM-head LR scale = %.3g (base=%.3g, head=%.3g) over "
|
||||
"%d head params + %d other params",
|
||||
scale,
|
||||
base_lr,
|
||||
base_lr * scale,
|
||||
len(head_params),
|
||||
len(other_params),
|
||||
)
|
||||
return groups
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Reset internal state - called when environment resets."""
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 — full reproduction of the π0.5 paper's hierarchical
|
||||
inference recipe on lerobot.
|
||||
|
||||
Extends :class:`lerobot.policies.pi05.PI05Policy` with:
|
||||
|
||||
* recipe-driven training (PR 1's :class:`RenderMessagesStep`),
|
||||
* PaliGemma ``lm_head`` cross-entropy on supervised subtask spans
|
||||
(the "high-level subtask prediction" of the paper, §IV.D),
|
||||
* AR text generation at inference (:meth:`PI052Policy.select_message`),
|
||||
* per-component prompt dropout (Pi 0.7 §V.E) for regularising the
|
||||
text head against missing context at inference.
|
||||
|
||||
See ``src/lerobot/configs/recipes/subtasks_vqa.yaml`` for the
|
||||
canonical training recipe and
|
||||
``examples/training/pi052_hirobot.slurm`` for the launcher.
|
||||
"""
|
||||
|
||||
from .configuration_pi052 import PI052Config
|
||||
from .modeling_pi052 import PI052Policy
|
||||
from .processor_pi052 import make_pi052_pre_post_processors
|
||||
from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
|
||||
__all__ = [
|
||||
"PI052Config",
|
||||
"PI052Policy",
|
||||
"PI052TextTokenizerStep",
|
||||
"make_pi052_pre_post_processors",
|
||||
]
|
||||
@@ -1,216 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 (with text head) — reproduction of the π0.5 paper's
|
||||
hierarchical inference recipe.
|
||||
|
||||
Same architecture as the existing ``PI05Policy`` (PaliGemma 2B VLM +
|
||||
~300M Gemma action expert, joint training with FAST tokens during
|
||||
pre-train and flow matching during post-train), but with the
|
||||
PaliGemma ``lm_head`` re-enabled so the same model can be supervised
|
||||
to predict both:
|
||||
|
||||
* **subtask strings** at the high level (cross-entropy on the LM
|
||||
head), and
|
||||
* **action chunks** at the low level (flow matching on the
|
||||
action-expert tokens).
|
||||
|
||||
This is the dual-head co-training pattern from the paper:
|
||||
|
||||
L = H(x, f_θ_text) + α * ‖ω - a - f_θ_action(a_τ, o, ℓ)‖²
|
||||
|
||||
with α = 10.0 per § IV.D of arxiv:2504.16054. The π0.5 model splits
|
||||
inference into a text-prediction step followed by an action-prediction
|
||||
step, which the multi-rate ``PI052Runtime`` (in
|
||||
``lerobot.policies.pi052.inference``) drives at separate rates.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
|
||||
from ..pi05.configuration_pi05 import PI05Config
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi052")
|
||||
@dataclass
|
||||
class PI052Config(PI05Config):
|
||||
"""π0.5 with the PaliGemma LM head re-enabled for subtask prediction.
|
||||
|
||||
Recipe-driven dual-head training: the flow head supervises actions,
|
||||
the LM head supervises subtask / plan / memory / VQA text. The
|
||||
flow:text loss split is the milder 5:1 (see ``flow_loss_weight``).
|
||||
"""
|
||||
|
||||
# Recipe / language stack ---------------------------------------------
|
||||
recipe_path: str | None = "recipes/subtasks_vqa.yaml"
|
||||
"""Path (absolute or relative to ``src/lerobot/configs/``) to a
|
||||
``TrainingRecipe`` YAML. Defaults to the canonical Hi-Robot blend
|
||||
shipped alongside this policy. Set to ``None`` to disable recipe
|
||||
rendering and fall back to π0.5's single-task ``Task: ... Action:``
|
||||
prompt path (unannotated datasets keep working that way)."""
|
||||
|
||||
apply_chat_template: bool = False
|
||||
"""PaliGemma is *not* chat-pretrained — its tokenizer doesn't ship a
|
||||
chat template, so we don't apply one. The recipe renderer's output
|
||||
is concatenated as a plain prefix + assistant suffix instead,
|
||||
mirroring how the π0.5 paper's high-level inference samples text
|
||||
auto-regressively after the prefix."""
|
||||
|
||||
# Loss weights --------------------------------------------------------
|
||||
# Paper §IV.D uses α=10 between the flow and text terms, assuming
|
||||
# text is a rare auxiliary task. With the recipe stack the flow-only
|
||||
# `low_level` branch fires on a large share of samples, so α=10
|
||||
# swamps the LM head and collapses generation into degenerate
|
||||
# repetition. We use the milder 5:1 split here.
|
||||
text_loss_weight: float = 1.0
|
||||
"""Weight on the LM-head cross-entropy term. Set to ``0`` to disable
|
||||
text training entirely (reverts to flow-only / π0.5 behaviour)."""
|
||||
|
||||
flow_loss_weight: float = 5.0
|
||||
"""Weight on the action-expert flow-matching term. ``5.0`` — a milder
|
||||
flow:text split than the paper's α=10, since the flow-only
|
||||
``low_level`` recipe already gives the action expert frequent
|
||||
gradient. Lower it further if the LM head still underfits."""
|
||||
|
||||
# Backbone training ---------------------------------------------------
|
||||
unfreeze_lm_head: bool = True
|
||||
"""Whether to keep the PaliGemma ``lm_head`` unfrozen for fine-tuning.
|
||||
The existing ``PI05Policy`` zeroes / freezes the head on load
|
||||
because it never reads from it. Must be ``True`` for π0.5-style
|
||||
hierarchical inference."""
|
||||
|
||||
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
|
||||
# Randomly drop non-target context messages so the LM head learns
|
||||
# to handle missing /
|
||||
# stale plan / memory at inference. Defaults to 0.0 so behaviour
|
||||
# is identical until explicitly enabled.
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
|
||||
# FAST discrete-action supervision — paper §III.B-C ------------------
|
||||
# When enabled, actions are *also* tokenised via the FAST tokenizer
|
||||
# ("physical-intelligence/fast") and supervised with cross-entropy
|
||||
# on the PaliGemma LM head — exactly as in the paper's pre-training
|
||||
# objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
|
||||
# ActionTokenizerProcessorStep is wired into the preprocessor
|
||||
# pipeline when this flag is set; the loss is computed in
|
||||
# PI052Policy.forward.
|
||||
enable_fast_action_loss: bool = True
|
||||
"""If True, tokenise actions with the FAST tokenizer and add a
|
||||
cross-entropy loss on the LM head. On by default to match the
|
||||
π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE,
|
||||
§III.B-C Eq. 1). Set to False if you only want the
|
||||
post-training-style flow + text recipe."""
|
||||
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
"""HF identifier for the FAST action tokenizer."""
|
||||
|
||||
max_action_tokens: int = 256
|
||||
"""Maximum number of FAST tokens per action chunk."""
|
||||
|
||||
fast_skip_tokens: int = 128
|
||||
"""Number of low-vocab tokens the FAST tokenizer skips to avoid
|
||||
collisions with PaliGemma's text vocabulary."""
|
||||
|
||||
fast_action_loss_weight: float = 1.0
|
||||
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
|
||||
|
||||
auto_fit_fast_tokenizer: bool = False
|
||||
"""If True, the processor factory checks ``fast_tokenizer_cache_dir``
|
||||
for a previously-fitted tokenizer keyed on ``(dataset_repo_id,
|
||||
base_tokenizer_name, fit_samples)``. On cache miss, it loads
|
||||
``action_tokenizer_name`` as a base, samples
|
||||
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
|
||||
``.fit()``, saves the result, and uses *that* fitted path as the
|
||||
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
|
||||
explicitly recommend per-dataset fitting for best compression.
|
||||
|
||||
Off by default because the fit requires a separate pre-training
|
||||
pass over the dataset (~1-2 min on a medium dataset) and depends
|
||||
on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in
|
||||
when you want paper-faithful compression; leave off to fall back
|
||||
on the universal ``physical-intelligence/fast`` codebook."""
|
||||
|
||||
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
|
||||
"""Where fitted FAST tokenizers are stored. ``~`` expands."""
|
||||
|
||||
fast_tokenizer_fit_samples: int = 1024
|
||||
"""Number of action chunks to sample for the fit. The FAST paper uses
|
||||
a few thousand; 1024 is a reasonable default for medium datasets."""
|
||||
|
||||
# Knowledge insulation — paper §III.B --------------------------------
|
||||
# When enabled, gradients from the action expert's flow loss are
|
||||
# blocked from flowing back into the VLM's K/V projections. This
|
||||
# prevents the action loss from over-fitting the language backbone
|
||||
# to robot-specific features. Implemented in ``modeling_pi052`` as
|
||||
# a per-instance monkey-patch on ``paligemma_with_expert.forward``
|
||||
# that splits queries into VLM and action halves and ``.detach()``-s
|
||||
# the VLM K/V tensors used in the action-half's attention.
|
||||
knowledge_insulation: bool = False
|
||||
"""If True, route every transformer layer through the KI
|
||||
attention path that blocks action→VLM gradient flow on K/V."""
|
||||
|
||||
# Learning-rate defaults --------------------------------------------
|
||||
# pi052 inherits π0.5's openpi-validated optimizer config (peak LR
|
||||
# 2.5e-5, cosine→2.5e-6, 1k warmup, AdamW (0.9, 0.95), wd=0.01,
|
||||
# grad_clip=1.0). The only place pi052 needs to diverge from pi05
|
||||
# is the LM-head LR multiplier: pi05 has no text supervision so the
|
||||
# head doesn't get gradients; pi052 always has text supervision
|
||||
# (subtask / memory / VQA) via the recipe, and under KI the LM head
|
||||
# only sees gradients on ~30–45% of the batch (the text-CE mask
|
||||
# share of the recipe). Under aggressive cosine decay this is too
|
||||
# weak to keep the head pinned, so it drifts back toward PaliGemma's
|
||||
# pretrained ``<loc>`` first-token bias. 5x is the documented fix
|
||||
# (see ``PI05Config.lm_head_lr_scale`` docstring); the wiring is
|
||||
# already in ``PI05Policy.get_optim_params`` — it splits the LM head
|
||||
# + tied ``embed_tokens`` into their own param group while sharing
|
||||
# the same cosine lambda, so the 5x ratio is preserved across decay.
|
||||
lm_head_lr_scale: float = 5.0
|
||||
|
||||
# PaLM-style z-loss on text CE. Penalises the log-partition function
|
||||
# ``z = log Σ exp(logits)`` drifting away from zero — without it, large-
|
||||
# vocab models (PaliGemma is 257k) can let ``logsumexp`` grow unbounded
|
||||
# while CE stays low, because a uniform additive logit bias cancels in
|
||||
# softmax. PaLM appendix B / Chinchilla report z-loss is essential for
|
||||
# stable large-vocab CE; it especially helps under ``lm_head_lr_scale=
|
||||
# 5.0`` which amplifies drift risk on the LM head. ``1e-4`` is the
|
||||
# commonly cited weight; set 0 to disable entirely.
|
||||
text_ce_z_loss_weight: float = 1e-4
|
||||
|
||||
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
|
||||
# unconditionally at model build time — see ``_enable_hf_kernels``
|
||||
# in ``modeling_pi052``. The patch is process-global, idempotent
|
||||
# and degrades gracefully if ``liger-kernel`` is missing. Measured
|
||||
# at -4.5% step time on H100 (bench job 22161421); peak memory
|
||||
# unchanged. ``fused_linear_cross_entropy`` ships separately via
|
||||
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
|
||||
use_hf_kernels: bool = True
|
||||
"""Deprecated. Liger HF kernels are patched unconditionally by
|
||||
``_enable_hf_kernels`` — this field is retained as a no-op for
|
||||
backward compatibility with checkpoints saved before commit
|
||||
d70c8104 (which still serialize ``use_hf_kernels: true`` into
|
||||
``config.json``). Loading those configs would otherwise raise
|
||||
``DecodingError: The fields use_hf_kernels are not valid for
|
||||
PI052Config`` (job 22164492). Remove in a future major bump."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through the text head when
|
||||
# we're training it. Override the π0.5 default
|
||||
# (``train_expert_only=True``) unless the user explicitly opts
|
||||
# out of text training via ``text_loss_weight=0``.
|
||||
if self.text_loss_weight > 0 and self.unfreeze_lm_head:
|
||||
self.train_expert_only = False
|
||||
@@ -1,263 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataset-specific FAST action tokenizer fitting.
|
||||
|
||||
The published ``physical-intelligence/fast`` tokenizer is a *universal*
|
||||
codebook fitted on a heterogeneous mix of robot datasets. Per Pertsch
|
||||
et al. 2025 (the FAST paper, [64] in the π0.5 paper) and §III.C of
|
||||
π0.5 itself, the recommended practice is to **finetune the tokenizer on
|
||||
your specific dataset's action distribution** before training the
|
||||
policy — same way one would adapt a language tokenizer to a domain
|
||||
corpus. Without this finetune step, action sequences from your robot
|
||||
may require more tokens per chunk than necessary, lowering effective
|
||||
compression and slowing convergence of the action-CE loss.
|
||||
|
||||
This module provides a single utility, :func:`fit_fast_tokenizer`,
|
||||
that does the finetune. The training entry point invokes it
|
||||
automatically when the policy's ``enable_fast_action_loss`` and
|
||||
``auto_fit_fast_tokenizer`` flags are both ``True`` and no cached
|
||||
fitted tokenizer is found at ``fast_tokenizer_cache_dir``.
|
||||
|
||||
The fitted tokenizer is saved to
|
||||
``{cache_dir}/{dataset_hash}_{base_hash}/`` so successive training
|
||||
runs over the same dataset re-use it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
|
||||
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
|
||||
# that's the image / feature-extractor convention). Centralised here so
|
||||
# the cache-hit check and the rank-N readiness wait agree on the same
|
||||
# sentinel.
|
||||
_CACHE_SENTINEL = "processor_config.json"
|
||||
|
||||
|
||||
def _dataset_signature(
|
||||
dataset_repo_id: str,
|
||||
base_tokenizer_name: str,
|
||||
n_samples: int,
|
||||
chunk_size: int,
|
||||
) -> str:
|
||||
"""Deterministic short hash for naming the cache directory.
|
||||
|
||||
Keys on (dataset, base tokenizer, sample count, chunk size) so any
|
||||
of those changing re-runs the fit. ``chunk_size`` matters because
|
||||
the tokenizer is fit on chunks of that length.
|
||||
"""
|
||||
h = hashlib.sha256()
|
||||
h.update(dataset_repo_id.encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(base_tokenizer_name.encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(n_samples).encode("utf-8"))
|
||||
h.update(b"\0")
|
||||
h.update(str(chunk_size).encode("utf-8"))
|
||||
return h.hexdigest()[:16]
|
||||
|
||||
|
||||
def fit_fast_tokenizer(
|
||||
*,
|
||||
dataset_repo_id: str,
|
||||
cache_dir: str | Path,
|
||||
base_tokenizer_name: str = "physical-intelligence/fast",
|
||||
n_samples: int = 1024,
|
||||
chunk_size: int = 50,
|
||||
seed: int = 42,
|
||||
) -> str:
|
||||
"""Fit a FAST tokenizer on a LeRobot dataset's action distribution.
|
||||
|
||||
Args:
|
||||
dataset_repo_id: HF Hub repo id of the LeRobotDataset to fit on.
|
||||
cache_dir: Directory under which to save (and look up) fitted
|
||||
tokenizers. The actual save path is
|
||||
``{cache_dir}/{signature}``.
|
||||
base_tokenizer_name: HF identifier for the base FAST tokenizer
|
||||
to finetune from. ``physical-intelligence/fast`` is the
|
||||
universal one.
|
||||
n_samples: Number of action chunks to sample for the fit. The
|
||||
FAST paper uses a few thousand; ``1024`` is a good default
|
||||
for medium datasets.
|
||||
chunk_size: Length of each action chunk (matches
|
||||
``policy.chunk_size``). The FAST tokenizer is fit on
|
||||
sequences of this length.
|
||||
seed: RNG seed for sample selection.
|
||||
|
||||
Returns:
|
||||
The local path to the fitted tokenizer. Passed directly to
|
||||
``--policy.action_tokenizer_name`` for the training run.
|
||||
|
||||
Raises:
|
||||
ImportError: If the ``transformers`` library doesn't expose
|
||||
``AutoProcessor`` or the FAST tokenizer doesn't have a
|
||||
``.fit()`` method (then you're on an older FAST snapshot —
|
||||
update to the current published model).
|
||||
FileNotFoundError: If the dataset can't be loaded.
|
||||
"""
|
||||
cache_dir = Path(cache_dir)
|
||||
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
|
||||
out_dir = cache_dir / sig
|
||||
|
||||
if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
|
||||
logger.info(
|
||||
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
|
||||
"dataset=%s base=%s n_samples=%d",
|
||||
out_dir, dataset_repo_id, base_tokenizer_name, n_samples,
|
||||
)
|
||||
return str(out_dir)
|
||||
|
||||
# DDP-safe fit: only the (local) main process actually fits + saves;
|
||||
# other ranks poll the cache sentinel until the leader is done.
|
||||
# Without this guard, all N ranks fit concurrently and race on
|
||||
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
|
||||
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
|
||||
# and compiles a ``.pyc`` — concurrent writers occasionally produce
|
||||
# a stale / partial ``.pyc`` and the subsequent ``from .. import
|
||||
# UniversalActionProcessor`` raises ``AttributeError``.
|
||||
is_leader = (
|
||||
int(os.environ.get("RANK", "0")) == 0
|
||||
and int(os.environ.get("LOCAL_RANK", "0")) == 0
|
||||
)
|
||||
if not is_leader:
|
||||
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
|
||||
start = time.monotonic()
|
||||
while not (out_dir / _CACHE_SENTINEL).exists():
|
||||
if time.monotonic() - start > timeout_s:
|
||||
raise RuntimeError(
|
||||
f"FAST tokenizer fit: non-leader rank timed out after "
|
||||
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
|
||||
"Leader rank likely crashed during the fit."
|
||||
)
|
||||
time.sleep(2.0)
|
||||
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
|
||||
return str(out_dir)
|
||||
|
||||
logger.info(
|
||||
"FAST tokenizer cache miss — fitting on dataset=%s "
|
||||
"base=%s n_samples=%d chunk_size=%d → %s",
|
||||
dataset_repo_id, base_tokenizer_name, n_samples, chunk_size, out_dir,
|
||||
)
|
||||
|
||||
from transformers import AutoProcessor # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
|
||||
|
||||
# Stream a single episode's worth of action chunks at a time so
|
||||
# we don't blow memory on huge datasets. Random episode +
|
||||
# random start offset gives a reasonable spread.
|
||||
#
|
||||
# Actions are read straight from the underlying HF dataset's
|
||||
# ``action`` *column* — never via ``ds[i]``. ``ds[i]`` builds a full
|
||||
# training item (delta-timestamp expansion + video decode + image
|
||||
# transforms); a single bad video frame would then throw and, since
|
||||
# the failure was swallowed at debug level, silently starve the fit
|
||||
# of every chunk. The action column carries no video, so reading it
|
||||
# directly is both faster and immune to decode errors.
|
||||
rng = np.random.default_rng(seed)
|
||||
actions_buf: list[np.ndarray] = []
|
||||
|
||||
# Load just the metadata first to know episode boundaries.
|
||||
ds_meta_only = LeRobotDataset(dataset_repo_id, episodes=[0])
|
||||
num_episodes = ds_meta_only.meta.total_episodes
|
||||
if "action" not in ds_meta_only.features:
|
||||
available = ", ".join(sorted(ds_meta_only.features)) or "<none>"
|
||||
raise RuntimeError(
|
||||
f"FAST fit: dataset {dataset_repo_id!r} has no ``action`` feature. "
|
||||
f"Available features: {available}."
|
||||
)
|
||||
del ds_meta_only
|
||||
|
||||
samples_per_episode = max(1, n_samples // max(num_episodes, 1))
|
||||
collected = 0
|
||||
eps_visited = 0
|
||||
short_episodes = 0
|
||||
for ep_idx in rng.permutation(num_episodes):
|
||||
if collected >= n_samples:
|
||||
break
|
||||
ep_idx = int(ep_idx)
|
||||
try:
|
||||
ds = LeRobotDataset(dataset_repo_id, episodes=[ep_idx])
|
||||
ep_actions = np.asarray(ds.hf_dataset["action"], dtype=np.float32)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("FAST fit: skipping episode %d: %s", ep_idx, exc)
|
||||
continue
|
||||
if ep_actions.ndim != 2 or ep_actions.shape[0] < chunk_size:
|
||||
short_episodes += 1
|
||||
continue
|
||||
# Sample ``samples_per_episode`` contiguous chunks uniformly.
|
||||
starts = rng.integers(0, ep_actions.shape[0] - chunk_size + 1, size=samples_per_episode)
|
||||
for s in starts:
|
||||
actions_buf.append(ep_actions[int(s) : int(s) + chunk_size])
|
||||
collected += 1
|
||||
if collected >= n_samples:
|
||||
break
|
||||
eps_visited += 1
|
||||
|
||||
if not actions_buf:
|
||||
raise RuntimeError(
|
||||
f"FAST fit collected zero action chunks from {dataset_repo_id!r}: "
|
||||
f"all {num_episodes} episodes were shorter than chunk_size="
|
||||
f"{chunk_size} ({short_episodes} too short) or had an unreadable "
|
||||
"``action`` column. Lower ``chunk_size`` to match your episode "
|
||||
"lengths."
|
||||
)
|
||||
|
||||
actions = np.stack(actions_buf, axis=0).astype(np.float32) # (N, H, D)
|
||||
logger.info(
|
||||
"FAST fit: collected %d chunks of shape %s from %d episodes",
|
||||
actions.shape[0], actions.shape[1:], eps_visited,
|
||||
)
|
||||
|
||||
# Quantile-normalise per dimension before fitting.
|
||||
#
|
||||
# The FAST tokenizer DCT-transforms actions, scales by ``scale`` and
|
||||
# rounds to integer tokens; the integer *range* must fit the
|
||||
# codebook (vocab_size, default 1024). Raw motor units (e.g. encoder
|
||||
# ticks) blow that range up — hence "Vocab size 1024 is too small".
|
||||
# More importantly, at training time ``ActionTokenizerProcessorStep``
|
||||
# runs *after* the QUANTILES ``NormalizerProcessorStep``, so it
|
||||
# encodes normalised actions. Fitting on raw actions would mismatch
|
||||
# that space. We replicate QUANTILES normalisation here (per-dim
|
||||
# [q01, q99] → [-1, 1], clipped) so the fit and the training-time
|
||||
# encode see the same distribution.
|
||||
flat = actions.reshape(-1, actions.shape[-1])
|
||||
q01 = np.quantile(flat, 0.01, axis=0)
|
||||
q99 = np.quantile(flat, 0.99, axis=0)
|
||||
span = np.where((q99 - q01) > 1e-6, q99 - q01, 1.0)
|
||||
actions = np.clip((actions - q01) / span * 2.0 - 1.0, -1.0, 1.0).astype(np.float32)
|
||||
|
||||
base = AutoProcessor.from_pretrained(base_tokenizer_name, trust_remote_code=True)
|
||||
if not hasattr(base, "fit"):
|
||||
raise ImportError(
|
||||
f"Base FAST tokenizer {base_tokenizer_name!r} has no ``.fit()`` "
|
||||
"method — your transformers / model snapshot is too old. Update "
|
||||
"to the current ``physical-intelligence/fast`` revision."
|
||||
)
|
||||
|
||||
fitted = base.fit(actions)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
fitted.save_pretrained(str(out_dir))
|
||||
logger.info("FAST fit: saved fitted tokenizer to %s", out_dir)
|
||||
return str(out_dir)
|
||||
@@ -1,73 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PI052 inference / runtime orchestration.
|
||||
|
||||
Multi-rate runtime that mirrors the recipe-time training shape:
|
||||
|
||||
low_level_execution → LowLevelForward + DispatchAction (high Hz)
|
||||
high_level_subtask → HighLevelSubtaskFwd (~1 Hz)
|
||||
memory_update → MemoryUpdateFwd (event: subtask_change)
|
||||
user_interjection_response → UserInterjectionFwd (event: stdin)
|
||||
ask_vqa_* → AskVQAFwd (event: stdin question)
|
||||
speech tool calls → DispatchToolCalls (event: tool_call_pending)
|
||||
|
||||
The CLI ``lerobot-pi052-runtime`` builds a ``PI052Runtime`` and calls
|
||||
``run()``.
|
||||
"""
|
||||
|
||||
from .repl import StdinReader
|
||||
from .runtime import PI052Runtime
|
||||
from .runtime_state import initial_runtime_state, push_log, set_if_changed, take_event
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
UserInterjectionFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, Tick, TickClock, Trigger
|
||||
from .ui import make_state_panel, print_robot_lines, print_user_line
|
||||
|
||||
__all__ = [
|
||||
# runtime
|
||||
"PI052Runtime",
|
||||
"StdinReader",
|
||||
# state helpers
|
||||
"initial_runtime_state",
|
||||
"push_log",
|
||||
"set_if_changed",
|
||||
"take_event",
|
||||
# triggers
|
||||
"Trigger",
|
||||
"Tick",
|
||||
"TickClock",
|
||||
"HzTrigger",
|
||||
"EventTrigger",
|
||||
# steps
|
||||
"InferenceStep",
|
||||
"LowLevelForward",
|
||||
"DispatchAction",
|
||||
"HighLevelSubtaskFwd",
|
||||
"MemoryUpdateFwd",
|
||||
"UserInterjectionFwd",
|
||||
"AskVQAFwd",
|
||||
"DispatchToolCalls",
|
||||
# UI
|
||||
"make_state_panel",
|
||||
"print_robot_lines",
|
||||
"print_user_line",
|
||||
]
|
||||
@@ -1,105 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Stdin REPL event collector for the PI052 runtime.
|
||||
|
||||
Reads non-blocking stdin lines, classifies each one heuristically:
|
||||
|
||||
"stop" / "quit" / "exit" → state["stop"] = True
|
||||
"/action" / "/pause" → set state["mode"]
|
||||
ends with "?" → user_vqa_query event
|
||||
starts with "task:" or first line → set runtime task
|
||||
anything else → user_interjection event
|
||||
|
||||
Plugged into the runtime via ``event_collector=StdinReader().poll``.
|
||||
|
||||
Note: the shipped CLI (``lerobot-pi052-runtime``) drives stdin
|
||||
directly in its REPL / autonomous loops and does *not* wire this
|
||||
collector; it's kept as the documented embedding hook and for tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import select
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class StdinReader:
|
||||
"""Non-blocking stdin line collector for the runtime loop."""
|
||||
|
||||
prompt: str = "> "
|
||||
_seen_first_line: bool = field(default=False, init=False)
|
||||
_prompted: bool = field(default=False, init=False)
|
||||
|
||||
def poll(self, state: dict[str, Any]) -> None:
|
||||
"""Drain pending stdin lines into runtime events."""
|
||||
# Print the input prompt once on every fresh tick if we don't
|
||||
# already have a pending line; matches the expected REPL feel.
|
||||
if not self._prompted:
|
||||
print(self.prompt, end="", flush=True)
|
||||
self._prompted = True
|
||||
|
||||
# ``select`` with timeout=0 makes this non-blocking. Only works
|
||||
# for actual TTY / pipe stdins; CI / scripted runs hit EOF.
|
||||
try:
|
||||
ready, _, _ = select.select([sys.stdin], [], [], 0)
|
||||
except (ValueError, OSError):
|
||||
return
|
||||
if not ready:
|
||||
return
|
||||
|
||||
line = sys.stdin.readline()
|
||||
if not line: # EOF
|
||||
state["stop"] = True
|
||||
return
|
||||
line = line.strip()
|
||||
self._prompted = False # we'll re-prompt next tick
|
||||
if not line:
|
||||
return
|
||||
|
||||
lower = line.lower()
|
||||
if lower in {"stop", "quit", "exit"}:
|
||||
state["stop"] = True
|
||||
return
|
||||
|
||||
# Slash commands flip the run mode. ``/pause`` stops the action
|
||||
# loop (the action steps gate on ``state["mode"]``); ``/action``
|
||||
# resumes it.
|
||||
if lower.split(" ", 1)[0] in {"/action", "/act", "/run"}:
|
||||
state["mode"] = "action"
|
||||
return
|
||||
if lower in {"/pause", "/p"}:
|
||||
state["mode"] = "paused"
|
||||
queue = state.get("action_queue")
|
||||
if hasattr(queue, "clear"):
|
||||
queue.clear()
|
||||
return
|
||||
|
||||
# First non-control line sets the task if no task is active.
|
||||
if not state.get("task"):
|
||||
task = line[5:].strip() if lower.startswith("task:") else line
|
||||
state["task"] = task
|
||||
print(f"[pi052] Task: {task}", flush=True)
|
||||
self._seen_first_line = True
|
||||
return
|
||||
|
||||
# Question → VQA; statement → interjection.
|
||||
if lower.endswith("?"):
|
||||
state["recent_vqa_query"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_vqa_query")
|
||||
else:
|
||||
state["recent_interjection"] = line
|
||||
state.setdefault("events_this_tick", []).append("user_interjection")
|
||||
@@ -1,205 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PI052 runtime loop.
|
||||
|
||||
Threads the multi-rate inference pipeline together with a stdin REPL
|
||||
event collector, drives ticks through :class:`TickClock`, and prints
|
||||
state-change updates to the user.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable
|
||||
|
||||
from .runtime_state import initial_runtime_state, push_log
|
||||
from .steps import (
|
||||
AskVQAFwd,
|
||||
DispatchAction,
|
||||
DispatchToolCalls,
|
||||
HighLevelSubtaskFwd,
|
||||
InferenceStep,
|
||||
LowLevelForward,
|
||||
MemoryUpdateFwd,
|
||||
)
|
||||
from .triggers import EventTrigger, HzTrigger, TickClock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PI052Runtime:
|
||||
"""Compose the inference pipeline and drive it tick-by-tick."""
|
||||
|
||||
policy: Any
|
||||
tools: dict[str, Any] = field(default_factory=dict)
|
||||
"""Name → tool-instance dict, e.g. ``{"say": SayTool(...)}``. Read
|
||||
from :func:`lerobot.tools.get_tools(meta)` when wiring the
|
||||
runtime."""
|
||||
observation_provider: Callable[[], dict | None] | None = None
|
||||
"""Closure returning the current preprocessed observation batch.
|
||||
``None`` for dry-run / language-only sessions."""
|
||||
robot_executor: Callable[[Any], None] | None = None
|
||||
"""Closure that takes one action chunk and forwards it to the
|
||||
robot. ``None`` for dry-run."""
|
||||
event_collector: Callable[[dict], None] | None = None
|
||||
"""Per-tick hook that polls external sources (stdin, network) and
|
||||
appends event names to ``state["events_this_tick"]``."""
|
||||
chunk_hz: float = 4.0
|
||||
ctrl_hz: float = 50.0
|
||||
high_level_hz: float = 1.0
|
||||
max_rate_hz: float = 50.0
|
||||
|
||||
pipeline: list[InferenceStep] = field(init=False)
|
||||
state: dict[str, Any] = field(init=False)
|
||||
_stop: bool = field(default=False, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Subtask + memory + VQA configuration. Pipeline:
|
||||
#
|
||||
# HighLevelSubtaskFwd → generate the next subtask via the LM
|
||||
# head at ~``high_level_hz``; writes
|
||||
# ``current_subtask`` and emits
|
||||
# ``subtask_change`` on a transition.
|
||||
# MemoryUpdateFwd → on ``subtask_change``, refresh
|
||||
# ``current_memory`` from the
|
||||
# ``memory_update`` head.
|
||||
# AskVQAFwd → answer camera-grounded stdin questions.
|
||||
# LowLevelForward → action chunk conditioned on the
|
||||
# generated ``current_subtask``.
|
||||
# DispatchAction → drain the chunk to the robot.
|
||||
# DispatchToolCalls → fire any pending tool calls.
|
||||
#
|
||||
# Order matters: ``HighLevelSubtaskFwd`` must run before
|
||||
# ``MemoryUpdateFwd`` so the event is visible the same tick, and
|
||||
# both must run before ``LowLevelForward`` (which is gated on
|
||||
# "action queue empty") so the chunk consumes the freshest
|
||||
# subtask. ``UserInterjectionFwd`` is still importable but
|
||||
# disabled until plan generation is wired in.
|
||||
self.pipeline = [
|
||||
HighLevelSubtaskFwd(
|
||||
trigger=HzTrigger(self.high_level_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
# Listens for the ``subtask_change`` event raised by
|
||||
# ``HighLevelSubtaskFwd`` and refreshes ``current_memory``.
|
||||
MemoryUpdateFwd(
|
||||
trigger=EventTrigger("subtask_change"),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
AskVQAFwd(
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
LowLevelForward(
|
||||
trigger=HzTrigger(self.chunk_hz),
|
||||
policy=self.policy,
|
||||
observation_provider=self.observation_provider,
|
||||
),
|
||||
DispatchAction(
|
||||
trigger=HzTrigger(self.ctrl_hz),
|
||||
robot_executor=self.robot_executor,
|
||||
),
|
||||
DispatchToolCalls(tools=self.tools),
|
||||
]
|
||||
self.state = initial_runtime_state()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_task(self, task: str) -> None:
|
||||
"""Set or replace the active task. Logged for the REPL."""
|
||||
self.state["task"] = task
|
||||
push_log(self.state, f"Task: {task}")
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop = True
|
||||
|
||||
def run(self, *, max_ticks: int | None = None) -> None:
|
||||
"""Main loop. Returns when ``stop()`` is called or after
|
||||
``max_ticks`` ticks (useful for tests / dry-run)."""
|
||||
clock = TickClock(max_rate_hz=self.max_rate_hz)
|
||||
while not self._stop:
|
||||
tick = clock.advance()
|
||||
self.state["_tick"] = tick
|
||||
self.state["events_this_tick"] = []
|
||||
self.state["log_lines"] = []
|
||||
|
||||
if self.event_collector is not None:
|
||||
self.event_collector(self.state)
|
||||
if self.state.get("stop"):
|
||||
self._stop = True
|
||||
break
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
self._flush_logs()
|
||||
if max_ticks is not None and tick.index >= max_ticks:
|
||||
break
|
||||
|
||||
self._on_shutdown()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# REPL helper: drive one full pipeline pass and return its logs
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def step_once(self) -> list[str]:
|
||||
"""Run one tick of the pipeline and return the log lines.
|
||||
|
||||
Used by the interactive REPL: instead of a background thread,
|
||||
the CLI drives ticks synchronously after each user input. Logs
|
||||
are returned (not printed) so the caller can route them into
|
||||
the rich-Live chat scrollback.
|
||||
"""
|
||||
from .triggers import Tick # noqa: PLC0415
|
||||
|
||||
# Synthesize a tick. We don't need the real wall-clock pacing
|
||||
# here — the REPL drives the runtime, not vice versa — but
|
||||
# ``HzTrigger`` uses ``tick.monotonic_seconds`` to gate, so we
|
||||
# bump it generously so every Hz-triggered step considers
|
||||
# itself due.
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
prev_index = self.state.get("_tick").index if isinstance(self.state.get("_tick"), Tick) else 0
|
||||
self.state["_tick"] = Tick(index=prev_index + 1, monotonic_seconds=_time.monotonic())
|
||||
self.state["log_lines"] = []
|
||||
# ``events_this_tick`` is set up by the caller before
|
||||
# ``step_once`` (the REPL pushes user-driven events first).
|
||||
self.state.setdefault("events_this_tick", [])
|
||||
|
||||
for step in self.pipeline:
|
||||
self.state = step(self.state)
|
||||
|
||||
return list(self.state.get("log_lines") or [])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# I/O
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _flush_logs(self) -> None:
|
||||
for line in self.state.get("log_lines") or []:
|
||||
print(f"[pi052] {line}", flush=True)
|
||||
|
||||
def _on_shutdown(self) -> None:
|
||||
# Drain any queued action chunks safely.
|
||||
queue = self.state.get("action_queue")
|
||||
if isinstance(queue, deque):
|
||||
queue.clear()
|
||||
print("[pi052] runtime stopped", flush=True)
|
||||
@@ -1,95 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Runtime state passed between inference steps each tick.
|
||||
|
||||
The runtime threads a single dict through the pipeline; this module
|
||||
documents the shape and provides factories. We use a plain ``dict``
|
||||
rather than a frozen dataclass because steps freely add and remove
|
||||
keys (``events_this_tick``, ``messages_pending``, ``tool_calls_pending``,
|
||||
…) and dataclass field churn would just get in the way.
|
||||
|
||||
Stable keys (read by multiple steps):
|
||||
|
||||
task str the current top-level task
|
||||
current_plan str | None latest plan emitted by the planner
|
||||
current_subtask str | None latest subtask the policy is executing
|
||||
current_memory str | None latest compressed memory
|
||||
recent_interjection str | None most recent user interjection text (consumed)
|
||||
|
||||
action_queue collections.deque[Tensor] pending action chunks
|
||||
tool_calls_pending list[dict] parsed but not-yet-dispatched tool calls
|
||||
|
||||
events_this_tick list[str] triggers consumed this tick
|
||||
_tick Tick current tick (set by the loop)
|
||||
|
||||
mode str "action" (run the robot) | "paused"
|
||||
(action loop stopped — robot holds)
|
||||
|
||||
log_lines list[str] human-readable status lines printed each tick
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
|
||||
def initial_runtime_state(task: str | None = None) -> dict[str, Any]:
|
||||
"""Build a fresh runtime state dict with sensible defaults."""
|
||||
return {
|
||||
"task": task,
|
||||
"current_plan": None,
|
||||
"current_subtask": None,
|
||||
"current_memory": None,
|
||||
"recent_interjection": None,
|
||||
"action_queue": deque(),
|
||||
"tool_calls_pending": [],
|
||||
"events_this_tick": [],
|
||||
"log_lines": [],
|
||||
"mode": "action",
|
||||
"stop": False,
|
||||
}
|
||||
|
||||
|
||||
def take_event(state: dict[str, Any], event_name: str) -> bool:
|
||||
"""Pop ``event_name`` from ``events_this_tick`` if present.
|
||||
|
||||
Steps that consume an event call this so the same event doesn't
|
||||
re-fire on a sibling step within the same tick.
|
||||
"""
|
||||
events: list[str] = state.get("events_this_tick") or []
|
||||
if event_name in events:
|
||||
events.remove(event_name)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def push_log(state: dict[str, Any], line: str) -> None:
|
||||
"""Append ``line`` to the per-tick log buffer; the runtime prints
|
||||
it at the end of the tick."""
|
||||
state.setdefault("log_lines", []).append(line)
|
||||
|
||||
|
||||
def set_if_changed(state: dict[str, Any], key: str, value: Any, label: str | None = None) -> bool:
|
||||
"""Update ``state[key]`` and log a diff line if the value changed.
|
||||
|
||||
Returns ``True`` if the value actually changed.
|
||||
"""
|
||||
prev = state.get(key)
|
||||
if prev == value:
|
||||
return False
|
||||
state[key] = value
|
||||
if label is not None:
|
||||
push_log(state, f" {label}: {value}")
|
||||
return True
|
||||
@@ -1,936 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference steps for the PI052 multi-rate runtime.
|
||||
|
||||
Each step is a tiny class with a ``trigger`` and an ``__call__(state)``;
|
||||
the runtime applies them in order each tick. When a step's trigger
|
||||
doesn't fire, the step is a no-op and the runtime moves on.
|
||||
|
||||
Stream-to-step mapping mirrors the ``subtasks_vqa.yaml`` recipe:
|
||||
|
||||
* ``LowLevelForward`` — calls ``policy.select_action`` for the
|
||||
action chunk; trained by
|
||||
``low_level_execution``
|
||||
* ``EnqueueChunk`` — pushes the chunk to ``action_queue``
|
||||
* ``DispatchAction`` — pops one action per control tick and
|
||||
forwards to the robot
|
||||
* ``HighLevelSubtaskFwd`` — calls ``policy.select_message`` for the
|
||||
next subtask; trained by
|
||||
``high_level_subtask``
|
||||
* ``MemoryUpdateFwd`` — fires on subtask boundary; trained by
|
||||
``memory_update``
|
||||
* ``UserInterjectionFwd`` — fires on stdin interjection; trained by
|
||||
``user_interjection_response``
|
||||
* ``AskVQAFwd`` — fires on stdin question; trained by
|
||||
``ask_vqa_*``
|
||||
* ``DispatchToolCalls`` — pops ``tool_calls_pending`` and calls
|
||||
the matching ``Tool`` instance
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .runtime_state import push_log, set_if_changed, take_event
|
||||
from .triggers import EventTrigger, HzTrigger, Trigger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step base + runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceStep:
|
||||
"""A trigger-gated callable. Subclasses override :meth:`run`."""
|
||||
|
||||
trigger: Trigger
|
||||
|
||||
def __call__(self, state: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self.trigger.should_fire(state["_tick"], state):
|
||||
return state
|
||||
return self.run(state) or state
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level (action) path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class LowLevelForward(InferenceStep):
|
||||
"""Run the policy's action head and produce one action chunk."""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
"""Callable ``() -> dict``: returns the current observation batch
|
||||
(already preprocessed). Typically wraps the robot's camera /
|
||||
proprio reads. ``None`` in dry-run mode → step skips."""
|
||||
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=4.0))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or self.observation_provider is None:
|
||||
return None
|
||||
# ``/vlm`` mode pauses the whole action loop so the robot holds
|
||||
# position while the operator probes the VLM with VQA.
|
||||
if state.get("mode", "action") != "action":
|
||||
return None
|
||||
if not state.get("task"):
|
||||
return None
|
||||
|
||||
# PI052 produces *action chunks* (typically 50 steps via
|
||||
# flow-matching). Every step gets dispatched to the robot;
|
||||
# popping one per dispatch tick is essentially free. Only
|
||||
# generate a new chunk once the previous one has fully
|
||||
# drained — this is the canonical "sense → think → act"
|
||||
# loop. Refreshing while a chunk is still queued causes the
|
||||
# new chunk to "telescope" past the old one (planned from an
|
||||
# observation that's already 25+ steps stale by the time it
|
||||
# starts dispatching).
|
||||
queue = state.setdefault("action_queue", [])
|
||||
if len(queue) > 0:
|
||||
return None
|
||||
|
||||
observation = self.observation_provider()
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
# The action expert is conditioned on the SUBTASK generated by
|
||||
# the high-level loop (``HighLevelSubtaskFwd`` runs earlier in
|
||||
# the pipeline and writes ``current_subtask``). Matches the
|
||||
# training-time ``low_level_execution`` recipe — ``user(${subtask})``.
|
||||
# Falls back to the task string only on the very first frame,
|
||||
# before the high-level loop has produced a subtask.
|
||||
subtask = state.get("current_subtask") or state.get("task") or ""
|
||||
ctx = [{"role": "user", "content": subtask}]
|
||||
# ``add_generation_prompt=False`` to match the training-time
|
||||
# prefix shape: at training the action expert sees the rendered
|
||||
# user turn ending at ``<|im_end|>`` (no trailing
|
||||
# ``<|im_start|>assistant\n``). Passing True here would append
|
||||
# extra role-marker tokens the action expert never saw during
|
||||
# training.
|
||||
text_batch = _build_text_batch(self.policy, ctx, add_generation_prompt=False)
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
observation = dict(observation)
|
||||
observation[OBS_LANGUAGE_TOKENS] = text_batch["lang_tokens"]
|
||||
observation[OBS_LANGUAGE_ATTENTION_MASK] = text_batch["lang_masks"]
|
||||
|
||||
try:
|
||||
# ``predict_action_chunk`` returns the *full* chunk shape
|
||||
# ``(batch, n_action_steps, action_dim)``. Enqueue every
|
||||
# step so DispatchAction at ctrl_hz can drain them
|
||||
# smoothly until the next refresh.
|
||||
chunk = self.policy.predict_action_chunk(observation)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"predict_action_chunk failed: %s",
|
||||
exc,
|
||||
exc_info=logger.isEnabledFor(logging.DEBUG),
|
||||
)
|
||||
push_log(
|
||||
state,
|
||||
f" [warn] predict_action_chunk failed: "
|
||||
f"{type(exc).__name__}: {exc}",
|
||||
)
|
||||
return None
|
||||
|
||||
# ``chunk`` shape: ``(batch, n_action_steps, action_dim)``. Push
|
||||
# each step as a ``(1, action_dim)`` tensor so the existing
|
||||
# action executor's batch-squeeze logic works unchanged.
|
||||
if chunk.ndim == 3:
|
||||
chunk_iter = chunk[0] # ``(n_action_steps, action_dim)``
|
||||
elif chunk.ndim == 2:
|
||||
chunk_iter = chunk
|
||||
else:
|
||||
chunk_iter = chunk.unsqueeze(0)
|
||||
|
||||
for step in chunk_iter:
|
||||
queue.append(step.unsqueeze(0))
|
||||
state["last_chunk_size"] = int(chunk_iter.shape[0])
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispatchAction(InferenceStep):
|
||||
"""Pop one action per tick and hand it to the robot.
|
||||
|
||||
In dry-run mode (``robot_executor=None``) the step still pops the
|
||||
queue so it doesn't grow unbounded — the popped tensor is logged
|
||||
instead of executed.
|
||||
|
||||
Wall-clock catch-up: the action queue represents an open-loop
|
||||
trajectory at a fixed step rate (``trigger.hz`` ≈ ``ctrl_hz``).
|
||||
When the main loop stalls — e.g. an LLM call for the high-level
|
||||
subtask blocks for ~2 s on MPS — the dispatch trigger fires only
|
||||
once over that whole interval. Naively popping a single entry per
|
||||
fire makes the robot lag further and further behind the planned
|
||||
timeline, and a 50-step chunk would take ~125 s to drain instead
|
||||
of ~1.7 s. Track real elapsed time between dispatches and pop
|
||||
``round(elapsed * hz)`` entries, sending the most recent one. The
|
||||
skipped intermediate joint targets are stale anyway — the dynamixel
|
||||
will smooth toward the latest goal position.
|
||||
"""
|
||||
|
||||
robot_executor: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=50.0))
|
||||
_last_dispatch_t: float | None = field(default=None, init=False)
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
# ``/vlm`` mode pauses dispatch — the robot holds its last
|
||||
# commanded position while the operator runs VQA.
|
||||
if state.get("mode", "action") != "action":
|
||||
self._last_dispatch_t = None
|
||||
return None
|
||||
|
||||
queue = state.get("action_queue")
|
||||
if not queue:
|
||||
# Reset wall-clock anchor when the queue is empty so the
|
||||
# next chunk doesn't see a huge fake "elapsed" window.
|
||||
self._last_dispatch_t = None
|
||||
return None
|
||||
|
||||
now = _time.monotonic()
|
||||
hz = getattr(self.trigger, "hz", 30.0)
|
||||
if self._last_dispatch_t is None or hz <= 0:
|
||||
n_to_pop = 1
|
||||
else:
|
||||
elapsed = now - self._last_dispatch_t
|
||||
# ``max(1, ...)`` so we always pop at least one when the
|
||||
# trigger fires; ``min(len(queue), ...)`` so we don't run
|
||||
# off the end of the chunk.
|
||||
n_to_pop = max(1, min(len(queue), int(round(elapsed * hz))))
|
||||
self._last_dispatch_t = now
|
||||
|
||||
# Drain ``n_to_pop`` stale entries, keep only the latest as the
|
||||
# action actually sent. The intermediate joint targets would
|
||||
# all be ~10–30 ms apart in chunk time — the robot can't track
|
||||
# them individually anyway when the host loop is slow.
|
||||
latest = None
|
||||
for _ in range(n_to_pop):
|
||||
if not queue:
|
||||
break
|
||||
latest = queue.popleft() if hasattr(queue, "popleft") else queue.pop(0)
|
||||
state["actions_dispatched"] = state.get("actions_dispatched", 0) + 1
|
||||
|
||||
if latest is not None and self.robot_executor is not None:
|
||||
self.robot_executor(latest)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level (text) paths — all use policy.select_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_text_batch(
|
||||
policy: Any,
|
||||
prompt_messages: list[dict[str, Any]],
|
||||
*,
|
||||
add_generation_prompt: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Tokenize chat messages into the batch ``select_message`` expects.
|
||||
|
||||
PI052's backbone (PaliGemma) ships no chat template, so we train on
|
||||
a plain role-prefixed concatenation built by
|
||||
``PI052TextTokenizerStep``. We reuse that exact formatter so the
|
||||
inference prefix matches training; ``add_generation_prompt`` appends
|
||||
the bare ``Assistant: `` header the LM head continues from.
|
||||
"""
|
||||
import torch # noqa: PLC0415
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
|
||||
_flatten_say_tool_calls,
|
||||
_format_messages,
|
||||
_strip_blocks,
|
||||
register_paligemma_loc_tokens,
|
||||
)
|
||||
|
||||
tok_name = (
|
||||
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
|
||||
)
|
||||
# Register PaliGemma's <locDDDD> tokens so inference encoding /
|
||||
# decoding sees them as single vocab ids — must match training.
|
||||
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
|
||||
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in prompt_messages]
|
||||
prompt, _spans = _format_messages(messages)
|
||||
if add_generation_prompt:
|
||||
prompt = prompt + "Assistant: "
|
||||
|
||||
encoded = tokenizer(prompt, return_tensors="pt")
|
||||
ids = encoded["input_ids"]
|
||||
attn = encoded.get("attention_mask")
|
||||
if attn is None and tokenizer.pad_token_id is not None:
|
||||
attn = ids != tokenizer.pad_token_id
|
||||
if attn is not None and hasattr(attn, "dtype") and attn.dtype != torch.bool:
|
||||
attn = attn.bool()
|
||||
|
||||
# Move tokens onto the policy's device — otherwise prefix embedding
|
||||
# raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA
|
||||
# model), which the caller's broad except would swallow silently.
|
||||
device = getattr(getattr(policy, "config", None), "device", None)
|
||||
if device is not None:
|
||||
try:
|
||||
ids = ids.to(device)
|
||||
if attn is not None and hasattr(attn, "to"):
|
||||
attn = attn.to(device)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not move pi052 lang tokens to %s: %s", device, exc)
|
||||
return {"lang_tokens": ids, "lang_masks": attn, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
def _strip_recipe_keys(m: dict[str, Any]) -> dict[str, Any]:
|
||||
new = dict(m)
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
return new
|
||||
|
||||
|
||||
@dataclass
|
||||
class HighLevelSubtaskFwd(InferenceStep):
|
||||
"""At ~1 Hz, ask the policy for the next subtask.
|
||||
|
||||
Mirrors the ``high_level_subtask`` recipe layout exactly:
|
||||
|
||||
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
|
||||
user: "Current subtask: ${subtask}" (if subtask present)
|
||||
↓ generate ↓
|
||||
assistant: <next subtask>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
"""Same shape as ``LowLevelForward.observation_provider``. When
|
||||
set, the resulting observation is merged into ``select_message``'s
|
||||
batch so text generation runs against real video + state."""
|
||||
|
||||
trigger: Trigger = field(default_factory=lambda: HzTrigger(hz=1.0))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not state.get("task"):
|
||||
return None
|
||||
# ``/vlm`` mode pauses subtask generation along with the rest of
|
||||
# the action loop.
|
||||
if state.get("mode", "action") != "action":
|
||||
return None
|
||||
# Gate to chunk boundaries: only generate a fresh subtask when
|
||||
# the action queue is empty (i.e. right before LowLevelForward
|
||||
# refreshes the chunk). ``select_message`` takes ~2 s on MPS,
|
||||
# and running it every loop iteration starves DispatchAction
|
||||
# at ctrl_hz=30 — the queue drains at ~0.4 actions/sec instead
|
||||
# of 30/sec and the robot barely moves. Tying it to the same
|
||||
# "queue empty" condition as the chunk refresh produces a
|
||||
# clean sense → think → act cycle.
|
||||
#
|
||||
# Rearm the trigger when skipping so a low-hz schedule
|
||||
# (e.g. ``--high_level_hz=0.2`` = once per 5 s) doesn't lose
|
||||
# the slot: the trigger fires once on the timer but the brief
|
||||
# queue-empty window almost never coincides, so without rearm
|
||||
# HL would effectively never run.
|
||||
queue = state.get("action_queue") or []
|
||||
if len(queue) > 0:
|
||||
if hasattr(self.trigger, "rearm"):
|
||||
self.trigger.rearm()
|
||||
return None
|
||||
# Per-chunk-boundary throttle: at each "queue empty" moment we
|
||||
# increment a counter; subtask gen only fires once the counter
|
||||
# reaches ``subtask_chunks_per_gen``. Lets the operator run e.g.
|
||||
# 5 action chunks per subtask-gen so the LM head doesn't churn
|
||||
# every 1.7 s (a fresh subtask while the previous one is still
|
||||
# being executed is wasted compute *and* causes the action
|
||||
# expert's flow trajectory to be re-planned mid-grasp).
|
||||
chunks_per_gen = max(1, int(state.get("subtask_chunks_per_gen", 1) or 1))
|
||||
# Initialise so the first chunk boundary fires immediately
|
||||
# (counter starts at chunks_per_gen, decrements per skip,
|
||||
# generates and resets when it hits 0).
|
||||
if "_hl_chunks_until_gen" not in state:
|
||||
state["_hl_chunks_until_gen"] = 0
|
||||
if state["_hl_chunks_until_gen"] > 0:
|
||||
state["_hl_chunks_until_gen"] -= 1
|
||||
if hasattr(self.trigger, "rearm"):
|
||||
self.trigger.rearm()
|
||||
return None
|
||||
state["_hl_chunks_until_gen"] = chunks_per_gen - 1
|
||||
ctx = _msgs_for_subtask(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
# Default: greedy argmax, no min_new_tokens, no special-token
|
||||
# suppression — matches training. Operator can override via
|
||||
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
|
||||
# on the CLI; useful for under-trained checkpoints whose LM
|
||||
# head still favours EOS at position 0 (pre-trained chat
|
||||
# backbone's short-turn prior hasn't been fully overridden
|
||||
# by the fine-tuning supervision yet).
|
||||
msg = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="subtask gen",
|
||||
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
|
||||
temperature=float(state.get("text_gen_temperature") or 0.0),
|
||||
top_p=float(state.get("text_gen_top_p") or 1.0),
|
||||
# Subtasks never legitimately contain PaliGemma ``<loc>``
|
||||
# tokens — suppress them so a checkpoint whose LM head
|
||||
# has drifted toward the pretrained loc-prior falls back
|
||||
# to its (still-correct) text mass.
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
# Diagnostics: surface what the model is *actually* producing
|
||||
# at chunk boundaries, even when the output gets rejected or
|
||||
# repeats. Memorisation collapse looks like "same accepted
|
||||
# subtask N times in a row" or "gibberish_count rising while
|
||||
# current_subtask is stuck". The state panel renders these.
|
||||
state["last_subtask_raw"] = msg or ""
|
||||
# Persistent empty completion is its own failure mode (model
|
||||
# immediately EOS-es from the chat-template generation
|
||||
# prompt) — surface it once every N occurrences so the
|
||||
# operator can distinguish "generation failing silently"
|
||||
# from "generating fine but filter rejecting".
|
||||
if not msg:
|
||||
empties = state.get("subtask_empty_count", 0) + 1
|
||||
state["subtask_empty_count"] = empties
|
||||
if empties == 1 or empties % 5 == 0:
|
||||
debug = getattr(self.policy, "_last_select_message_debug", "") or ""
|
||||
if debug:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen empty (×{empties}); {debug}",
|
||||
)
|
||||
else:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen returned empty (×{empties}) — "
|
||||
"no tokens generated (head EOS-ing before any "
|
||||
"non-special token).",
|
||||
)
|
||||
if msg and _looks_like_gibberish(msg):
|
||||
# Bump a counter so the operator can see the model is
|
||||
# struggling without spamming the log every tick. A first
|
||||
# rejection still logs once so the failure is visible.
|
||||
count = state.get("subtask_gibberish_count", 0) + 1
|
||||
state["subtask_gibberish_count"] = count
|
||||
if count == 1 or count % 30 == 0:
|
||||
push_log(
|
||||
state,
|
||||
f" [info] subtask gen rejected (gibberish ×{count}): {msg[:60]!r}",
|
||||
)
|
||||
return None
|
||||
if msg:
|
||||
prev_subtask = state.get("current_subtask")
|
||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||
if changed:
|
||||
# Stash the just-completed subtask so ``MemoryUpdateFwd``
|
||||
# can drop it into its prompt as ``Completed subtask:``
|
||||
# — the recipe binds ``completed_subtask`` to
|
||||
# ``nth_prev(style=subtask, offset=1)``, i.e. the subtask
|
||||
# that was active *before* the change.
|
||||
if prev_subtask:
|
||||
state["prior_subtask"] = prev_subtask
|
||||
# Subtask change is a downstream trigger.
|
||||
state.setdefault("events_this_tick", []).append("subtask_change")
|
||||
state["subtask_repeat_count"] = 0
|
||||
else:
|
||||
# Same accepted string regenerated — memorisation tell.
|
||||
# Once this counter climbs past a few, you're seeing
|
||||
# the model unable to move past the current subtask
|
||||
# despite the chunk having drained (visual scene may
|
||||
# have changed but the LM is replaying training
|
||||
# tokens).
|
||||
state["subtask_repeat_count"] = (
|
||||
state.get("subtask_repeat_count", 0) + 1
|
||||
)
|
||||
# Silently skip empty completions — common when the model
|
||||
# warms up or generates only EOS; logging it every tick at
|
||||
# ctrl_hz is just noise.
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryUpdateFwd(InferenceStep):
|
||||
"""On subtask boundary, refresh the compressed memory.
|
||||
|
||||
Mirrors the ``memory_update`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous memory: ${prior_memory}" (if prior memory)
|
||||
user: "Completed subtask: ${completed_subtask}" (if subtask)
|
||||
↓ generate ↓
|
||||
assistant: <new memory>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("subtask_change"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# Don't consume the event — multiple steps may want to react.
|
||||
if self.policy is None:
|
||||
return None
|
||||
ctx = _msgs_for_memory(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
new_memory = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="memory gen",
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
state["last_memory_raw"] = new_memory or ""
|
||||
if new_memory and _looks_like_gibberish(new_memory):
|
||||
count = state.get("memory_gibberish_count", 0) + 1
|
||||
state["memory_gibberish_count"] = count
|
||||
push_log(
|
||||
state,
|
||||
f" [info] memory gen rejected (gibberish ×{count}): {new_memory[:60]!r}",
|
||||
)
|
||||
return None
|
||||
if new_memory:
|
||||
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInterjectionFwd(InferenceStep):
|
||||
"""On stdin interjection, refresh the plan + emit a paired ``say``.
|
||||
|
||||
Mirrors the ``user_interjection_response`` recipe layout exactly:
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous plan:\\n${prior_plan}" (if prior plan)
|
||||
user: "${interjection}" (the new utterance)
|
||||
↓ generate ↓
|
||||
assistant: <plan + <say>...</say>>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_interjection"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not take_event(state, "user_interjection"):
|
||||
return None
|
||||
ctx = _msgs_for_interjection(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
out = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="plan/say gen",
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
if not out:
|
||||
# Don't log every empty completion — happens repeatedly on
|
||||
# MPS during warm-up and floods the panel. The user can
|
||||
# re-trigger by typing again.
|
||||
return None
|
||||
if _looks_like_gibberish(out):
|
||||
count = state.get("plan_gibberish_count", 0) + 1
|
||||
state["plan_gibberish_count"] = count
|
||||
push_log(
|
||||
state,
|
||||
f" [info] plan/say gen rejected (gibberish ×{count}): {out[:60]!r}",
|
||||
)
|
||||
return None
|
||||
# Heuristic split: model is trained to emit one assistant turn
|
||||
# carrying both plan text AND a `say` tool call. Look for a
|
||||
# "<say>...</say>" or "say(...)" marker; fall back to whole
|
||||
# text → plan, no speech.
|
||||
plan_text, speech_text = _split_plan_and_say(out)
|
||||
if plan_text and _looks_like_gibberish(plan_text):
|
||||
plan_text = ""
|
||||
if plan_text:
|
||||
set_if_changed(state, "current_plan", plan_text, label="plan")
|
||||
if speech_text:
|
||||
push_log(state, f" speech: {speech_text}")
|
||||
state.setdefault("tool_calls_pending", []).append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "say", "arguments": {"text": speech_text}},
|
||||
}
|
||||
)
|
||||
state.setdefault("events_this_tick", []).append("tool_call_pending")
|
||||
# Mark interjection consumed.
|
||||
state["recent_interjection"] = None
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AskVQAFwd(InferenceStep):
|
||||
"""On stdin question, answer a frame-grounded VQA.
|
||||
|
||||
Mirrors the ``ask_vqa_*`` recipe layout exactly: a single user
|
||||
turn carrying just the VQA question, plus the camera image block
|
||||
in training (we drop the image at inference because the dataset's
|
||||
image preprocessing doesn't match SmolVLM's vision tower input).
|
||||
|
||||
user: <question>
|
||||
↓ generate ↓
|
||||
assistant: <vqa answer>
|
||||
"""
|
||||
|
||||
policy: Any = None
|
||||
observation_provider: Any = None
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("user_vqa_query"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if self.policy is None or not take_event(state, "user_vqa_query"):
|
||||
return None
|
||||
question = state.get("recent_vqa_query")
|
||||
if not question:
|
||||
return None
|
||||
ctx = _msgs_for_vqa(question)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
answer = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="vqa gen",
|
||||
)
|
||||
# VQA answers are intentionally JSON-like during training, so
|
||||
# ``_looks_like_gibberish`` would false-positive on them. Keep
|
||||
# the answer as-is — the VQA panel line lets the user judge.
|
||||
if answer:
|
||||
push_log(state, f" vqa: {answer}")
|
||||
state["recent_vqa_query"] = None
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispatchToolCalls(InferenceStep):
|
||||
"""Pop ``tool_calls_pending`` and execute them via :data:`TOOL_REGISTRY`."""
|
||||
|
||||
tools: dict[str, Any] = field(default_factory=dict)
|
||||
trigger: Trigger = field(default_factory=lambda: EventTrigger("tool_call_pending"))
|
||||
|
||||
def run(self, state: dict[str, Any]) -> dict[str, Any] | None:
|
||||
take_event(state, "tool_call_pending")
|
||||
pending = state.get("tool_calls_pending") or []
|
||||
for call in pending:
|
||||
try:
|
||||
fn = (call or {}).get("function") or {}
|
||||
name = fn.get("name")
|
||||
args = fn.get("arguments") or {}
|
||||
tool = self.tools.get(name)
|
||||
if tool is None:
|
||||
push_log(state, f" [warn] tool {name!r} not registered — skipping call")
|
||||
continue
|
||||
tool.call(args)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
push_log(state, f" [error] tool dispatch failed: {exc}")
|
||||
state["tool_calls_pending"] = []
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _looks_like_gibberish(text: str) -> bool:
|
||||
"""Heuristically detect generation that's clearly off the rails.
|
||||
|
||||
Memorised models can collapse to dominant-mode outputs when the
|
||||
prompt drifts even slightly from training distribution. Reject:
|
||||
|
||||
* empty / whitespace-only
|
||||
* too few alphabetic characters (mostly punctuation)
|
||||
* a single character repeated past the threshold
|
||||
* starts with ``":"`` and contains no letters
|
||||
* too few unique tokens — e.g. ``"the"``, ``"the the the"``,
|
||||
``"Ass\\n::\\nthe"`` (the collapse seen on real-robot frames
|
||||
where the model emits one or two memorised tokens repeatedly)
|
||||
* chat-template fragment leakage (``Assistant:``, ``User:``,
|
||||
``Ass\\n``)
|
||||
|
||||
Real subtasks look like ``"close the gripper to grasp the blue
|
||||
cube"`` — multiple unique alphabetic tokens, no role-marker
|
||||
fragments. Anything materially shorter than that is rejected.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return True
|
||||
stripped = text.strip()
|
||||
alpha = sum(1 for c in stripped if c.isalpha())
|
||||
if alpha < max(3, len(stripped) // 8):
|
||||
return True
|
||||
if stripped.startswith('":') and stripped.count('"') > stripped.count(" "):
|
||||
return True
|
||||
# Single repeating char: e.g. ``""""""``.
|
||||
if len(set(stripped)) <= 2 and len(stripped) > 4:
|
||||
return True
|
||||
# Chat-template fragment leakage — the model emits ``Ass``,
|
||||
# ``Assistant:``, ``User:``, often with extra newlines/colons.
|
||||
# Reject if the cleaned text is mostly role-marker shards.
|
||||
cleaned = stripped.replace("\n", " ").replace(":", " ")
|
||||
for marker in ("Assistant", "User", "Ass "):
|
||||
if marker in cleaned and len(cleaned.split()) < 4:
|
||||
return True
|
||||
tokens = [t for t in cleaned.split() if any(c.isalpha() for c in t)]
|
||||
unique_alpha = {t.lower() for t in tokens}
|
||||
# Short degenerate output — model stuck on ``the`` or a couple of
|
||||
# memorised single-token continuations.
|
||||
if len(unique_alpha) < 3 and len(stripped) < 80:
|
||||
return True
|
||||
# Long repetition collapse — the LM head loops an n-gram for the
|
||||
# whole generation budget ("the arm the arm … the the the the").
|
||||
# Length-independent: many tokens but a tiny unique ratio. The
|
||||
# earlier ``< 80`` check missed these because the looped string
|
||||
# blows well past 80 chars.
|
||||
if len(tokens) >= 8 and len(unique_alpha) <= max(3, len(tokens) // 10):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _control_context_messages(
|
||||
state: dict[str, Any],
|
||||
*,
|
||||
include_completed: bool = False,
|
||||
extra_user: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build a chat-template-ready prompt from current runtime state.
|
||||
|
||||
Mirrors what ``subtasks_vqa.yaml`` renders into ``${task}\nPlan:
|
||||
${plan}\nMemory: ${memory}`` for the high-level branches.
|
||||
"""
|
||||
# Always emit ``Plan: `` / ``Memory: `` labels — even with empty
|
||||
# values — to mirror the training-time recipe substitution.
|
||||
task = state.get("task") or ""
|
||||
plan = state.get("current_plan") or ""
|
||||
memory = state.get("current_memory") or ""
|
||||
parts = [task, f"Plan: {plan}", f"Memory: {memory}"]
|
||||
if include_completed and state.get("current_subtask"):
|
||||
parts.append(f"Completed subtask: {state['current_subtask']}")
|
||||
head = "\n".join(parts)
|
||||
msgs: list[dict[str, Any]] = [{"role": "user", "content": head}]
|
||||
if extra_user:
|
||||
msgs.append({"role": "user", "content": extra_user})
|
||||
return msgs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-recipe prompt builders. Each one mirrors a single sub-recipe's
|
||||
# message layout in ``subtasks_vqa.yaml`` so the chat-templated
|
||||
# prompt at inference matches what the model saw during training.
|
||||
# Generic ``_control_context_messages`` is kept around as a fallback
|
||||
# for ad-hoc callers but the four high-level steps now use these.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _hirobot_user_head(state: dict[str, Any]) -> str:
|
||||
"""Build the ``task\\nPlan: …\\nMemory: …`` user content string.
|
||||
|
||||
Mirrors what the recipe renders at training time, where
|
||||
``language_render._substitute`` substitutes empty strings for
|
||||
missing ``${plan}`` / ``${memory}`` bindings — i.e. the
|
||||
``Plan: `` / ``Memory: `` prefix labels are *always* in the
|
||||
user turn, even when their values aren't set yet. Skipping them
|
||||
here (the previous behaviour) produced a different prompt shape
|
||||
on early frames before plan / memory are populated and on
|
||||
samples where the dataset has no plan / memory annotation.
|
||||
"""
|
||||
task = state.get("task") or ""
|
||||
plan = state.get("current_plan") or ""
|
||||
memory = state.get("current_memory") or ""
|
||||
return f"{task}\nPlan: {plan}\nMemory: {memory}"
|
||||
|
||||
|
||||
def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``high_level_subtask`` recipe layout — predict the subtask from the
|
||||
task. The v-current recipe's user turn is just ``${task}`` (plan and
|
||||
memory are not trained), so the inference prompt is the bare task —
|
||||
no ``Plan: `` / ``Memory: `` lines.
|
||||
"""
|
||||
return [{"role": "user", "content": state.get("task") or ""}]
|
||||
|
||||
|
||||
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""Memory-update prompt — mirrors ``memory_update`` recipe layout.
|
||||
|
||||
Recipe layout (``subtask_mem.yaml``):
|
||||
|
||||
user: "${task}"
|
||||
assistant: "Previous memory: ${prior_memory}" (if_present prior)
|
||||
user: "Completed subtask: ${completed}" (if_present completed)
|
||||
assistant: → predicts new memory
|
||||
|
||||
Fired by ``MemoryUpdateFwd`` on a ``subtask_change`` event:
|
||||
``state['current_memory']`` is the memory the policy last emitted
|
||||
(= the ``prior_memory`` binding at training), and
|
||||
``state['prior_subtask']`` is the subtask that just got replaced
|
||||
(= the ``completed_subtask`` binding at training).
|
||||
"""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""},
|
||||
]
|
||||
prior_memory = state.get("current_memory")
|
||||
if prior_memory:
|
||||
msgs.append(
|
||||
{"role": "assistant", "content": f"Previous memory: {prior_memory}"}
|
||||
)
|
||||
completed_subtask = state.get("prior_subtask")
|
||||
if completed_subtask:
|
||||
msgs.append(
|
||||
{"role": "user", "content": f"Completed subtask: {completed_subtask}"}
|
||||
)
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_interjection(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``user_interjection_response`` recipe layout."""
|
||||
msgs: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": state.get("task") or ""}
|
||||
]
|
||||
if state.get("current_plan"):
|
||||
msgs.append(
|
||||
{"role": "assistant", "content": f"Previous plan:\n{state['current_plan']}"}
|
||||
)
|
||||
interjection = state.get("recent_interjection")
|
||||
if interjection:
|
||||
msgs.append({"role": "user", "content": interjection})
|
||||
return msgs
|
||||
|
||||
|
||||
def _msgs_for_plan(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""``plan_generation`` recipe layout — bare task → plan.
|
||||
|
||||
The assistant turn is the generation target, so we only render
|
||||
the user turn at inference; the runtime appends the predicted
|
||||
plan after sampling.
|
||||
"""
|
||||
return [{"role": "user", "content": state.get("task") or ""}]
|
||||
|
||||
|
||||
def _msgs_for_vqa(question: str) -> list[dict[str, Any]]:
|
||||
"""``ask_vqa_*`` recipe layout (text-only at inference)."""
|
||||
return [{"role": "user", "content": question}]
|
||||
|
||||
|
||||
def _maybe_observation(provider: Any) -> dict | None:
|
||||
"""Pull one observation from ``provider`` if it's set, else ``None``.
|
||||
|
||||
Errors from the provider are logged at debug level and swallowed —
|
||||
text generation still runs (in text-only mode) so a flaky frame
|
||||
source doesn't kill the REPL.
|
||||
"""
|
||||
if provider is None:
|
||||
return None
|
||||
try:
|
||||
return provider()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("observation_provider raised %s — falling back to text-only", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _generate_with_policy(
|
||||
policy: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
observation: dict | None = None,
|
||||
state: dict[str, Any] | None = None,
|
||||
label: str = "select_message",
|
||||
min_new_tokens: int = 0,
|
||||
temperature: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
suppress_loc_tokens: bool = False,
|
||||
) -> str:
|
||||
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
|
||||
|
||||
When ``observation`` carries ``observation.images.*`` and
|
||||
``observation.state``, those are merged into the batch so
|
||||
``select_message`` runs the same VLM prefix the policy was trained
|
||||
on. Without an observation the runtime falls back to a text-only
|
||||
prompt — the text head still runs, but generations may drift from
|
||||
the training distribution.
|
||||
|
||||
Failures are surfaced both to the module logger (``warning``) and,
|
||||
when ``state`` is given, to the runtime's user-visible log via
|
||||
:func:`push_log`, so the REPL no longer "looks dead" when
|
||||
something goes wrong inside generation.
|
||||
"""
|
||||
if not hasattr(policy, "select_message"):
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] policy has no select_message — skipping {label}")
|
||||
return ""
|
||||
text_batch = _build_text_batch(policy, messages)
|
||||
try:
|
||||
from lerobot.utils.constants import ( # noqa: PLC0415
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
|
||||
batch: dict[str, Any] = {
|
||||
OBS_LANGUAGE_TOKENS: text_batch["lang_tokens"],
|
||||
OBS_LANGUAGE_ATTENTION_MASK: text_batch["lang_masks"],
|
||||
}
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if isinstance(k, str) and k.startswith("observation.") and k not in batch:
|
||||
batch[k] = v
|
||||
kwargs: dict[str, Any] = {
|
||||
"tokenizer": text_batch["tokenizer"],
|
||||
"min_new_tokens": min_new_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
}
|
||||
kwargs["suppress_loc_tokens"] = suppress_loc_tokens
|
||||
return policy.select_message(batch, **kwargs)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
if state is not None:
|
||||
push_log(state, f" [warn] {label} failed: {type(exc).__name__}: {exc}")
|
||||
return ""
|
||||
|
||||
|
||||
_SAY_RE = re.compile(r"<\s*say\s*>(.*?)<\s*/\s*say\s*>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
||||
def _split_plan_and_say(text: str) -> tuple[str, str]:
|
||||
"""Pull a ``<say>...</say>`` snippet out of ``text``; remainder is plan.
|
||||
|
||||
The training-time tool-call serializer wraps ``say(text="…")`` in a
|
||||
deterministic textual marker so prefix-LM-style training learns to
|
||||
emit it. The runtime parses it back here. If no marker is present,
|
||||
the entire text is treated as plan with no speech.
|
||||
"""
|
||||
if not text:
|
||||
return "", ""
|
||||
match = _SAY_RE.search(text)
|
||||
if not match:
|
||||
return text.strip(), ""
|
||||
speech = match.group(1).strip().strip('"').strip("'")
|
||||
plan = (text[: match.start()] + text[match.end() :]).strip()
|
||||
return plan, speech
|
||||
@@ -1,134 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Trigger primitives for PI052's multi-rate inference runtime.
|
||||
|
||||
Mirrors the plan's Section "Runtime orchestration": each
|
||||
``InferenceStep`` is gated by a :class:`Trigger` that decides per tick
|
||||
whether the step fires. Two trigger flavours cover all the cadences
|
||||
the canonical recipe needs:
|
||||
|
||||
* :class:`HzTrigger` for periodic beats (action chunks at ~3-5 Hz,
|
||||
high-level subtask generation at ~1 Hz, action dispatch at ~50 Hz)
|
||||
* :class:`EventTrigger` for one-shot reactions (subtask boundary →
|
||||
memory update; user interjection → plan refresh; user VQA query →
|
||||
vqa answer; pending tool call → dispatcher)
|
||||
|
||||
Triggers are stateless except for ``HzTrigger``'s last-fire timestamp.
|
||||
The runtime stores the :class:`Tick` clock as ``state["_tick"]`` so
|
||||
every step shares a single time source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tick:
|
||||
"""Single tick from :class:`TickClock`. Carries time references the
|
||||
runtime steps consume to gate themselves."""
|
||||
|
||||
index: int
|
||||
"""Monotonic counter — increments by one per tick."""
|
||||
|
||||
monotonic_seconds: float
|
||||
"""``time.monotonic()`` at the start of this tick."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickClock:
|
||||
"""Drives the runtime loop at up to ``max_rate_hz``.
|
||||
|
||||
Sleeps just enough between :meth:`advance` calls to enforce the
|
||||
rate. With ``max_rate_hz=50`` the loop wakes ~every 20ms; the
|
||||
higher-level ``HzTrigger`` slices that timeline into sub-cadences.
|
||||
"""
|
||||
|
||||
max_rate_hz: float = 50.0
|
||||
_index: int = field(default=0, init=False)
|
||||
_last_seconds: float | None = field(default=None, init=False)
|
||||
|
||||
def advance(self) -> Tick:
|
||||
period = 1.0 / max(self.max_rate_hz, 0.1)
|
||||
now = time.monotonic()
|
||||
if self._last_seconds is not None:
|
||||
sleep_for = (self._last_seconds + period) - now
|
||||
if sleep_for > 0:
|
||||
time.sleep(sleep_for)
|
||||
now = time.monotonic()
|
||||
self._last_seconds = now
|
||||
self._index += 1
|
||||
return Tick(index=self._index, monotonic_seconds=now)
|
||||
|
||||
|
||||
class Trigger(Protocol):
|
||||
"""Decide whether the next ``InferenceStep`` should fire."""
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class HzTrigger:
|
||||
"""Fire at most ``hz`` times per second.
|
||||
|
||||
A step that gates further (e.g. ``HighLevelSubtaskFwd`` skipping
|
||||
when the action queue is non-empty) and wants the trigger to
|
||||
retry next tick instead of waiting a full period can call
|
||||
:meth:`rearm` from inside ``run``. Without this, a low-hz trigger
|
||||
(e.g. ``hz=0.2`` = once per 5 s) almost never coincides with the
|
||||
brief queue-empty window and the step never fires at all.
|
||||
"""
|
||||
|
||||
hz: float
|
||||
_last_seconds: float | None = field(default=None, init=False)
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||
period = 1.0 / max(self.hz, 1e-6)
|
||||
if self._last_seconds is None or (tick.monotonic_seconds - self._last_seconds) >= period:
|
||||
self._last_seconds = tick.monotonic_seconds
|
||||
return True
|
||||
return False
|
||||
|
||||
def rearm(self) -> None:
|
||||
"""Mark the trigger as not having fired, so the next tick re-evaluates.
|
||||
|
||||
Used by a step that decided to skip after ``should_fire`` already
|
||||
committed the firing — keeps the cadence honest without losing
|
||||
the slot.
|
||||
"""
|
||||
self._last_seconds = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventTrigger:
|
||||
"""Fire when ``event_name`` is in ``state["events_this_tick"]``.
|
||||
|
||||
The runtime fills ``events_this_tick`` once per tick from:
|
||||
|
||||
* stdin / network input (``user_interjection``, ``user_vqa_query``,
|
||||
``stop``)
|
||||
* internal state transitions (``subtask_change``,
|
||||
``tool_call_pending``)
|
||||
|
||||
The list is consumed (cleared at the end of the tick) so events
|
||||
fire at most once.
|
||||
"""
|
||||
|
||||
event_name: str
|
||||
|
||||
def should_fire(self, tick: Tick, state: dict[str, Any]) -> bool:
|
||||
events: list[str] = state.get("events_this_tick") or []
|
||||
return self.event_name in events
|
||||
@@ -1,127 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Rich-based REPL layout for the PI052 runtime.
|
||||
|
||||
Two-zone terminal layout:
|
||||
|
||||
[chat scrollback — user messages / robot responses, scrolls naturally]
|
||||
|
||||
┌── State ──────────────────────────────────────────┐
|
||||
│ task please clean up the kitchen │
|
||||
│ subtask grasp the handle of the sponge │
|
||||
│ plan 1. grasp sponge 2. wipe 3. tidy │
|
||||
│ memory sponge picked up; counter still dirty │
|
||||
└───────────────────────────────────────────────────┘
|
||||
> _
|
||||
|
||||
The state panel re-renders on every state change. Chat lines are
|
||||
``console.print``'d above the live region so they accumulate naturally
|
||||
in scrollback. Implemented with :class:`rich.live.Live` plus
|
||||
:func:`rich.console.Console.input` for the prompt — when an input is
|
||||
pending, ``rich.Live`` auto-suspends so the input doesn't fight the
|
||||
panel for cursor position.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
try: # rich is optional; only required for the interactive REPL.
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
_HAS_RICH = True
|
||||
except ImportError: # pragma: no cover
|
||||
_HAS_RICH = False
|
||||
Console = Any # type: ignore[assignment]
|
||||
Panel = Any # type: ignore[assignment]
|
||||
Table = Any # type: ignore[assignment]
|
||||
Text = Any # type: ignore[assignment]
|
||||
|
||||
|
||||
_STATE_KEYS = (
|
||||
("task", "task"),
|
||||
("current_subtask", "subtask"),
|
||||
("current_plan", "plan"),
|
||||
("current_memory", "memory"),
|
||||
)
|
||||
|
||||
|
||||
def make_state_panel(state: dict[str, Any]) -> Any:
|
||||
"""Render the persistent state panel for the live region.
|
||||
|
||||
Returns a :class:`rich.panel.Panel`. Caller passes it to
|
||||
``Live.update(panel)`` whenever the state changes.
|
||||
"""
|
||||
if not _HAS_RICH:
|
||||
raise RuntimeError(
|
||||
"rich is required for the interactive REPL. "
|
||||
"`pip install rich` (it's a transitive dep of lerobot)."
|
||||
)
|
||||
table = Table.grid(padding=(0, 2), expand=True)
|
||||
table.add_column(justify="right", style="dim", no_wrap=True, width=10)
|
||||
table.add_column(justify="left")
|
||||
for key, label in _STATE_KEYS:
|
||||
value = state.get(key)
|
||||
if value is None:
|
||||
rendered = Text("(not set)", style="dim italic")
|
||||
else:
|
||||
rendered = Text(str(value), style="bold")
|
||||
table.add_row(label, rendered)
|
||||
queue = state.get("action_queue")
|
||||
queue_len = len(queue) if hasattr(queue, "__len__") else 0
|
||||
pending = state.get("tool_calls_pending") or []
|
||||
footer = Text.assemble(
|
||||
("queued actions: ", "dim"),
|
||||
(str(queue_len), "bold cyan"),
|
||||
(" pending tool calls: ", "dim"),
|
||||
(str(len(pending)), "bold magenta"),
|
||||
)
|
||||
table.add_row("", footer)
|
||||
run_mode = state.get("mode", "action")
|
||||
mode_tag = (
|
||||
"[green]action[/]" if run_mode == "action" else "[yellow]paused[/]"
|
||||
)
|
||||
return Panel(
|
||||
table,
|
||||
title=f"[bold]PI052 state[/] · mode: {mode_tag}",
|
||||
border_style="cyan",
|
||||
)
|
||||
|
||||
|
||||
def print_user_line(console: Any, line: str) -> None:
|
||||
"""Append a user-typed line to the chat scrollback."""
|
||||
if not _HAS_RICH:
|
||||
print(f"you: {line}", flush=True)
|
||||
return
|
||||
console.print(f"[bold cyan]you:[/] {line}")
|
||||
|
||||
|
||||
def print_robot_lines(console: Any, lines: list[str]) -> None:
|
||||
"""Append robot/runtime log lines to the chat scrollback."""
|
||||
if not _HAS_RICH:
|
||||
for line in lines:
|
||||
print(f"robot: {line.lstrip()}", flush=True)
|
||||
return
|
||||
for line in lines:
|
||||
# The runtime uses leading whitespace + "label: text"; render
|
||||
# the label in green and the value in default for readability.
|
||||
stripped = line.lstrip()
|
||||
if ":" in stripped:
|
||||
label, _, value = stripped.partition(":")
|
||||
console.print(f"[bold green]robot[/] [dim]({label.strip()})[/] {value.strip()}")
|
||||
else:
|
||||
console.print(f"[bold green]robot:[/] {stripped}")
|
||||
@@ -1,423 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Interactive VQA for the PI052 runtime.
|
||||
|
||||
In ``/vlm`` mode a typed line is treated as a VQA question. This module
|
||||
runs the full interactive flow:
|
||||
|
||||
1. pull the current observation and list available cameras,
|
||||
2. ask the operator which camera to ground the question on,
|
||||
3. generate the answer with the VLM conditioned on that one camera,
|
||||
4. parse the JSON answer; if it carries a bounding box (``bbox``) or a
|
||||
point (``keypoint``), draw the overlay on the camera frame, save a
|
||||
PNG to ``./vqa_overlays/`` and auto-open it.
|
||||
|
||||
VQA answer schemas mirror the annotation pipeline's ``VQA_ANSWER_SHAPES``
|
||||
(see ``lerobot.annotations.steerable_pipeline.validator``):
|
||||
|
||||
* ``bbox`` — ``{"detections": [{"label", "bbox_format": "xyxy",
|
||||
"bbox": [x1, y1, x2, y2]}, ...]}``
|
||||
* ``keypoint`` — ``{"label", "point_format": "xy", "point": [x, y]}``
|
||||
* ``count`` / ``attribute`` / ``spatial`` — text-only, no overlay.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .runtime_state import push_log
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IMAGE_PREFIX = "observation.images."
|
||||
|
||||
# PaliGemma detection / pointing vocabulary. PI052 trains spatial VQA
|
||||
# answers in this native ``<locNNNN>`` format (index in [0, 1023],
|
||||
# normalized to the image axis) instead of pixel-coordinate JSON, so the
|
||||
# answer string the runtime parses can be e.g.
|
||||
# ``<loc0512><loc0301> blue cube`` (point) or
|
||||
# ``<loc0100><loc0080><loc0400><loc0360> blue cube`` (box).
|
||||
_LOC_RE = re.compile(r"<loc(\d{1,4})>")
|
||||
|
||||
# Iteration order for shape matching — most specific keys first so an
|
||||
# answer is classified deterministically.
|
||||
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
|
||||
|
||||
_BBOX_COLOR = (255, 64, 64)
|
||||
_POINT_COLOR = (64, 220, 64)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Camera selection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def available_cameras(observation: dict | None) -> list[str]:
|
||||
"""Return the sorted ``observation.images.*`` keys present in ``observation``."""
|
||||
if not observation:
|
||||
return []
|
||||
return sorted(k for k in observation if isinstance(k, str) and k.startswith(_IMAGE_PREFIX))
|
||||
|
||||
|
||||
def camera_short_name(camera_key: str) -> str:
|
||||
"""Strip the ``observation.images.`` prefix for display."""
|
||||
return camera_key[len(_IMAGE_PREFIX) :] if camera_key.startswith(_IMAGE_PREFIX) else camera_key
|
||||
|
||||
|
||||
def prompt_camera_choice(
|
||||
cameras: list[str],
|
||||
*,
|
||||
input_fn: Any = input,
|
||||
print_fn: Any = print,
|
||||
) -> str | None:
|
||||
"""Ask the operator which camera frame to draw a VQA overlay on.
|
||||
|
||||
Accepts either the menu number or the (short or full) camera name.
|
||||
A single-camera setup auto-selects without prompting. Returns the
|
||||
chosen ``observation.images.*`` key, or ``None`` if the operator
|
||||
cancels / gives an invalid answer.
|
||||
"""
|
||||
if not cameras:
|
||||
return None
|
||||
if len(cameras) == 1:
|
||||
return cameras[0]
|
||||
print_fn("Draw the result on which camera?")
|
||||
for i, cam in enumerate(cameras, 1):
|
||||
print_fn(f" [{i}] {camera_short_name(cam)}")
|
||||
try:
|
||||
raw = str(input_fn("camera> ")).strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return None
|
||||
if not raw:
|
||||
return cameras[0]
|
||||
if raw.isdigit():
|
||||
idx = int(raw) - 1
|
||||
return cameras[idx] if 0 <= idx < len(cameras) else None
|
||||
for cam in cameras:
|
||||
if raw == cam or raw == camera_short_name(cam):
|
||||
return cam
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Answer parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _loc_to_norm(idx: int) -> float:
|
||||
"""PaliGemma ``<locNNNN>`` index → normalized [0, 1] axis coordinate."""
|
||||
return max(0.0, min(1023.0, float(idx))) / 1023.0
|
||||
|
||||
|
||||
def parse_loc_answer(answer: str) -> dict | None:
|
||||
"""Parse a PaliGemma ``<loc>``-format spatial VQA answer.
|
||||
|
||||
PI052 trains spatial answers in PaliGemma's native detection
|
||||
vocabulary, label-first: a point is ``<label> <locY><locX>``, a box
|
||||
is ``<label> <locY0><locX0><locY1><locX1>``, and multiple boxes are
|
||||
joined by `` ; `` (e.g. ``cube <loc..><loc..><loc..><loc..> ; box
|
||||
<loc..><loc..><loc..><loc..>``). Loc-first formats are also accepted
|
||||
— this parser strips loc tokens and treats the remainder as the
|
||||
label, so order is irrelevant. Coordinates come back *normalized*
|
||||
([0, 1]); the overlay denormalizes them against the chosen camera
|
||||
frame's pixel size.
|
||||
|
||||
Returns ``{"kind", "payload", "normalized": True}`` on success
|
||||
(``payload`` mirrors the JSON shapes so the overlay code is shared),
|
||||
or ``None`` when the answer carries no ``<loc>`` tokens.
|
||||
"""
|
||||
if not answer or "<loc" not in answer:
|
||||
return None
|
||||
segments = [seg for seg in answer.split(";") if "<loc" in seg]
|
||||
points: list[tuple[float, float, str]] = []
|
||||
boxes: list[tuple[float, float, float, float, str]] = []
|
||||
for seg in segments:
|
||||
locs = [int(m) for m in _LOC_RE.findall(seg)]
|
||||
label = _LOC_RE.sub("", seg).strip()
|
||||
if len(locs) == 2:
|
||||
y, x = (_loc_to_norm(v) for v in locs[:2])
|
||||
points.append((x, y, label))
|
||||
elif len(locs) >= 4:
|
||||
y1, x1, y2, x2 = (_loc_to_norm(v) for v in locs[:4])
|
||||
boxes.append((x1, y1, x2, y2, label))
|
||||
if boxes:
|
||||
detections = [
|
||||
{"label": lbl, "bbox_format": "xyxy", "bbox": [x1, y1, x2, y2]}
|
||||
for (x1, y1, x2, y2, lbl) in boxes
|
||||
]
|
||||
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
|
||||
if len(points) == 1:
|
||||
x, y, lbl = points[0]
|
||||
return {
|
||||
"kind": "keypoint",
|
||||
"payload": {"label": lbl, "point_format": "xy", "point": [x, y]},
|
||||
"normalized": True,
|
||||
}
|
||||
if points: # several bare points → treat as detections-as-points
|
||||
detections = [
|
||||
{"label": lbl, "bbox_format": "xyxy", "bbox": [x, y, x, y]} for (x, y, lbl) in points
|
||||
]
|
||||
return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
|
||||
return None
|
||||
|
||||
|
||||
def parse_vqa_answer(answer: str) -> dict | None:
|
||||
"""Parse a VQA answer string into ``{"kind", "payload"}``.
|
||||
|
||||
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
|
||||
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
|
||||
when the JSON doesn't match any known shape. PaliGemma ``<loc>``
|
||||
spatial answers are detected first (PI052 trains them in that native
|
||||
format). Returns ``None`` when the answer is neither ``<loc>`` text
|
||||
nor a parseable JSON object.
|
||||
"""
|
||||
if not answer or not answer.strip():
|
||||
return None
|
||||
loc_parsed = parse_loc_answer(answer)
|
||||
if loc_parsed is not None:
|
||||
return loc_parsed
|
||||
try:
|
||||
payload = json.loads(answer)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
|
||||
try:
|
||||
from lerobot.annotations.steerable_pipeline.validator import ( # noqa: PLC0415
|
||||
VQA_ANSWER_SHAPES,
|
||||
)
|
||||
|
||||
shapes = VQA_ANSWER_SHAPES
|
||||
except ImportError: # pragma: no cover - annotation extra not installed
|
||||
shapes = {
|
||||
"bbox": {"detections"},
|
||||
"keypoint": {"label", "point_format", "point"},
|
||||
"count": {"label", "count"},
|
||||
"attribute": {"label", "attribute", "value"},
|
||||
"spatial": {"subject", "relation", "object"},
|
||||
}
|
||||
|
||||
keys = set(payload)
|
||||
for kind in _SHAPE_ORDER:
|
||||
required = shapes.get(kind)
|
||||
if required and required <= keys:
|
||||
return {"kind": kind, "payload": payload}
|
||||
return {"kind": "unknown", "payload": payload}
|
||||
|
||||
|
||||
def answer_has_overlay(parsed: dict | None) -> bool:
|
||||
"""True iff ``parsed`` carries drawable spatial coordinates."""
|
||||
return bool(parsed) and parsed.get("kind") in ("bbox", "keypoint")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlay drawing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def observation_image_to_pil(image_tensor: Any) -> Any:
|
||||
"""Convert an ``observation.images.*`` tensor to a PIL RGB image.
|
||||
|
||||
The runtime observation stores images as ``(1, C, H, W)`` (or
|
||||
``(C, H, W)``) float tensors in ``[0, 1]``. Reuses
|
||||
``image_array_to_pil_image`` which handles the CHW→HWC transpose and
|
||||
the float→uint8 scaling.
|
||||
"""
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image # noqa: PLC0415
|
||||
|
||||
arr = image_tensor
|
||||
if hasattr(arr, "detach"):
|
||||
arr = arr.detach().cpu()
|
||||
if hasattr(arr, "numpy"):
|
||||
arr = arr.numpy()
|
||||
while arr.ndim > 3: # drop leading batch dim(s)
|
||||
arr = arr[0]
|
||||
return image_array_to_pil_image(arr).convert("RGB")
|
||||
|
||||
|
||||
def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
|
||||
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
|
||||
|
||||
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
|
||||
``unknown``) are returned as an unmodified copy. When ``parsed`` has
|
||||
``normalized=True`` (PaliGemma ``<loc>`` answers) the [0, 1]
|
||||
coordinates are scaled to the image's pixel size.
|
||||
"""
|
||||
from PIL import ImageDraw # noqa: PLC0415
|
||||
|
||||
img = image.convert("RGB").copy()
|
||||
kind = parsed.get("kind")
|
||||
payload = parsed.get("payload") or {}
|
||||
draw = ImageDraw.Draw(img)
|
||||
w, h = img.size
|
||||
sx, sy = (w, h) if parsed.get("normalized") else (1, 1)
|
||||
|
||||
if kind == "bbox":
|
||||
for det in payload.get("detections") or []:
|
||||
if not isinstance(det, dict):
|
||||
continue
|
||||
box = det.get("bbox")
|
||||
if not (isinstance(box, list | tuple) and len(box) == 4):
|
||||
continue
|
||||
try:
|
||||
x1, y1, x2, y2 = (float(v) for v in box)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
x1, x2 = x1 * sx, x2 * sx
|
||||
y1, y2 = y1 * sy, y2 * sy
|
||||
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
|
||||
label = str(det.get("label", "")).strip()
|
||||
if label:
|
||||
draw.text((x1 + 3, max(0.0, y1 - 12)), label, fill=_BBOX_COLOR)
|
||||
elif kind == "keypoint":
|
||||
point = payload.get("point")
|
||||
if isinstance(point, list | tuple) and len(point) == 2:
|
||||
try:
|
||||
x, y = float(point[0]) * sx, float(point[1]) * sy
|
||||
except (TypeError, ValueError):
|
||||
return img
|
||||
r = 6
|
||||
draw.ellipse([x - r, y - r, x + r, y + r], outline=_POINT_COLOR, width=3)
|
||||
draw.line([x - 2 * r, y, x + 2 * r, y], fill=_POINT_COLOR, width=2)
|
||||
draw.line([x, y - 2 * r, x, y + 2 * r], fill=_POINT_COLOR, width=2)
|
||||
label = str(payload.get("label", "")).strip()
|
||||
if label:
|
||||
draw.text((x + r + 3, y - r), label, fill=_POINT_COLOR)
|
||||
return img
|
||||
|
||||
|
||||
def _open_file(path: Path) -> None:
|
||||
"""Best-effort open ``path`` in the OS default viewer."""
|
||||
try:
|
||||
if sys.platform == "darwin":
|
||||
subprocess.run(["open", str(path)], check=False)
|
||||
elif sys.platform.startswith("linux"):
|
||||
subprocess.run(["xdg-open", str(path)], check=False)
|
||||
elif os.name == "nt":
|
||||
os.startfile(str(path)) # type: ignore[attr-defined] # noqa: S606
|
||||
else: # pragma: no cover - exotic platform
|
||||
webbrowser.open(path.resolve().as_uri())
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("could not auto-open %s: %s", path, exc)
|
||||
|
||||
|
||||
def save_and_open_overlay(image: Any, out_dir: str | Path = "./vqa_overlays") -> Path:
|
||||
"""Save ``image`` as a timestamped PNG under ``out_dir`` and auto-open it."""
|
||||
out = Path(out_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
path = out / f"vqa_{int(time.time() * 1000)}.png"
|
||||
image.save(path)
|
||||
_open_file(path)
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def handle_vqa_query(
|
||||
*,
|
||||
policy: Any,
|
||||
observation_provider: Any,
|
||||
question: str,
|
||||
state: dict[str, Any],
|
||||
input_fn: Any = input,
|
||||
print_fn: Any = print,
|
||||
) -> None:
|
||||
"""Run one interactive VQA question end to end.
|
||||
|
||||
Called synchronously from the input layer while the runtime is in
|
||||
``/question`` mode (the action loop is gated off, so the policy is
|
||||
not in concurrent use). Progress is reported via both
|
||||
:func:`push_log` (REPL panel scrollback) and ``print_fn`` (direct
|
||||
stdout) — in autonomous question mode the panel redraw is suspended,
|
||||
so the direct print is what the operator actually sees.
|
||||
"""
|
||||
from .steps import _generate_with_policy, _msgs_for_vqa # noqa: PLC0415
|
||||
|
||||
def report(line: str) -> None:
|
||||
"""Surface a line both to the panel scrollback and to stdout."""
|
||||
push_log(state, line)
|
||||
try:
|
||||
print_fn(line)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
if policy is None or not hasattr(policy, "select_message"):
|
||||
report(" [warn] vqa: policy has no select_message — skipping")
|
||||
return
|
||||
|
||||
observation: dict | None = None
|
||||
if observation_provider is not None:
|
||||
try:
|
||||
observation = observation_provider()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.debug("observation_provider raised %s", exc)
|
||||
|
||||
# Feed the FULL observation (every camera + state) to the VLM. The
|
||||
# ``ask_vqa_*`` recipes look single-camera, but the image *block* is
|
||||
# stripped before tokenization — the actual frames reach the model
|
||||
# via PI052's ``OBS_IMAGES_*`` channels, and ``embed_prefix``
|
||||
# consumes *all* ``config.image_features`` regardless of which
|
||||
# camera the sub-recipe was tagged for. So the model always sees
|
||||
# every camera; the operator never has to name one to ask.
|
||||
answer = _generate_with_policy(
|
||||
policy,
|
||||
_msgs_for_vqa(question),
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="vqa gen",
|
||||
)
|
||||
if not answer:
|
||||
report(" [info] vqa gen returned empty")
|
||||
return
|
||||
report(f" vqa: {answer}")
|
||||
|
||||
parsed = parse_vqa_answer(answer)
|
||||
if not answer_has_overlay(parsed):
|
||||
if parsed is None:
|
||||
report(" [info] vqa answer is not JSON — no overlay")
|
||||
return
|
||||
|
||||
# The answer carries a bounding box / point. Its pixel coordinates
|
||||
# are camera-specific and the text answer doesn't say which camera,
|
||||
# so ask the operator *now* — only when there is actually something
|
||||
# to draw — which camera frame to render the overlay on.
|
||||
cameras = available_cameras(observation)
|
||||
if observation is None or not cameras:
|
||||
report(" [info] no camera image — cannot draw overlay")
|
||||
return
|
||||
chosen = prompt_camera_choice(cameras, input_fn=input_fn, print_fn=print_fn)
|
||||
if chosen is None:
|
||||
report(" [info] overlay skipped — no camera selected")
|
||||
return
|
||||
try:
|
||||
pil = observation_image_to_pil(observation[chosen])
|
||||
overlay = draw_vqa_overlay(pil, parsed)
|
||||
path = save_and_open_overlay(overlay)
|
||||
report(f" vqa overlay ({camera_short_name(chosen)}) saved: {path}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("vqa overlay failed: %s", exc, exc_info=logger.isEnabledFor(logging.DEBUG))
|
||||
report(f" [warn] vqa overlay failed: {type(exc).__name__}: {exc}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,198 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 pre/post-processor factory.
|
||||
|
||||
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
|
||||
|
||||
rename observations
|
||||
add batch dim
|
||||
relative-action prep (inherited from π0.5)
|
||||
NormalizerProcessorStep
|
||||
RenderMessagesStep — recipe → messages, target_message_indices,
|
||||
message_streams (PR 1 of the steerable
|
||||
stack)
|
||||
PI052TextTokenizerStep — messages → input_ids + label mask +
|
||||
predict_actions
|
||||
DeviceProcessorStep
|
||||
|
||||
When ``recipe_path`` is ``None`` we delegate to the plain π0.5 pipeline
|
||||
so unannotated datasets keep working.
|
||||
|
||||
Post-processor is unchanged from π0.5.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
# RenderMessagesStep is intentionally not re-exported from
|
||||
# ``lerobot.processor`` because it pulls in optional language-stack deps;
|
||||
# import it directly.
|
||||
from lerobot.processor.render_messages_processor import RenderMessagesStep
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from ..pi05.processor_pi05 import make_pi05_pre_post_processors
|
||||
from .configuration_pi052 import PI052Config
|
||||
from .text_processor_pi052 import PI052TextTokenizerStep
|
||||
|
||||
|
||||
def make_pi052_pre_post_processors(
|
||||
config: PI052Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build PI0.5-v2's pre/post-processor pipelines.
|
||||
|
||||
Falls through to π0.5's stock pipeline when ``recipe_path`` is unset.
|
||||
"""
|
||||
if not config.recipe_path:
|
||||
return make_pi05_pre_post_processors(config, dataset_stats=dataset_stats)
|
||||
|
||||
recipe = _load_recipe(config.recipe_path)
|
||||
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=config.use_relative_actions,
|
||||
exclude_joints=getattr(config, "relative_exclude_joints", []),
|
||||
action_names=getattr(config, "action_feature_names", None),
|
||||
)
|
||||
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
relative_step,
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
RenderMessagesStep(recipe=recipe),
|
||||
PI052TextTokenizerStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
|
||||
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
||||
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
||||
),
|
||||
]
|
||||
|
||||
# FAST tokenizer for discrete-action CE supervision (paper §III.C).
|
||||
# Only inserted when explicitly enabled — keeps the post-training-
|
||||
# style recipe (flow + text) as the default. When on, the step
|
||||
# writes ACTION_TOKENS / ACTION_TOKEN_MASK into
|
||||
# ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
|
||||
if getattr(config, "enable_fast_action_loss", False):
|
||||
# Per Pertsch et al. 2025 (FAST [64], π0.5 §III.C): fit the
|
||||
# tokenizer on this dataset's action distribution rather than
|
||||
# using the universal codebook off the shelf. We do this once
|
||||
# and cache to disk, keyed on (dataset, base, n_samples).
|
||||
action_tokenizer_path = config.action_tokenizer_name
|
||||
if (
|
||||
getattr(config, "auto_fit_fast_tokenizer", False)
|
||||
and dataset_repo_id is not None
|
||||
):
|
||||
from .fit_fast_tokenizer import fit_fast_tokenizer # noqa: PLC0415
|
||||
|
||||
cache_dir = Path(config.fast_tokenizer_cache_dir).expanduser()
|
||||
try:
|
||||
action_tokenizer_path = fit_fast_tokenizer(
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
cache_dir=cache_dir,
|
||||
base_tokenizer_name=config.action_tokenizer_name,
|
||||
n_samples=config.fast_tokenizer_fit_samples,
|
||||
chunk_size=config.chunk_size,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
import logging # noqa: PLC0415
|
||||
|
||||
logging.getLogger(__name__).warning(
|
||||
"FAST tokenizer fit failed (%s) — falling back to "
|
||||
"the universal base tokenizer %r. Train will still "
|
||||
"work but compression will be suboptimal.",
|
||||
exc, config.action_tokenizer_name,
|
||||
)
|
||||
|
||||
input_steps.append(
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=action_tokenizer_path,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
fast_skip_tokens=config.fast_skip_tokens,
|
||||
paligemma_tokenizer_name="google/paligemma-3b-pt-224",
|
||||
)
|
||||
)
|
||||
|
||||
input_steps.append(DeviceProcessorStep(device=config.device))
|
||||
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AbsoluteActionsProcessorStep(
|
||||
enabled=config.use_relative_actions,
|
||||
relative_step=relative_step,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _load_recipe(path_str: str) -> TrainingRecipe:
|
||||
"""Resolve ``path_str`` to a ``TrainingRecipe``.
|
||||
|
||||
Accepts an absolute path or a path relative to
|
||||
``src/lerobot/configs/``.
|
||||
"""
|
||||
p = Path(path_str)
|
||||
if not p.is_absolute() and not p.exists():
|
||||
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
|
||||
|
||||
configs_dir = Path(_recipe_module.__file__).resolve().parent
|
||||
candidate = configs_dir / path_str
|
||||
if candidate.exists():
|
||||
p = candidate
|
||||
return TrainingRecipe.from_yaml(p)
|
||||
@@ -1,598 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""π0.5 v2 text-tokenisation step.
|
||||
|
||||
PaliGemma is *not* chat-pretrained, so we can't lean on
|
||||
``tokenizer.apply_chat_template``. Instead we concatenate the rendered
|
||||
messages as plain text with simple ``User: ... Assistant: ...`` role
|
||||
delimiters — matching the prompt format π0.5 uses in the paper
|
||||
(``Task: ... State: ... Action: ...``).
|
||||
|
||||
Outputs:
|
||||
|
||||
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` — the
|
||||
concatenated prompt tokenised by the PaliGemma tokenizer (the same
|
||||
one ``processor_pi05`` already uses).
|
||||
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
|
||||
positions belonging to messages whose index is in
|
||||
``target_message_indices``. ``modeling_pi052`` runs cross-entropy on
|
||||
those positions via the PaliGemma ``lm_head``.
|
||||
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
|
||||
target messages has ``message_streams[i] == "low_level"``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b["text"]
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text" and isinstance(b.get("text"), str)
|
||||
]
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Serialize assistant ``say`` tool calls into a ``<say>...</say>`` marker.
|
||||
|
||||
PaliGemma's flat text prompt has no notion of structured tool calls,
|
||||
and ``_format_messages`` only reads ``role`` / ``content`` — so
|
||||
without this a ``say`` tool call is dropped entirely and never
|
||||
supervised. Rewriting it into the content text as a ``<say>...</say>``
|
||||
marker lets the LM head learn to emit it; the runtime parses it back
|
||||
via ``_split_plan_and_say``. Messages without ``say`` tool calls are
|
||||
returned unchanged (the structured calls, if any, are still dropped).
|
||||
"""
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not tool_calls:
|
||||
return message
|
||||
say_texts: list[str] = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
fn = call.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
continue
|
||||
args = fn.get("arguments")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
import json # noqa: PLC0415
|
||||
|
||||
args = json.loads(args)
|
||||
except (ValueError, TypeError):
|
||||
args = {}
|
||||
text = args.get("text", "") if isinstance(args, dict) else ""
|
||||
if text:
|
||||
say_texts.append(str(text))
|
||||
new = dict(message)
|
||||
new.pop("tool_calls", None)
|
||||
if not say_texts:
|
||||
return new
|
||||
base = _content_to_text(new.get("content")).strip()
|
||||
marker = "".join(f"<say>{t}</say>" for t in say_texts)
|
||||
new["content"] = f"{base}\n{marker}" if base else marker
|
||||
return new
|
||||
|
||||
|
||||
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalise a message's content to a plain string.
|
||||
|
||||
The recipe renderer can emit ``content`` as a string OR as a list
|
||||
of HF-style multimodal blocks (``{type: text, text: ...}``,
|
||||
``{type: image, feature: ...}``). PaliGemma's text tokenizer can
|
||||
only consume strings, so we flatten: drop image blocks (cameras
|
||||
flow through ``observation.images.*`` separately) and join text
|
||||
block texts.
|
||||
"""
|
||||
new = dict(message)
|
||||
new.pop("stream", None)
|
||||
new.pop("target", None)
|
||||
content = new.get("content")
|
||||
if content is None:
|
||||
new["content"] = ""
|
||||
elif isinstance(content, str):
|
||||
pass
|
||||
elif isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
t = block.get("text", "")
|
||||
if isinstance(t, str):
|
||||
parts.append(t)
|
||||
new["content"] = "\n".join(parts)
|
||||
else:
|
||||
new["content"] = str(content)
|
||||
return new
|
||||
|
||||
|
||||
def _is_batched_messages(messages: Any) -> bool:
|
||||
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||
|
||||
|
||||
def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
|
||||
if value is None:
|
||||
return [None] * batch_size
|
||||
if isinstance(value, torch.Tensor):
|
||||
if value.numel() == 1:
|
||||
return [int(value.item())] * batch_size
|
||||
values = value.reshape(-1).tolist()
|
||||
return [int(v) for v in values[:batch_size]]
|
||||
if isinstance(value, (list, tuple)):
|
||||
if len(value) == 1:
|
||||
return _sample_indices(value[0], batch_size)
|
||||
return [int(v.item() if hasattr(v, "item") else v) for v in value[:batch_size]]
|
||||
return [int(value)] * batch_size
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VQA spatial answers → PaliGemma <loc> format (PI052 only)
|
||||
#
|
||||
# PaliGemma is pre-trained on detection / pointing with a ``<locNNNN>``
|
||||
# vocabulary (normalized [0, 1023]). The recipe's bbox / keypoint VQA
|
||||
# answers are stored as JSON in Qwen2.5-VL's grounding convention:
|
||||
# **0–1000 normalized coordinates**, NOT pixels. (Verified empirically
|
||||
# on the published datasets: x and y both span 0..1000 with ~30% of
|
||||
# values exceeding the camera's pixel dimensions — they're not pixels.)
|
||||
# Converting to ``<loc>`` is therefore camera-resolution-independent:
|
||||
# ``loc_idx = round(coord / 1000 * 1023)``. We do the conversion here —
|
||||
# not in the dataset — so the dataset keeps the raw JSON and stays
|
||||
# backbone-agnostic.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The 0–1000 scale Qwen2.5-VL emits for grounding coordinates.
|
||||
_VQA_COORD_SCALE = 1000.0
|
||||
|
||||
|
||||
def register_paligemma_loc_tokens(tokenizer: Any) -> Any:
|
||||
"""Make PaliGemma's ``<locDDDD>`` ids match on raw text — single tokens.
|
||||
|
||||
PaliGemma reserves vocab ids [256000, 257023] for ``<locDDDD>``
|
||||
(detection / pointing) tokens, but the *stock* tokenizer does NOT
|
||||
match them when encoding raw text — it BPE-splits ``<loc0162>`` into
|
||||
7 pieces (``<``, ``loc``, ``0``, ``1``, ``6``, ``2``, ``>``). Training
|
||||
the LM head on a ``<loc>`` target then supervises those 7 generic
|
||||
BPE pieces instead of one detection-vocab id, the LM head learns to
|
||||
emit the *character sequence*, and those pieces' logits dominate
|
||||
other turns (the ``<loc>``-salad on subtasks). Registering the loc
|
||||
tokens once makes them tokenize as their single ids (256000+idx),
|
||||
leveraging PaliGemma's detection prior properly. Idempotent.
|
||||
"""
|
||||
if "<loc0000>" in getattr(tokenizer, "added_tokens_encoder", {}):
|
||||
return tokenizer
|
||||
tokenizer.add_tokens([f"<loc{i:04d}>" for i in range(1024)])
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _loc_token(coord: float, scale: float = _VQA_COORD_SCALE) -> str:
|
||||
"""PaliGemma ``<locNNNN>`` for a coord on a ``[0, scale]`` axis."""
|
||||
idx = round(float(coord) / scale * 1023) if scale > 0 else 0
|
||||
return f"<loc{max(0, min(1023, idx)):04d}>"
|
||||
|
||||
|
||||
def _vqa_answer_to_loc(answer: dict[str, Any]) -> str | None:
|
||||
"""Convert a bbox / keypoint VQA answer dict to PaliGemma ``<loc>`` text.
|
||||
|
||||
Input coordinates are in Qwen2.5-VL's 0–1000 normalized space (see
|
||||
module-level note). y is emitted before x for each coordinate pair
|
||||
(PaliGemma convention), with the integer indices in [0, 1023].
|
||||
|
||||
**Format: label first, locs after.** PaliGemma's pretraining puts
|
||||
locs first (``<loc><loc> label``), but for our small-dataset VQA
|
||||
blend that turns the LM head into a loc-emission attractor at every
|
||||
``Assistant:`` position — VQA targets share their first supervised
|
||||
token with ~25% of all text samples, and the head collapses to
|
||||
emitting ``<loc>`` regardless of the prompt. Putting the label
|
||||
first (``label <locY><locX>``) means every text sample (subtask,
|
||||
memory, VQA, …) starts the supervised target with a real word,
|
||||
breaking the attractor. The model still learns the loc vocabulary
|
||||
for the *spatial* portion of the answer; it just can't fire it as
|
||||
the first generation step from a clean prompt.
|
||||
|
||||
Returns ``None`` for non-spatial answers (count / attribute /
|
||||
spatial-relation) — those keep their JSON form.
|
||||
"""
|
||||
point = answer.get("point")
|
||||
if isinstance(point, list | tuple) and len(point) == 2 and "point_format" in answer:
|
||||
try:
|
||||
x, y = float(point[0]), float(point[1])
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
label = str(answer.get("label", "")).strip()
|
||||
if not label:
|
||||
return None
|
||||
return f"{label} {_loc_token(y)}{_loc_token(x)}"
|
||||
|
||||
detections = answer.get("detections")
|
||||
if isinstance(detections, list) and detections:
|
||||
parts: list[str] = []
|
||||
for det in detections:
|
||||
if not isinstance(det, dict):
|
||||
continue
|
||||
box = det.get("bbox")
|
||||
if not (isinstance(box, list | tuple) and len(box) == 4):
|
||||
continue
|
||||
try:
|
||||
x1, y1, x2, y2 = (float(v) for v in box)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
label = str(det.get("label", "")).strip()
|
||||
if not label:
|
||||
continue
|
||||
toks = (
|
||||
f"{_loc_token(y1)}{_loc_token(x1)}"
|
||||
f"{_loc_token(y2)}{_loc_token(x2)}"
|
||||
)
|
||||
parts.append(f"{label} {toks}")
|
||||
return " ; ".join(parts) if parts else None
|
||||
return None
|
||||
|
||||
|
||||
def _messages_vqa_to_loc(
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Rewrite bbox / keypoint VQA *target* answers from JSON to ``<loc>`` text.
|
||||
|
||||
Each target turn whose content parses as a spatial VQA answer is
|
||||
converted. Non-spatial answers and subtask / memory targets (plain
|
||||
text → not JSON) are left untouched. Camera-independent: VQA coords
|
||||
are 0–1000 normalized, so no observation lookup is needed.
|
||||
"""
|
||||
if not target_indices:
|
||||
return messages
|
||||
out = list(messages)
|
||||
for idx in target_indices:
|
||||
if not (0 <= idx < len(out)):
|
||||
continue
|
||||
content = out[idx].get("content")
|
||||
if not isinstance(content, str) or not content.strip():
|
||||
continue
|
||||
try:
|
||||
answer = json.loads(content)
|
||||
except (ValueError, TypeError):
|
||||
continue # subtask / memory targets are plain text — skip
|
||||
if not isinstance(answer, dict):
|
||||
continue
|
||||
loc_text = _vqa_answer_to_loc(answer)
|
||||
if loc_text is not None:
|
||||
out[idx] = {**out[idx], "content": loc_text}
|
||||
return out
|
||||
|
||||
|
||||
def _format_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int] | None = None,
|
||||
eos_token: str | None = None,
|
||||
) -> tuple[str, list[tuple[int, int]]]:
|
||||
"""Concatenate messages into the π0.5-style flat prompt.
|
||||
|
||||
When both ``target_indices`` and ``eos_token`` are given, the EOS
|
||||
string is appended to each supervised target turn's content and the
|
||||
returned span covers it — so the label builder marks the EOS token
|
||||
as a supervised label. That teaches the LM head where the answer
|
||||
*ends*: without an EOS in the target span the model is never given a
|
||||
stop signal and rambles to ``max_length`` at inference. Inference
|
||||
callers omit both args (no EOS baked into the prompt — the model
|
||||
generates it and ``select_message`` stops on it).
|
||||
|
||||
Returns:
|
||||
prompt: the full text the tokenizer will consume.
|
||||
msg_spans: list of ``(char_start, char_end)`` covering each
|
||||
message's supervised payload (content, plus the
|
||||
appended EOS for target turns) within ``prompt``.
|
||||
"""
|
||||
targets = set(target_indices or [])
|
||||
parts: list[str] = []
|
||||
spans: list[tuple[int, int]] = []
|
||||
cursor = 0
|
||||
for i, m in enumerate(messages):
|
||||
role = m.get("role", "user")
|
||||
content = m.get("content", "") or ""
|
||||
# Role tag + newline. The model has to learn to emit the same
|
||||
# role tokens at generation time, which is fine for greedy
|
||||
# decoding because the chat template is implicit in the
|
||||
# supervised target span.
|
||||
header = f"{role.capitalize()}: "
|
||||
# A supervised target turn ends with EOS so the model learns to
|
||||
# terminate; the span below covers content + EOS. Non-target
|
||||
# turns (and inference) carry no EOS.
|
||||
body = content + eos_token if (eos_token and i in targets) else content
|
||||
# span covers the content (+ EOS) portion only — never the role
|
||||
# tag — so labels are computed over the supervised payload.
|
||||
full = header + body + "\n"
|
||||
start = cursor + len(header)
|
||||
end = start + len(body)
|
||||
parts.append(full)
|
||||
spans.append((start, end))
|
||||
cursor += len(full)
|
||||
return "".join(parts), spans
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="pi052_text_tokenizer")
|
||||
class PI052TextTokenizerStep(ProcessorStep):
|
||||
"""Render messages → token ids + label mask + predict_actions flag.
|
||||
|
||||
No chat template; concatenates messages as
|
||||
``User: ... \\nAssistant: ...`` text.
|
||||
"""
|
||||
|
||||
tokenizer_name: str = "google/paligemma-3b-pt-224"
|
||||
max_length: int = 200
|
||||
padding: str = "max_length"
|
||||
padding_side: str = "right"
|
||||
plan_dropout_prob: float = 0.0
|
||||
memory_dropout_prob: float = 0.0
|
||||
subtask_dropout_prob: float = 0.0
|
||||
interjection_dropout_prob: float = 0.0
|
||||
dropout_seed: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._tokenizer: Any = None
|
||||
|
||||
def _ensure_tokenizer(self) -> Any:
|
||||
if self._tokenizer is not None:
|
||||
return self._tokenizer
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
self._tokenizer = register_paligemma_loc_tokens(
|
||||
AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
)
|
||||
return self._tokenizer
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pipeline step
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
transition = transition.copy()
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||
messages = complementary.get("messages") or []
|
||||
|
||||
if not messages:
|
||||
# No recipe was rendered — caller will fall back to the
|
||||
# plain Pi0.5 prompt path. We pass the transition through
|
||||
# unmodified.
|
||||
return transition
|
||||
|
||||
tokenizer = self._ensure_tokenizer()
|
||||
# VQA coords are 0–1000 normalized (Qwen2.5-VL convention) — the
|
||||
# <loc> conversion is camera-resolution-independent and needs no
|
||||
# observation lookup here.
|
||||
if _is_batched_messages(messages):
|
||||
indices_iter = _sample_indices(complementary.get("index"), len(messages))
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
msg,
|
||||
list(streams),
|
||||
list(tgt_indices),
|
||||
complementary,
|
||||
sample_idx=int(s_idx) if s_idx is not None else None,
|
||||
)
|
||||
for msg, streams, tgt_indices, s_idx in zip(
|
||||
messages,
|
||||
complementary.get("message_streams") or [[] for _ in messages],
|
||||
complementary.get("target_message_indices") or [[] for _ in messages],
|
||||
indices_iter,
|
||||
strict=False,
|
||||
)
|
||||
]
|
||||
else:
|
||||
sample_idx = _sample_indices(complementary.get("index"), 1)[0]
|
||||
encoded = [
|
||||
self._encode_messages(
|
||||
tokenizer,
|
||||
messages,
|
||||
list(complementary.get("message_streams") or []),
|
||||
list(complementary.get("target_message_indices") or []),
|
||||
complementary,
|
||||
sample_idx=sample_idx,
|
||||
)
|
||||
]
|
||||
|
||||
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
|
||||
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = {
|
||||
**complementary,
|
||||
"text_labels": torch.stack([labels for _, _, labels, _, _ in encoded]),
|
||||
"predict_actions": torch.stack([pred for _, _, _, pred, _ in encoded]),
|
||||
}
|
||||
return transition
|
||||
|
||||
def _encode_messages(
|
||||
self,
|
||||
tokenizer: Any,
|
||||
messages: list[dict[str, Any]],
|
||||
message_streams: list[str | None],
|
||||
target_indices: list[int],
|
||||
complementary: dict[str, Any],
|
||||
sample_idx: int | None = None,
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
|
||||
# Optional: drop non-target messages per the dropout config.
|
||||
# Keeps the supervised-target indices stable by re-mapping
|
||||
# after removal.
|
||||
if (
|
||||
self.plan_dropout_prob
|
||||
or self.memory_dropout_prob
|
||||
or self.subtask_dropout_prob
|
||||
or self.interjection_dropout_prob
|
||||
):
|
||||
messages, target_indices = self._apply_prompt_dropout(
|
||||
messages,
|
||||
target_indices,
|
||||
complementary,
|
||||
sample_idx=sample_idx,
|
||||
)
|
||||
|
||||
# Rewrite bbox / keypoint VQA target answers from JSON to
|
||||
# PaliGemma <loc> text. Coords are 0–1000 normalized so this is
|
||||
# camera-independent.
|
||||
messages = _messages_vqa_to_loc(messages, target_indices)
|
||||
|
||||
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
|
||||
# stripping, so the spoken reply is actually tokenized and
|
||||
# supervised (PaliGemma's flat prompt has no structured calls).
|
||||
messages = [_strip_blocks(_flatten_say_tool_calls(m)) for m in messages]
|
||||
# Append EOS to supervised target turns so the LM head learns to
|
||||
# stop (the span covers it → it becomes a supervised label).
|
||||
prompt, spans = _format_messages(
|
||||
messages, target_indices, getattr(tokenizer, "eos_token", None)
|
||||
)
|
||||
|
||||
encoded = tokenizer(
|
||||
prompt,
|
||||
max_length=self.max_length,
|
||||
padding=self.padding,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
return_offsets_mapping=True,
|
||||
padding_side=self.padding_side,
|
||||
)
|
||||
|
||||
input_ids = encoded["input_ids"][0]
|
||||
attention_mask = encoded["attention_mask"][0].bool()
|
||||
offsets = encoded["offset_mapping"][0] # (seq, 2), char (start,end)
|
||||
|
||||
# Build label mask: -100 everywhere except over supervised
|
||||
# target message char ranges.
|
||||
labels = torch.full_like(input_ids, fill_value=-100)
|
||||
for idx in target_indices:
|
||||
if idx >= len(spans):
|
||||
continue
|
||||
char_start, char_end = spans[idx]
|
||||
for token_pos in range(input_ids.shape[0]):
|
||||
if not attention_mask[token_pos]:
|
||||
continue
|
||||
tok_start, tok_end = int(offsets[token_pos, 0]), int(offsets[token_pos, 1])
|
||||
if tok_end <= char_start or tok_start >= char_end:
|
||||
continue
|
||||
labels[token_pos] = input_ids[token_pos]
|
||||
|
||||
# Scan ALL message streams (not just targets): the
|
||||
# ``low_level_execution`` recipe drops ``target: true`` on
|
||||
# the assistant to avoid trivial copy-from-user text-CE; the
|
||||
# flow loss still needs to fire, gated by ``stream: low_level``.
|
||||
predict_actions = torch.tensor(
|
||||
bool(any(s == "low_level" for s in message_streams)),
|
||||
dtype=torch.bool,
|
||||
)
|
||||
return input_ids, attention_mask, labels, predict_actions, prompt
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per-component prompt dropout (Pi0.7 §V.E)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _apply_prompt_dropout(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
target_indices: list[int],
|
||||
complementary: dict[str, Any],
|
||||
sample_idx: int | None = None,
|
||||
) -> tuple[list[dict[str, Any]], list[int]]:
|
||||
"""Drop messages classified as plan/memory/subtask context.
|
||||
|
||||
Targets are *never* dropped (they're the supervised payload).
|
||||
Re-maps target_indices to the new positions after drops.
|
||||
"""
|
||||
import random # noqa: PLC0415
|
||||
|
||||
seed = self.dropout_seed
|
||||
if seed is None:
|
||||
# Canonical row-index key set by ``BatchProcessor`` /
|
||||
# ``render_messages_processor``. Falling back to other
|
||||
# keys silently gave every sample seed=0 → identical
|
||||
# dropout pattern across the whole epoch.
|
||||
seed_src = sample_idx if sample_idx is not None else complementary.get("index", 0)
|
||||
try:
|
||||
if hasattr(seed_src, "item"):
|
||||
seed_src = seed_src.item()
|
||||
seed = int(seed_src)
|
||||
except (TypeError, ValueError):
|
||||
seed = 0
|
||||
rng = random.Random(seed)
|
||||
|
||||
keep_indices: list[int] = []
|
||||
for idx, msg in enumerate(messages):
|
||||
if idx in target_indices:
|
||||
keep_indices.append(idx)
|
||||
continue
|
||||
kind = _classify_for_dropout(msg)
|
||||
prob = {
|
||||
"plan": self.plan_dropout_prob,
|
||||
"memory": self.memory_dropout_prob,
|
||||
"subtask": self.subtask_dropout_prob,
|
||||
"interjection": self.interjection_dropout_prob,
|
||||
}.get(kind, 0.0)
|
||||
if prob > 0.0 and rng.random() < prob:
|
||||
continue
|
||||
keep_indices.append(idx)
|
||||
|
||||
# Build remap and apply
|
||||
new_messages = [messages[i] for i in keep_indices]
|
||||
old_to_new = {old: new for new, old in enumerate(keep_indices)}
|
||||
new_targets = [old_to_new[t] for t in target_indices if t in old_to_new]
|
||||
return new_messages, new_targets
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
def _classify_for_dropout(message: dict[str, Any]) -> str | None:
|
||||
"""Heuristic content-prefix classifier (plan / memory / subtask)."""
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
text_parts = [b.get("text", "") for b in content if isinstance(b, dict) and b.get("type") == "text"]
|
||||
content = " ".join(text_parts)
|
||||
elif content is None:
|
||||
return None
|
||||
elif not isinstance(content, str):
|
||||
return None
|
||||
s = content.strip()
|
||||
if s.startswith("Plan:") or s.startswith("Previous plan"):
|
||||
return "plan"
|
||||
if s.startswith("Memory:") or s.startswith("Previous memory"):
|
||||
return "memory"
|
||||
if s.startswith("Current subtask") or s.startswith("Completed subtask"):
|
||||
return "subtask"
|
||||
return None
|
||||
@@ -275,8 +275,6 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||
# Convert to bfloat16 if the first layer uses bfloat16
|
||||
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.bfloat16)
|
||||
if causal_mask is not None and torch.is_floating_point(causal_mask):
|
||||
causal_mask = causal_mask.to(dtype=hidden_states.dtype)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
@@ -175,6 +175,9 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
complementary_data.pop("language_persistent", None)
|
||||
complementary_data.pop("language_events", None)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
|
||||
@@ -52,9 +52,6 @@ class RenderMessagesStep(ProcessorStep):
|
||||
if not persistent and not events:
|
||||
return transition
|
||||
|
||||
if _is_batched_language(persistent) or _is_batched_language(events):
|
||||
return self._call_batch(transition, complementary_data, persistent, events)
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
@@ -70,131 +67,18 @@ class RenderMessagesStep(ProcessorStep):
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
rendered = _fallback_low_level_render(complementary_data.get("task"))
|
||||
if rendered is None:
|
||||
return None
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def _call_batch(
|
||||
self,
|
||||
transition: EnvTransition,
|
||||
complementary_data: dict[str, Any],
|
||||
persistent_batch: list,
|
||||
events_batch: list,
|
||||
) -> EnvTransition | None:
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
batch_size = max(len(persistent_batch), len(events_batch))
|
||||
messages: list[list[dict[str, Any]]] = []
|
||||
message_streams: list[list[str | None]] = []
|
||||
target_message_indices: list[list[int]] = []
|
||||
keep_indices: list[int] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent_batch[i] if i < len(persistent_batch) else [],
|
||||
events=events_batch[i] if i < len(events_batch) else [],
|
||||
t=_batch_value(timestamp, i),
|
||||
sample_idx=int(_batch_value(complementary_data.get("index", 0), i)),
|
||||
task=_batch_value(complementary_data.get("task"), i),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i))
|
||||
if rendered is None:
|
||||
continue
|
||||
keep_indices.append(i)
|
||||
messages.append(rendered["messages"])
|
||||
message_streams.append(rendered["message_streams"])
|
||||
target_message_indices.append(rendered["target_message_indices"])
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
new_transition = (
|
||||
_select_batch_indices(transition, keep_indices)
|
||||
if len(keep_indices) != batch_size
|
||||
else transition.copy()
|
||||
)
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data["messages"] = messages
|
||||
new_complementary_data["message_streams"] = message_streams
|
||||
new_complementary_data["target_message_indices"] = target_message_indices
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass features through unchanged; rendering only touches complementary data."""
|
||||
return features
|
||||
|
||||
|
||||
def _scalar(value: Any) -> float | int:
|
||||
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
|
||||
|
||||
def _is_batched_language(value: Any) -> bool:
|
||||
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
|
||||
|
||||
|
||||
def _batch_value(value: Any, index: int) -> Any:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
return value[index]
|
||||
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
|
||||
return _scalar(value[index])
|
||||
return _scalar(value)
|
||||
|
||||
|
||||
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
|
||||
selected = transition.copy()
|
||||
for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
|
||||
data = selected.get(key)
|
||||
if isinstance(data, dict):
|
||||
selected[key] = {k: _select_value(v, indices) for k, v in data.items()}
|
||||
action = selected.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
selected[TransitionKey.ACTION] = _select_value(action, indices)
|
||||
return selected
|
||||
|
||||
|
||||
def _select_value(value: Any, indices: list[int]) -> Any:
|
||||
if isinstance(value, list) and len(value) >= len(indices):
|
||||
return [value[i] for i in indices]
|
||||
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
|
||||
return value.index_select(0, value.new_tensor(indices).long())
|
||||
return value
|
||||
|
||||
|
||||
def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
|
||||
"""Keep action-only samples trainable when no recipe branch matches."""
|
||||
if hasattr(task, "item"):
|
||||
task = task.item()
|
||||
if not isinstance(task, str) or not task:
|
||||
return None
|
||||
return {
|
||||
"messages": [{"role": "user", "content": task}],
|
||||
"message_streams": ["low_level"],
|
||||
"target_message_indices": [],
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ import torch
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION_CODE_TOKEN_MASK,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -413,15 +412,14 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# During inference, no action is available, skip tokenization
|
||||
return new_transition
|
||||
|
||||
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
|
||||
tokens, mask, code_mask = self._tokenize_action(action)
|
||||
# Tokenize and get both tokens and mask
|
||||
tokens, mask = self._tokenize_action(action)
|
||||
|
||||
# Store mask in complementary data
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if complementary_data is None:
|
||||
complementary_data = {}
|
||||
complementary_data[ACTION_TOKEN_MASK] = mask
|
||||
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
|
||||
complementary_data[ACTION_TOKENS] = tokens
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
return new_transition
|
||||
@@ -432,7 +430,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenizes the action tensor and creates a mask.
|
||||
|
||||
@@ -461,7 +459,6 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# The fast tokenizer expects action data and returns token IDs
|
||||
tokens_list = []
|
||||
masks_list = []
|
||||
code_masks_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
||||
@@ -479,26 +476,19 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
|
||||
bos_id = self._paligemma_tokenizer.bos_token_id
|
||||
prompt_tokens = torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
)
|
||||
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
|
||||
|
||||
code_start = 1 + len(prompt_tokens)
|
||||
code_end = code_start + len(action_code_tokens)
|
||||
# add bos
|
||||
tokens = torch.cat(
|
||||
[
|
||||
torch.tensor([bos_id], device=action.device),
|
||||
prompt_tokens,
|
||||
action_code_tokens,
|
||||
end_tokens,
|
||||
torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
),
|
||||
self._act_tokens_to_paligemma_tokens(tokens),
|
||||
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
|
||||
]
|
||||
)
|
||||
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
|
||||
code_mask[code_start:code_end] = True
|
||||
|
||||
# Truncate or pad to max_action_tokens
|
||||
if len(tokens) > self.max_action_tokens:
|
||||
@@ -507,49 +497,44 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
|
||||
)
|
||||
tokens = tokens[: self.max_action_tokens]
|
||||
code_mask = code_mask[: self.max_action_tokens]
|
||||
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
||||
else:
|
||||
pad_len = self.max_action_tokens - len(tokens)
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
||||
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
|
||||
torch.zeros(
|
||||
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
|
||||
),
|
||||
]
|
||||
)
|
||||
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
|
||||
# Pad tokens with zeros
|
||||
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
|
||||
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
|
||||
|
||||
tokens_list.append(tokens)
|
||||
masks_list.append(mask)
|
||||
code_masks_list.append(code_mask)
|
||||
|
||||
# Stack into batched tensors
|
||||
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
||||
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
||||
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
tokens_batch = tokens_batch.squeeze(0)
|
||||
masks_batch = masks_batch.squeeze(0)
|
||||
code_masks_batch = code_masks_batch.squeeze(0)
|
||||
|
||||
# Move to the same device as the input
|
||||
if device is not None:
|
||||
tokens_batch = tokens_batch.to(device)
|
||||
masks_batch = masks_batch.to(device)
|
||||
code_masks_batch = code_masks_batch.to(device)
|
||||
|
||||
return tokens_batch, masks_batch, code_masks_batch
|
||||
return tokens_batch, masks_batch
|
||||
|
||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method is not used since we override __call__.
|
||||
Required by ActionProcessorStep ABC.
|
||||
"""
|
||||
tokens, _, _ = self._tokenize_action(action)
|
||||
tokens, _ = self._tokenize_action(action)
|
||||
return tokens
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
|
||||
@@ -21,8 +21,6 @@ from lerobot.utils.import_utils import make_device_from_device_class
|
||||
from .config import RobotConfig
|
||||
from .robot import Robot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
# TODO(Steven): Consider just using the make_device_from_device_class for all types
|
||||
@@ -120,7 +118,7 @@ def ensure_safe_goal_position(
|
||||
}
|
||||
|
||||
if warnings_dict:
|
||||
logger.warning(
|
||||
logging.warning(
|
||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||
f"{pformat(warnings_dict, indent=4)}"
|
||||
)
|
||||
|
||||
@@ -1,345 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Build a single combined LeRobotDataset from RoboCasa's 16 composite_seen tasks.
|
||||
|
||||
RoboCasa 1.0 already ships in LeRobot format (parquet + mp4), distributed as
|
||||
``lerobot.tar`` archives from Box. This script:
|
||||
|
||||
1. Downloads each composite_seen task's ``target/human`` archive via RoboCasa's
|
||||
official ``download_datasets`` helper (idempotent — skipped if already on
|
||||
disk).
|
||||
2. Opens each extracted directory as a ``LeRobotDataset``.
|
||||
3. Merges all 16 into one unified dataset via ``merge_datasets`` (a thin wrapper
|
||||
over ``aggregate_datasets`` that revalidates fps / robot_type / features,
|
||||
unifies task indices, concatenates videos and parquet, and recomputes stats).
|
||||
4. Optionally pushes the merged dataset to the Hub.
|
||||
|
||||
The result is one ~8,000-trajectory dataset where each episode carries its
|
||||
source task as the ``task`` field — ready for downstream annotation
|
||||
(subtasks / memory / VQA / tool calls) without per-task bookkeeping.
|
||||
|
||||
Usage::
|
||||
|
||||
uv run python -m lerobot.scripts.build_robocasa_composite_seen \\
|
||||
--output-dir=/data/lerobot/robocasa_composite_seen \\
|
||||
--hub-repo-id=${HF_USER}/robocasa_composite_seen \\
|
||||
--push-to-hub
|
||||
|
||||
Prereqs: ``robocasa`` and ``robosuite`` installed (see
|
||||
``docs/source/benchmarks/robocasa.mdx`` for the editable-install dance — they
|
||||
are not on PyPI and RoboCasa's own ``setup.py`` pins an old LeRobot version).
|
||||
|
||||
The 16 composite_seen tasks are the multi-step subset of the official
|
||||
RoboCasa365 target benchmark — exactly the slice used to compute the
|
||||
``Composite-Seen`` column of the leaderboard.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.datasets.dataset_tools import merge_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Canonical 16 composite_seen tasks (RoboCasa365 target benchmark).
|
||||
# Order matches the leaderboard docs.
|
||||
COMPOSITE_SEEN_TASKS: list[str] = [
|
||||
"DeliverStraw",
|
||||
"GetToastedBread",
|
||||
"KettleBoiling",
|
||||
"LoadDishwasher",
|
||||
"PackIdenticalLunches",
|
||||
"PreSoakPan",
|
||||
"PrepareCoffee",
|
||||
"RinseSinkBasin",
|
||||
"ScrubCuttingBoard",
|
||||
"SearingMeat",
|
||||
"SetUpCuttingStation",
|
||||
"StackBowlsCabinet",
|
||||
"SteamInMicrowave",
|
||||
"StirVegetables",
|
||||
"StoreLeftoversInBowl",
|
||||
"WashLettuce",
|
||||
]
|
||||
|
||||
|
||||
def _require_robocasa() -> None:
|
||||
"""Fail fast with an actionable message if robocasa is missing.
|
||||
|
||||
RoboCasa is not on PyPI and is not a LeRobot extra — see the installation
|
||||
notes in ``docs/source/benchmarks/robocasa.mdx``.
|
||||
"""
|
||||
try:
|
||||
import robocasa # noqa: F401, PLC0415
|
||||
from robocasa.scripts import download_datasets as _dl # noqa: F401, PLC0415
|
||||
from robocasa.utils import dataset_registry as _reg # noqa: F401, PLC0415
|
||||
except ImportError as exc:
|
||||
sys.exit(
|
||||
"[build_robocasa_composite_seen] robocasa is not importable.\n"
|
||||
"Install it (and robosuite) per the LeRobot RoboCasa docs:\n"
|
||||
" git clone https://github.com/robocasa/robocasa.git ~/robocasa\n"
|
||||
" git clone https://github.com/ARISE-Initiative/robosuite.git ~/robosuite\n"
|
||||
" pip install -e ~/robocasa --no-deps\n"
|
||||
" pip install -e ~/robosuite\n"
|
||||
f"(original error: {exc})"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_task_root(task: str) -> Path:
|
||||
"""Resolve the local extracted ``LeRobotDataset`` root for a target/human task.
|
||||
|
||||
Uses RoboCasa's own ``dataset_registry`` so we follow whatever directory
|
||||
layout RoboCasa picks (currently ``v1.0/target/composite/<task>/<date>/``
|
||||
under ``robocasa.macros.DATASET_BASE_DIR``). Falls back to discovering the
|
||||
extracted directory if the helper's signature drifted between releases.
|
||||
"""
|
||||
from robocasa.utils import dataset_registry # noqa: PLC0415
|
||||
|
||||
# ``get_ds_path`` is the canonical helper. RoboCasa 1.0 signature is
|
||||
# ``get_ds_path(task, ds_type, return_info=False)`` with ``ds_type`` like
|
||||
# ``"human_im"`` (image-observation human demos). We try the common
|
||||
# ``split=`` kwarg first (newer registry); if it's rejected, fall back.
|
||||
try:
|
||||
ds_path = dataset_registry.get_ds_path(
|
||||
task=task,
|
||||
ds_type="human_im",
|
||||
return_info=False,
|
||||
split="target",
|
||||
)
|
||||
except TypeError:
|
||||
# Older registry — ds_type alone disambiguates target/human.
|
||||
ds_path = dataset_registry.get_ds_path(
|
||||
task=task,
|
||||
ds_type="human_im",
|
||||
return_info=False,
|
||||
)
|
||||
|
||||
root = Path(ds_path)
|
||||
# ``get_ds_path`` may return either the extracted dir or the .tar; normalize.
|
||||
if root.suffix == ".tar":
|
||||
root = root.parent
|
||||
return root
|
||||
|
||||
|
||||
def _download_task(task: str, *, overwrite: bool = False) -> Path:
|
||||
"""Download (or locate) a single target/human task and return its extracted root."""
|
||||
from robocasa.scripts import download_datasets as dl # noqa: PLC0415
|
||||
|
||||
# Try the documented programmatic API. The CLI is
|
||||
# python -m robocasa.scripts.download_datasets --tasks <T> --source human --split target
|
||||
# which is a thin wrapper over a function of the same name.
|
||||
if hasattr(dl, "download_datasets"):
|
||||
try:
|
||||
dl.download_datasets(
|
||||
tasks=[task],
|
||||
source="human",
|
||||
split="target",
|
||||
overwrite=overwrite,
|
||||
)
|
||||
except TypeError:
|
||||
# Older signature — drop the kwargs RoboCasa didn't have yet.
|
||||
dl.download_datasets(tasks=[task])
|
||||
else:
|
||||
# No public function — shell out to the CLI as a last resort. This
|
||||
# guarantees we use whatever entrypoint RoboCasa's authors maintain.
|
||||
import subprocess # noqa: PLC0415
|
||||
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"robocasa.scripts.download_datasets",
|
||||
"--tasks",
|
||||
task,
|
||||
"--source",
|
||||
"human",
|
||||
"--split",
|
||||
"target",
|
||||
]
|
||||
if overwrite:
|
||||
cmd.append("--overwrite")
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
root = _resolve_task_root(task)
|
||||
if not root.exists():
|
||||
raise RuntimeError(
|
||||
f"Expected {root} after download, but it doesn't exist. "
|
||||
"RoboCasa may have changed its data layout — verify with "
|
||||
"`robocasa.utils.dataset_registry.get_ds_path()`."
|
||||
)
|
||||
return root
|
||||
|
||||
|
||||
def _open_as_lerobot_dataset(task: str, root: Path) -> LeRobotDataset:
|
||||
"""Open an extracted RoboCasa target/human task as a ``LeRobotDataset``.
|
||||
|
||||
The placeholder ``repo_id`` (``robocasa/<task>_target_human``) is only used
|
||||
by the aggregator for logging and for the unified task table — the actual
|
||||
data is loaded from ``root``.
|
||||
"""
|
||||
repo_id = f"robocasa/{task}_target_human"
|
||||
return LeRobotDataset(repo_id=repo_id, root=root)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Aggregate the 16 RoboCasa composite_seen target tasks into one LeRobotDataset.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Local directory for the merged dataset (will be created).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Hub repo_id for the merged dataset (e.g. ``yourname/"
|
||||
"robocasa_composite_seen``). Required for ``--push-to-hub``; also "
|
||||
"becomes the merged dataset's canonical ``repo_id``."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push the merged dataset to the Hub after building. Requires "
|
||||
"``--hub-repo-id`` and a prior ``huggingface-cli login``.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="When pushing, create the Hub repo as private.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated task names to override the default 16 "
|
||||
"composite_seen list (useful for smoke-testing with 1–2 tasks).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-download",
|
||||
action="store_true",
|
||||
help="Skip the download step entirely; assume each task is already "
|
||||
"extracted on disk at the path ``dataset_registry.get_ds_path`` "
|
||||
"returns.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite-download",
|
||||
action="store_true",
|
||||
help="Force re-download even when a complete local extraction exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, args.log_level),
|
||||
format="[%(levelname)s] %(message)s",
|
||||
)
|
||||
|
||||
tasks = (
|
||||
[t.strip() for t in args.tasks.split(",") if t.strip()]
|
||||
if args.tasks
|
||||
else list(COMPOSITE_SEEN_TASKS)
|
||||
)
|
||||
if not tasks:
|
||||
sys.exit("No tasks selected.")
|
||||
|
||||
if args.push_to_hub and not args.hub_repo_id:
|
||||
sys.exit("--push-to-hub requires --hub-repo-id.")
|
||||
|
||||
output_repo_id = args.hub_repo_id or "local/robocasa_composite_seen"
|
||||
logger.info(
|
||||
"Building merged RoboCasa dataset: %d tasks → %s (output dir: %s)",
|
||||
len(tasks),
|
||||
output_repo_id,
|
||||
args.output_dir,
|
||||
)
|
||||
|
||||
_require_robocasa()
|
||||
|
||||
# 1. Download (or locate) each task's extracted directory.
|
||||
task_roots: list[tuple[str, Path]] = []
|
||||
for i, task in enumerate(tasks, 1):
|
||||
logger.info("[%d/%d] %s", i, len(tasks), task)
|
||||
if args.skip_download:
|
||||
root = _resolve_task_root(task)
|
||||
if not root.exists():
|
||||
sys.exit(
|
||||
f"--skip-download set but extracted directory does not "
|
||||
f"exist for {task}: {root}"
|
||||
)
|
||||
else:
|
||||
root = _download_task(task, overwrite=args.overwrite_download)
|
||||
logger.info(" extracted at: %s", root)
|
||||
task_roots.append((task, root))
|
||||
|
||||
# 2. Open each as a LeRobotDataset (validation happens inside aggregator).
|
||||
datasets: list[LeRobotDataset] = []
|
||||
for task, root in task_roots:
|
||||
logger.info("Opening %s", task)
|
||||
ds = _open_as_lerobot_dataset(task, root)
|
||||
logger.info(
|
||||
" %s: %d episodes, %d frames, %d FPS",
|
||||
task,
|
||||
ds.num_episodes,
|
||||
ds.num_frames,
|
||||
ds.fps,
|
||||
)
|
||||
datasets.append(ds)
|
||||
|
||||
# 3. Merge — re-validates features/fps/robot_type, unifies tasks, concats
|
||||
# videos + parquet, recomputes stats.
|
||||
logger.info("Merging %d datasets into %s", len(datasets), output_repo_id)
|
||||
merged = merge_datasets(
|
||||
datasets=datasets,
|
||||
output_repo_id=output_repo_id,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
logger.info(
|
||||
"Merged: %d episodes, %d frames across %d unique task strings",
|
||||
merged.num_episodes,
|
||||
merged.num_frames,
|
||||
len(merged.meta.tasks) if merged.meta.tasks is not None else 0,
|
||||
)
|
||||
|
||||
# 4. Push to Hub.
|
||||
if args.push_to_hub:
|
||||
logger.info("Pushing %s to the Hub (private=%s)", args.hub_repo_id, args.private)
|
||||
# ``upload_large_folder=True`` is the right mode for tens-of-GB
|
||||
# datasets — uses multipart uploads + resumable transfers.
|
||||
merged.push_to_hub(
|
||||
private=args.private,
|
||||
upload_large_folder=True,
|
||||
tags=["lerobot", "robocasa", "composite_seen", "manipulation"],
|
||||
)
|
||||
logger.info(
|
||||
"Push complete: https://huggingface.co/datasets/%s",
|
||||
args.hub_repo_id,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Skipping Hub push (no --push-to-hub). Merged dataset is at %s.",
|
||||
args.output_dir,
|
||||
)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,6 @@ Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wand
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
@@ -44,7 +43,7 @@ from lerobot.common.train_utils import (
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets import EpisodeAwareSampler, WeightedEpisodeAwareSampler, make_dataset
|
||||
from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
@@ -162,196 +161,6 @@ def update_policy(
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def _print_debug_text_predictions(
|
||||
policy: Any, batch: dict[str, Any], step: int, n_samples: int = 5
|
||||
) -> None:
|
||||
"""Forward the current batch and print head-argmax vs label per supervised position.
|
||||
|
||||
Opt-in via ``LEROBOT_DEBUG_PREDS_EVERY=<step_interval>``. Only the
|
||||
policy types that expose ``debug_text_predictions`` participate
|
||||
(currently PI052); others are silently skipped. Pretty-prints up to
|
||||
``n_samples`` samples from the current batch, showing the prompt,
|
||||
every supervised position's (label, prediction, ✓/✗), and a
|
||||
per-sample token-accuracy summary — the cheapest "is text training
|
||||
actually learning anything" signal.
|
||||
"""
|
||||
# Accelerator/DDP wraps the policy in a ``module`` attribute and
|
||||
# doesn't proxy custom methods through, so a naive
|
||||
# ``hasattr(policy, "debug_text_predictions")`` returns False on the
|
||||
# wrapper — and the helper would silently no-op. Walk through any
|
||||
# ``.module`` indirection (DDP, FSDP, ``accelerator.prepare`` wrappers)
|
||||
# to reach the raw policy that actually defines the method.
|
||||
inner = policy
|
||||
while hasattr(inner, "module") and not hasattr(inner, "debug_text_predictions"):
|
||||
inner = inner.module
|
||||
if not hasattr(inner, "debug_text_predictions"):
|
||||
logging.warning(
|
||||
"LEROBOT_DEBUG_PREDS_EVERY set but policy %s has no "
|
||||
"debug_text_predictions method — skipping dump.",
|
||||
type(inner).__name__,
|
||||
)
|
||||
return
|
||||
try:
|
||||
debug = inner.debug_text_predictions(batch, max_samples=n_samples)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("debug_text_predictions failed: %s", exc, exc_info=True)
|
||||
return
|
||||
if not debug:
|
||||
logging.warning(
|
||||
"debug_text_predictions returned no supervised samples — "
|
||||
"current batch has no text labels."
|
||||
)
|
||||
return
|
||||
policy = inner # used below for select_message-style decoding parity
|
||||
|
||||
# Build a tokenizer for decoding — match training side exactly.
|
||||
try:
|
||||
from transformers import AutoTokenizer # noqa: PLC0415
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415
|
||||
register_paligemma_loc_tokens,
|
||||
)
|
||||
|
||||
tok_name = (
|
||||
getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
|
||||
)
|
||||
tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("debug preds: tokenizer load failed: %s", exc)
|
||||
return
|
||||
|
||||
ids = debug["input_ids"]
|
||||
labels = debug["labels"]
|
||||
preds = debug["predictions"]
|
||||
attn = debug["attention_mask"]
|
||||
inference = debug.get("inference") or []
|
||||
|
||||
n = ids.shape[0]
|
||||
print(
|
||||
f"\n========== STEP {step} DEBUG PREDICTIONS ({n} samples) ==========",
|
||||
flush=True,
|
||||
)
|
||||
for s in range(n):
|
||||
a = attn[s].tolist()
|
||||
real = sum(a)
|
||||
sid = ids[s].tolist()
|
||||
sl = labels[s].tolist()
|
||||
sp = preds[s].tolist()
|
||||
prompt = tokenizer.decode(sid[:real], skip_special_tokens=False)
|
||||
print(f"\n --- sample {s + 1}/{n} ---", flush=True)
|
||||
print(f" prompt: {prompt!r}", flush=True)
|
||||
|
||||
# Ground-truth target (the contiguous supervised label span).
|
||||
sup_ids = [int(sid[i]) for i in range(real) if sl[i] != -100]
|
||||
if sup_ids:
|
||||
print(
|
||||
f" target (ground truth) : {tokenizer.decode(sup_ids, skip_special_tokens=False)!r}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Training-side teacher-forced argmax on the same prompt+target.
|
||||
n_sup = n_ok = 0
|
||||
first_sup_pred: int | None = None
|
||||
teacher_chars: list[int] = []
|
||||
for i in range(1, real):
|
||||
label = sl[i]
|
||||
if label == -100:
|
||||
continue
|
||||
n_sup += 1
|
||||
pred = int(sp[i - 1])
|
||||
if first_sup_pred is None:
|
||||
first_sup_pred = pred
|
||||
teacher_chars.append(pred)
|
||||
if label == pred:
|
||||
n_ok += 1
|
||||
teacher_text = (
|
||||
tokenizer.decode(teacher_chars, skip_special_tokens=False) if teacher_chars else ""
|
||||
)
|
||||
acc = n_ok / max(n_sup, 1)
|
||||
print(
|
||||
f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Inference-side autoregressive output from the same prompt prefix.
|
||||
inf_entry = inference[s] if s < len(inference) else None
|
||||
if inf_entry:
|
||||
inf_decoded = inf_entry.get("decoded", "")
|
||||
print(f" inference (autoregressive) : {inf_decoded!r}", flush=True)
|
||||
# First-token parity: training-side argmax at the prompt-end
|
||||
# position MUST equal inference's first generated token —
|
||||
# both compute argmax(lm_head(h_last_prompt)) on identical
|
||||
# context. Any divergence signals a training↔inference bug.
|
||||
if first_sup_pred is not None and inf_decoded and not inf_decoded.startswith("<inference"):
|
||||
inf_ids = tokenizer(inf_decoded, add_special_tokens=False)["input_ids"]
|
||||
if inf_ids:
|
||||
inf_first = int(inf_ids[0])
|
||||
match = inf_first == first_sup_pred
|
||||
print(
|
||||
f" first-token parity : "
|
||||
f"train={first_sup_pred} ({tokenizer.decode([first_sup_pred])!r}) "
|
||||
f"vs infer={inf_first} ({tokenizer.decode([inf_first])!r}) "
|
||||
f"{'✓ MATCH' if match else '✗ DIVERGED — training/inference mismatch'}",
|
||||
flush=True,
|
||||
)
|
||||
print("=" * 60 + "\n", flush=True)
|
||||
|
||||
|
||||
def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None":
|
||||
"""Build per-frame sampling weights that oversample VQA-annotated frames.
|
||||
|
||||
Scans the dataset's ``language_events`` column for frames carrying a
|
||||
``vqa``-style annotation and returns a weight tensor (length == total
|
||||
dataset frames) such that, under multinomial sampling, VQA frames make up
|
||||
roughly ``target_fraction`` of the training stream.
|
||||
|
||||
Returns ``None`` (⇒ fall back to uniform episode-aware sampling) when VQA
|
||||
frames cannot be detected or there are none.
|
||||
"""
|
||||
if not 0.0 < target_fraction < 1.0:
|
||||
logging.warning(
|
||||
"vqa_target_fraction must be in (0, 1); got %s — VQA oversampling disabled.",
|
||||
target_fraction,
|
||||
)
|
||||
return None
|
||||
hf = getattr(dataset, "hf_dataset", None)
|
||||
if hf is None or "language_events" not in getattr(hf, "column_names", []):
|
||||
logging.warning(
|
||||
"Dataset has no `language_events` column — VQA oversampling disabled."
|
||||
)
|
||||
return None
|
||||
|
||||
events_col = hf["language_events"]
|
||||
n_frames = len(events_col)
|
||||
is_vqa = torch.zeros(n_frames, dtype=torch.bool)
|
||||
for i, rows in enumerate(events_col):
|
||||
if rows and any((row or {}).get("style") == "vqa" for row in rows):
|
||||
is_vqa[i] = True
|
||||
|
||||
n_vqa = int(is_vqa.sum())
|
||||
if n_vqa == 0:
|
||||
logging.warning("No `vqa` annotations found in the dataset — VQA oversampling disabled.")
|
||||
return None
|
||||
n_other = n_frames - n_vqa
|
||||
|
||||
# Solve target = (n_vqa·w) / (n_vqa·w + n_other) for the VQA weight w.
|
||||
# Clamp to ≥ 1 so VQA frames are never *down*-weighted below uniform.
|
||||
weight = (target_fraction * n_other) / ((1.0 - target_fraction) * max(n_vqa, 1))
|
||||
weight = max(weight, 1.0)
|
||||
weights = torch.ones(n_frames, dtype=torch.double)
|
||||
weights[is_vqa] = weight
|
||||
logging.info(
|
||||
"VQA oversampling: %d/%d frames carry a `vqa` annotation (%.2f%%); "
|
||||
"weighting them x%.2f to target ~%.0f%% of the training stream.",
|
||||
n_vqa,
|
||||
n_frames,
|
||||
100.0 * n_vqa / n_frames,
|
||||
weight,
|
||||
100.0 * target_fraction,
|
||||
)
|
||||
return weights
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
"""
|
||||
@@ -483,17 +292,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
# pi052: even when loading pretrained weights, build the processors
|
||||
# from the current pi052 config so the recipe text-label and FAST
|
||||
# action-label steps are generated and not silently swapped for the
|
||||
# checkpoint's older processor stack.
|
||||
if cfg.policy.type == "pi052" and processor_pretrained_path is not None and not cfg.resume:
|
||||
logging.warning(
|
||||
"pi052 is loading pretrained weights from %s, but building processors from the current "
|
||||
"pi052 config so recipe text labels and FAST action labels are generated.",
|
||||
processor_pretrained_path,
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
if (
|
||||
getattr(active_cfg, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
@@ -513,14 +311,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if cfg.is_reward_model_training:
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
# For pi052 (and any future policy that auto-fits part of its
|
||||
# preprocessing per-dataset), pass the dataset repo id so the
|
||||
# processor factory can locate/refresh dataset-specific artifacts
|
||||
# (e.g. fitted FAST tokenizers per Pertsch et al. 2025 [64],
|
||||
# π0.5 §III.C).
|
||||
if cfg.policy.type == "pi052":
|
||||
processor_kwargs["dataset_repo_id"] = cfg.dataset.repo_id
|
||||
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
@@ -601,29 +391,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
from_indices = dataset.meta.episodes["dataset_from_index"]
|
||||
to_indices = dataset.meta.episodes["dataset_to_index"]
|
||||
# When `vqa_target_fraction` is set, oversample VQA-annotated
|
||||
# frames via a weighted sampler; otherwise plain episode-aware.
|
||||
vqa_weights = None
|
||||
if cfg.vqa_target_fraction is not None and not cfg.dataset.streaming:
|
||||
vqa_weights = _build_vqa_oversample_weights(dataset, cfg.vqa_target_fraction)
|
||||
if vqa_weights is not None:
|
||||
sampler = WeightedEpisodeAwareSampler(
|
||||
from_indices,
|
||||
to_indices,
|
||||
vqa_weights,
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
)
|
||||
else:
|
||||
sampler = EpisodeAwareSampler(
|
||||
from_indices,
|
||||
to_indices,
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
@@ -654,54 +428,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
policy.train()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# EMA setup
|
||||
# ------------------------------------------------------------------
|
||||
# Shadow copy of the trainable params for late-training averaging
|
||||
# (Chi et al. 2023 Diffusion Policy §V.D; openpi JAX trainer ships
|
||||
# this with decay=0.999 for pi05_libero; openpi PyTorch port and
|
||||
# LeRobot main both skip it). Off by default; opt in with
|
||||
# ``--ema.enable=true``. Implemented via ema-pytorch
|
||||
# (https://github.com/lucidrains/ema-pytorch) — the standard PyTorch
|
||||
# EMA library, also used by lucidrains' diffusion repos.
|
||||
ema = None
|
||||
if cfg.ema.enable and is_main_process:
|
||||
from ema_pytorch import EMA # noqa: PLC0415
|
||||
|
||||
ema = EMA(
|
||||
accelerator.unwrap_model(policy),
|
||||
beta=cfg.ema.decay,
|
||||
update_after_step=cfg.ema.warmup_steps,
|
||||
update_every=1, # update on every ema.update() call
|
||||
# Don't register the live model as an ema submodule — accelerator
|
||||
# already owns its lifecycle, and double-registration would
|
||||
# double-count its params in ``ema.state_dict()``.
|
||||
include_online_model=False,
|
||||
)
|
||||
ema.to(accelerator.device)
|
||||
logging.info(
|
||||
"EMA enabled (ema-pytorch): beta=%g, update_after_step=%d, "
|
||||
"use_for_eval=%s, use_for_wandb_examples=%s",
|
||||
cfg.ema.decay,
|
||||
cfg.ema.warmup_steps,
|
||||
cfg.ema.use_for_eval,
|
||||
cfg.ema.use_for_wandb_examples,
|
||||
)
|
||||
|
||||
# Resume the EMA shadow if a previous run wrote one.
|
||||
if cfg.checkpoint_path is not None:
|
||||
ema_path = cfg.checkpoint_path / "training_state" / "ema_state.pt"
|
||||
if ema_path.exists():
|
||||
logging.info("Resuming EMA shadow from %s", ema_path)
|
||||
try:
|
||||
ema.load_state_dict(torch.load(ema_path, map_location=accelerator.device))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"Failed to load EMA shadow (%s) — restarting EMA from "
|
||||
"current live weights",
|
||||
exc,
|
||||
)
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
@@ -754,14 +480,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
sample_weighter=sample_weighter,
|
||||
)
|
||||
|
||||
# EMA update: pull one step of the live weights into the shadow.
|
||||
# Runs only on the main process (the shadow lives there); other
|
||||
# ranks rely on the live model staying in sync via accelerator.
|
||||
# ``ema-pytorch`` holds an internal reference to the online model
|
||||
# (set at construction), so ``ema.update()`` takes no args.
|
||||
if ema is not None:
|
||||
ema.update()
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
@@ -772,27 +490,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
# Optional periodic head-prediction dump for the LM head:
|
||||
# ``LEROBOT_DEBUG_PREDS_EVERY=1000`` prints 5 samples + per-token
|
||||
# (label, argmax, ✓/✗) every 1000 steps. Cheap diagnostic to see
|
||||
# whether the text head is actually learning what we expect, vs
|
||||
# collapsing to a fixed token. Refilling the recipe-sample dump
|
||||
# budget at the same cadence also redumps the raw input shapes.
|
||||
_debug_preds_every = int(os.environ.get("LEROBOT_DEBUG_PREDS_EVERY", "0"))
|
||||
if (
|
||||
_debug_preds_every > 0
|
||||
and step % _debug_preds_every == 0
|
||||
and is_main_process
|
||||
):
|
||||
try:
|
||||
from lerobot.policies.pi052 import text_processor_pi052 as _tp # noqa: PLC0415
|
||||
|
||||
_tp._DUMPED_SO_FAR = 0
|
||||
_tp._DUMP_BUDGET = max(_tp._DUMP_BUDGET, 5)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
_print_debug_text_predictions(policy, batch, step, n_samples=5)
|
||||
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
@@ -803,49 +500,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
# EMA observability: ``ema.step`` is the count of
|
||||
# ``ema.update()`` calls (= optimizer steps once EMA is
|
||||
# enabled); ``ema.initted`` flips to True once we've
|
||||
# crossed ``update_after_step``.
|
||||
if ema is not None:
|
||||
wandb_log_dict["ema/step"] = int(ema.step.item())
|
||||
wandb_log_dict["ema/initted"] = float(ema.initted.item())
|
||||
wandb_log_dict["ema/beta"] = float(cfg.ema.decay)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
# Periodic training-example dump to wandb (camera images + text
|
||||
# fields + action endpoints). Opt-in via ``--wandb.log_examples_freq``;
|
||||
# independent of ``--log_freq`` so you can keep scalar logs frequent
|
||||
# and the heavier visual dump rare (e.g. every 5000 steps).
|
||||
if (
|
||||
wandb_logger is not None
|
||||
and cfg.wandb.log_examples_freq > 0
|
||||
and step % cfg.wandb.log_examples_freq == 0
|
||||
and is_main_process
|
||||
):
|
||||
try:
|
||||
# Optionally use the EMA shadow model directly for the
|
||||
# predicted-action columns (matches what eval / deployment
|
||||
# would see). ``ema-pytorch`` exposes the shadow as a
|
||||
# full ``nn.Module`` at ``ema.ema_model``, so we just
|
||||
# pass that instead of swap-and-restore.
|
||||
target_policy = (
|
||||
ema.ema_model
|
||||
if (ema is not None and cfg.ema.use_for_wandb_examples)
|
||||
else accelerator.unwrap_model(policy)
|
||||
)
|
||||
wandb_logger.log_training_examples(
|
||||
batch=batch,
|
||||
step=step,
|
||||
camera_keys=list(dataset.meta.camera_keys),
|
||||
n_samples=cfg.wandb.log_examples_n,
|
||||
policy=target_policy,
|
||||
predict_actions=cfg.wandb.log_examples_predict_actions,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("wandb log_training_examples failed: %s", exc)
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
if is_main_process:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
@@ -861,18 +518,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
# Save the EMA shadow alongside the training state so a
|
||||
# resumed run picks up exactly where the live EMA left off.
|
||||
# ``ema-pytorch.state_dict()`` returns the full shadow
|
||||
# nn.Module's state dict + step/initted buffers; saved as
|
||||
# .pt (the rest of training_state mixes formats already).
|
||||
if ema is not None:
|
||||
try:
|
||||
ema_path = checkpoint_dir / "training_state" / "ema_state.pt"
|
||||
ema_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(ema.state_dict(), ema_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("Failed to save EMA shadow: %s", exc)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
@@ -882,20 +527,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
# Use the EMA shadow model for eval when enabled —
|
||||
# standard practice for diffusion-style policies (~1–3%
|
||||
# lift on closed-loop success). ``ema.ema_model`` is a
|
||||
# full nn.Module clone, so we just pass it through; no
|
||||
# swap/restore on the live policy needed.
|
||||
eval_target_policy = (
|
||||
ema.ema_model
|
||||
if (ema is not None and cfg.ema.use_for_eval)
|
||||
else accelerator.unwrap_model(policy)
|
||||
)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
eval_info = eval_policy_all(
|
||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||
policy=eval_target_policy,
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""LeRobot tool implementations.
|
||||
|
||||
Storage of the tool catalog (``meta/info.json["tools"]``) and the
|
||||
``SAY_TOOL_SCHEMA`` constant live in PR 1
|
||||
(``lerobot.datasets.language``). This package holds the *runnable*
|
||||
implementations one file per tool, plus the registry that maps tool
|
||||
names to classes.
|
||||
|
||||
See ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
|
||||
from .base import Tool
|
||||
from .registry import TOOL_REGISTRY, get_tools
|
||||
from .say import SayTool
|
||||
|
||||
__all__ = ["Tool", "TOOL_REGISTRY", "get_tools", "SayTool"]
|
||||
@@ -1,58 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tool protocol — the contract every runnable tool implementation honors.
|
||||
|
||||
Tools are the executable side of the OpenAI-style function-calling
|
||||
abstraction the v3.1 language schema (PR 1) carries on assistant
|
||||
messages: the schema describes *what can be called*, the tool
|
||||
implementation describes *how to call it*.
|
||||
|
||||
Implementations live one-per-file under :mod:`lerobot.tools` (e.g.
|
||||
``say.py`` for ``SayTool``) and are registered in
|
||||
:mod:`lerobot.tools.registry`. The runtime instantiates them lazily so
|
||||
heavy dependencies (torch models, audio backends, network clients,
|
||||
hardware drivers) only load when the dataset actually declares the tool.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Tool(Protocol):
|
||||
"""Minimum surface every tool must expose."""
|
||||
|
||||
#: Name matching ``schema["function"]["name"]``. The runtime dispatcher
|
||||
#: routes incoming ``tool_calls`` to the implementation by this key.
|
||||
name: str
|
||||
|
||||
#: OpenAI-style function-call schema. Same dict the dataset stores in
|
||||
#: ``meta/info.json["tools"]`` and the chat template renders into the
|
||||
#: prompt.
|
||||
schema: dict[str, Any]
|
||||
|
||||
def call(self, arguments: dict[str, Any]) -> Any:
|
||||
"""Execute the tool with the model-provided arguments.
|
||||
|
||||
``arguments`` is the parsed dict from
|
||||
``tool_calls[i]["function"]["arguments"]`` (already JSON-decoded
|
||||
when the model emits a JSON-string by the chat-template
|
||||
convention). Implementations validate the dict against their own
|
||||
schema; the runtime only routes by name.
|
||||
|
||||
Return value is implementation-defined — typically a tensor
|
||||
(TTS audio), a Path (saved file), a dict (structured result), or
|
||||
``None`` (side-effect-only call).
|
||||
"""
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tool registry — name → implementation class.
|
||||
|
||||
Adding a new tool:
|
||||
|
||||
1. Drop a file under ``src/lerobot/tools/`` that defines a class
|
||||
conforming to :class:`lerobot.tools.base.Tool` (must expose ``name``,
|
||||
``schema``, ``call(arguments)``).
|
||||
2. Register the class here under :data:`TOOL_REGISTRY`.
|
||||
3. (Optional) Pre-populate ``meta/info.json["tools"]`` on your dataset
|
||||
to advertise the schema to the chat-template + policy. The PR 2
|
||||
annotation pipeline preserves anything you put there.
|
||||
|
||||
See ``docs/source/tools.mdx`` for the full authoring guide.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import Tool
|
||||
from .say import SayTool
|
||||
|
||||
#: Map from ``function.name`` to a class implementing :class:`Tool`.
|
||||
#: The runtime instantiates entries lazily — registering a tool here is
|
||||
#: essentially free (no model load happens until ``call`` runs).
|
||||
TOOL_REGISTRY: dict[str, type] = {
|
||||
"say": SayTool,
|
||||
}
|
||||
|
||||
|
||||
def get_tools(meta: Any, **kwargs: Any) -> dict[str, Tool]:
|
||||
"""Build name → tool-instance dict from a dataset's declared catalog.
|
||||
|
||||
``meta`` is anything with a ``.tools`` attribute returning the
|
||||
OpenAI-style schema list — typically a
|
||||
:class:`lerobot.datasets.dataset_metadata.LeRobotDatasetMetadata`.
|
||||
Each entry whose ``function.name`` is registered here is
|
||||
instantiated with the schema dict; tools whose name is unknown to
|
||||
the registry are skipped (the schema still rides through the chat
|
||||
template, the model just can't actually invoke that tool at
|
||||
inference).
|
||||
|
||||
Extra keyword arguments are forwarded to every constructor — useful
|
||||
for runtime defaults like ``output_dir=Path("./tts_log")``.
|
||||
"""
|
||||
declared = list(meta.tools)
|
||||
instances: dict[str, Tool] = {}
|
||||
for schema in declared:
|
||||
try:
|
||||
name = schema["function"]["name"]
|
||||
except (KeyError, TypeError):
|
||||
continue
|
||||
cls = TOOL_REGISTRY.get(name)
|
||||
if cls is None:
|
||||
continue
|
||||
instances[name] = cls(schema=schema, **kwargs)
|
||||
return instances
|
||||
@@ -1,169 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""``SayTool`` — text-to-speech tool wrapping Kyutai's pocket-tts.
|
||||
|
||||
The first concrete tool implementation. PI052 and downstream runtime
|
||||
dispatchers consume this when the model emits an assistant message
|
||||
with ``tool_calls=[{function: {name: "say", arguments: {text: ...}}}]``.
|
||||
|
||||
Why pocket-tts:
|
||||
|
||||
- runs on CPU (no GPU dependency); ~6× real-time on a MacBook Air M4
|
||||
- ~100M parameters, ~200ms first-chunk latency
|
||||
- streamable, voice-cloneable
|
||||
- pip-installable, MIT-style permissive license
|
||||
|
||||
The pocket-tts model is loaded **lazily** the first time ``call(...)``
|
||||
runs (or eagerly via ``preload()``). Loading takes a few seconds and
|
||||
several hundred MB of RAM, so we don't pay the cost when the tool is
|
||||
merely *registered* — only when it's *invoked*.
|
||||
|
||||
Optional dependency. Install with::
|
||||
|
||||
pip install lerobot[tools]
|
||||
# or directly:
|
||||
pip install pocket-tts
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lerobot.datasets.language import SAY_TOOL_SCHEMA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SayTool:
|
||||
"""Speak a short utterance via Kyutai's pocket-tts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
schema:
|
||||
Optional schema override; defaults to the canonical
|
||||
``SAY_TOOL_SCHEMA`` from PR 1. Custom voices or extended
|
||||
argument shapes can pass in a modified schema, but the
|
||||
implementation only reads ``arguments["text"]``.
|
||||
voice:
|
||||
One of the pocket-tts catalog voices (``alba``, ``marius``,
|
||||
``javert``, ``jean``, ``fantine``, ``cosette``, ``eponine``,
|
||||
``azelma``) or a path to a ``.wav`` / ``.safetensors`` voice
|
||||
file for cloning. See the pocket-tts model card for licensing.
|
||||
output_dir:
|
||||
If set, every ``call(...)`` writes a ``<timestamp>.wav`` audio
|
||||
file there in addition to returning the PCM tensor.
|
||||
``None`` (default) skips disk writes — useful for live
|
||||
playback paths that hand the tensor directly to a sounddevice
|
||||
/ WebAudio sink.
|
||||
"""
|
||||
|
||||
schema: dict[str, Any] = field(default_factory=lambda: dict(SAY_TOOL_SCHEMA))
|
||||
voice: str = "alba"
|
||||
output_dir: Path | None = None
|
||||
|
||||
name: str = field(init=False, default="say")
|
||||
_model: Any = field(init=False, default=None, repr=False)
|
||||
_voice_state: Any = field(init=False, default=None, repr=False)
|
||||
_sample_rate: int = field(init=False, default=24000, repr=False)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lazy model load
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def preload(self) -> None:
|
||||
"""Load the pocket-tts model + voice state into memory.
|
||||
|
||||
Optional — ``call(...)`` triggers this automatically on first
|
||||
invocation. Useful when you want the multi-second load to
|
||||
happen at startup rather than on the first ``say`` the policy
|
||||
emits.
|
||||
"""
|
||||
if self._model is not None and self._voice_state is not None:
|
||||
return
|
||||
try:
|
||||
from pocket_tts import TTSModel # noqa: PLC0415 (optional dep)
|
||||
except ImportError as exc: # pragma: no cover (env-dependent)
|
||||
raise ImportError(
|
||||
"SayTool requires pocket-tts. Install with `pip install "
|
||||
"lerobot[tools]` or `pip install pocket-tts`."
|
||||
) from exc
|
||||
logger.info("SayTool: loading pocket-tts model + voice=%r", self.voice)
|
||||
self._model = TTSModel.load_model()
|
||||
self._voice_state = self._model.get_state_for_audio_prompt(self.voice)
|
||||
self._sample_rate = int(getattr(self._model, "sample_rate", 24000))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool protocol
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def call(self, arguments: dict[str, Any]) -> Any:
|
||||
"""Speak ``arguments["text"]`` and return the PCM tensor.
|
||||
|
||||
Optionally also writes ``<output_dir>/<timestamp>.wav`` when
|
||||
``self.output_dir`` is set. The returned tensor is a 1-D
|
||||
``torch.Tensor`` of float32 PCM samples at
|
||||
``self.sample_rate`` Hz — directly playable by
|
||||
``sounddevice.play(audio.numpy(), self.sample_rate)`` or
|
||||
encodable by ``scipy.io.wavfile.write``.
|
||||
"""
|
||||
text = arguments.get("text")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
raise ValueError(
|
||||
f"SayTool.call expects arguments={{'text': str}}, got {arguments!r}"
|
||||
)
|
||||
self.preload()
|
||||
|
||||
audio = self._model.generate_audio(self._voice_state, text)
|
||||
|
||||
if self.output_dir is not None:
|
||||
self._write_wav(audio, text)
|
||||
|
||||
return audio
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""PCM sample rate of the returned tensor (Hz)."""
|
||||
return self._sample_rate
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _write_wav(self, audio: Any, text: str) -> Path:
|
||||
"""Write a ``.wav`` next to ``output_dir`` for offline inspection."""
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
try:
|
||||
import scipy.io.wavfile # noqa: PLC0415
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise ImportError(
|
||||
"SayTool.output_dir requires scipy. `pip install scipy`."
|
||||
) from exc
|
||||
|
||||
out_dir = Path(self.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
# One file per call; suffix with a millisecond timestamp + a
|
||||
# short text snippet so a directory listing is informative.
|
||||
snippet = "".join(c if c.isalnum() else "_" for c in text[:32]).strip("_")
|
||||
ts_ms = int(_time.time() * 1000)
|
||||
path = out_dir / f"say_{ts_ms}_{snippet}.wav"
|
||||
|
||||
# ``audio`` is a torch tensor; pocket-tts uses CPU, so a plain
|
||||
# ``.numpy()`` is safe.
|
||||
scipy.io.wavfile.write(path, self.sample_rate, audio.numpy())
|
||||
return path
|
||||
@@ -22,7 +22,7 @@ from torch.utils.data._utils.collate import default_collate
|
||||
|
||||
from lerobot.datasets.language import LANGUAGE_COLUMNS
|
||||
|
||||
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices", *LANGUAGE_COLUMNS}
|
||||
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
|
||||
|
||||
|
||||
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
|
||||
|
||||
@@ -34,7 +34,6 @@ ACTION = "action"
|
||||
ACTION_PREFIX = ACTION + "."
|
||||
ACTION_TOKENS = ACTION + ".tokens"
|
||||
ACTION_TOKEN_MASK = ACTION + ".token_mask"
|
||||
ACTION_CODE_TOKEN_MASK = ACTION + ".code_token_mask"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
|
||||
@@ -40,6 +40,7 @@ from lerobot.datasets.language_render import render_sample
|
||||
|
||||
from ._helpers import make_canned_responder
|
||||
|
||||
|
||||
def _build_pr1_style_blend_recipe() -> TrainingRecipe:
|
||||
"""Inline blend recipe that consumes every style this pipeline produces.
|
||||
|
||||
@@ -144,29 +145,22 @@ def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output(
|
||||
recipe = _build_pr1_style_blend_recipe()
|
||||
|
||||
rendered_any = False
|
||||
for sample_idx, (ts, persistent, events) in enumerate(
|
||||
zip(timestamps, persistent_lists, events_lists, strict=True)
|
||||
):
|
||||
for ts, persistent, events in zip(timestamps, persistent_lists, events_lists, strict=True):
|
||||
result = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=float(ts),
|
||||
sample_idx=sample_idx,
|
||||
sample_idx=0,
|
||||
dataset_ctx={"task": "Pour water from the bottle into the cup."},
|
||||
)
|
||||
if result is None:
|
||||
continue
|
||||
if result["messages"]:
|
||||
rendered_any = True
|
||||
# A valid render supervises something: a text-CE target turn
|
||||
# OR a flow-only ``low_level``-stream turn (action loss).
|
||||
assert (
|
||||
result["target_message_indices"]
|
||||
or "low_level" in result["message_streams"]
|
||||
)
|
||||
assert result["target_message_indices"]
|
||||
break
|
||||
assert rendered_any, "recipe rendered no messages from pipeline output"
|
||||
assert rendered_any, "PR 1 recipe rendered no messages from pipeline output"
|
||||
|
||||
# Sanity: speech atom appears in events column intact
|
||||
flat_events = [r for ev in events_lists for r in ev]
|
||||
|
||||
@@ -29,15 +29,6 @@ def test_message_recipe_validates_unknown_binding():
|
||||
)
|
||||
|
||||
|
||||
def test_canonical_recipe_loads():
|
||||
"""The canonical PI052 blend YAML loads + validates."""
|
||||
recipe = TrainingRecipe.from_yaml(
|
||||
Path("src/lerobot/configs/recipes/subtask_mem_vqa_speech.yaml")
|
||||
)
|
||||
assert recipe.blend is not None
|
||||
assert sum(c.weight for c in recipe.blend.values()) == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_message_turn_requires_a_stream():
|
||||
"""Every turn must declare a stream — None is rejected at construction.
|
||||
|
||||
|
||||
@@ -343,84 +343,6 @@ def test_resolve_task_explicit_override_beats_rephrasings():
|
||||
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||
|
||||
|
||||
def test_flow_only_low_level_recipe_renders_without_target():
|
||||
"""Regression: a flow-only ``low_level`` recipe has no ``target`` turn —
|
||||
its supervision is the action-expert flow loss, not text-CE. It must
|
||||
still render (not ``None``), otherwise every blend draw of it is dropped
|
||||
and the action expert never receives a flow loss."""
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content="${subtask}",
|
||||
stream="low_level",
|
||||
if_present="subtask",
|
||||
),
|
||||
],
|
||||
bindings={"subtask": "active_at(t, style=subtask)"},
|
||||
)
|
||||
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=PERSISTENT,
|
||||
events=[],
|
||||
t=0.5,
|
||||
sample_idx=0,
|
||||
task="clean kitchen",
|
||||
)
|
||||
|
||||
assert rendered is not None
|
||||
assert rendered["messages"] == [{"role": "user", "content": "subtask 0"}]
|
||||
assert rendered["message_streams"] == ["low_level"]
|
||||
assert rendered["target_message_indices"] == []
|
||||
|
||||
|
||||
def test_vqa_frame_is_consumed_over_the_weighted_blend():
|
||||
"""A frame carrying a VQA annotation renders the ``ask_vqa*`` sub-recipe
|
||||
even when its blend weight is tiny — VQA annotations are sparse and must
|
||||
never be wasted on a subtask/action draw."""
|
||||
recipe = TrainingRecipe(
|
||||
blend={
|
||||
"high_level_subtask": TrainingRecipe(
|
||||
weight=0.99,
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="a subtask", stream="high_level", target=True),
|
||||
],
|
||||
),
|
||||
"ask_vqa_top": TrainingRecipe(
|
||||
weight=0.01,
|
||||
bindings={
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)",
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
# A frame WITH a vqa event renders VQA on every sample_idx, despite the
|
||||
# ask_vqa weight being only 0.01.
|
||||
for sample_idx in range(20):
|
||||
rendered = render_sample(
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_1, t=1.0, sample_idx=sample_idx, task="x"
|
||||
)
|
||||
assert rendered["messages"][-1]["content"] == '{"count": 2}', sample_idx
|
||||
# A frame WITHOUT a vqa event falls back to the normal weighted blend.
|
||||
rendered = render_sample(recipe=recipe, persistent=PERSISTENT, events=[], t=1.0, sample_idx=0, task="x")
|
||||
assert rendered["messages"][-1]["content"] == "a subtask"
|
||||
|
||||
|
||||
def test_emitted_at_persistent_tolerates_small_timestamp_drift():
|
||||
"""Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
|
||||
so callers that derive ``t`` arithmetically (``frame_idx / fps``) still
|
||||
|
||||
@@ -25,7 +25,7 @@ from datasets import Dataset # noqa: E402
|
||||
from lerobot.datasets.io_utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler, WeightedEpisodeAwareSampler
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
|
||||
@@ -137,49 +137,3 @@ def test_partial_episode_drop_warns(caplog):
|
||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||
assert sampler.indices == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
|
||||
# --- WeightedEpisodeAwareSampler --------------------------------------------
|
||||
|
||||
|
||||
def test_weighted_sampler_respects_episode_drop_and_length():
|
||||
"""The episode-boundary frame filtering is applied before weighting,
|
||||
and one epoch still yields ``len(indices)`` samples."""
|
||||
# One episode, 10 frames; drop the last 2.
|
||||
sampler = WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(10), drop_n_last_frames=2)
|
||||
assert sampler.indices == list(range(8))
|
||||
assert len(sampler) == 8
|
||||
draws = list(sampler)
|
||||
assert len(draws) == 8
|
||||
# Dropped frames 8 and 9 must never be sampled.
|
||||
assert all(d in set(range(8)) for d in draws)
|
||||
|
||||
|
||||
def test_weighted_sampler_oversamples_high_weight_frames():
|
||||
"""A heavily-weighted frame dominates the draws."""
|
||||
torch.manual_seed(0)
|
||||
# 100 frames, frame 7 is weighted 1000x.
|
||||
weights = torch.ones(100)
|
||||
weights[7] = 1000.0
|
||||
sampler = WeightedEpisodeAwareSampler([0], [100], frame_weights=weights)
|
||||
counts = {}
|
||||
for _ in range(20): # 20 epochs
|
||||
for d in sampler:
|
||||
counts[d] = counts.get(d, 0) + 1
|
||||
total = sum(counts.values())
|
||||
# Frame 7 should be the overwhelming majority of the 2000 draws.
|
||||
assert counts.get(7, 0) / total > 0.9
|
||||
|
||||
|
||||
def test_weighted_sampler_zero_weights_fall_back_to_uniform():
|
||||
"""If every surviving frame has zero weight, sampling is uniform
|
||||
rather than crashing."""
|
||||
sampler = WeightedEpisodeAwareSampler([0], [6], frame_weights=torch.zeros(6))
|
||||
draws = set(sampler)
|
||||
assert draws.issubset(set(range(6)))
|
||||
assert len(list(sampler)) == 6
|
||||
|
||||
|
||||
def test_weighted_sampler_rejects_short_weight_vector():
|
||||
with pytest.raises(ValueError, match="frame_weights"):
|
||||
WeightedEpisodeAwareSampler([0], [10], frame_weights=torch.ones(5))
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Attention-masking tests for the PI052 (π0.5 v2) text head.
|
||||
|
||||
Regression coverage for the text-CE collapse bug: PaliGemma's
|
||||
``embed_prefix`` flags every language token ``att=0``, which
|
||||
``make_att_2d_masks`` turns into one fully *bidirectional* block. Under
|
||||
that mask the text cross-entropy degenerates into a copy task — a
|
||||
supervised target token attends to the tokens it is trained to predict —
|
||||
and the LM head never learns causal generation, so ``select_message``
|
||||
collapses at inference.
|
||||
|
||||
``_mark_target_span_causal`` sets ``att=1`` on the supervised target
|
||||
language positions so each target token attends causally among the
|
||||
targets while staying bidirectional to images + the user prompt. These
|
||||
tests pin that behaviour for the PaliGemma prefix layout.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# modeling_pi052 / modeling_pi05 import transformers transitively.
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import make_att_2d_masks # noqa: E402
|
||||
from lerobot.policies.pi052.modeling_pi052 import _mark_target_span_causal, _shifted_ce # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# A synthetic PI052 prefix layout: [images, prompt-lang, target-lang]
|
||||
#
|
||||
# indices 0-1 : 2 image tokens (att = 0)
|
||||
# indices 2-4 : 3 user-prompt lang (att = 0)
|
||||
# indices 5-8 : 4 supervised target lang(att = 0 from embed_prefix)
|
||||
#
|
||||
# ``text_labels`` covers the 7 language tokens; -100 on the prompt span,
|
||||
# real ids on the 4-token target span. PaliGemma's prefix has no state
|
||||
# token (unlike SmolVLA), so the lang span ends at the prefix end.
|
||||
# ---------------------------------------------------------------------------
|
||||
N_IMAGE = 2
|
||||
N_PROMPT = 3
|
||||
N_TARGET = 4
|
||||
LANG_START = N_IMAGE
|
||||
LANG_END = N_IMAGE + N_PROMPT + N_TARGET # = prefix length
|
||||
PREFIX_LEN = LANG_END
|
||||
|
||||
|
||||
def _embed_prefix_att_masks() -> torch.Tensor:
|
||||
"""Mimic PaliGemma ``embed_prefix``: images + lang all att=0."""
|
||||
return torch.zeros(1, PREFIX_LEN, dtype=torch.bool)
|
||||
|
||||
|
||||
def _text_labels() -> torch.Tensor:
|
||||
"""-100 over the prompt span, real ids over the target span."""
|
||||
labels = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
|
||||
labels[0, N_PROMPT:] = torch.arange(10, 10 + N_TARGET)
|
||||
return labels
|
||||
|
||||
|
||||
def _attends(prefix_att_masks: torch.Tensor) -> torch.Tensor:
|
||||
"""2D boolean attendance matrix; ``[i, j]`` True ⇒ i attends to j."""
|
||||
pad = torch.ones(1, PREFIX_LEN, dtype=torch.bool)
|
||||
return make_att_2d_masks(pad, prefix_att_masks)[0]
|
||||
|
||||
|
||||
def test_mark_sets_att_on_targets_only():
|
||||
"""Only the supervised target language positions flip to att=1."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
expected = [False] * PREFIX_LEN
|
||||
for i in range(LANG_START + N_PROMPT, LANG_END): # target span
|
||||
expected[i] = True
|
||||
assert marked[0].tolist() == expected
|
||||
|
||||
|
||||
def test_target_tokens_attend_causally_among_themselves():
|
||||
"""A target token must NOT attend to later targets, but must attend
|
||||
to earlier ones — genuine causal next-token prediction."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
attends = _attends(marked)
|
||||
tgt = range(LANG_START + N_PROMPT, LANG_END)
|
||||
for i in tgt:
|
||||
for j in tgt:
|
||||
if j > i:
|
||||
assert not attends[i, j], f"target {i} must not see future target {j}"
|
||||
else:
|
||||
assert attends[i, j], f"target {i} must see earlier/self target {j}"
|
||||
|
||||
|
||||
def test_target_tokens_attend_prompt_and_images_bidirectionally():
|
||||
"""Targets keep full visibility of images + the user prompt."""
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), _text_labels(), LANG_START, LANG_END
|
||||
)
|
||||
attends = _attends(marked)
|
||||
context = list(range(0, LANG_START + N_PROMPT)) # images + prompt
|
||||
for i in range(LANG_START + N_PROMPT, LANG_END):
|
||||
for j in context:
|
||||
assert attends[i, j], f"target {i} must attend context {j}"
|
||||
|
||||
|
||||
def test_non_target_subtask_stays_bidirectional():
|
||||
"""A flow-only / non-target language span (all -100 labels) leaves the
|
||||
mask untouched — the action expert reads it bidirectionally."""
|
||||
all_ignored = torch.full((1, N_PROMPT + N_TARGET), -100, dtype=torch.long)
|
||||
marked = _mark_target_span_causal(
|
||||
_embed_prefix_att_masks(), all_ignored, LANG_START, LANG_END
|
||||
)
|
||||
assert torch.equal(marked, _embed_prefix_att_masks())
|
||||
|
||||
|
||||
def test_unmarked_mask_is_bidirectional_the_bug():
|
||||
"""Documents the bug the fix prevents: without ``_mark_target_span_causal``
|
||||
a target token attends *bidirectionally* to later targets — the
|
||||
text-CE can copy the answer it is trained to predict."""
|
||||
attends = _attends(_embed_prefix_att_masks())
|
||||
first_tgt = LANG_START + N_PROMPT
|
||||
last_tgt = LANG_END - 1
|
||||
assert attends[first_tgt, last_tgt], (
|
||||
"raw embed_prefix mask is bidirectional over language — the first "
|
||||
"target token can see the last, which is the collapse bug"
|
||||
)
|
||||
|
||||
|
||||
def test_shifted_ce_returns_zero_when_no_text_positions_are_supervised():
|
||||
logits = torch.randn(2, 4, 8, requires_grad=True)
|
||||
labels = torch.full((2, 4), -100, dtype=torch.long)
|
||||
|
||||
loss = _shifted_ce(logits, labels)
|
||||
|
||||
assert loss.item() == 0
|
||||
loss.backward()
|
||||
assert logits.grad is not None
|
||||
@@ -1,87 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Regression tests for PI052 FAST action-code supervision."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi052.modeling_pi052 import _fast_ce # noqa: E402
|
||||
|
||||
|
||||
def test_fast_ce_supervises_only_discrete_action_codes():
|
||||
"""Wrapper tokens can be wrong without affecting the FAST action-code loss."""
|
||||
vocab_size = 8
|
||||
action_tokens = torch.tensor([[1, 2, 3, 4, 5, 0]])
|
||||
action_code_mask = torch.tensor([[False, False, True, True, False, False]])
|
||||
|
||||
logits = torch.zeros(1, action_tokens.shape[1], vocab_size)
|
||||
# Deliberately bad wrapper-token predictions. These should be ignored.
|
||||
logits[0, 0, 7] = 10.0 # target would be token 2
|
||||
logits[0, 3, 7] = 10.0 # target would be delimiter token 5
|
||||
# Correct action-code predictions: hidden t predicts target t + 1.
|
||||
logits[0, 1, 3] = 10.0
|
||||
logits[0, 2, 4] = 10.0
|
||||
|
||||
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
|
||||
expected = F.cross_entropy(
|
||||
torch.stack([logits[0, 1], logits[0, 2]]),
|
||||
torch.tensor([3, 4]),
|
||||
reduction="mean",
|
||||
)
|
||||
|
||||
assert torch.allclose(loss, expected)
|
||||
|
||||
|
||||
def test_fast_ce_masks_non_action_samples():
|
||||
"""Recipe samples with predict_actions=False do not contribute FAST loss."""
|
||||
vocab_size = 8
|
||||
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
|
||||
action_code_mask = torch.tensor(
|
||||
[[False, False, True, True], [False, False, True, True]]
|
||||
)
|
||||
predict_actions = torch.tensor([True, False])
|
||||
|
||||
logits = torch.zeros(2, action_tokens.shape[1], vocab_size)
|
||||
logits[0, 1, 3] = 10.0
|
||||
logits[0, 2, 4] = 10.0
|
||||
# Bad predictions in the masked sample should not matter.
|
||||
logits[1, 1, 7] = 10.0
|
||||
logits[1, 2, 7] = 10.0
|
||||
|
||||
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions)
|
||||
expected = F.cross_entropy(
|
||||
torch.stack([logits[0, 1], logits[0, 2]]),
|
||||
torch.tensor([3, 4]),
|
||||
reduction="mean",
|
||||
)
|
||||
|
||||
assert torch.allclose(loss, expected)
|
||||
|
||||
|
||||
def test_fast_ce_returns_zero_when_no_action_code_positions_are_valid():
|
||||
logits = torch.randn(2, 4, 8, requires_grad=True)
|
||||
action_tokens = torch.tensor([[1, 2, 3, 4], [1, 2, 5, 6]])
|
||||
action_code_mask = torch.zeros_like(action_tokens, dtype=torch.bool)
|
||||
|
||||
loss = _fast_ce(logits, action_tokens, action_code_mask, predict_actions_t=None)
|
||||
|
||||
assert loss.item() == 0
|
||||
loss.backward()
|
||||
assert logits.grad is not None
|
||||
@@ -1,155 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Numerical-parity tests for the SDPA attention port.
|
||||
|
||||
``pi05`` / ``pi052`` replaced the per-layer call from
|
||||
``modeling_gemma.eager_attention_forward`` with
|
||||
``sdpa_attention_forward`` (PyTorch SDPA + GQA repeat). The forward
|
||||
output must be bit-equivalent (within bf16 tolerance) on the masks
|
||||
this model actually uses — block-bidirectional with an arbitrary
|
||||
additive bias — otherwise we silently change training behaviour.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from transformers.models.gemma import modeling_gemma # noqa: E402
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import ( # noqa: E402
|
||||
make_att_2d_masks,
|
||||
sdpa_attention_forward,
|
||||
)
|
||||
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE # noqa: E402
|
||||
|
||||
|
||||
def _mock_self_attn(num_kv_groups: int, training: bool = False):
|
||||
"""Bare module surface that both forwards read."""
|
||||
return SimpleNamespace(
|
||||
num_key_value_groups=num_kv_groups,
|
||||
training=training,
|
||||
)
|
||||
|
||||
|
||||
def _build_inputs(
|
||||
bsize: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
seq_len: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int = 0,
|
||||
):
|
||||
g = torch.Generator(device="cpu").manual_seed(seed)
|
||||
q = torch.randn(bsize, num_heads, seq_len, head_dim, dtype=dtype, generator=g)
|
||||
k = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
|
||||
v = torch.randn(bsize, num_kv_heads, seq_len, head_dim, dtype=dtype, generator=g)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def _block_bidirectional_mask(
|
||||
bsize: int, seq_len: int, block_sizes: list[int], dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""Mimic ``_prepare_attention_masks_4d`` on a block layout that
|
||||
matches ``[images, language, suffix]`` from ``embed_prefix`` +
|
||||
``embed_suffix``: every block bidirectional internally, later
|
||||
blocks visible to earlier ones via the cumulative-block rule.
|
||||
"""
|
||||
assert sum(block_sizes) == seq_len
|
||||
att_marks = []
|
||||
for i, n in enumerate(block_sizes):
|
||||
att_marks += [1 if i > 0 else 0] + [0] * (n - 1)
|
||||
pad = torch.ones(bsize, seq_len, dtype=torch.bool)
|
||||
att = torch.tensor(att_marks, dtype=torch.bool)[None].expand(bsize, seq_len)
|
||||
att_2d = make_att_2d_masks(pad, att)
|
||||
bias = torch.where(
|
||||
att_2d[:, None, :, :],
|
||||
torch.zeros((), dtype=dtype),
|
||||
torch.tensor(OPENPI_ATTENTION_MASK_VALUE, dtype=dtype),
|
||||
)
|
||||
return bias
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_heads,num_kv_heads,head_dim",
|
||||
[
|
||||
(8, 1, 256), # gemma_2b / paligemma config
|
||||
(8, 8, 64), # MHA control (no GQA repeat)
|
||||
],
|
||||
)
|
||||
def test_sdpa_parity_with_eager_block_bidirectional(num_heads, num_kv_heads, head_dim):
|
||||
"""SDPA forward output matches the eager softmax(QK^T)@V on the
|
||||
block-bidirectional mask layout pi05 actually uses."""
|
||||
bsize, seq_len = 2, 13
|
||||
block_sizes = [4, 5, 4] # images, language, suffix-style blocks
|
||||
dtype = torch.float32 # cpu math kernel — keep fp32 for tight tol
|
||||
scaling = head_dim ** -0.5
|
||||
|
||||
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, dtype)
|
||||
mask = _block_bidirectional_mask(bsize, seq_len, block_sizes, dtype)
|
||||
|
||||
module = _mock_self_attn(num_heads // num_kv_heads)
|
||||
|
||||
out_eager, _ = modeling_gemma.eager_attention_forward(
|
||||
module, q, k, v, mask, scaling
|
||||
)
|
||||
out_sdpa, _ = sdpa_attention_forward(
|
||||
module, q, k, v, mask, scaling
|
||||
)
|
||||
assert out_eager.shape == out_sdpa.shape
|
||||
torch.testing.assert_close(out_sdpa, out_eager, atol=1e-5, rtol=1e-4)
|
||||
|
||||
|
||||
def test_sdpa_parity_bf16():
|
||||
"""bf16 path — looser tolerance, must still match eager."""
|
||||
bsize, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 1, 17, 256
|
||||
scaling = head_dim ** -0.5
|
||||
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.bfloat16)
|
||||
mask = _block_bidirectional_mask(bsize, seq_len, [5, 6, 6], torch.bfloat16)
|
||||
module = _mock_self_attn(num_heads // num_kv_heads)
|
||||
|
||||
out_eager, _ = modeling_gemma.eager_attention_forward(
|
||||
module, q, k, v, mask, scaling
|
||||
)
|
||||
out_sdpa, _ = sdpa_attention_forward(
|
||||
module, q, k, v, mask, scaling
|
||||
)
|
||||
torch.testing.assert_close(out_sdpa, out_eager, atol=2e-2, rtol=2e-2)
|
||||
|
||||
|
||||
def test_sdpa_parity_backward():
|
||||
"""Gradients flow through SDPA and match the eager path within
|
||||
bf16 tolerance — critical for any training-side parity claim."""
|
||||
bsize, num_heads, num_kv_heads, seq_len, head_dim = 1, 4, 2, 9, 32
|
||||
scaling = head_dim ** -0.5
|
||||
q, k, v = _build_inputs(bsize, num_heads, num_kv_heads, seq_len, head_dim, torch.float32)
|
||||
q.requires_grad_(True); k.requires_grad_(True); v.requires_grad_(True)
|
||||
mask = _block_bidirectional_mask(bsize, seq_len, [3, 3, 3], torch.float32)
|
||||
module = _mock_self_attn(num_heads // num_kv_heads)
|
||||
|
||||
out_e, _ = modeling_gemma.eager_attention_forward(module, q, k, v, mask, scaling)
|
||||
g_q_e, g_k_e, g_v_e = torch.autograd.grad(out_e.sum(), [q, k, v])
|
||||
|
||||
out_s, _ = sdpa_attention_forward(module, q, k, v, mask, scaling)
|
||||
g_q_s, g_k_s, g_v_s = torch.autograd.grad(out_s.sum(), [q, k, v])
|
||||
|
||||
torch.testing.assert_close(g_q_s, g_q_e, atol=1e-5, rtol=1e-4)
|
||||
torch.testing.assert_close(g_k_s, g_k_e, atol=1e-5, rtol=1e-4)
|
||||
torch.testing.assert_close(g_v_s, g_v_e, atol=1e-5, rtol=1e-4)
|
||||
@@ -1,196 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for PI052's text tokenizer.
|
||||
|
||||
Covers ``say`` tool-call flattening (PaliGemma's flat prompt has no
|
||||
structured tool calls, so a ``say`` call must be serialized into a
|
||||
``<say>...</say>`` text marker) and EOS-termination supervision (the
|
||||
supervised target span must end with an EOS token so the LM head learns
|
||||
to stop instead of rambling to ``max_length`` at inference).
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import (
|
||||
PI052TextTokenizerStep,
|
||||
_flatten_say_tool_calls,
|
||||
_format_messages,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
|
||||
def _say_call(text):
|
||||
return {"type": "function", "function": {"name": "say", "arguments": {"text": text}}}
|
||||
|
||||
|
||||
def test_flatten_appends_say_marker_and_drops_tool_calls():
|
||||
msg = {"role": "assistant", "content": "Heading to the cube.", "tool_calls": [_say_call("On it!")]}
|
||||
out = _flatten_say_tool_calls(msg)
|
||||
assert "tool_calls" not in out
|
||||
assert out["content"] == "Heading to the cube.\n<say>On it!</say>"
|
||||
|
||||
|
||||
def test_flatten_marker_only_when_content_empty_or_none():
|
||||
out = _flatten_say_tool_calls({"role": "assistant", "tool_calls": [_say_call("hi")]})
|
||||
assert out["content"] == "<say>hi</say>"
|
||||
|
||||
|
||||
def test_flatten_accepts_json_string_arguments():
|
||||
call = {"type": "function", "function": {"name": "say", "arguments": '{"text": "hello there"}'}}
|
||||
out = _flatten_say_tool_calls({"role": "assistant", "content": "p", "tool_calls": [call]})
|
||||
assert out["content"] == "p\n<say>hello there</say>"
|
||||
|
||||
|
||||
def test_flatten_leaves_messages_without_tool_calls_untouched():
|
||||
msg = {"role": "assistant", "content": "just a plan"}
|
||||
assert _flatten_say_tool_calls(msg) == msg
|
||||
|
||||
|
||||
def test_flatten_drops_non_say_tool_calls_but_keeps_content():
|
||||
weather = {"type": "function", "function": {"name": "check_weather", "arguments": {}}}
|
||||
out = _flatten_say_tool_calls(
|
||||
{"role": "assistant", "content": "plan only", "tool_calls": [weather]}
|
||||
)
|
||||
assert out["content"] == "plan only"
|
||||
assert "tool_calls" not in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EOS-termination supervision
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_format_messages_appends_eos_to_target_turns_only():
|
||||
msgs = [
|
||||
{"role": "user", "content": "pick cube"},
|
||||
{"role": "assistant", "content": "move to cube"},
|
||||
]
|
||||
prompt, spans = _format_messages(msgs, target_indices=[1], eos_token="<eos>")
|
||||
# EOS is appended to the supervised target (assistant) turn only.
|
||||
assert prompt == "User: pick cube\nAssistant: move to cube<eos>\n"
|
||||
# The user span is unchanged; the target span covers content + EOS.
|
||||
assert prompt[spans[0][0] : spans[0][1]] == "pick cube"
|
||||
assert prompt[spans[1][0] : spans[1][1]] == "move to cube<eos>"
|
||||
|
||||
|
||||
def test_format_messages_without_eos_args_is_unchanged():
|
||||
"""Inference callers omit target_indices / eos_token — no EOS baked in."""
|
||||
prompt, spans = _format_messages([{"role": "user", "content": "hi"}])
|
||||
assert prompt == "User: hi\n"
|
||||
assert prompt[spans[0][0] : spans[0][1]] == "hi"
|
||||
|
||||
|
||||
def _eos_char_id() -> int:
|
||||
"""Token id _CharTokenizer assigns to its 1-char EOS."""
|
||||
return ord("\x1f") % 251 + 1
|
||||
|
||||
|
||||
def test_pi052_text_tokenizer_supervises_eos_at_target_end():
|
||||
"""The appended EOS is the last supervised label on a target turn —
|
||||
that's the signal that teaches the LM head to stop. The trailing
|
||||
newline right after it stays unsupervised (-100)."""
|
||||
step = PI052TextTokenizerStep(max_length=64)
|
||||
step._tokenizer = _CharTokenizer()
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"messages": [
|
||||
{"role": "user", "content": "pick cube"},
|
||||
{"role": "assistant", "content": "move to cube"},
|
||||
],
|
||||
"target_message_indices": [1],
|
||||
"message_streams": ["high_level", "high_level"],
|
||||
"index": torch.tensor(10),
|
||||
},
|
||||
}
|
||||
out = step(transition)
|
||||
ids = out[TransitionKey.OBSERVATION][OBS_LANGUAGE_TOKENS][0]
|
||||
labels = out[TransitionKey.COMPLEMENTARY_DATA]["text_labels"][0]
|
||||
|
||||
supervised = (labels != -100).nonzero().flatten().tolist()
|
||||
assert supervised, "target turn produced no supervised labels"
|
||||
last = supervised[-1]
|
||||
# The last supervised token is the appended EOS.
|
||||
assert int(ids[last]) == _eos_char_id()
|
||||
assert int(labels[last]) == _eos_char_id()
|
||||
# The token right after the EOS (the trailing newline) is NOT supervised.
|
||||
assert int(labels[last + 1]) == -100
|
||||
|
||||
|
||||
class _CharTokenizer:
|
||||
pad_token_id = 0
|
||||
eos_token = "\x1f" # unit separator — a 1-char "EOS" for testing
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text,
|
||||
max_length,
|
||||
padding,
|
||||
truncation,
|
||||
return_tensors,
|
||||
return_offsets_mapping,
|
||||
padding_side,
|
||||
):
|
||||
ids = [ord(c) % 251 + 1 for c in text[:max_length]]
|
||||
offsets = [(i, i + 1) for i in range(len(ids))]
|
||||
attention = [1] * len(ids)
|
||||
if padding == "max_length" and len(ids) < max_length:
|
||||
pad = max_length - len(ids)
|
||||
ids += [self.pad_token_id] * pad
|
||||
offsets += [(0, 0)] * pad
|
||||
attention += [0] * pad
|
||||
return {
|
||||
"input_ids": torch.tensor([ids], dtype=torch.long),
|
||||
"attention_mask": torch.tensor([attention], dtype=torch.long),
|
||||
"offset_mapping": torch.tensor([offsets], dtype=torch.long),
|
||||
}
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=False):
|
||||
return "".join(chr(max(int(i) - 1, 0)) for i in token_ids if int(i) != self.pad_token_id)
|
||||
|
||||
|
||||
def test_pi052_text_tokenizer_handles_batched_rendered_messages():
|
||||
step = PI052TextTokenizerStep(max_length=64)
|
||||
step._tokenizer = _CharTokenizer()
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {
|
||||
"messages": [
|
||||
[
|
||||
{"role": "user", "content": "pick cube"},
|
||||
{"role": "assistant", "content": "move to cube"},
|
||||
],
|
||||
[{"role": "user", "content": "open drawer"}],
|
||||
],
|
||||
"target_message_indices": [[1], []],
|
||||
"message_streams": [["high_level", "high_level"], ["low_level"]],
|
||||
"index": torch.tensor([10, 11]),
|
||||
},
|
||||
}
|
||||
|
||||
out = step(transition)
|
||||
obs = out[TransitionKey.OBSERVATION]
|
||||
comp = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert obs[OBS_LANGUAGE_TOKENS].shape == (2, 64)
|
||||
assert obs[OBS_LANGUAGE_ATTENTION_MASK].shape == (2, 64)
|
||||
assert comp["text_labels"].shape == (2, 64)
|
||||
assert comp["predict_actions"].tolist() == [False, True]
|
||||
assert (comp["text_labels"][0] != -100).any()
|
||||
assert not (comp["text_labels"][1] != -100).any()
|
||||
@@ -1,187 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Training-side conversion of VQA answers to PaliGemma ``<loc>`` text.
|
||||
|
||||
PI052 trains spatial VQA answers (``bbox`` / ``keypoint``) in
|
||||
PaliGemma's native ``<locNNNN>`` detection vocabulary so the LM head
|
||||
reuses the detection prior instead of fighting it (the ``<loc>``-salad
|
||||
bug). The dataset stores Qwen2.5-VL's grounding output — **0–1000
|
||||
normalized** coordinates, *not* pixels. (Verified empirically on the
|
||||
published datasets: x and y both span 0..1000 with ~30% of values
|
||||
exceeding the camera's pixel dimensions.) The conversion is therefore
|
||||
camera-resolution-independent. The dataset stays backbone-agnostic
|
||||
JSON; the conversion lives in PI052's tokenizer. These tests pin the
|
||||
JSON → ``<loc>`` rewrite.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: E402
|
||||
_loc_token,
|
||||
_messages_vqa_to_loc,
|
||||
_vqa_answer_to_loc,
|
||||
register_paligemma_loc_tokens,
|
||||
)
|
||||
|
||||
|
||||
class _FakeTokenizer:
|
||||
"""Tracks ``add_tokens`` calls; mimics the bits ``register_paligemma_loc_tokens`` reads."""
|
||||
|
||||
def __init__(self, prepopulate: bool = False):
|
||||
self.added_tokens_encoder: dict[str, int] = {}
|
||||
self.calls: list[list[str]] = []
|
||||
if prepopulate:
|
||||
self.added_tokens_encoder["<loc0000>"] = 256000
|
||||
|
||||
def add_tokens(self, tokens: list[str]) -> int:
|
||||
self.calls.append(list(tokens))
|
||||
for t in tokens:
|
||||
self.added_tokens_encoder.setdefault(t, len(self.added_tokens_encoder) + 256000)
|
||||
return len(tokens)
|
||||
|
||||
|
||||
def test_register_loc_tokens_adds_full_1024_range():
|
||||
tok = _FakeTokenizer()
|
||||
out = register_paligemma_loc_tokens(tok)
|
||||
assert out is tok # returns same instance
|
||||
assert len(tok.calls) == 1
|
||||
added = tok.calls[0]
|
||||
assert len(added) == 1024
|
||||
assert added[0] == "<loc0000>"
|
||||
assert added[-1] == "<loc1023>"
|
||||
# Spot check a few in the middle.
|
||||
assert added[162] == "<loc0162>"
|
||||
assert added[759] == "<loc0759>"
|
||||
|
||||
|
||||
def test_register_loc_tokens_is_idempotent():
|
||||
"""If the loc tokens are already present we skip re-adding them."""
|
||||
tok = _FakeTokenizer(prepopulate=True)
|
||||
register_paligemma_loc_tokens(tok)
|
||||
register_paligemma_loc_tokens(tok)
|
||||
assert tok.calls == [] # never called add_tokens
|
||||
|
||||
|
||||
def test_loc_token_normalizes_and_clamps():
|
||||
# Default scale is the 0–1000 Qwen convention.
|
||||
assert _loc_token(0) == "<loc0000>"
|
||||
assert _loc_token(1000) == "<loc1023>"
|
||||
assert _loc_token(500) == f"<loc{round(500 / 1000 * 1023):04d}>"
|
||||
# out-of-range coordinates clamp into [0, 1023]
|
||||
assert _loc_token(9999) == "<loc1023>"
|
||||
assert _loc_token(-5) == "<loc0000>"
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_keypoint_normalized():
|
||||
# Label-first: avoids the "Assistant: → <loc>" attractor at training.
|
||||
answer = {"label": "blue cube", "point_format": "xy", "point": [500, 500]}
|
||||
assert _vqa_answer_to_loc(answer) == "blue cube <loc0512><loc0512>"
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_bbox_normalized():
|
||||
answer = {
|
||||
"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [0, 0, 1000, 1000]}]
|
||||
}
|
||||
assert _vqa_answer_to_loc(answer) == "cube <loc0000><loc0000><loc1023><loc1023>"
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_multiple_detections_separator():
|
||||
answer = {
|
||||
"detections": [
|
||||
{"label": "blue", "bbox_format": "xyxy", "bbox": [0, 0, 500, 500]},
|
||||
{"label": "yellow", "bbox_format": "xyxy", "bbox": [500, 500, 1000, 1000]},
|
||||
]
|
||||
}
|
||||
out = _vqa_answer_to_loc(answer)
|
||||
# Each segment is "label <locs>", joined by " ; "
|
||||
assert out == (
|
||||
"blue <loc0000><loc0000><loc0512><loc0512> ; "
|
||||
"yellow <loc0512><loc0512><loc1023><loc1023>"
|
||||
)
|
||||
|
||||
|
||||
def test_vqa_answer_to_loc_returns_none_for_non_spatial():
|
||||
assert _vqa_answer_to_loc({"label": "cubes", "count": 2}) is None
|
||||
assert _vqa_answer_to_loc({"weird": "payload"}) is None
|
||||
|
||||
|
||||
def test_messages_vqa_to_loc_rewrites_target_turn():
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": "where is the cube?"}]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": '{"label": "cube", "point_format": "xy", "point": [500, 500]}',
|
||||
},
|
||||
]
|
||||
out = _messages_vqa_to_loc(messages, target_indices=[1])
|
||||
assert out[1]["content"] == "cube <loc0512><loc0512>"
|
||||
# input messages are not mutated
|
||||
assert messages[1]["content"].startswith("{")
|
||||
|
||||
|
||||
def test_messages_vqa_to_loc_leaves_plain_text_targets_untouched():
|
||||
messages = [
|
||||
{"role": "user", "content": "pick the cube"},
|
||||
{"role": "assistant", "content": "pick up the cube"},
|
||||
]
|
||||
out = _messages_vqa_to_loc(messages, target_indices=[1])
|
||||
assert out[1]["content"] == "pick up the cube"
|
||||
|
||||
|
||||
def test_messages_vqa_to_loc_noop_without_target_indices():
|
||||
messages = [
|
||||
{"role": "assistant", "content": '{"label": "c", "point_format": "xy", "point": [1, 2]}'}
|
||||
]
|
||||
assert _messages_vqa_to_loc(messages, []) is messages
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Round-trip: training-side JSON -> <loc> -> runtime-side parse back
|
||||
#
|
||||
# Pins that the conversion preserves coordinate *order* (JSON is x-first,
|
||||
# PaliGemma <loc> is y-first) and the 0–1000 → [0, 1023] scaling. The
|
||||
# only loss is quantization to the 1024-bucket <loc> grid, so a coord
|
||||
# survives within half a bucket (~1000/2046 ≈ 0.49 on the 0–1000 scale).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_loc_round_trip_keypoint_preserves_normalized_coords():
|
||||
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
|
||||
|
||||
answer = {"label": "blue cube", "point_format": "xy", "point": [640, 480]}
|
||||
loc = _vqa_answer_to_loc(answer)
|
||||
parsed = parse_vqa_answer(loc)
|
||||
nx, ny = parsed["payload"]["point"]
|
||||
# parse_vqa_answer returns [0, 1] normalized; rescale back to 0–1000.
|
||||
assert abs(nx * 1000.0 - 640) <= 1000.0 / 2046 + 1e-6
|
||||
assert abs(ny * 1000.0 - 480) <= 1000.0 / 2046 + 1e-6
|
||||
assert parsed["payload"]["label"] == "blue cube"
|
||||
|
||||
|
||||
def test_loc_round_trip_bbox_preserves_order_and_scale():
|
||||
from lerobot.policies.pi052.inference.vqa import parse_vqa_answer
|
||||
|
||||
answer = {
|
||||
"detections": [{"label": "cube", "bbox_format": "xyxy", "bbox": [100, 200, 800, 900]}]
|
||||
}
|
||||
loc = _vqa_answer_to_loc(answer)
|
||||
parsed = parse_vqa_answer(loc)
|
||||
x1, y1, x2, y2 = parsed["payload"]["detections"][0]["bbox"]
|
||||
for got, want in ((x1, 100), (y1, 200), (x2, 800), (y2, 900)):
|
||||
assert abs(got * 1000.0 - want) <= 1000.0 / 2046 + 1e-6
|
||||
@@ -58,70 +58,3 @@ def test_render_messages_step_renders_and_drops_raw_language():
|
||||
assert data["messages"][-1]["content"] == "reach carefully"
|
||||
assert data["message_streams"] == ["high_level", "low_level"]
|
||||
assert data["target_message_indices"] == [1]
|
||||
|
||||
|
||||
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${subtask}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="subtask",
|
||||
),
|
||||
]
|
||||
)
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": "pick the cube",
|
||||
"timestamp": torch.tensor(0.0),
|
||||
"index": torch.tensor(7),
|
||||
"language_persistent": [],
|
||||
"language_events": [{"style": "unmatched", "timestamp": 0.0}],
|
||||
}
|
||||
)
|
||||
|
||||
out = RenderMessagesStep(recipe)(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["messages"] == [{"role": "user", "content": "pick the cube"}]
|
||||
assert data["message_streams"] == ["low_level"]
|
||||
assert data["target_message_indices"] == []
|
||||
|
||||
|
||||
def test_render_messages_step_falls_back_per_sample_in_batched_language():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${subtask}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="subtask",
|
||||
),
|
||||
]
|
||||
)
|
||||
transition = create_transition(
|
||||
action=torch.arange(4).reshape(2, 2),
|
||||
complementary_data={
|
||||
"task": ["pick the cube", "open the drawer"],
|
||||
"timestamp": torch.tensor([0.0, 1.0]),
|
||||
"index": torch.tensor([7, 8]),
|
||||
"language_persistent": [[], []],
|
||||
"language_events": [
|
||||
[{"style": "unmatched", "timestamp": 0.0}],
|
||||
[{"style": "unmatched", "timestamp": 1.0}],
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
out = RenderMessagesStep(recipe)(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["messages"] == [
|
||||
[{"role": "user", "content": "pick the cube"}],
|
||||
[{"role": "user", "content": "open the drawer"}],
|
||||
]
|
||||
assert data["message_streams"] == [["low_level"], ["low_level"]]
|
||||
assert data["target_message_indices"] == [[], []]
|
||||
|
||||
81
uv.lock
generated
81
uv.lock
generated
@@ -416,15 +416,6 @@ dependencies = [
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5c/37/0211f82891a9f14efcfd2b2096f8d9e4351398ad637fdd1ee59cfc580b0e/bddl-1.0.1.tar.gz", hash = "sha256:1fa4e6e5050b93888ff6fd8455c39bfb29d3864ce06b4c37c0f781f513a2ae26", size = 164809, upload-time = "2022-03-08T01:48:23.564Z" }
|
||||
|
||||
[[package]]
|
||||
name = "beartype"
|
||||
version = "0.22.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c7/94/1009e248bbfbab11397abca7193bea6626806be9a327d399810d523a07cb/beartype-0.22.9.tar.gz", hash = "sha256:8f82b54aa723a2848a56008d18875f91c1db02c32ef6a62319a002e3e25a975f", size = 1608866, upload-time = "2025-12-13T06:50:30.72Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/71/cc/18245721fa7747065ab478316c7fea7c74777d07f37ae60db2e84f8172e8/beartype-0.22.9-py3-none-any.whl", hash = "sha256:d16c9bbc61ea14637596c5f6fbff2ee99cbe3573e46a716401734ef50c3060c2", size = 1333658, upload-time = "2025-12-13T06:50:28.266Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "beautifulsoup4"
|
||||
version = "4.14.3"
|
||||
@@ -3203,7 +3194,6 @@ all = [
|
||||
{ name = "ruff" },
|
||||
{ name = "scikit-image" },
|
||||
{ name = "scipy" },
|
||||
{ name = "sentencepiece" },
|
||||
{ name = "teleop" },
|
||||
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
|
||||
{ name = "torchdiffeq" },
|
||||
@@ -3425,7 +3415,6 @@ phone = [
|
||||
]
|
||||
pi = [
|
||||
{ name = "scipy" },
|
||||
{ name = "sentencepiece" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
placo-dep = [
|
||||
@@ -3477,9 +3466,6 @@ sarm = [
|
||||
scipy-dep = [
|
||||
{ name = "scipy" },
|
||||
]
|
||||
sentencepiece-dep = [
|
||||
{ name = "sentencepiece" },
|
||||
]
|
||||
smolvla = [
|
||||
{ name = "accelerate" },
|
||||
{ name = "num2words" },
|
||||
@@ -3491,10 +3477,6 @@ test = [
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-timeout" },
|
||||
]
|
||||
tools = [
|
||||
{ name = "pocket-tts" },
|
||||
{ name = "scipy" },
|
||||
]
|
||||
training = [
|
||||
{ name = "accelerate" },
|
||||
{ name = "av" },
|
||||
@@ -3652,7 +3634,6 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'phone'" },
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'pi'" },
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["sentencepiece-dep"], marker = "extra == 'pi'" },
|
||||
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
|
||||
@@ -3694,7 +3675,6 @@ requires-dist = [
|
||||
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
|
||||
{ name = "pillow", specifier = ">=10.0.0,<13.0.0" },
|
||||
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" },
|
||||
{ name = "pocket-tts", marker = "extra == 'tools'", specifier = ">=1.0.0,<3.0.0" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
|
||||
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
|
||||
{ name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
|
||||
@@ -3719,8 +3699,6 @@ requires-dist = [
|
||||
{ name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" },
|
||||
{ name = "scipy", marker = "extra == 'all'", specifier = ">=1.14.0,<2.0.0" },
|
||||
{ name = "scipy", marker = "extra == 'scipy-dep'", specifier = ">=1.14.0,<2.0.0" },
|
||||
{ name = "scipy", marker = "extra == 'tools'", specifier = ">=1.11.0,<2.0.0" },
|
||||
{ name = "sentencepiece", marker = "extra == 'sentencepiece-dep'", specifier = ">=0.2.0,<0.3.0" },
|
||||
{ name = "setuptools", specifier = ">=71.0.0,<81.0.0" },
|
||||
{ name = "teleop", marker = "extra == 'phone'", specifier = ">=0.1.0,<0.2.0" },
|
||||
{ name = "termcolor", specifier = ">=2.4.0,<4.0.0" },
|
||||
@@ -3738,7 +3716,7 @@ requires-dist = [
|
||||
{ name = "vllm", marker = "sys_platform == 'linux' and extra == 'annotations'", specifier = ">=0.6.0,<1.0.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "sentencepiece-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "annotations", "tools", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "annotations", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
@@ -5256,33 +5234,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pocket-tts"
|
||||
version = "2.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "beartype" },
|
||||
{ name = "einops" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "requests" },
|
||||
{ name = "safetensors" },
|
||||
{ name = "scipy" },
|
||||
{ name = "sentencepiece" },
|
||||
{ name = "torch", version = "2.8.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
|
||||
{ name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" },
|
||||
{ name = "typer" },
|
||||
{ name = "typing-extensions" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/2c/7445f57163bb40e2b2fab4df70d18a4216c4965cdf74196344d95859fc07/pocket_tts-2.1.0.tar.gz", hash = "sha256:6f244f445413400f686506f5ccfb75048547caab7b455b927f4a854c551c60a8", size = 642108, upload-time = "2026-05-04T14:00:29.207Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cf/63/d16958d388efee3f0fc7287e1418ed652ddbc2b61ff4f581f0ad0abce625/pocket_tts-2.1.0-py3-none-any.whl", hash = "sha256:7b8f01d3e52aa7df84887b711994586bdc875e024a8b40a15f757feeeb29f752", size = 68096, upload-time = "2026-05-04T14:00:27.547Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pre-commit"
|
||||
version = "4.6.0"
|
||||
@@ -6868,46 +6819,16 @@ version = "0.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/15/15/2e7a025fc62d764b151ae6d0f2a92f8081755ebe8d4a64099accc6f77ba6/sentencepiece-0.2.1.tar.gz", hash = "sha256:8138cec27c2f2282f4a34d9a016e3374cd40e5c6e9cb335063db66a0a3b71fad", size = 3228515, upload-time = "2025-08-12T07:00:51.718Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4a/be/32ce495aa1d0e0c323dcb1ba87096037358edee539cac5baf8755a6bd396/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:57cae326c8727de58c85977b175af132a7138d84c764635d7e71bbee7e774133", size = 1943152, upload-time = "2025-08-12T06:59:40.048Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/88/7e/ff23008899a58678e98c6ff592bf4d368eee5a71af96d0df6b38a039dd4f/sentencepiece-0.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:56dd39a3c4d6493db3cdca7e8cc68c6b633f0d4195495cbadfcf5af8a22d05a6", size = 1325651, upload-time = "2025-08-12T06:59:41.536Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/84/42eb3ce4796777a1b5d3699dfd4dca85113e68b637f194a6c8d786f16a04/sentencepiece-0.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d9381351182ff9888cc80e41c632e7e274b106f450de33d67a9e8f6043da6f76", size = 1253645, upload-time = "2025-08-12T06:59:42.903Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/fa/d3d5ebcba3cb9e6d3775a096251860c41a6bc53a1b9461151df83fe93255/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99f955df238021bf11f0fc37cdb54fd5e5b5f7fd30ecc3d93fb48b6815437167", size = 1316273, upload-time = "2025-08-12T06:59:44.476Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/88/14f2f4a2b922d8b39be45bf63d79e6cd3a9b2f248b2fcb98a69b12af12f5/sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cdfecef430d985f1c2bcbfff3defd1d95dae876fbd0173376012d2d7d24044b", size = 1387881, upload-time = "2025-08-12T06:59:46.09Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/b8/903e5ccb77b4ef140605d5d71b4f9e0ad95d456d6184688073ed11712809/sentencepiece-0.2.1-cp312-cp312-win32.whl", hash = "sha256:a483fd29a34c3e34c39ac5556b0a90942bec253d260235729e50976f5dba1068", size = 999540, upload-time = "2025-08-12T06:59:48.023Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2d/81/92df5673c067148c2545b1bfe49adfd775bcc3a169a047f5a0e6575ddaca/sentencepiece-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4cdc7c36234fda305e85c32949c5211faaf8dd886096c7cea289ddc12a2d02de", size = 1054671, upload-time = "2025-08-12T06:59:49.895Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fe/02/c5e3bc518655d714622bec87d83db9cdba1cd0619a4a04e2109751c4f47f/sentencepiece-0.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:daeb5e9e9fcad012324807856113708614d534f596d5008638eb9b40112cd9e4", size = 1033923, upload-time = "2025-08-12T06:59:51.952Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ba/4a/85fbe1706d4d04a7e826b53f327c4b80f849cf1c7b7c5e31a20a97d8f28b/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dcd8161eee7b41aae57ded06272905dbd680a0a04b91edd0f64790c796b2f706", size = 1943150, upload-time = "2025-08-12T06:59:53.588Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/83/4cfb393e287509fc2155480b9d184706ef8d9fa8cbf5505d02a5792bf220/sentencepiece-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c6c8f42949f419ff8c7e9960dbadcfbc982d7b5efc2f6748210d3dd53a7de062", size = 1325651, upload-time = "2025-08-12T06:59:55.073Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/de/5a007fb53b1ab0aafc69d11a5a3dd72a289d5a3e78dcf2c3a3d9b14ffe93/sentencepiece-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:097f3394e99456e9e4efba1737c3749d7e23563dd1588ce71a3d007f25475fff", size = 1253641, upload-time = "2025-08-12T06:59:56.562Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/d2/f552be5928105588f4f4d66ee37dd4c61460d8097e62d0e2e0eec41bc61d/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d7b670879c370d350557edabadbad1f6561a9e6968126e6debca4029e5547820", size = 1316271, upload-time = "2025-08-12T06:59:58.109Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/df/0cfe748ace5485be740fed9476dee7877f109da32ed0d280312c94ec259f/sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7f0fd2f2693309e6628aeeb2e2faf6edd221134dfccac3308ca0de01f8dab47", size = 1387882, upload-time = "2025-08-12T07:00:00.701Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/dd/f7774d42a881ced8e1739f393ab1e82ece39fc9abd4779e28050c2e975b5/sentencepiece-0.2.1-cp313-cp313-win32.whl", hash = "sha256:92b3816aa2339355fda2c8c4e021a5de92180b00aaccaf5e2808972e77a4b22f", size = 999541, upload-time = "2025-08-12T07:00:02.709Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/e9/932b9eae6fd7019548321eee1ab8d5e3b3d1294df9d9a0c9ac517c7b636d/sentencepiece-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:10ed3dab2044c47f7a2e7b4969b0c430420cdd45735d78c8f853191fa0e3148b", size = 1054669, upload-time = "2025-08-12T07:00:04.915Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/3a/76488a00ea7d6931689cda28726a1447d66bf1a4837943489314593d5596/sentencepiece-0.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac650534e2251083c5f75dde4ff28896ce7c8904133dc8fef42780f4d5588fcd", size = 1033922, upload-time = "2025-08-12T07:00:06.496Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4a/b6/08fe2ce819e02ccb0296f4843e3f195764ce9829cbda61b7513f29b95718/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:8dd4b477a7b069648d19363aad0cab9bad2f4e83b2d179be668efa672500dc94", size = 1946052, upload-time = "2025-08-12T07:00:08.136Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/d9/1ea0e740591ff4c6fc2b6eb1d7510d02f3fb885093f19b2f3abd1363b402/sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0c0f672da370cc490e4c59d89e12289778310a0e71d176c541e4834759e1ae07", size = 1327408, upload-time = "2025-08-12T07:00:09.572Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/7e/1fb26e8a21613f6200e1ab88824d5d203714162cf2883248b517deb500b7/sentencepiece-0.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad8493bea8432dae8d6830365352350f3b4144415a1d09c4c8cb8d30cf3b6c3c", size = 1254857, upload-time = "2025-08-12T07:00:11.021Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/85/c72fd1f3c7a6010544d6ae07f8ddb38b5e2a7e33bd4318f87266c0bbafbf/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b81a24733726e3678d2db63619acc5a8dccd074f7aa7a54ecd5ca33ca6d2d596", size = 1315722, upload-time = "2025-08-12T07:00:12.989Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4a/e8/661e5bd82a8aa641fd6c1020bd0e890ef73230a2b7215ddf9c8cd8e941c2/sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0a81799d0a68d618e89063fb423c3001a034c893069135ffe51fee439ae474d6", size = 1387452, upload-time = "2025-08-12T07:00:15.088Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/5e/ae66c361023a470afcbc1fbb8da722c72ea678a2fcd9a18f1a12598c7501/sentencepiece-0.2.1-cp313-cp313t-win32.whl", hash = "sha256:89a3ea015517c42c0341d0d962f3e6aaf2cf10d71b1932d475c44ba48d00aa2b", size = 1002501, upload-time = "2025-08-12T07:00:16.966Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/03/d332828c4ff764e16c1b56c2c8f9a33488bbe796b53fb6b9c4205ddbf167/sentencepiece-0.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:33f068c9382dc2e7c228eedfd8163b52baa86bb92f50d0488bf2b7da7032e484", size = 1057555, upload-time = "2025-08-12T07:00:18.573Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/88/14/5aee0bf0864df9bd82bd59e7711362908e4935e3f9cdc1f57246b5d5c9b9/sentencepiece-0.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:b3616ad246f360e52c85781e47682d31abfb6554c779e42b65333d4b5f44ecc0", size = 1036042, upload-time = "2025-08-12T07:00:20.209Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/24/9c/89eb8b2052f720a612478baf11c8227dcf1dc28cd4ea4c0c19506b5af2a2/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:5d0350b686c320068702116276cfb26c066dc7e65cfef173980b11bb4d606719", size = 1943147, upload-time = "2025-08-12T07:00:21.809Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/0b/a1432bc87f97c2ace36386ca23e8bd3b91fb40581b5e6148d24b24186419/sentencepiece-0.2.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:c7f54a31cde6fa5cb030370566f68152a742f433f8d2be458463d06c208aef33", size = 1325624, upload-time = "2025-08-12T07:00:23.289Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/99/bbe054ebb5a5039457c590e0a4156ed073fb0fe9ce4f7523404dd5b37463/sentencepiece-0.2.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c83b85ab2d6576607f31df77ff86f28182be4a8de6d175d2c33ca609925f5da1", size = 1253670, upload-time = "2025-08-12T07:00:24.69Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/ad/d5c7075f701bd97971d7c2ac2904f227566f51ef0838dfbdfdccb58cd212/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1855f57db07b51fb51ed6c9c452f570624d2b169b36f0f79ef71a6e6c618cd8b", size = 1316247, upload-time = "2025-08-12T07:00:26.435Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/03/35fbe5f3d9a7435eebd0b473e09584bd3cc354ce118b960445b060d33781/sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01e6912125cb45d3792f530a4d38f8e21bf884d6b4d4ade1b2de5cf7a8d2a52b", size = 1387894, upload-time = "2025-08-12T07:00:28.339Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/aa/956ef729aafb6c8f9c443104c9636489093bb5c61d6b90fc27aa1a865574/sentencepiece-0.2.1-cp314-cp314-win32.whl", hash = "sha256:c415c9de1447e0a74ae3fdb2e52f967cb544113a3a5ce3a194df185cbc1f962f", size = 1096698, upload-time = "2025-08-12T07:00:29.764Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/cb/fe400d8836952cc535c81a0ce47dc6875160e5fedb71d2d9ff0e9894c2a6/sentencepiece-0.2.1-cp314-cp314-win_amd64.whl", hash = "sha256:881b2e44b14fc19feade3cbed314be37de639fc415375cefaa5bc81a4be137fd", size = 1155115, upload-time = "2025-08-12T07:00:32.865Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/32/89/047921cf70f36c7b6b6390876b2399b3633ab73b8d0cb857e5a964238941/sentencepiece-0.2.1-cp314-cp314-win_arm64.whl", hash = "sha256:2005242a16d2dc3ac5fe18aa7667549134d37854823df4c4db244752453b78a8", size = 1133890, upload-time = "2025-08-12T07:00:34.763Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/11/5b414b9fae6255b5fb1e22e2ed3dc3a72d3a694e5703910e640ac78346bb/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:a19adcec27c524cb7069a1c741060add95f942d1cbf7ad0d104dffa0a7d28a2b", size = 1946081, upload-time = "2025-08-12T07:00:36.97Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/77/eb/7a5682bb25824db8545f8e5662e7f3e32d72a508fdce086029d89695106b/sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:e37e4b4c4a11662b5db521def4e44d4d30ae69a1743241412a93ae40fdcab4bb", size = 1327406, upload-time = "2025-08-12T07:00:38.669Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/03/b0/811dae8fb9f2784e138785d481469788f2e0d0c109c5737372454415f55f/sentencepiece-0.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:477c81505db072b3ab627e7eab972ea1025331bd3a92bacbf798df2b75ea86ec", size = 1254846, upload-time = "2025-08-12T07:00:40.611Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/23/195b2e7ec85ebb6a547969f60b723c7aca5a75800ece6cc3f41da872d14e/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:010f025a544ef770bb395091d57cb94deb9652d8972e0d09f71d85d5a0816c8c", size = 1315721, upload-time = "2025-08-12T07:00:42.914Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/aa/553dbe4178b5f23eb28e59393dddd64186178b56b81d9b8d5c3ff1c28395/sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:733e59ff1794d26db706cd41fc2d7ca5f6c64a820709cb801dc0ea31780d64ab", size = 1387458, upload-time = "2025-08-12T07:00:44.56Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/7c/08ff0012507297a4dd74a5420fdc0eb9e3e80f4e88cab1538d7f28db303d/sentencepiece-0.2.1-cp314-cp314t-win32.whl", hash = "sha256:d3233770f78e637dc8b1fda2cd7c3b99ec77e7505041934188a4e7fe751de3b0", size = 1099765, upload-time = "2025-08-12T07:00:46.058Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/91/d5/2a69e1ce15881beb9ddfc7e3f998322f5cedcd5e4d244cb74dade9441663/sentencepiece-0.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:5e4366c97b68218fd30ea72d70c525e6e78a6c0a88650f57ac4c43c63b234a9d", size = 1157807, upload-time = "2025-08-12T07:00:47.673Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/16/54f611fcfc2d1c46cbe3ec4169780b2cfa7cf63708ef2b71611136db7513/sentencepiece-0.2.1-cp314-cp314t-win_arm64.whl", hash = "sha256:105e36e75cbac1292642045458e8da677b2342dcd33df503e640f0b457cb6751", size = 1136264, upload-time = "2025-08-12T07:00:49.485Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user