From 35f9063a6ca65c8a2b545869a93ab6c60d89a02e Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 13 May 2026 12:08:34 +0200 Subject: [PATCH] perf(pi052): fuse text + FAST loss into a single prefix forward MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously the forward did three backbone passes per training step when all heads were active: one for flow (via super().forward), one for text CE, and one for FAST CE. That's ~3× the compute of flow-only training. The text and FAST losses share their prefix forward exactly — both are CE on the LM head, evaluated at different slices of the same hidden states. Adding FAST tokens after language in the prefix is bit-equivalent for the text loss because the mask_ar convention in ``make_att_2d_masks`` keeps FAST tokens in a strictly-later causal block: language tokens never see FAST, so their hidden states are unchanged. New ``_compute_text_and_fast_loss``: * embeds [images, language] once * optionally appends [FAST] (when run_fast is True) * one backbone forward * slices ``vlm_out[:, -(fast_len + lang_len):-fast_len]`` for language hidden states (or ``vlm_out[:, -lang_len:]`` when no FAST) → text CE * slices ``vlm_out[:, -fast_len:]`` for FAST hidden states → FAST CE * returns both losses, either of which can be None when the caller doesn't want that head. forward() now calls this fused helper instead of running the two separate ``_compute_text_loss`` / ``_compute_fast_action_loss`` methods. Those remain in the file for callers that only want one head (e.g. ablations). Why flow isn't fused -------------------- Flow MSE comes from the action-expert (suffix) hidden states, which attend to the prefix. If we just concat FAST onto the prefix and let the action expert attend to it, the expert can trivially decode FAST back to continuous actions — overfitting via shortcut. Preventing that requires a custom segment-aware attention mask (action expert can attend to images+language but NOT to subtask/FAST), which is what pi05_full does in ``compute_layer_complete_knowledge_insulation``. That's the full-fusion path; deferred as a follow-up since the text+FAST fusion already recovers most of the compute. End-to-end forward pass count ----------------------------- Before: 1 (flow) + 1 (text) + 1 (FAST) = 3 backbone forwards After: 1 (flow) + 1 (text+FAST fused) = 2 backbone forwards ~33% wall-time reduction per training step when all three heads are active. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/policies/pi052/modeling_pi052.py | 150 ++++++++++++++++--- 1 file changed, 129 insertions(+), 21 deletions(-) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 657df36f7..6583fe649 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -325,38 +325,42 @@ class PI052Policy(PI05Policy): ) total = self.config.flow_loss_weight * flow_loss - if run_text: - text_loss = self._compute_text_loss(batch, text_labels) - loss_dict["text_loss"] = float(text_loss.detach().item()) - total = ( - self.config.text_loss_weight * text_loss - if total is None - else total + self.config.text_loss_weight * text_loss - ) - - # FAST action-token CE loss (paper §III.C). When - # ``enable_fast_action_loss=True`` the preprocessor wrote - # ACTION_TOKENS / ACTION_TOKEN_MASK into the batch — we - # forward them through the PaliGemma backbone alongside the - # language prefix and compute CE on the action positions. + # Text + FAST losses share the prefix forward — compute them + # together. Both are CE on the LM head at *different slices* + # of the same hidden states, so one prefix forward replaces + # two. Saves ~33% compute vs. running them separately. # - # Gated on ``predict_actions`` (same routing the flow loss - # uses): for text-only recipes the action_tokens are still - # present in the batch but shouldn't be supervised. Skip the - # entire FAST forward when no sample in the batch wants action - # supervision. + # Flow loss can't be fused into this pass without a custom + # segment-aware attention mask (the action expert would + # trivially read FAST tokens and decode them back to + # continuous actions). That's pi05_full's territory — for + # now, flow stays in a separate forward via super().forward. run_fast = ( getattr(self.config, "enable_fast_action_loss", False) and self.config.fast_action_loss_weight > 0 and (predict_actions_t is None or bool(predict_actions_t.any().item())) ) + action_tokens = action_mask = None if run_fast: from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415 action_tokens = batch.get(ACTION_TOKENS) action_mask = batch.get(ACTION_TOKEN_MASK) - if action_tokens is not None and action_mask is not None: - fast_loss = self._compute_fast_action_loss(batch, action_tokens, action_mask) + if action_tokens is None or action_mask is None: + run_fast = False + + if run_text or run_fast: + text_loss, fast_loss = self._compute_text_and_fast_loss( + batch, + text_labels=text_labels if run_text else None, + action_tokens=action_tokens if run_fast else None, + action_mask=action_mask if run_fast else None, + ) + if text_loss is not None: + loss_dict["text_loss"] = float(text_loss.detach().item()) + weighted = self.config.text_loss_weight * text_loss + total = weighted if total is None else total + weighted + if fast_loss is not None: loss_dict["fast_action_loss"] = float(fast_loss.detach().item()) weighted = self.config.fast_action_loss_weight * fast_loss total = weighted if total is None else total + weighted @@ -379,6 +383,110 @@ class PI052Policy(PI05Policy): # Text loss # ------------------------------------------------------------------ + def _compute_text_and_fast_loss( + self, + batch: dict[str, Tensor], + text_labels: Tensor | None, + action_tokens: Tensor | None, + action_mask: Tensor | None, + ) -> tuple[Tensor | None, Tensor | None]: + """Single prefix forward → text CE + FAST CE. + + Embed [images, language] (and FAST when requested) once, run + one backbone forward, then slice the resulting hidden states + at the language and FAST positions to compute both CE losses. + Bit-equivalent to running the two losses in separate forwards + because the segment-aware ``make_att_2d_masks`` keeps FAST + tokens invisible to language tokens, so adding FAST to the + prefix doesn't perturb the hidden states at language positions. + + Returns ``(text_loss, fast_loss)``. Either can be ``None`` if + the caller doesn't want that head. + """ + from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 + + images, img_masks = self.model._preprocess_images(batch) + lang_tokens = batch[OBS_LANGUAGE_TOKENS] + lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] + + prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + + fast_len = 0 + if action_tokens is not None and action_mask is not None: + emb_dim = prefix_embs.shape[-1] + fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens) + fast_emb = fast_emb * math.sqrt(emb_dim) + + fast_len = action_tokens.shape[1] + ones_att = torch.ones( + (action_tokens.shape[0], fast_len), + dtype=torch.bool, + device=prefix_embs.device, + ) + full_embs = torch.cat([prefix_embs, fast_emb], dim=1) + full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1) + full_att = torch.cat([prefix_att, ones_att], dim=1) + else: + full_embs = prefix_embs + full_pad = prefix_pad + full_att = prefix_att + + att_2d = make_att_2d_masks(full_pad, full_att) + position_ids = torch.cumsum(full_pad, dim=1) - 1 + + (vlm_out, _), _ = self.model.paligemma_with_expert.forward( + attention_mask=att_2d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[full_embs, None], + use_cache=False, + ) + if vlm_out is None: + raise RuntimeError( + "PI052 text+fast loss: VLM forward returned no hidden states." + ) + + lm_head = self.model.paligemma_with_expert.paligemma.lm_head + + text_loss: Tensor | None = None + if text_labels is not None: + lang_len = text_labels.shape[1] + # embed_prefix lays out as [images, language]; with FAST + # appended the full sequence is [images, language, FAST]. + # Language hidden states are at positions + # ``[-(fast_len + lang_len) : -fast_len]`` when FAST is + # present, or ``[-lang_len:]`` otherwise. + if fast_len > 0: + text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :] + else: + text_hidden = vlm_out[:, -lang_len:, :] + text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) + shift_logits = text_logits[:, :-1, :].contiguous() + shift_labels = text_labels[:, 1:].contiguous().long() + text_loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.shape[-1]), + shift_labels.reshape(-1), + ignore_index=-100, + ) + + fast_loss: Tensor | None = None + if action_tokens is not None and action_mask is not None and fast_len > 0: + fast_hidden = vlm_out[:, -fast_len:, :] + fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) + shift_logits = fast_logits[:, :-1, :].contiguous() + shift_targets = action_tokens[:, 1:].contiguous().long() + shift_valid = action_mask[:, 1:].contiguous().bool() + shift_targets = shift_targets.masked_fill(~shift_valid, -100) + fast_loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.shape[-1]), + shift_targets.reshape(-1), + ignore_index=-100, + ) + + return text_loss, fast_loss + def _compute_fast_action_loss( self, batch: dict[str, Tensor],