diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index 67dfd446a..df9325c01 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -107,12 +107,12 @@ class PI052Config(PI05Config): # ActionTokenizerProcessorStep is wired into the preprocessor # pipeline when this flag is set; the loss is computed in # PI052Policy.forward. - enable_fast_action_loss: bool = False + enable_fast_action_loss: bool = True """If True, tokenise actions with the FAST tokenizer and add a - cross-entropy loss on the LM head. Off by default because most - fine-tuning runs only need the flow head + text supervision; the - FAST CE term is most useful when training from a base PaliGemma - rather than an existing π0.5 checkpoint.""" + cross-entropy loss on the LM head. On by default to match the + π0.5 paper's three-loss objective (text CE + FAST CE + flow MSE, + §III.B-C Eq. 1). Set to False if you only want the + post-training-style flow + text recipe.""" action_tokenizer_name: str = "physical-intelligence/fast" """HF identifier for the FAST action tokenizer.""" @@ -127,15 +127,21 @@ class PI052Config(PI05Config): fast_action_loss_weight: float = 1.0 """Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0.""" - auto_fit_fast_tokenizer: bool = True - """If True (default), the processor factory checks - ``fast_tokenizer_cache_dir`` for a previously-fitted tokenizer keyed - on ``(dataset_repo_id, base_tokenizer_name, fit_samples)``. On cache - miss, it loads ``action_tokenizer_name`` as a base, samples + auto_fit_fast_tokenizer: bool = False + """If True, the processor factory checks ``fast_tokenizer_cache_dir`` + for a previously-fitted tokenizer keyed on ``(dataset_repo_id, + base_tokenizer_name, fit_samples)``. On cache miss, it loads + ``action_tokenizer_name`` as a base, samples ``fast_tokenizer_fit_samples`` action chunks from the dataset, runs ``.fit()``, saves the result, and uses *that* fitted path as the actual tokenizer. Pertsch et al. 2025 (FAST paper [64], π0.5 §III.C) - explicitly recommend per-dataset fitting for best compression.""" + explicitly recommend per-dataset fitting for best compression. + + Off by default because the fit requires a separate pre-training + pass over the dataset (~1-2 min on a medium dataset) and depends + on the FAST tokenizer snapshot having a ``.fit()`` method. Opt in + when you want paper-faithful compression; leave off to fall back + on the universal ``physical-intelligence/fast`` codebook.""" fast_tokenizer_cache_dir: str = "~/.cache/lerobot/fast_tokenizers" """Where fitted FAST tokenizers are stored. ``~`` expands.""" diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index b5e4e6054..52349c614 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -76,7 +76,7 @@ def _compute_layer_ki( ): from transformers.models.gemma import modeling_gemma # noqa: PLC0415 - models = [paligemma.language_model, gemma_expert.model] + models = [paligemma.model.language_model, gemma_expert.model] query_states, key_states, value_states, gates = [], [], [], [] vlm_len = inputs_embeds[0].shape[1] @@ -111,7 +111,7 @@ def _compute_layer_ki( ) batch_size = query_states.shape[0] - scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling # Split queries / K / V at the VLM-vs-action boundary. Q_vlm = query_states[:, :, :vlm_len, :] @@ -133,16 +133,16 @@ def _compute_layer_ki( mask_for_action = attention_mask[:, :, vlm_len:, :] att_vlm, _ = modeling_gemma.eager_attention_forward( - paligemma.language_model.layers[layer_idx].self_attn, + paligemma.model.language_model.layers[layer_idx].self_attn, Q_vlm, K_for_vlm, V_for_vlm, mask_for_vlm, scaling, ) att_action, _ = modeling_gemma.eager_attention_forward( - paligemma.language_model.layers[layer_idx].self_attn, + paligemma.model.language_model.layers[layer_idx].self_attn, Q_action, K_for_action, V_for_action, mask_for_action, scaling, ) att = torch.cat([att_vlm, att_action], dim=1) - head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim att = att.reshape(batch_size, -1, 1 * 8 * head_dim) outputs_embeds = [] @@ -173,33 +173,24 @@ def _paligemma_forward_ki( inputs_embeds=None, use_cache=None, adarms_cond=None, - fill_kv_cache=None, ): """Replacement ``PaliGemmaWithExpertModel.forward`` that routes the dual-expert layer pass through :func:`_compute_layer_ki`. Bound onto the model instance when ``config.knowledge_insulation`` is True (see ``PI052Policy.__init__``). Single-expert branches - (VLM-only or action-only) reuse the parent's implementation - because there's no KI signal to add — KI only matters when - actions and VLM tokens are forwarded together. + (VLM-only or action-only) defer back to the original forward — + KI only matters when actions and VLM tokens are forwarded together. """ from ..pi05.modeling_pi05 import layernorm_forward # noqa: PLC0415 if adarms_cond is None: adarms_cond = [None, None] - # Single-expert paths: defer to the bound class method via super(). + # Single-expert paths: defer to the original forward saved in + # PI052Policy.__init__. if inputs_embeds[0] is None or inputs_embeds[1] is None: - return type(self).__bases__[0].forward( - self, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - adarms_cond=adarms_cond, - ) if hasattr(self, "_pi052_orig_forward") else self._pi052_orig_forward( + return self._pi052_orig_forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -429,7 +420,6 @@ class PI052Policy(PI05Policy): past_key_values=None, inputs_embeds=[full_embs, None], use_cache=False, - fill_kv_cache=True, ) if vlm_out is None: raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.") @@ -456,12 +446,15 @@ class PI052Policy(PI05Policy): def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor: """Cross-entropy on PaliGemma's LM head over the supervised span. - Re-uses the same prefix-embedding path the flow head does: - embed images + state + language tokens, run a forward pass, - slice out the per-token logits at the supervised positions, - compute CE. + Embeds images + language, runs the VLM-only forward (the + action expert is skipped via ``inputs_embeds=[..., None]``), + slices the hidden states to the *language* portion so they + align with ``text_labels`` (which covers only the language + tokens, not the image patch tokens), then computes shifted + next-token CE with ``-100`` ignoring padding/non-target + positions. """ - from torch.nn import functional as F # noqa: PLC0415 + from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 images, img_masks = self.model._preprocess_images(batch) tokens = batch[OBS_LANGUAGE_TOKENS] @@ -470,12 +463,6 @@ class PI052Policy(PI05Policy): prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix( images, img_masks, tokens, masks ) - # PaliGemma's text path: forward the prefix through the - # backbone *without* the action expert. We piggy-back on the - # existing PaliGemmaWithExpertModel.forward — it accepts a - # list of expert inputs and returns parallel outputs. - from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 - att_2d = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 @@ -485,18 +472,24 @@ class PI052Policy(PI05Policy): past_key_values=None, inputs_embeds=[prefix_embs, None], use_cache=False, - fill_kv_cache=True, ) if vlm_out is None: raise RuntimeError("PI052 text loss: VLM forward returned no hidden states.") - # Logits over the vocab via the PaliGemma lm_head. + # Slice the hidden states to the language portion. embed_prefix + # concatenates [images, language] in that order, so the trailing + # ``text_labels.shape[1]`` positions are the language tokens. + # Without this slice, applying lm_head to the full vlm_out and + # shifting against text_labels[..., 1:] produces a shape + # mismatch in cross_entropy. + lang_len = text_labels.shape[1] + text_hidden = vlm_out[:, -lang_len:, :] + lm_head = self.model.paligemma_with_expert.paligemma.lm_head - logits = lm_head(vlm_out.to(lm_head.weight.dtype)) + logits = lm_head(text_hidden.to(lm_head.weight.dtype)) # Shift for next-token prediction: predict token[i+1] from - # hidden[i]. Both ``logits`` and ``text_labels`` are over the - # same sequence length, so shift logits[:-1] vs labels[1:]. + # hidden[i] within the language span. shift_logits = logits[..., :-1, :].contiguous() shift_labels = text_labels[..., 1:].contiguous() loss = F.cross_entropy( @@ -579,7 +572,6 @@ class PI052Policy(PI05Policy): past_key_values=None, inputs_embeds=[current_embs, None], use_cache=False, - fill_kv_cache=True, ) if vlm_out is None: break