mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
perf(pi052): full fusion — text + FAST + flow in ONE backbone forward
Previously the forward did 2 backbone passes when all heads were
active: one for flow (via super().forward) and one for the fused
text+FAST helper. This commit reduces it to **one pass** — same
compute as flow-only training.
New ``_compute_all_losses_fused`` builds:
prefix = [images, language, FAST (when provided)]
suffix = [noisy_actions] (action expert via gemma_expert)
and runs a single ``paligemma_with_expert.forward`` with
``inputs_embeds=[prefix_embs, suffix_embs]`` (both experts active
in the same call). Captures *both* prefix_out and suffix_out, slices
each for its respective loss:
flow MSE ← suffix_out (existing action_out_proj + MSE path)
text CE ← prefix_out at language positions (lm_head + CE)
FAST CE ← prefix_out at FAST positions (lm_head + CE)
Critical attention mask override
--------------------------------
``make_att_2d_masks`` produces a cumulative-block attention mask in
which suffix tokens (highest cumsum) attend to *every* lower-cumsum
position by default, including FAST tokens. If we let that stand the
action expert reads the discrete FAST tokens and trivially decodes
them back to the same continuous actions the flow head is supposed
to predict from noise — the entire training signal collapses to a
copy operation.
The fix is a single line right after make_att_2d_masks:
att_2d_masks[:, fast_end:, fast_start:fast_end] = False
Explicitly zeros out *suffix → FAST* attention. Everything else
remains correct under the cumsum semantics:
* prefix images/language stay bidirectional among themselves
* FAST stays causal within itself, attending bidirectionally
to images+language
* FAST cannot see suffix (cumsum < suffix cumsum, default)
* suffix attends bidirectionally among itself, to images+language,
and now NOT to FAST (this override)
Bit-equivalent to the previous separated forward path for text+FAST
losses (the prefix hidden states at language and FAST positions are
unchanged whether suffix is present or not — the prefix doesn't
attend to suffix). For flow loss, suffix→FAST being masked is the
correct behaviour we *want* — if anything the previous separated
path was less correct for production use because the joint
gradient signal through the action expert was missing the prefix
extension.
Forward routing in ``forward()``
--------------------------------
* run_flow=True → _compute_all_losses_fused (one forward, all
three losses)
* run_flow=False, run_text or run_fast → _compute_text_and_fast_loss
(one prefix-only forward, two CE losses, no
suffix → cheaper than fusion)
* neither → RuntimeError (explicit; both losses disabled)
Wall-time per step
------------------
Before this commit: flow + (text+FAST fused) = 2 forwards
After this commit: (flow+text+FAST fused) = 1 forward
Compute parity with flow-only training when all three heads active.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -316,25 +316,7 @@ class PI052Policy(PI05Policy):
|
||||
loss_dict: dict[str, Any] = {}
|
||||
total: Tensor | None = None
|
||||
|
||||
if run_flow:
|
||||
flow_loss, flow_dict = super().forward(batch, reduction=reduction)
|
||||
for k, v in flow_dict.items():
|
||||
loss_dict[f"flow_{k}"] = v
|
||||
loss_dict["flow_loss"] = (
|
||||
flow_loss.item() if isinstance(flow_loss, Tensor) and flow_loss.dim() == 0 else float("nan")
|
||||
)
|
||||
total = self.config.flow_loss_weight * flow_loss
|
||||
|
||||
# Text + FAST losses share the prefix forward — compute them
|
||||
# together. Both are CE on the LM head at *different slices*
|
||||
# of the same hidden states, so one prefix forward replaces
|
||||
# two. Saves ~33% compute vs. running them separately.
|
||||
#
|
||||
# Flow loss can't be fused into this pass without a custom
|
||||
# segment-aware attention mask (the action expert would
|
||||
# trivially read FAST tokens and decode them back to
|
||||
# continuous actions). That's pi05_full's territory — for
|
||||
# now, flow stays in a separate forward via super().forward.
|
||||
# Decide which losses fire this step.
|
||||
run_fast = (
|
||||
getattr(self.config, "enable_fast_action_loss", False)
|
||||
and self.config.fast_action_loss_weight > 0
|
||||
@@ -349,7 +331,35 @@ class PI052Policy(PI05Policy):
|
||||
if action_tokens is None or action_mask is None:
|
||||
run_fast = False
|
||||
|
||||
if run_text or run_fast:
|
||||
# ------------------------------------------------------------
|
||||
# Dispatch: full fusion when flow is active, otherwise the
|
||||
# prefix-only text+FAST helper (no suffix forward needed).
|
||||
#
|
||||
# Full fusion (flow ON):
|
||||
# ONE backbone forward with prefix=[images, lang, FAST] +
|
||||
# suffix=[noisy_actions], suffix→FAST attention masked out.
|
||||
# All three losses computed from slices of the single output.
|
||||
#
|
||||
# Prefix-only fusion (flow OFF, e.g. text-only recipes):
|
||||
# ONE prefix-only forward, both text + FAST losses computed
|
||||
# from slices. No suffix forward → cheaper.
|
||||
# ------------------------------------------------------------
|
||||
if run_flow:
|
||||
flow_loss, text_loss, fast_loss = self._compute_all_losses_fused(
|
||||
batch,
|
||||
text_labels=text_labels if run_text else None,
|
||||
action_tokens=action_tokens if run_fast else None,
|
||||
action_mask=action_mask if run_fast else None,
|
||||
)
|
||||
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||
total = self.config.flow_loss_weight * flow_loss
|
||||
if text_loss is not None:
|
||||
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||
total = total + self.config.text_loss_weight * text_loss
|
||||
if fast_loss is not None:
|
||||
loss_dict["fast_action_loss"] = float(fast_loss.detach().item())
|
||||
total = total + self.config.fast_action_loss_weight * fast_loss
|
||||
elif run_text or run_fast:
|
||||
text_loss, fast_loss = self._compute_text_and_fast_loss(
|
||||
batch,
|
||||
text_labels=text_labels if run_text else None,
|
||||
@@ -383,6 +393,160 @@ class PI052Policy(PI05Policy):
|
||||
# Text loss
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_all_losses_fused(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
text_labels: Tensor | None,
|
||||
action_tokens: Tensor | None,
|
||||
action_mask: Tensor | None,
|
||||
) -> tuple[Tensor, Tensor | None, Tensor | None]:
|
||||
"""Full fusion: flow + text + FAST in ONE backbone forward.
|
||||
|
||||
Builds:
|
||||
|
||||
prefix = [images, language, FAST (when provided)]
|
||||
suffix = [noisy_actions] (action expert via gemma_expert)
|
||||
|
||||
Then overrides the unified 2D attention mask to *explicitly*
|
||||
zero out ``suffix → FAST`` attention. Without this override
|
||||
the action expert would attend to the discrete FAST tokens
|
||||
and trivially decode them back to the same continuous
|
||||
actions it's supposed to predict via flow matching — the
|
||||
whole training signal collapses.
|
||||
|
||||
Both prefix_out and suffix_out are captured from the same
|
||||
forward. From prefix_out we slice the language and FAST
|
||||
token positions and compute their CE losses. From suffix_out
|
||||
we run the existing flow path (action_out_proj → MSE).
|
||||
|
||||
Returns ``(flow_loss, text_loss, fast_loss)`` where text/fast
|
||||
can be ``None`` when the caller didn't supply the
|
||||
corresponding inputs.
|
||||
"""
|
||||
from lerobot.utils.constants import ACTION # noqa: PLC0415
|
||||
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
# ---- preamble (mirrors PI05Pytorch.forward) ------------------
|
||||
actions = self.model.prepare_action(batch)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
# ---- prefix: images + language + (optional FAST) -------------
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
non_fast_prefix_len = prefix_embs.shape[1] # images + language only
|
||||
|
||||
fast_len = 0
|
||||
if action_tokens is not None and action_mask is not None:
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
|
||||
fast_emb = fast_emb * math.sqrt(emb_dim)
|
||||
fast_len = action_tokens.shape[1]
|
||||
ones_att = torch.ones(
|
||||
(action_tokens.shape[0], fast_len),
|
||||
dtype=torch.bool,
|
||||
device=prefix_embs.device,
|
||||
)
|
||||
prefix_embs = torch.cat([prefix_embs, fast_emb], dim=1)
|
||||
prefix_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
|
||||
prefix_att = torch.cat([prefix_att, ones_att], dim=1)
|
||||
|
||||
# ---- suffix: noisy actions ----------------------------------
|
||||
suffix_embs, suffix_pad, suffix_att, adarms_cond = self.model.embed_suffix(x_t, time)
|
||||
|
||||
# ---- bf16 alignment (mirrors PI05Pytorch.forward) -----------
|
||||
first_layer = (
|
||||
self.model.paligemma_with_expert.paligemma.model.language_model.layers[0]
|
||||
)
|
||||
if first_layer.self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
# ---- combined attention -------------------------------------
|
||||
pad_masks = torch.cat([prefix_pad, suffix_pad], dim=1)
|
||||
att_masks = torch.cat([prefix_att, suffix_att], dim=1)
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
|
||||
# Critical: zero out suffix → FAST attention. Without this the
|
||||
# action expert reads the FAST tokens and trivially decodes
|
||||
# them back to the same continuous actions it's supposed to
|
||||
# predict from noise. Cumulative-block attention from
|
||||
# ``make_att_2d_masks`` doesn't enforce this on its own
|
||||
# because suffix tokens have a strictly higher cumsum than
|
||||
# FAST tokens and therefore attend to them by default.
|
||||
if fast_len > 0:
|
||||
fast_start = non_fast_prefix_len
|
||||
fast_end = non_fast_prefix_len + fast_len # = prefix_pad.shape[1]
|
||||
att_2d_masks[:, fast_end:, fast_start:fast_end] = False
|
||||
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
att_2d_masks_4d = self.model._prepare_attention_masks_4d(att_2d_masks)
|
||||
|
||||
# ---- forward (capture BOTH expert outputs) ------------------
|
||||
(prefix_out, suffix_out), _ = self.model.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
|
||||
# ---- flow loss (mirrors PI05Pytorch.forward) ----------------
|
||||
suffix_out_slice = suffix_out[:, -self.model.config.chunk_size :].to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
v_t = self.model.action_out_proj(suffix_out_slice)
|
||||
flow_per_dim = F.mse_loss(u_t, v_t, reduction="none")
|
||||
# Truncate to the actual action dimensionality (PI05 pads
|
||||
# internally to max_action_dim).
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
flow_per_dim = flow_per_dim[:, :, :original_action_dim]
|
||||
flow_loss = flow_per_dim.mean()
|
||||
|
||||
# ---- text + FAST CE from prefix_out ------------------------
|
||||
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
text_loss: Tensor | None = None
|
||||
if text_labels is not None and prefix_out is not None:
|
||||
lang_len = text_labels.shape[1]
|
||||
if fast_len > 0:
|
||||
text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :]
|
||||
else:
|
||||
text_hidden = prefix_out[:, -lang_len:, :]
|
||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||
shift_logits = text_logits[:, :-1, :].contiguous()
|
||||
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||
text_loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_labels.reshape(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
|
||||
fast_loss: Tensor | None = None
|
||||
if fast_len > 0 and prefix_out is not None:
|
||||
fast_hidden = prefix_out[:, -fast_len:, :]
|
||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
||||
shift_logits = fast_logits[:, :-1, :].contiguous()
|
||||
shift_targets = action_tokens[:, 1:].contiguous().long()
|
||||
shift_valid = action_mask[:, 1:].contiguous().bool()
|
||||
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
|
||||
fast_loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_targets.reshape(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
|
||||
return flow_loss, text_loss, fast_loss
|
||||
|
||||
def _compute_text_and_fast_loss(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
|
||||
Reference in New Issue
Block a user