fix(pi05): don't scale image features by sqrt(hidden_size)

lerobot/pi05_base was trained in the OpenPI/big_vision regime where image
(soft) tokens are NOT multiplied by the Gemma embedder normalizer
(sqrt(hidden_size)) — only text tokens are. Scaling image features here
over-scaled them ~45x, breaking the pretrained vision-language alignment
and yielding ~0% closed-loop success on RoboCasa across all pi05 runs.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-04 17:20:34 +02:00
parent 9596e3d53f
commit a48d4e32a1

View File

@@ -477,7 +477,11 @@ class PaliGemmaWithExpertModel(
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
# OpenPI / big_vision convention: image (soft) tokens are NOT scaled by the
# Gemma embedder normalizer (sqrt(hidden_size)) — only text tokens are. lerobot/pi05_base
# was trained in this regime, so scaling image features here over-scales them ~45x and
# breaks the pretrained vision-language alignment. Keep image features un-normalized.
features = image_outputs.pooler_output
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features