mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
pi052: suppress FAST action tokens in select_message text generation
The FAST action tokenizer maps each action code t to the top of the PaliGemma vocab (id = vocab_size - 1 - fast_skip_tokens - t). The lower part of that band sits just below the reserved <loc> block, so it escaped the existing suppress_loc_tokens mask and leaked into generated subtask/VQA/memory text as high-codepoint gibberish. Mask the part of the FAST band that lies below the <loc> block on every select_message call (the part overlapping <loc> remains covered by suppress_loc_tokens) so the high-level head emits clean language.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -55,6 +55,17 @@ from .configuration_pi052 import PI052Config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# FAST action-token vocab size (``lerobot/fast-action-tokenizer``). The
|
||||
# tokenizer maps a FAST BPE id ``t`` to the PaliGemma vocab id
|
||||
# ``vocab_size - 1 - fast_skip_tokens - t`` (see ``TokenizerProcessorStep``),
|
||||
# so action tokens occupy the top ``_FAST_ACTION_VOCAB_SIZE`` ids below the
|
||||
# ``fast_skip_tokens`` margin. The upper part collides with the reserved
|
||||
# ``<loc>`` block; the lower part sits just under it and otherwise leaks into
|
||||
# generated text as high-codepoint gibberish (the action-trained LM head puts
|
||||
# heavy mass on these ids), so ``select_message`` masks it.
|
||||
_FAST_ACTION_VOCAB_SIZE = 2048
|
||||
|
||||
|
||||
_HF_KERNELS_ENABLED = False
|
||||
|
||||
|
||||
@@ -1166,6 +1177,15 @@ class PI052Policy(PI05Policy):
|
||||
if special_ids and len(generated) < min_new_tokens:
|
||||
for sid in special_ids:
|
||||
logits_step[..., sid] = float("-inf")
|
||||
# Mask FAST action tokens that fall *below* the ``<loc>`` block.
|
||||
# They are never valid text, but the action-trained head leaks
|
||||
# them as gibberish; unlike the loc/seg block this region is never
|
||||
# legitimately emitted (even by VQA), so suppress it on every call.
|
||||
vocab_size = logits_step.shape[-1]
|
||||
fast_skip = int(getattr(self.config, "fast_skip_tokens", 128))
|
||||
fast_lo = vocab_size - 1 - fast_skip - (_FAST_ACTION_VOCAB_SIZE - 1)
|
||||
if 0 < fast_lo < 256000:
|
||||
logits_step[..., fast_lo:256000] = float("-inf")
|
||||
if suppress_loc_tokens:
|
||||
logits_step[..., 256000:257024] = float("-inf")
|
||||
next_ids = self._sample_next_token(logits_step, temperature, top_p)
|
||||
|
||||
Reference in New Issue
Block a user