diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 6583fe649..b0e7269a2 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -316,25 +316,7 @@ class PI052Policy(PI05Policy): loss_dict: dict[str, Any] = {} total: Tensor | None = None - if run_flow: - flow_loss, flow_dict = super().forward(batch, reduction=reduction) - for k, v in flow_dict.items(): - loss_dict[f"flow_{k}"] = v - loss_dict["flow_loss"] = ( - flow_loss.item() if isinstance(flow_loss, Tensor) and flow_loss.dim() == 0 else float("nan") - ) - total = self.config.flow_loss_weight * flow_loss - - # Text + FAST losses share the prefix forward — compute them - # together. Both are CE on the LM head at *different slices* - # of the same hidden states, so one prefix forward replaces - # two. Saves ~33% compute vs. running them separately. - # - # Flow loss can't be fused into this pass without a custom - # segment-aware attention mask (the action expert would - # trivially read FAST tokens and decode them back to - # continuous actions). That's pi05_full's territory — for - # now, flow stays in a separate forward via super().forward. + # Decide which losses fire this step. run_fast = ( getattr(self.config, "enable_fast_action_loss", False) and self.config.fast_action_loss_weight > 0 @@ -349,7 +331,35 @@ class PI052Policy(PI05Policy): if action_tokens is None or action_mask is None: run_fast = False - if run_text or run_fast: + # ------------------------------------------------------------ + # Dispatch: full fusion when flow is active, otherwise the + # prefix-only text+FAST helper (no suffix forward needed). + # + # Full fusion (flow ON): + # ONE backbone forward with prefix=[images, lang, FAST] + + # suffix=[noisy_actions], suffix→FAST attention masked out. + # All three losses computed from slices of the single output. + # + # Prefix-only fusion (flow OFF, e.g. 
def _compute_all_losses_fused(
    self,
    batch: dict[str, Tensor],
    text_labels: Tensor | None,
    action_tokens: Tensor | None,
    action_mask: Tensor | None,
) -> tuple[Tensor, Tensor | None, Tensor | None]:
    """Compute flow, text and FAST losses from ONE backbone forward.

    Layout of the single forward pass::

        prefix = [images, language, FAST (only when provided)]
        suffix = [noisy_actions]

    After the combined 2D attention mask is built, the
    ``suffix -> FAST`` entries are explicitly cleared. Without that
    override the action expert could attend to the discrete FAST
    tokens and trivially decode them back into the continuous actions
    it is supposed to predict via flow matching.

    ``prefix_out`` and ``suffix_out`` both come from the same forward
    call. The language and FAST positions are sliced from the tail of
    ``prefix_out`` for the two cross-entropy losses; ``suffix_out``
    goes through ``action_out_proj`` for the flow-matching MSE.

    Args:
        batch: Training batch. Read here via ``prepare_action``,
            ``_preprocess_images`` and the ``OBS_LANGUAGE_TOKENS`` /
            ``OBS_LANGUAGE_ATTENTION_MASK`` entries.
        text_labels: Next-token targets for the language positions,
            with ``-100`` marking ignored positions, or ``None`` to
            skip the text loss. Its second dimension is used as the
            number of language positions to slice from ``prefix_out``.
        action_tokens: Discrete FAST token ids, or ``None`` to skip
            the FAST loss (must be given together with
            ``action_mask``).
        action_mask: Validity mask for ``action_tokens``; invalid
            positions are treated as padding and excluded from the
            FAST loss.

    Returns:
        ``(flow_loss, text_loss, fast_loss)``. ``text_loss`` /
        ``fast_loss`` are ``None`` when the corresponding inputs were
        not supplied. A CE loss with no supervised target at all
        evaluates to ``0`` (still attached to the graph) rather than
        ``NaN``, so it cannot poison the weighted total.
    """
    from lerobot.utils.constants import ACTION  # noqa: PLC0415

    from ..pi05.modeling_pi05 import make_att_2d_masks  # noqa: PLC0415

    # ---- preamble (mirrors PI05Pytorch.forward) ------------------
    actions = self.model.prepare_action(batch)
    noise = self.model.sample_noise(actions.shape, actions.device)
    time = self.model.sample_time(actions.shape[0], actions.device)
    time_expanded = time[:, None, None]
    # Noisy sample on the noise<->action interpolation path and the
    # velocity target the action expert has to regress.
    x_t = time_expanded * noise + (1 - time_expanded) * actions
    u_t = noise - actions

    # ---- prefix: images + language + (optional FAST) -------------
    images, img_masks = self.model._preprocess_images(batch)
    lang_tokens = batch[OBS_LANGUAGE_TOKENS]
    lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
    prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
        images, img_masks, lang_tokens, lang_masks
    )
    non_fast_prefix_len = prefix_embs.shape[1]  # images + language only

    fast_len = 0
    if action_tokens is not None and action_mask is not None:
        emb_dim = prefix_embs.shape[-1]
        fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
        fast_emb = fast_emb * math.sqrt(emb_dim)
        fast_len = action_tokens.shape[1]
        # Every FAST position carries attention flag 1.
        ones_att = torch.ones(
            (action_tokens.shape[0], fast_len),
            dtype=torch.bool,
            device=prefix_embs.device,
        )
        prefix_embs = torch.cat([prefix_embs, fast_emb], dim=1)
        prefix_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
        prefix_att = torch.cat([prefix_att, ones_att], dim=1)

    # ---- suffix: noisy actions -----------------------------------
    suffix_embs, suffix_pad, suffix_att, adarms_cond = self.model.embed_suffix(x_t, time)

    # ---- bf16 alignment (mirrors PI05Pytorch.forward) ------------
    first_layer = self.model.paligemma_with_expert.paligemma.model.language_model.layers[0]
    if first_layer.self_attn.q_proj.weight.dtype == torch.bfloat16:
        suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
        prefix_embs = prefix_embs.to(dtype=torch.bfloat16)

    # ---- combined attention --------------------------------------
    pad_masks = torch.cat([prefix_pad, suffix_pad], dim=1)
    att_masks = torch.cat([prefix_att, suffix_att], dim=1)
    att_2d_masks = make_att_2d_masks(pad_masks, att_masks)

    # Critical: clear suffix -> FAST attention. Per the original
    # author's note, the cumulative-block mask from
    # ``make_att_2d_masks`` lets suffix tokens attend to FAST tokens
    # by default, which would leak the answer to the action expert.
    if fast_len > 0:
        fast_start = non_fast_prefix_len
        fast_end = non_fast_prefix_len + fast_len  # == prefix_pad.shape[1]
        att_2d_masks[:, fast_end:, fast_start:fast_end] = False

    position_ids = torch.cumsum(pad_masks, dim=1) - 1
    att_2d_masks_4d = self.model._prepare_attention_masks_4d(att_2d_masks)

    # ---- forward (capture BOTH expert outputs) -------------------
    (prefix_out, suffix_out), _ = self.model.paligemma_with_expert.forward(
        attention_mask=att_2d_masks_4d,
        position_ids=position_ids,
        past_key_values=None,
        inputs_embeds=[prefix_embs, suffix_embs],
        use_cache=False,
        adarms_cond=[None, adarms_cond],
    )

    # ---- flow loss (mirrors PI05Pytorch.forward) -----------------
    suffix_out_slice = suffix_out[:, -self.model.config.chunk_size :].to(dtype=torch.float32)
    v_t = self.model.action_out_proj(suffix_out_slice)
    flow_per_dim = F.mse_loss(u_t, v_t, reduction="none")
    # Keep only the real action dimensions (the original comment
    # notes PI05 pads internally to max_action_dim).
    original_action_dim = self.config.output_features[ACTION].shape[0]
    flow_loss = flow_per_dim[:, :, :original_action_dim].mean()

    # ---- text + FAST CE from prefix_out --------------------------
    lm_head = self.model.paligemma_with_expert.paligemma.lm_head

    def _shifted_ce(hidden: Tensor, targets: Tensor) -> Tensor:
        """Next-token CE over ``hidden`` against ``targets`` (-100 = ignore).

        Equals the mean over supervised targets. Uses sum / max(count, 1)
        instead of ``reduction="mean"`` because the latter returns NaN
        when zero targets are supervised (everything ignored, or nothing
        left after the shift).
        """
        logits = lm_head(hidden.to(lm_head.weight.dtype))
        shift_logits = logits[:, :-1, :].contiguous()
        shift_targets = targets[:, 1:].contiguous().long()
        n_valid = (shift_targets != -100).sum()
        summed = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.shape[-1]),
            shift_targets.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        # Divide in float32 so the count is not rounded in low precision.
        return summed.float() / n_valid.clamp(min=1)

    text_loss: Tensor | None = None
    if text_labels is not None and prefix_out is not None:
        lang_len = text_labels.shape[1]
        # Language positions sit right before the FAST block (if any)
        # at the tail of the prefix.
        if fast_len > 0:
            text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :]
        else:
            text_hidden = prefix_out[:, -lang_len:, :]
        text_loss = _shifted_ce(text_hidden, text_labels)

    fast_loss: Tensor | None = None
    if fast_len > 0 and prefix_out is not None:
        # NOTE(review): only FAST positions are sliced, so the first
        # FAST token is never a prediction target — confirm intended.
        fast_hidden = prefix_out[:, -fast_len:, :]
        fast_targets = action_tokens.long().masked_fill(~action_mask.bool(), -100)
        fast_loss = _shifted_ce(fast_hidden, fast_targets)

    return flow_loss, text_loss, fast_loss