perf(pi052): fuse text + FAST loss into a single prefix forward

Previously the forward did three backbone passes per training step
when all heads were active: one for flow (via super().forward), one
for text CE, and one for FAST CE. That's ~3× the compute of
flow-only training.

The text and FAST losses share their prefix forward exactly — both
are CE on the LM head, evaluated at different slices of the same
hidden states. Adding FAST tokens after language in the prefix is
bit-equivalent for the text loss because the mask_ar convention in
``make_att_2d_masks`` keeps FAST tokens in a strictly-later causal
block: language tokens never see FAST, so their hidden states are
unchanged.

New ``_compute_text_and_fast_loss``:

  * embeds [images, language] once
  * optionally appends [FAST] (when run_fast is True)
  * one backbone forward
  * slices ``vlm_out[:, -(fast_len + lang_len):-fast_len]`` for
    language hidden states (or ``vlm_out[:, -lang_len:]`` when no
    FAST) → text CE
  * slices ``vlm_out[:, -fast_len:]`` for FAST hidden states →
    FAST CE
  * returns both losses, either of which can be None when the
    caller doesn't want that head.

forward() now calls this fused helper instead of running the two
separate ``_compute_text_loss`` / ``_compute_fast_action_loss``
methods. Those remain in the file for callers that only want one
head (e.g. ablations).

Why flow isn't fused
--------------------

Flow MSE comes from the action-expert (suffix) hidden states, which
attend to the prefix. If we just concat FAST onto the prefix and let
the action expert attend to it, the expert can trivially decode FAST
back to continuous actions — overfitting via shortcut. Preventing
that requires a custom segment-aware attention mask (action expert
can attend to images+language but NOT to subtask/FAST), which is
what pi05_full does in ``compute_layer_complete_knowledge_insulation``.
That's the full-fusion path; deferred as a follow-up since the
text+FAST fusion already recovers most of the compute.

End-to-end forward pass count
-----------------------------

Before: 1 (flow) + 1 (text) + 1 (FAST) = 3 backbone forwards
After:  1 (flow) + 1 (text+FAST fused) = 2 backbone forwards

~33% wall-time reduction per training step when all three heads
are active.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-13 12:08:34 +02:00
parent 17c0800461
commit 35f9063a6c

View File

@@ -325,38 +325,42 @@ class PI052Policy(PI05Policy):
)
total = self.config.flow_loss_weight * flow_loss
if run_text:
text_loss = self._compute_text_loss(batch, text_labels)
loss_dict["text_loss"] = float(text_loss.detach().item())
total = (
self.config.text_loss_weight * text_loss
if total is None
else total + self.config.text_loss_weight * text_loss
)
# FAST action-token CE loss (paper §III.C). When
# ``enable_fast_action_loss=True`` the preprocessor wrote
# ACTION_TOKENS / ACTION_TOKEN_MASK into the batch — we
# forward them through the PaliGemma backbone alongside the
# language prefix and compute CE on the action positions.
# Text + FAST losses share the prefix forward — compute them
# together. Both are CE on the LM head at *different slices*
# of the same hidden states, so one prefix forward replaces
# two. Saves ~33% compute vs. running them separately.
#
# Gated on ``predict_actions`` (same routing the flow loss
# uses): for text-only recipes the action_tokens are still
# present in the batch but shouldn't be supervised. Skip the
# entire FAST forward when no sample in the batch wants action
# supervision.
# Flow loss can't be fused into this pass without a custom
# segment-aware attention mask (the action expert would
# trivially read FAST tokens and decode them back to
# continuous actions). That's pi05_full's territory — for
# now, flow stays in a separate forward via super().forward.
run_fast = (
getattr(self.config, "enable_fast_action_loss", False)
and self.config.fast_action_loss_weight > 0
and (predict_actions_t is None or bool(predict_actions_t.any().item()))
)
action_tokens = action_mask = None
if run_fast:
from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415
action_tokens = batch.get(ACTION_TOKENS)
action_mask = batch.get(ACTION_TOKEN_MASK)
if action_tokens is not None and action_mask is not None:
fast_loss = self._compute_fast_action_loss(batch, action_tokens, action_mask)
if action_tokens is None or action_mask is None:
run_fast = False
if run_text or run_fast:
text_loss, fast_loss = self._compute_text_and_fast_loss(
batch,
text_labels=text_labels if run_text else None,
action_tokens=action_tokens if run_fast else None,
action_mask=action_mask if run_fast else None,
)
if text_loss is not None:
loss_dict["text_loss"] = float(text_loss.detach().item())
weighted = self.config.text_loss_weight * text_loss
total = weighted if total is None else total + weighted
if fast_loss is not None:
loss_dict["fast_action_loss"] = float(fast_loss.detach().item())
weighted = self.config.fast_action_loss_weight * fast_loss
total = weighted if total is None else total + weighted
@@ -379,6 +383,110 @@ class PI052Policy(PI05Policy):
# Text loss
# ------------------------------------------------------------------
def _compute_text_and_fast_loss(
self,
batch: dict[str, Tensor],
text_labels: Tensor | None,
action_tokens: Tensor | None,
action_mask: Tensor | None,
) -> tuple[Tensor | None, Tensor | None]:
"""Single prefix forward → text CE + FAST CE.
Embed [images, language] (and FAST when requested) once, run
one backbone forward, then slice the resulting hidden states
at the language and FAST positions to compute both CE losses.
Bit-equivalent to running the two losses in separate forwards
because the segment-aware ``make_att_2d_masks`` keeps FAST
tokens invisible to language tokens, so adding FAST to the
prefix doesn't perturb the hidden states at language positions.
Returns ``(text_loss, fast_loss)``. Either can be ``None`` if
the caller doesn't want that head.
"""
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
images, img_masks = self.model._preprocess_images(batch)
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
fast_len = 0
if action_tokens is not None and action_mask is not None:
emb_dim = prefix_embs.shape[-1]
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
fast_emb = fast_emb * math.sqrt(emb_dim)
fast_len = action_tokens.shape[1]
ones_att = torch.ones(
(action_tokens.shape[0], fast_len),
dtype=torch.bool,
device=prefix_embs.device,
)
full_embs = torch.cat([prefix_embs, fast_emb], dim=1)
full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
full_att = torch.cat([prefix_att, ones_att], dim=1)
else:
full_embs = prefix_embs
full_pad = prefix_pad
full_att = prefix_att
att_2d = make_att_2d_masks(full_pad, full_att)
position_ids = torch.cumsum(full_pad, dim=1) - 1
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
attention_mask=att_2d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[full_embs, None],
use_cache=False,
)
if vlm_out is None:
raise RuntimeError(
"PI052 text+fast loss: VLM forward returned no hidden states."
)
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
text_loss: Tensor | None = None
if text_labels is not None:
lang_len = text_labels.shape[1]
# embed_prefix lays out as [images, language]; with FAST
# appended the full sequence is [images, language, FAST].
# Language hidden states are at positions
# ``[-(fast_len + lang_len) : -fast_len]`` when FAST is
# present, or ``[-lang_len:]`` otherwise.
if fast_len > 0:
text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :]
else:
text_hidden = vlm_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
shift_logits = text_logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
text_loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
)
fast_loss: Tensor | None = None
if action_tokens is not None and action_mask is not None and fast_len > 0:
fast_hidden = vlm_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
shift_logits = fast_logits[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_mask[:, 1:].contiguous().bool()
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
fast_loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_targets.reshape(-1),
ignore_index=-100,
)
return text_loss, fast_loss
def _compute_fast_action_loss(
self,
batch: dict[str, Tensor],