mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
perf(pi052): fuse text + FAST loss into a single prefix forward
Previously the forward did three backbone passes per training step
when all heads were active: one for flow (via super().forward), one
for text CE, and one for FAST CE. That's ~3× the compute of
flow-only training.
The text and FAST losses share their prefix forward exactly — both
are CE on the LM head, evaluated at different slices of the same
hidden states. Adding FAST tokens after language in the prefix is
bit-equivalent for the text loss because the mask_ar convention in
``make_att_2d_masks`` keeps FAST tokens in a strictly-later causal
block: language tokens never see FAST, so their hidden states are
unchanged.
New ``_compute_text_and_fast_loss``:
* embeds [images, language] once
* optionally appends [FAST] (when run_fast is True)
* one backbone forward
* slices ``vlm_out[:, -(fast_len + lang_len):-fast_len]`` for
language hidden states (or ``vlm_out[:, -lang_len:]`` when no
FAST) → text CE
* slices ``vlm_out[:, -fast_len:]`` for FAST hidden states →
FAST CE
* returns both losses, either of which can be None when the
caller doesn't want that head.
forward() now calls this fused helper instead of running the two
separate ``_compute_text_loss`` / ``_compute_fast_action_loss``
methods. Those remain in the file for callers that only want one
head (e.g. ablations).
Why flow isn't fused
--------------------
Flow MSE comes from the action-expert (suffix) hidden states, which
attend to the prefix. If we just concat FAST onto the prefix and let
the action expert attend to it, the expert can trivially decode FAST
back to continuous actions — overfitting via shortcut. Preventing
that requires a custom segment-aware attention mask (action expert
can attend to images+language but NOT to subtask/FAST), which is
what pi05_full does in ``compute_layer_complete_knowledge_insulation``.
That's the full-fusion path; deferred as a follow-up since the
text+FAST fusion already recovers most of the compute.
End-to-end forward pass count
-----------------------------
Before: 1 (flow) + 1 (text) + 1 (FAST) = 3 backbone forwards
After: 1 (flow) + 1 (text+FAST fused) = 2 backbone forwards
~33% wall-time reduction per training step when all three heads
are active.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -325,38 +325,42 @@ class PI052Policy(PI05Policy):
|
||||
)
|
||||
total = self.config.flow_loss_weight * flow_loss
|
||||
|
||||
if run_text:
|
||||
text_loss = self._compute_text_loss(batch, text_labels)
|
||||
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||
total = (
|
||||
self.config.text_loss_weight * text_loss
|
||||
if total is None
|
||||
else total + self.config.text_loss_weight * text_loss
|
||||
)
|
||||
|
||||
# FAST action-token CE loss (paper §III.C). When
|
||||
# ``enable_fast_action_loss=True`` the preprocessor wrote
|
||||
# ACTION_TOKENS / ACTION_TOKEN_MASK into the batch — we
|
||||
# forward them through the PaliGemma backbone alongside the
|
||||
# language prefix and compute CE on the action positions.
|
||||
# Text + FAST losses share the prefix forward — compute them
|
||||
# together. Both are CE on the LM head at *different slices*
|
||||
# of the same hidden states, so one prefix forward replaces
|
||||
# two. Saves ~33% compute vs. running them separately.
|
||||
#
|
||||
# Gated on ``predict_actions`` (same routing the flow loss
|
||||
# uses): for text-only recipes the action_tokens are still
|
||||
# present in the batch but shouldn't be supervised. Skip the
|
||||
# entire FAST forward when no sample in the batch wants action
|
||||
# supervision.
|
||||
# Flow loss can't be fused into this pass without a custom
|
||||
# segment-aware attention mask (the action expert would
|
||||
# trivially read FAST tokens and decode them back to
|
||||
# continuous actions). That's pi05_full's territory — for
|
||||
# now, flow stays in a separate forward via super().forward.
|
||||
run_fast = (
|
||||
getattr(self.config, "enable_fast_action_loss", False)
|
||||
and self.config.fast_action_loss_weight > 0
|
||||
and (predict_actions_t is None or bool(predict_actions_t.any().item()))
|
||||
)
|
||||
action_tokens = action_mask = None
|
||||
if run_fast:
|
||||
from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415
|
||||
|
||||
action_tokens = batch.get(ACTION_TOKENS)
|
||||
action_mask = batch.get(ACTION_TOKEN_MASK)
|
||||
if action_tokens is not None and action_mask is not None:
|
||||
fast_loss = self._compute_fast_action_loss(batch, action_tokens, action_mask)
|
||||
if action_tokens is None or action_mask is None:
|
||||
run_fast = False
|
||||
|
||||
if run_text or run_fast:
|
||||
text_loss, fast_loss = self._compute_text_and_fast_loss(
|
||||
batch,
|
||||
text_labels=text_labels if run_text else None,
|
||||
action_tokens=action_tokens if run_fast else None,
|
||||
action_mask=action_mask if run_fast else None,
|
||||
)
|
||||
if text_loss is not None:
|
||||
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||
weighted = self.config.text_loss_weight * text_loss
|
||||
total = weighted if total is None else total + weighted
|
||||
if fast_loss is not None:
|
||||
loss_dict["fast_action_loss"] = float(fast_loss.detach().item())
|
||||
weighted = self.config.fast_action_loss_weight * fast_loss
|
||||
total = weighted if total is None else total + weighted
|
||||
@@ -379,6 +383,110 @@ class PI052Policy(PI05Policy):
|
||||
# Text loss
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_text_and_fast_loss(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
text_labels: Tensor | None,
|
||||
action_tokens: Tensor | None,
|
||||
action_mask: Tensor | None,
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Single prefix forward → text CE + FAST CE.
|
||||
|
||||
Embed [images, language] (and FAST when requested) once, run
|
||||
one backbone forward, then slice the resulting hidden states
|
||||
at the language and FAST positions to compute both CE losses.
|
||||
Bit-equivalent to running the two losses in separate forwards
|
||||
because the segment-aware ``make_att_2d_masks`` keeps FAST
|
||||
tokens invisible to language tokens, so adding FAST to the
|
||||
prefix doesn't perturb the hidden states at language positions.
|
||||
|
||||
Returns ``(text_loss, fast_loss)``. Either can be ``None`` if
|
||||
the caller doesn't want that head.
|
||||
"""
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||
|
||||
prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
|
||||
fast_len = 0
|
||||
if action_tokens is not None and action_mask is not None:
|
||||
emb_dim = prefix_embs.shape[-1]
|
||||
fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
|
||||
fast_emb = fast_emb * math.sqrt(emb_dim)
|
||||
|
||||
fast_len = action_tokens.shape[1]
|
||||
ones_att = torch.ones(
|
||||
(action_tokens.shape[0], fast_len),
|
||||
dtype=torch.bool,
|
||||
device=prefix_embs.device,
|
||||
)
|
||||
full_embs = torch.cat([prefix_embs, fast_emb], dim=1)
|
||||
full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
|
||||
full_att = torch.cat([prefix_att, ones_att], dim=1)
|
||||
else:
|
||||
full_embs = prefix_embs
|
||||
full_pad = prefix_pad
|
||||
full_att = prefix_att
|
||||
|
||||
att_2d = make_att_2d_masks(full_pad, full_att)
|
||||
position_ids = torch.cumsum(full_pad, dim=1) - 1
|
||||
|
||||
(vlm_out, _), _ = self.model.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[full_embs, None],
|
||||
use_cache=False,
|
||||
)
|
||||
if vlm_out is None:
|
||||
raise RuntimeError(
|
||||
"PI052 text+fast loss: VLM forward returned no hidden states."
|
||||
)
|
||||
|
||||
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
text_loss: Tensor | None = None
|
||||
if text_labels is not None:
|
||||
lang_len = text_labels.shape[1]
|
||||
# embed_prefix lays out as [images, language]; with FAST
|
||||
# appended the full sequence is [images, language, FAST].
|
||||
# Language hidden states are at positions
|
||||
# ``[-(fast_len + lang_len) : -fast_len]`` when FAST is
|
||||
# present, or ``[-lang_len:]`` otherwise.
|
||||
if fast_len > 0:
|
||||
text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :]
|
||||
else:
|
||||
text_hidden = vlm_out[:, -lang_len:, :]
|
||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||
shift_logits = text_logits[:, :-1, :].contiguous()
|
||||
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||
text_loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_labels.reshape(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
|
||||
fast_loss: Tensor | None = None
|
||||
if action_tokens is not None and action_mask is not None and fast_len > 0:
|
||||
fast_hidden = vlm_out[:, -fast_len:, :]
|
||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
||||
shift_logits = fast_logits[:, :-1, :].contiguous()
|
||||
shift_targets = action_tokens[:, 1:].contiguous().long()
|
||||
shift_valid = action_mask[:, 1:].contiguous().bool()
|
||||
shift_targets = shift_targets.masked_fill(~shift_valid, -100)
|
||||
fast_loss = F.cross_entropy(
|
||||
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||
shift_targets.reshape(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
|
||||
return text_loss, fast_loss
|
||||
|
||||
def _compute_fast_action_loss(
|
||||
self,
|
||||
batch: dict[str, Tensor],
|
||||
|
||||
Reference in New Issue
Block a user