pi052: suppress FAST action tokens in select_message text generation

The FAST action tokenizer maps action codes to the top of the PaliGemma
vocab (id = vocab_size-1-fast_skip_tokens-t). The lower part of that band
sits just below the reserved <loc> block, so it escaped the existing
suppress_loc_tokens mask and leaked into generated subtask/VQA/memory text
as high-codepoint gibberish. Mask the FAST band on every select_message
call so the high-level head emits clean language.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-02 13:06:51 +02:00
parent d04ea0ea8a
commit ff1d58a46f

View File

@@ -55,6 +55,17 @@ from .configuration_pi052 import PI052Config
logger = logging.getLogger(__name__)
# FAST action-token vocab size (``lerobot/fast-action-tokenizer``). The
# tokenizer maps a FAST BPE id ``t`` to the PaliGemma vocab id
# ``vocab_size - 1 - fast_skip_tokens - t`` (see ``TokenizerProcessorStep``),
# so action tokens occupy the top ``_FAST_ACTION_VOCAB_SIZE`` ids below the
# ``fast_skip_tokens`` margin. The upper part collides with the reserved
# ``<loc>`` block; the lower part sits just under it and otherwise leaks into
# generated text as high-codepoint gibberish (the action-trained LM head puts
# heavy mass on these ids), so ``select_message`` masks it.
_FAST_ACTION_VOCAB_SIZE = 2048
_HF_KERNELS_ENABLED = False
@@ -1166,6 +1177,15 @@ class PI052Policy(PI05Policy):
if special_ids and len(generated) < min_new_tokens:
for sid in special_ids:
logits_step[..., sid] = float("-inf")
# Mask FAST action tokens that fall *below* the ``<loc>`` block.
# They are never valid text, but the action-trained head leaks
# them as gibberish; unlike the loc/seg block this region is never
# legitimately emitted (even by VQA), so suppress it on every call.
vocab_size = logits_step.shape[-1]
fast_skip = int(getattr(self.config, "fast_skip_tokens", 128))
fast_lo = vocab_size - 1 - fast_skip - (_FAST_ACTION_VOCAB_SIZE - 1)
if 0 < fast_lo < 256000:
logits_step[..., fast_lo:256000] = float("-inf")
if suppress_loc_tokens:
logits_step[..., 256000:257024] = float("-inf")
next_ids = self._sample_next_token(logits_step, temperature, top_p)