perf(pi052): full fusion — text + FAST + flow in ONE backbone forward

Previously the forward did 2 backbone passes when all heads were
active: one for flow (via super().forward) and one for the fused
text+FAST helper. This commit reduces it to **one pass** — same
compute as flow-only training.

New ``_compute_all_losses_fused`` builds:

    prefix = [images, language, FAST (when provided)]
    suffix = [noisy_actions]  (action expert via gemma_expert)

and runs a single ``paligemma_with_expert.forward`` with
``inputs_embeds=[prefix_embs, suffix_embs]`` (both experts active
in the same call). Captures *both* prefix_out and suffix_out, slices
each for its respective loss:

    flow MSE     ← suffix_out  (existing action_out_proj + MSE path)
    text  CE     ← prefix_out at language positions (lm_head + CE)
    FAST  CE     ← prefix_out at FAST positions (lm_head + CE)

Critical attention mask override
--------------------------------

``make_att_2d_masks`` produces a cumulative-block attention mask in
which suffix tokens (highest cumsum) attend to *every* lower-cumsum
position by default, including FAST tokens. If we let that stand the
action expert reads the discrete FAST tokens and trivially decodes
them back to the same continuous actions the flow head is supposed
to predict from noise — the entire training signal collapses to a
copy operation.

The fix is a single line right after make_att_2d_masks:

    att_2d_masks[:, fast_end:, fast_start:fast_end] = False

Explicitly zeros out *suffix → FAST* attention. Everything else
remains correct under the cumsum semantics:

  * prefix images/language stay bidirectional among themselves
  * FAST stays causal within itself, attending bidirectionally
    to images+language
  * FAST cannot see suffix (cumsum < suffix cumsum, default)
  * suffix attends bidirectionally among itself, to images+language,
    and now NOT to FAST (this override)

Bit-equivalent to the previous separated forward path for text+FAST
losses (the prefix hidden states at language and FAST positions are
unchanged whether suffix is present or not — the prefix doesn't
attend to suffix). For flow loss, suffix→FAST being masked is the
correct behaviour we *want* — if anything the previous separated
path was less correct for production use because the joint
gradient signal through the action expert was missing the prefix
extension.

Forward routing in ``forward()``
--------------------------------

  * run_flow=True  →  _compute_all_losses_fused (one forward, all
                      three losses)
  * run_flow=False, run_text or run_fast → _compute_text_and_fast_loss
                      (one prefix-only forward, two CE losses, no
                      suffix → cheaper than fusion)
  * neither       →  RuntimeError (explicit; both losses disabled)

Wall-time per step
------------------

  Before this commit:  flow + (text+FAST fused) = 2 forwards
  After this commit:   (flow+text+FAST fused)   = 1 forward

Compute parity with flow-only training when all three heads active.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-13 12:28:38 +02:00
parent 83d7250a22
commit b873fe454c

View File

@@ -316,25 +316,7 @@ class PI052Policy(PI05Policy):
loss_dict: dict[str, Any] = {}
total: Tensor | None = None
if run_flow:
flow_loss, flow_dict = super().forward(batch, reduction=reduction)
for k, v in flow_dict.items():
loss_dict[f"flow_{k}"] = v
loss_dict["flow_loss"] = (
flow_loss.item() if isinstance(flow_loss, Tensor) and flow_loss.dim() == 0 else float("nan")
)
total = self.config.flow_loss_weight * flow_loss
# Text + FAST losses share the prefix forward — compute them
# together. Both are CE on the LM head at *different slices*
# of the same hidden states, so one prefix forward replaces
# two. Saves ~33% compute vs. running them separately.
#
# Flow loss can't be fused into this pass without a custom
# segment-aware attention mask (the action expert would
# trivially read FAST tokens and decode them back to
# continuous actions). That's pi05_full's territory — for
# now, flow stays in a separate forward via super().forward.
# Decide which losses fire this step.
run_fast = (
getattr(self.config, "enable_fast_action_loss", False)
and self.config.fast_action_loss_weight > 0
@@ -349,7 +331,35 @@ class PI052Policy(PI05Policy):
if action_tokens is None or action_mask is None:
run_fast = False
if run_text or run_fast:
# ------------------------------------------------------------
# Dispatch: full fusion when flow is active, otherwise the
# prefix-only text+FAST helper (no suffix forward needed).
#
# Full fusion (flow ON):
# ONE backbone forward with prefix=[images, lang, FAST] +
# suffix=[noisy_actions], suffix→FAST attention masked out.
# All three losses computed from slices of the single output.
#
# Prefix-only fusion (flow OFF, e.g. text-only recipes):
# ONE prefix-only forward, both text + FAST losses computed
# from slices. No suffix forward → cheaper.
# ------------------------------------------------------------
if run_flow:
flow_loss, text_loss, fast_loss = self._compute_all_losses_fused(
batch,
text_labels=text_labels if run_text else None,
action_tokens=action_tokens if run_fast else None,
action_mask=action_mask if run_fast else None,
)
loss_dict["flow_loss"] = float(flow_loss.detach().item())
total = self.config.flow_loss_weight * flow_loss
if text_loss is not None:
loss_dict["text_loss"] = float(text_loss.detach().item())
total = total + self.config.text_loss_weight * text_loss
if fast_loss is not None:
loss_dict["fast_action_loss"] = float(fast_loss.detach().item())
total = total + self.config.fast_action_loss_weight * fast_loss
elif run_text or run_fast:
text_loss, fast_loss = self._compute_text_and_fast_loss(
batch,
text_labels=text_labels if run_text else None,
@@ -383,6 +393,160 @@ class PI052Policy(PI05Policy):
# Text loss
# ------------------------------------------------------------------
def _compute_all_losses_fused(
self,
batch: dict[str, Tensor],
text_labels: Tensor | None,
action_tokens: Tensor | None,
action_mask: Tensor | None,
) -> tuple[Tensor, Tensor | None, Tensor | None]:
"""Full fusion: flow + text + FAST in ONE backbone forward.
Builds:
prefix = [images, language, FAST (when provided)]
suffix = [noisy_actions] (action expert via gemma_expert)
Then overrides the unified 2D attention mask to *explicitly*
zero out ``suffix → FAST`` attention. Without this override
the action expert would attend to the discrete FAST tokens
and trivially decode them back to the same continuous
actions it's supposed to predict via flow matching — the
whole training signal collapses.
Both prefix_out and suffix_out are captured from the same
forward. From prefix_out we slice the language and FAST
token positions and compute their CE losses. From suffix_out
we run the existing flow path (action_out_proj → MSE).
Returns ``(flow_loss, text_loss, fast_loss)`` where text/fast
can be ``None`` when the caller didn't supply the
corresponding inputs.
"""
from lerobot.utils.constants import ACTION # noqa: PLC0415
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
# ---- preamble (mirrors PI05Pytorch.forward) ------------------
actions = self.model.prepare_action(batch)
noise = self.model.sample_noise(actions.shape, actions.device)
time = self.model.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
# ---- prefix: images + language + (optional FAST) -------------
images, img_masks = self.model._preprocess_images(batch)
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
non_fast_prefix_len = prefix_embs.shape[1] # images + language only
fast_len = 0
if action_tokens is not None and action_mask is not None:
emb_dim = prefix_embs.shape[-1]
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
fast_emb = fast_emb * math.sqrt(emb_dim)
fast_len = action_tokens.shape[1]
ones_att = torch.ones(
(action_tokens.shape[0], fast_len),
dtype=torch.bool,
device=prefix_embs.device,
)
prefix_embs = torch.cat([prefix_embs, fast_emb], dim=1)
prefix_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
prefix_att = torch.cat([prefix_att, ones_att], dim=1)
# ---- suffix: noisy actions ----------------------------------
suffix_embs, suffix_pad, suffix_att, adarms_cond = self.model.embed_suffix(x_t, time)
# ---- bf16 alignment (mirrors PI05Pytorch.forward) -----------
first_layer = (
self.model.paligemma_with_expert.paligemma.model.language_model.layers[0]
)
if first_layer.self_attn.q_proj.weight.dtype == torch.bfloat16:
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
# ---- combined attention -------------------------------------
pad_masks = torch.cat([prefix_pad, suffix_pad], dim=1)
att_masks = torch.cat([prefix_att, suffix_att], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
# Critical: zero out suffix → FAST attention. Without this the
# action expert reads the FAST tokens and trivially decodes
# them back to the same continuous actions it's supposed to
# predict from noise. Cumulative-block attention from
# ``make_att_2d_masks`` doesn't enforce this on its own
# because suffix tokens have a strictly higher cumsum than
# FAST tokens and therefore attend to them by default.
if fast_len > 0:
fast_start = non_fast_prefix_len
fast_end = non_fast_prefix_len + fast_len # = prefix_pad.shape[1]
att_2d_masks[:, fast_end:, fast_start:fast_end] = False
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self.model._prepare_attention_masks_4d(att_2d_masks)
# ---- forward (capture BOTH expert outputs) ------------------
(prefix_out, suffix_out), _ = self.model.paligemma_with_expert.forward(
attention_mask=att_2d_masks_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
adarms_cond=[None, adarms_cond],
)
# ---- flow loss (mirrors PI05Pytorch.forward) ----------------
suffix_out_slice = suffix_out[:, -self.model.config.chunk_size :].to(
dtype=torch.float32
)
v_t = self.model.action_out_proj(suffix_out_slice)
flow_per_dim = F.mse_loss(u_t, v_t, reduction="none")
# Truncate to the actual action dimensionality (PI05 pads
# internally to max_action_dim).
original_action_dim = self.config.output_features[ACTION].shape[0]
flow_per_dim = flow_per_dim[:, :, :original_action_dim]
flow_loss = flow_per_dim.mean()
# ---- text + FAST CE from prefix_out ------------------------
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
text_loss: Tensor | None = None
if text_labels is not None and prefix_out is not None:
lang_len = text_labels.shape[1]
if fast_len > 0:
text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :]
else:
text_hidden = prefix_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
shift_logits = text_logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
text_loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
)
fast_loss: Tensor | None = None
if fast_len > 0 and prefix_out is not None:
fast_hidden = prefix_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
shift_logits = fast_logits[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_mask[:, 1:].contiguous().bool()
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
fast_loss = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_targets.reshape(-1),
ignore_index=-100,
)
return flow_loss, text_loss, fast_loss
def _compute_text_and_fast_loss(
self,
batch: dict[str, Tensor],