From ff1d58a46f1df7da754a16dcd45cd6b6bcbbf2c0 Mon Sep 17 00:00:00 2001 From: pepijn223 Date: Tue, 2 Jun 2026 13:06:51 +0200 Subject: [PATCH] pi052: suppress FAST action tokens in select_message text generation The FAST action tokenizer maps action codes to the top of the PaliGemma vocab (id = vocab_size-1-fast_skip_tokens-t). The lower part of that band sits just below the reserved block, so it escaped the existing suppress_loc_tokens mask and leaked into generated subtask/VQA/memory text as high-codepoint gibberish. Mask the FAST band on every select_message call so the high-level head emits clean language. Co-authored-by: Cursor --- src/lerobot/policies/pi052/modeling_pi052.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index ce8c3abc6..73799cbc9 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -55,6 +55,17 @@ from .configuration_pi052 import PI052Config logger = logging.getLogger(__name__) +# FAST action-token vocab size (``lerobot/fast-action-tokenizer``). The +# tokenizer maps a FAST BPE id ``t`` to the PaliGemma vocab id +# ``vocab_size - 1 - fast_skip_tokens - t`` (see ``TokenizerProcessorStep``), +# so action tokens occupy the top ``_FAST_ACTION_VOCAB_SIZE`` ids below the +# ``fast_skip_tokens`` margin. The upper part collides with the reserved +# ``<loc>`` block; the lower part sits just under it and otherwise leaks into +# generated text as high-codepoint gibberish (the action-trained LM head puts +# heavy mass on these ids), so ``select_message`` masks it. +_FAST_ACTION_VOCAB_SIZE = 2048 + + _HF_KERNELS_ENABLED = False @@ -1166,6 +1177,15 @@ class PI052Policy(PI05Policy): if special_ids and len(generated) < min_new_tokens: for sid in special_ids: logits_step[..., sid] = float("-inf") + # Mask FAST action tokens that fall *below* the ``<loc>`` block. 
+ # They are never valid text, but the action-trained head leaks + # them as gibberish; unlike the loc/seg block this region is never + # legitimately emitted (even by VQA), so suppress it on every call. + vocab_size = logits_step.shape[-1] + fast_skip = int(getattr(self.config, "fast_skip_tokens", 128)) + fast_lo = vocab_size - 1 - fast_skip - (_FAST_ACTION_VOCAB_SIZE - 1) + if 0 < fast_lo < 256000: + logits_step[..., fast_lo:256000] = float("-inf") if suppress_loc_tokens: logits_step[..., 256000:257024] = float("-inf") next_ids = self._sample_next_token(logits_step, temperature, top_p)