diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 657df36f7..6583fe649 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -325,38 +325,42 @@ class PI052Policy(PI05Policy): ) total = self.config.flow_loss_weight * flow_loss - if run_text: - text_loss = self._compute_text_loss(batch, text_labels) - loss_dict["text_loss"] = float(text_loss.detach().item()) - total = ( - self.config.text_loss_weight * text_loss - if total is None - else total + self.config.text_loss_weight * text_loss - ) - - # FAST action-token CE loss (paper §III.C). When - # ``enable_fast_action_loss=True`` the preprocessor wrote - # ACTION_TOKENS / ACTION_TOKEN_MASK into the batch — we - # forward them through the PaliGemma backbone alongside the - # language prefix and compute CE on the action positions. + # Text + FAST losses share the prefix forward — compute them + # together. Both are CE on the LM head at *different slices* + # of the same hidden states, so one prefix forward replaces + # two. Saves ~33% compute vs. running them separately. # - # Gated on ``predict_actions`` (same routing the flow loss - # uses): for text-only recipes the action_tokens are still - # present in the batch but shouldn't be supervised. Skip the - # entire FAST forward when no sample in the batch wants action - # supervision. + # Flow loss can't be fused into this pass without a custom + # segment-aware attention mask (the action expert would + # trivially read FAST tokens and decode them back to + # continuous actions). That's pi05_full's territory — for + # now, flow stays in a separate forward via super().forward. run_fast = ( getattr(self.config, "enable_fast_action_loss", False) and self.config.fast_action_loss_weight > 0 and (predict_actions_t is None or bool(predict_actions_t.any().item())) ) + action_tokens = action_mask = None if run_fast: from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415 action_tokens = batch.get(ACTION_TOKENS) action_mask = batch.get(ACTION_TOKEN_MASK) - if action_tokens is not None and action_mask is not None: - fast_loss = self._compute_fast_action_loss(batch, action_tokens, action_mask) + if action_tokens is None or action_mask is None: + run_fast = False + + if run_text or run_fast: + text_loss, fast_loss = self._compute_text_and_fast_loss( + batch, + text_labels=text_labels if run_text else None, + action_tokens=action_tokens if run_fast else None, + action_mask=action_mask if run_fast else None, + ) + if text_loss is not None: + loss_dict["text_loss"] = float(text_loss.detach().item()) + weighted = self.config.text_loss_weight * text_loss + total = weighted if total is None else total + weighted + if fast_loss is not None: loss_dict["fast_action_loss"] = float(fast_loss.detach().item()) weighted = self.config.fast_action_loss_weight * fast_loss total = weighted if total is None else total + weighted @@ -379,6 +383,110 @@ class PI052Policy(PI05Policy): # Text loss # ------------------------------------------------------------------ + def _compute_text_and_fast_loss( + self, + batch: dict[str, Tensor], + text_labels: Tensor | None, + action_tokens: Tensor | None, + action_mask: Tensor | None, + ) -> tuple[Tensor | None, Tensor | None]: + """Single prefix forward → text CE + FAST CE. + + Embed [images, language] (and FAST when requested) once, run + one backbone forward, then slice the resulting hidden states + at the language and FAST positions to compute both CE losses. + Bit-equivalent to running the two losses in separate forwards + because the segment-aware ``make_att_2d_masks`` keeps FAST + tokens invisible to language tokens, so adding FAST to the + prefix doesn't perturb the hidden states at language positions. + + Returns ``(text_loss, fast_loss)``. Either can be ``None`` if + the caller doesn't want that head. + """ + from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 + + images, img_masks = self.model._preprocess_images(batch) + lang_tokens = batch[OBS_LANGUAGE_TOKENS] + lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] + + prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + + fast_len = 0 + if action_tokens is not None and action_mask is not None: + emb_dim = prefix_embs.shape[-1] + fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens) + fast_emb = fast_emb * math.sqrt(emb_dim) + + fast_len = action_tokens.shape[1] + ones_att = torch.ones( + (action_tokens.shape[0], fast_len), + dtype=torch.bool, + device=prefix_embs.device, + ) + full_embs = torch.cat([prefix_embs, fast_emb], dim=1) + full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1) + full_att = torch.cat([prefix_att, ones_att], dim=1) + else: + full_embs = prefix_embs + full_pad = prefix_pad + full_att = prefix_att + + att_2d = make_att_2d_masks(full_pad, full_att) + position_ids = torch.cumsum(full_pad, dim=1) - 1 + + (vlm_out, _), _ = self.model.paligemma_with_expert.forward( + attention_mask=att_2d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[full_embs, None], + use_cache=False, + ) + if vlm_out is None: + raise RuntimeError( + "PI052 text+fast loss: VLM forward returned no hidden states." + ) + + lm_head = self.model.paligemma_with_expert.paligemma.lm_head + + text_loss: Tensor | None = None + if text_labels is not None: + lang_len = text_labels.shape[1] + # embed_prefix lays out as [images, language]; with FAST + # appended the full sequence is [images, language, FAST]. + # Language hidden states are at positions + # ``[-(fast_len + lang_len) : -fast_len]`` when FAST is + # present, or ``[-lang_len:]`` otherwise. + if fast_len > 0: + text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :] + else: + text_hidden = vlm_out[:, -lang_len:, :] + text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) + shift_logits = text_logits[:, :-1, :].contiguous() + shift_labels = text_labels[:, 1:].contiguous().long() + text_loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.shape[-1]), + shift_labels.reshape(-1), + ignore_index=-100, + ) + + fast_loss: Tensor | None = None + if action_tokens is not None and action_mask is not None and fast_len > 0: + fast_hidden = vlm_out[:, -fast_len:, :] + fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) + shift_logits = fast_logits[:, :-1, :].contiguous() + shift_targets = action_tokens[:, 1:].contiguous().long() + shift_valid = action_mask[:, 1:].contiguous().bool() + shift_targets = shift_targets.masked_fill(~shift_valid, -100) + fast_loss = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.shape[-1]), + shift_targets.reshape(-1), + ignore_index=-100, + ) + + return text_loss, fast_loss + def _compute_fast_action_loss( self, batch: dict[str, Tensor],