mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
fix(pi052): four real bugs in the modeling code + flip defaults
Defaults
--------
* enable_fast_action_loss: False -> True (match paper §III.B-C Eq.1)
* auto_fit_fast_tokenizer: True -> False (opt-in; needs base.fit())

Bug fixes
---------
1. Wrong attribute path on PaliGemma. The KI port copied pi05_full's ``paligemma.language_model.layers[...]`` literally, but the production pi05 wrapper exposes the text model at ``paligemma.model.language_model``. With KI enabled, every layer would have raised AttributeError on first forward. Fixed all references in _compute_layer_ki + _paligemma_forward_ki.

2. ``fill_kv_cache=True`` passed to PaliGemmaWithExpertModel.forward. That kwarg is a SmolVLA-only concept; pi05's signature has no such argument, so every forward call from pi052 (text loss, FAST loss, select_message) would have crashed with TypeError. Dropped from all four call sites — pi05's forward already handles the cache via past_key_values, and re-forwarding the cumulative sequence each step in select_message is fine for our short subtask completions.

3. Text-loss shape mismatch. _compute_text_loss applied lm_head to the *full* vlm_out (image tokens + language tokens), then tried to cross-entropy that against text_labels which only covers the language portion — the .view(-1) calls would produce two tensors of different lengths and CE would fail. Now slices vlm_out to the last text_labels.shape[1] positions before running lm_head, matching the [images, language] order embed_prefix produces.

4. Dead-code conditional in _paligemma_forward_ki's single-expert fallback. The ``if hasattr(...) else self._pi052_orig_forward`` ternary always took the wrong branch because the attribute is always set (we save it in PI052Policy.__init__). Simplified to just call self._pi052_orig_forward directly.

After this commit, pi052 should be runnable end-to-end for the first time with all three loss heads + KI active. Still worth a 100-step smoke test before kicking off a long run.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -107,12 +107,12 @@ class PI052Config(PI05Config):
|
||||
# ActionTokenizerProcessorStep is wired into the preprocessor
|
||||
# pipeline when this flag is set; the loss is computed in
|
||||
# PI052Policy.forward.
|
||||
enable_fast_action_loss: bool = False
|
||||
enable_fast_action_loss: bool = True
|
||||
"""If True, tokenise actions with the FAST tokenizer and add a
|
||||
cross-entropy loss on the LM head. Off by default because most
|
||||
fine-tuning runs only need the flow head + text supervision; the
|
||||
FAST CE term is most useful when training from a base PaliGemma
|
||||
rather than an existing π0.5 checkpoint."""
|
||||
cross-entropy loss on the LM head. On by default to match the
|
||||
π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE,
|
||||
§III.B-C Eq. 1). Set to False if you only want the
|
||||
post-training-style flow + text recipe."""
|
||||
|
||||
action_tokenizer_name: str = "physical-intelligence/fast"
|
||||
"""HF identifier for the FAST action tokenizer."""
|
||||
@@ -127,15 +127,21 @@ class PI052Config(PI05Config):
|
||||
fast_action_loss_weight: float = 1.0
|
||||
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""
|
||||
|
||||
auto_fit_fast_tokenizer: bool = True
|
||||
"""If True (default), the processor factory checks
|
||||
``fast_tokenizer_cache_dir`` for a previously-fitted tokenizer keyed
|
||||
on ``(dataset_repo_id, base_tokenizer_name, fit_samples)``. On cache
|
||||
miss, it loads ``action_tokenizer_name`` as a base, samples
|
||||
auto_fit_fast_tokenizer: bool = False
|
||||
"""If True, the processor factory checks ``fast_tokenizer_cache_dir``
|
||||
for a previously-fitted tokenizer keyed on ``(dataset_repo_id,
|
||||
base_tokenizer_name, fit_samples)``. On cache miss, it loads
|
||||
``action_tokenizer_name`` as a base, samples
|
||||
``fast_tokenizer_fit_samples`` action chunks from the dataset, runs
|
||||
``.fit()``, saves the result, and uses *that* fitted path as the
|
||||
actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C)
|
||||
explicitly recommend per-dataset fitting for best compression."""
|
||||
explicitly recommend per-dataset fitting for best compression.
|
||||
|
||||
Off by default because the fit requires a separate pre-training
|
||||
pass over the dataset (~1-2 min on a medium dataset) and depends
|
||||
on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in
|
||||
when you want paper-faithful compression; leave off to fall back
|
||||
on the universal ``physical-intelligence/fast`` codebook."""
|
||||
|
||||
fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers"
|
||||
"""Where fitted FAST tokenizers are stored. ``~`` expands."""
|
||||
|
||||
@@ -76,7 +76,7 @@ def _compute_layer_ki(
|
||||
):
|
||||
from transformers.models.gemma import modeling_gemma # noqa: PLC0415
|
||||
|
||||
models = [paligemma.language_model, gemma_expert.model]
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
query_states, key_states, value_states, gates = [], [], [], []
|
||||
|
||||
vlm_len = inputs_embeds[0].shape[1]
|
||||
@@ -111,7 +111,7 @@ def _compute_layer_ki(
|
||||
)
|
||||
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
|
||||
# Split queries / K / V at the VLM-vs-action boundary.
|
||||
Q_vlm = query_states[:, :, :vlm_len, :]
|
||||
@@ -133,16 +133,16 @@ def _compute_layer_ki(
|
||||
mask_for_action = attention_mask[:, :, vlm_len:, :]
|
||||
|
||||
att_vlm, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
Q_vlm, K_for_vlm, V_for_vlm, mask_for_vlm, scaling,
|
||||
)
|
||||
att_action, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.language_model.layers[layer_idx].self_attn,
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
Q_action, K_for_action, V_for_action, mask_for_action, scaling,
|
||||
)
|
||||
att = torch.cat([att_vlm, att_action], dim=1)
|
||||
|
||||
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att = att.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
|
||||
outputs_embeds = []
|
||||
@@ -173,33 +173,24 @@ def _paligemma_forward_ki(
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
adarms_cond=None,
|
||||
fill_kv_cache=None,
|
||||
):
|
||||
"""Replacement ``PaliGemmaWithExpertModel.forward`` that routes the
|
||||
dual-expert layer pass through :func:`_compute_layer_ki`.
|
||||
|
||||
Bound onto the model instance when ``config.knowledge_insulation``
|
||||
is True (see ``PI052Policy.__init__``). Single-expert branches
|
||||
(VLM-only or action-only) reuse the parent's implementation
|
||||
because there's no KI signal to add — KI only matters when
|
||||
actions and VLM tokens are forwarded together.
|
||||
(VLM-only or action-only) defer back to the original forward —
|
||||
KI only matters when actions and VLM tokens are forwarded together.
|
||||
"""
|
||||
from ..pi05.modeling_pi05 import layernorm_forward # noqa: PLC0415
|
||||
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
|
||||
# Single-expert paths: defer to the bound class method via super().
|
||||
# Single-expert paths: defer to the original forward saved in
|
||||
# PI052Policy.__init__.
|
||||
if inputs_embeds[0] is None or inputs_embeds[1] is None:
|
||||
return type(self).__bases__[0].forward(
|
||||
self,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond,
|
||||
) if hasattr(self, "_pi052_orig_forward") else self._pi052_orig_forward(
|
||||
return self._pi052_orig_forward(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
@@ -429,7 +420,6 @@ class PI052Policy(PI05Policy):
|
||||
past_key_values=None,
|
||||
inputs_embeds=[full_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
if vlm_out is None:
|
||||
raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.")
|
||||
@@ -456,12 +446,15 @@ class PI052Policy(PI05Policy):
|
||||
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
|
||||
"""Cross-entropy on PaliGemma's LM head over the supervised span.
|
||||
|
||||
Re-uses the same prefix-embedding path the flow head does:
|
||||
embed images + state + language tokens, run a forward pass,
|
||||
slice out the per-token logits at the supervised positions,
|
||||
compute CE.
|
||||
Embeds images + language, runs the VLM-only forward (the
|
||||
action expert is skipped via ``inputs_embeds=[..., None]``),
|
||||
slices the hidden states to the *language* portion so they
|
||||
align with ``text_labels`` (which covers only the language
|
||||
tokens, not the image patch tokens), then computes shifted
|
||||
next-token CE with ``-100`` ignoring padding/non-target
|
||||
positions.
|
||||
"""
|
||||
from torch.nn import functional as F # noqa: PLC0415
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
images, img_masks = self.model._preprocess_images(batch)
|
||||
tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||
@@ -470,12 +463,6 @@ class PI052Policy(PI05Policy):
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix(
|
||||
images, img_masks, tokens, masks
|
||||
)
|
||||
# PaliGemma's text path: forward the prefix through the
|
||||
# backbone *without* the action expert. We piggy-back on the
|
||||
# existing PaliGemmaWithExpertModel.forward — it accepts a
|
||||
# list of expert inputs and returns parallel outputs.
|
||||
from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415
|
||||
|
||||
att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
@@ -485,18 +472,24 @@ class PI052Policy(PI05Policy):
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
if vlm_out is None:
|
||||
raise RuntimeError("PI052 text loss: VLM forward returned no hidden states.")
|
||||
|
||||
# Logits over the vocab via the PaliGemma lm_head.
|
||||
# Slice the hidden states to the language portion. embed_prefix
|
||||
# concatenates [images, language] in that order, so the trailing
|
||||
# ``text_labels.shape[1]`` positions are the language tokens.
|
||||
# Without this slice, applying lm_head to the full vlm_out and
|
||||
# shifting against text_labels[..., 1:] produces a shape
|
||||
# mismatch in cross_entropy.
|
||||
lang_len = text_labels.shape[1]
|
||||
text_hidden = vlm_out[:, -lang_len:, :]
|
||||
|
||||
lm_head = self.model.paligemma_with_expert.paligemma.lm_head
|
||||
logits = lm_head(vlm_out.to(lm_head.weight.dtype))
|
||||
logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||
|
||||
# Shift for next-token prediction: predict token[i+1] from
|
||||
# hidden[i]. Both ``logits`` and ``text_labels`` are over the
|
||||
# same sequence length, so shift logits[:-1] vs labels[1:].
|
||||
# hidden[i] within the language span.
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = text_labels[..., 1:].contiguous()
|
||||
loss = F.cross_entropy(
|
||||
@@ -579,7 +572,6 @@ class PI052Policy(PI05Policy):
|
||||
past_key_values=None,
|
||||
inputs_embeds=[current_embs, None],
|
||||
use_cache=False,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
if vlm_out is None:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user