diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index ce8c3abc6..73799cbc9 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -55,6 +55,17 @@ from .configuration_pi052 import PI052Config logger = logging.getLogger(__name__) +# FAST action-token vocab size (``lerobot/fast-action-tokenizer``). The +# tokenizer maps a FAST BPE id ``t`` to the PaliGemma vocab id +# ``vocab_size - 1 - fast_skip_tokens - t`` (see ``TokenizerProcessorStep``), +# so action tokens occupy the top ``_FAST_ACTION_VOCAB_SIZE`` ids below the +# ``fast_skip_tokens`` margin. The upper part collides with the reserved +# ```` block; the lower part sits just under it and otherwise leaks into +# generated text as high-codepoint gibberish (the action-trained LM head puts +# heavy mass on these ids), so ``select_message`` masks it. +_FAST_ACTION_VOCAB_SIZE = 2048 + + _HF_KERNELS_ENABLED = False @@ -1166,6 +1177,15 @@ class PI052Policy(PI05Policy): if special_ids and len(generated) < min_new_tokens: for sid in special_ids: logits_step[..., sid] = float("-inf") + # Mask FAST action tokens that fall *below* the ```` block. + # They are never valid text, but the action-trained head leaks + # them as gibberish; unlike the loc/seg block this region is never + # legitimately emitted (even by VQA), so suppress it on every call. + vocab_size = logits_step.shape[-1] + fast_skip = int(getattr(self.config, "fast_skip_tokens", 128)) + fast_lo = vocab_size - 1 - fast_skip - (_FAST_ACTION_VOCAB_SIZE - 1) + if 0 < fast_lo < 256000: + logits_step[..., fast_lo:256000] = float("-inf") if suppress_loc_tokens: logits_step[..., 256000:257024] = float("-inf") next_ids = self._sample_next_token(logits_step, temperature, top_p)