From 02dbaf22eea97d0d86b9ef28eb0583c62367f399 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 26 Jan 2026 17:00:04 +0000 Subject: [PATCH] fix(videovla): use video_resampler dtype instead of embed_tokens for consistency The dtype inference in embed_video() was using embed_tokens.weight.dtype, but embed_tokens can be missing/tied in some checkpoints (e.g., when loading pi05-video models). This caused a RuntimeError: "expected scalar type BFloat16 but found Float" because: - embed_tokens was freshly initialized as float32 (missing from checkpoint) - video_resampler layers were loaded as bfloat16 (from checkpoint) - video_embeddings were cast to float32, then passed to bfloat16 layers Fix: Use video_resampler.ln_kv.weight.dtype as the target dtype source, since this is the exact layer that requires dtype consistency and is always present when use_video_encoder=True. --- src/lerobot/policies/videovla/modeling_pi05.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/lerobot/policies/videovla/modeling_pi05.py b/src/lerobot/policies/videovla/modeling_pi05.py index 493bd8e62..29ad1dc1a 100644 --- a/src/lerobot/policies/videovla/modeling_pi05.py +++ b/src/lerobot/policies/videovla/modeling_pi05.py @@ -826,15 +826,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` Returns: Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to - PaliGemma's hidden dimension, in the same dtype as PaliGemma's language model. + PaliGemma's hidden dimension, in the same dtype as the video_resampler. """ if self.video_encoder is None: raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.") device = video_frames.device - # Determine target dtype: match PaliGemma's language model dtype for consistency - paligemma_dtype = self.paligemma_with_expert.paligemma.language_model.embed_tokens.weight.dtype + # Determine target dtype: use the video_resampler's dtype for consistency + # Note: We use video_resampler (not embed_tokens) because embed_tokens may be + # tied/missing in some checkpoints, but video_resampler is always initialized + target_dtype = self.video_resampler.ln_kv.weight.dtype # Move video encoder to the same device if needed if next(self.video_encoder.parameters()).device != device: @@ -869,8 +871,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768]) video_embeddings = video_outputs.last_hidden_state - # Convert to PaliGemma's dtype for consistency with the rest of the model - video_embeddings = video_embeddings.to(dtype=paligemma_dtype) + # Convert to target dtype for consistency with the video_resampler + video_embeddings = video_embeddings.to(dtype=target_dtype) # Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128) # This uses cross-attention from learnable queries to the video tokens