fix(pi052): four real bugs in the modeling code + flip defaults

Defaults
--------
* enable_fast_action_loss: False -> True   (match paper §III.B-C Eq.1)
* auto_fit_fast_tokenizer: True -> False   (opt-in; needs base.fit())
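
A minimal sketch of what the flipped defaults mean for a caller, and how
to get the previous behaviour back. ``FastLossDefaults`` is a hypothetical
stand-in that mirrors only the two fields above; it is not the real
``PI052Config`` and carries none of its other fields.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FastLossDefaults:
    """Hypothetical stand-in for the two PI052Config fields touched here."""

    enable_fast_action_loss: bool = True   # was False before this commit
    auto_fit_fast_tokenizer: bool = False  # was True before this commit


# New defaults: FAST CE loss on, per-dataset tokenizer fitting off.
new_defaults = FastLossDefaults()

# Restoring the pre-commit behaviour now takes explicit overrides.
previous_behaviour = replace(
    new_defaults,
    enable_fast_action_loss=False,
    auto_fit_fast_tokenizer=True,
)

print(new_defaults)
print(previous_behaviour)
```

Runs that relied on the old defaults need both overrides spelled out
after this commit.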

Bug fixes
---------

1. Wrong attribute path on PaliGemma. The KI port copied
   pi05_full's ``paligemma.language_model.layers[...]`` literally,
   but the production pi05 wrapper exposes the text model at
   ``paligemma.model.language_model``. With KI enabled, every layer
   would have raised AttributeError on first forward. Fixed all
   references in _compute_layer_ki + _paligemma_forward_ki.

2. ``fill_kv_cache=True`` passed to PaliGemmaWithExpertModel.forward.
   That kwarg is a SmolVLA-only concept; pi05's signature has no
   such argument, so every forward call from pi052 (text loss, FAST
   loss, select_message) would have crashed with TypeError. Dropped
   from all four call sites — pi05's forward already handles the
   cache via past_key_values, and re-forwarding the cumulative
   sequence each step in select_message is fine for our short
   subtask completions.

3. Text-loss shape mismatch. _compute_text_loss applied lm_head to
   the *full* vlm_out (image tokens + language tokens), then computed
   cross-entropy against text_labels, which only covers the language
   portion. The .view(-1) calls would produce two tensors of
   different lengths and CE would fail. Now slices vlm_out to the
   last text_labels.shape[1] positions before running lm_head,
   matching the [images, language] order embed_prefix produces.

4. Dead-code conditional in _paligemma_forward_ki's single-expert
   fallback. The ``if hasattr(...) else self._pi052_orig_forward``
   ternary always took the first branch
   (``type(self).__bases__[0].forward``) because the attribute is
   always set (we save it in PI052Policy.__init__), so the intended
   ``self._pi052_orig_forward`` call was unreachable. Simplified to
   just call self._pi052_orig_forward directly.
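
The length bookkeeping behind fix 3 can be sketched in plain Python. This
is an illustration only: ``shifted_ce_lengths`` is a hypothetical helper,
and the token counts (256 image tokens, 48 language tokens) are made-up
numbers, not values from the model.

```python
def shifted_ce_lengths(
    n_image: int, n_language: int, slice_to_language: bool
) -> tuple[int, int]:
    """Return (logit rows, label rows) per sample after the next-token shift."""
    # embed_prefix order is [images, language], so the hidden states
    # cover both spans.
    n_hidden = n_image + n_language
    if slice_to_language:
        # Keep only the trailing language positions (the fix).
        n_hidden = n_language
    n_logits = n_hidden - 1    # logits[..., :-1, :]
    n_labels = n_language - 1  # text_labels[..., 1:]
    return n_logits, n_labels


before = shifted_ce_lengths(n_image=256, n_language=48, slice_to_language=False)
after = shifted_ce_lengths(n_image=256, n_language=48, slice_to_language=True)

print(before)  # (303, 47): flattened logits and labels disagree
print(after)   # (47, 47): one logit row per supervised label
```

The slice relies on the language tokens being the trailing positions of
the prefix, which is the order the commit message says embed_prefix
produces.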

After this commit, pi052 should be runnable end-to-end for the
first time with all three loss heads + KI active. Still worth a
100-step smoke test before kicking off a long run.
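
One possible shape for that smoke test, as a sketch under stated
assumptions: ``smoke_test`` and ``fake_train_step`` are hypothetical
names, and the fake step only returns a number. A real run would call
the policy's forward/backward pass in its place.

```python
import math
from typing import Callable


def smoke_test(train_step: Callable[[int], float], n_steps: int = 100) -> list[float]:
    """Run n_steps training steps and fail fast on the first non-finite loss."""
    losses = []
    for step in range(n_steps):
        loss = float(train_step(step))
        if not math.isfinite(loss):
            raise RuntimeError(f"non-finite loss {loss!r} at step {step}")
        losses.append(loss)
    return losses


def fake_train_step(step: int) -> float:
    # Placeholder for a real optimisation step; returns a decaying "loss".
    return 1.0 / (step + 1)


losses = smoke_test(fake_train_step, n_steps=100)
print(len(losses), losses[0], losses[-1])  # 100 1.0 0.01
```

A crash in any of the three loss heads or in the KI forward would
surface within the first step, so the loop mostly guards against NaNs
that appear a few updates in.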

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn
2026-05-13 11:58:40 +02:00
parent 0f4faddc01
commit c8763e0ad5
2 changed files with 46 additions and 48 deletions


@@ -107,12 +107,12 @@ class PI052Config(PI05Config):
# ActionTokenizerProcessorStep is wired into the preprocessor
# pipeline when this flag is set; the loss is computed in
# PI052Policy.forward.
-enable_fast_action_loss: bool = False
+enable_fast_action_loss: bool = True
"""If True, tokenise actions with the FAST tokenizer and add a
-cross-entropy loss on the LM head. Off by default because most
-fine-tuning runs only need the flow head + text supervision; the
-FAST CE term is most useful when training from a base PaliGemma
-rather than an existing π0.5 checkpoint."""
+cross-entropy loss on the LM head. On by default to match the
+π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE,
+§III.B-C Eq. 1). Set to False if you only want the
+post-training-style flow + text recipe."""
action_tokenizer_name: str = "physical-intelligence/fast"
"""HF identifier for the FAST action tokenizer."""
@@ -127,15 +127,21 @@ class PI052Config(PI05Config):
fast_action_loss_weight: float = 1.0
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
-auto_fit_fast_tokenizer: bool = True
-"""If True (default), the processor factory checks
-``fast_tokenizer_cache_dir`` for a previously-fitted tokenizer keyed
-on ``(dataset_repo_id, base_tokenizer_name, fit_samples)``. On cache
-miss, it loads ``action_tokenizer_name`` as a base, samples
+auto_fit_fast_tokenizer: bool = False
+"""If True, the processor factory checks ``fast_tokenizer_cache_dir``
+for a previously-fitted tokenizer keyed on ``(dataset_repo_id,
+base_tokenizer_name, fit_samples)``. On cache miss, it loads
+``action_tokenizer_name`` as a base, samples
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
``.fit()``, saves the result, and uses *that* fitted path as the
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
-explicitly recommend per-dataset fitting for best compression."""
+explicitly recommend per-dataset fitting for best compression.
+Off by default because the fit requires a separate pre-training
+pass over the dataset (~1-2 min on a medium dataset) and depends
+on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in
+when you want paper-faithful compression; leave off to fall back
+on the universal ``physical-intelligence/fast`` codebook."""
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
"""Where fitted FAST tokenizers are stored. ``~`` expands."""


@@ -76,7 +76,7 @@ def _compute_layer_ki(
):
from transformers.models.gemma import modeling_gemma # noqa: PLC0415
-models = [paligemma.language_model, gemma_expert.model]
+models = [paligemma.model.language_model, gemma_expert.model]
query_states, key_states, value_states, gates = [], [], [], []
vlm_len = inputs_embeds[0].shape[1]
@@ -111,7 +111,7 @@ def _compute_layer_ki(
)
batch_size = query_states.shape[0]
-scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
+scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
# Split queries / K / V at the VLM-vs-action boundary.
Q_vlm = query_states[:, :, :vlm_len, :]
@@ -133,16 +133,16 @@ def _compute_layer_ki(
mask_for_action = attention_mask[:, :, vlm_len:, :]
att_vlm, _ = modeling_gemma.eager_attention_forward(
-paligemma.language_model.layers[layer_idx].self_attn,
+paligemma.model.language_model.layers[layer_idx].self_attn,
Q_vlm, K_for_vlm, V_for_vlm, mask_for_vlm, scaling,
)
att_action, _ = modeling_gemma.eager_attention_forward(
-paligemma.language_model.layers[layer_idx].self_attn,
+paligemma.model.language_model.layers[layer_idx].self_attn,
Q_action, K_for_action, V_for_action, mask_for_action, scaling,
)
att = torch.cat([att_vlm, att_action], dim=1)
-head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
+head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
att = att.reshape(batch_size, -1, 1 * 8 * head_dim)
outputs_embeds = []
@@ -173,33 +173,24 @@ def _paligemma_forward_ki(
inputs_embeds=None,
use_cache=None,
adarms_cond=None,
-fill_kv_cache=None,
):
"""Replacement ``PaliGemmaWithExpertModel.forward`` that routes the
dual-expert layer pass through :func:`_compute_layer_ki`.
Bound onto the model instance when ``config.knowledge_insulation``
is True (see ``PI052Policy.__init__``). Single-expert branches
-(VLM-only or action-only) reuse the parent's implementation
-because there's no KI signal to add — KI only matters when
-actions and VLM tokens are forwarded together.
+(VLM-only or action-only) defer back to the original forward —
+KI only matters when actions and VLM tokens are forwarded together.
"""
from ..pi05.modeling_pi05 import layernorm_forward # noqa: PLC0415
if adarms_cond is None:
adarms_cond = [None, None]
-# Single-expert paths: defer to the bound class method via super().
+# Single-expert paths: defer to the original forward saved in
+# PI052Policy.__init__.
if inputs_embeds[0] is None or inputs_embeds[1] is None:
-return type(self).__bases__[0].forward(
-self,
-attention_mask=attention_mask,
-position_ids=position_ids,
-past_key_values=past_key_values,
-inputs_embeds=inputs_embeds,
-use_cache=use_cache,
-adarms_cond=adarms_cond,
-) if hasattr(self, "_pi052_orig_forward") else self._pi052_orig_forward(
+return self._pi052_orig_forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -429,7 +420,6 @@ class PI052Policy(PI05Policy):
past_key_values=None,
inputs_embeds=[full_embs, None],
use_cache=False,
-fill_kv_cache=True,
)
if vlm_out is None:
raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.")
@@ -456,12 +446,15 @@ class PI052Policy(PI05Policy):
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
"""Cross-entropy on PaliGemma's LM head over the supervised span.
-Re-uses the same prefix-embedding path the flow head does:
-embed images + state + language tokens, run a forward pass,
-slice out the per-token logits at the supervised positions,
-compute CE.
+Embeds images + language, runs the VLM-only forward (the
+action expert is skipped via ``inputs_embeds=[..., None]``),
+slices the hidden states to the *language* portion so they
+align with ``text_labels`` (which covers only the language
+tokens, not the image patch tokens), then computes shifted
+next-token CE with ``-100`` ignoring padding/non-target
+positions.
"""
from torch.nn import functional as F # noqa: PLC0415
+from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
images, img_masks = self.model._preprocess_images(batch)
tokens = batch[OBS_LANGUAGE_TOKENS]
@@ -470,12 +463,6 @@ class PI052Policy(PI05Policy):
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
images, img_masks, tokens, masks
)
-# PaliGemma's text path: forward the prefix through the
-# backbone *without* the action expert. We piggy-back on the
-# existing PaliGemmaWithExpertModel.forward — it accepts a
-# list of expert inputs and returns parallel outputs.
-from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
@@ -485,18 +472,24 @@ class PI052Policy(PI05Policy):
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
-fill_kv_cache=True,
)
if vlm_out is None:
raise RuntimeError("PI052 text loss: VLM forward returned no hidden states.")
-# Logits over the vocab via the PaliGemma lm_head.
+# Slice the hidden states to the language portion. embed_prefix
+# concatenates [images, language] in that order, so the trailing
+# ``text_labels.shape[1]`` positions are the language tokens.
+# Without this slice, applying lm_head to the full vlm_out and
+# shifting against text_labels[..., 1:] produces a shape
+# mismatch in cross_entropy.
+lang_len = text_labels.shape[1]
+text_hidden = vlm_out[:, -lang_len:, :]
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
-logits = lm_head(vlm_out.to(lm_head.weight.dtype))
+logits = lm_head(text_hidden.to(lm_head.weight.dtype))
# Shift for next-token prediction: predict token[i+1] from
# hidden[i]. Both ``logits`` and ``text_labels`` are over the
-# same sequence length, so shift logits[:-1] vs labels[1:].
+# hidden[i] within the language span.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = text_labels[..., 1:].contiguous()
loss = F.cross_entropy(
@@ -579,7 +572,6 @@ class PI052Policy(PI05Policy):
past_key_values=None,
inputs_embeds=[current_embs, None],
use_cache=False,
-fill_kv_cache=True,
)
if vlm_out is None:
break