From be2267974a3846e7f9904ea2e5212848bb73def4 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 26 Jan 2026 11:16:43 +0100 Subject: [PATCH] fix(videovla): improve PerceiverResampler and address code review issues Key fixes for the PI05Video policy implementation: PerceiverResampler improvements: - Add residual connection (latents + attn_out) for better gradient flow - Add output LayerNorm after residual connection - Initialize latents at a smaller scale (randn * 0.02, i.e. std 0.02) for stability Bug fixes: - Replace expand() with repeat() in _preprocess_video to create copies instead of memory views, preventing potential in-place modification bugs - Fix dtype consistency in embed_video: use PaliGemma's dtype instead of input dtype for consistent processing throughout the pipeline - Add bfloat16/float16 support to resize_with_pad_torch PEFT improvements: - Remove state_proj from target modules (PI0-only, not in PI05) - Rename the action_time_mlp_in/action_time_mlp_out targets to time_mlp_in/time_mlp_out in the default target regex - Add video_proj and video_resampler to PEFT targets for fine-tuning Other improvements: - Add warning when use_video_encoder=True but no image features found - Add gradient checkpointing support for video encoder - Remove duplicate tokenizer_max_length definition in config - Add validation for video_num_latents and video_resampler_num_heads --- .../policies/videovla/configuration_pi05.py | 8 ++- .../policies/videovla/modeling_pi05.py | 70 +++++++++++++++---- 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/src/lerobot/policies/videovla/configuration_pi05.py b/src/lerobot/policies/videovla/configuration_pi05.py index fa969e08c..9f25d809a 100644 --- a/src/lerobot/policies/videovla/configuration_pi05.py +++ b/src/lerobot/policies/videovla/configuration_pi05.py @@ -108,8 +108,6 @@ class PI05VideoConfig(PreTrainedConfig): scheduler_decay_steps: int = 30_000 scheduler_decay_lr: float = 2.5e-6 - tokenizer_max_length: int = 200 # see openpi `__post_init__` - def __post_init__(self): super().__post_init__() @@ -138,6 +136,12 @@ class 
PI05VideoConfig(PreTrainedConfig): raise ValueError( f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}" ) + if self.video_num_latents < 1: + raise ValueError(f"video_num_latents must be >= 1, got {self.video_num_latents}") + if self.video_resampler_num_heads < 1: + raise ValueError( + f"video_resampler_num_heads must be >= 1, got {self.video_resampler_num_heads}" + ) def validate_features(self) -> None: """Validate and set up input/output features.""" diff --git a/src/lerobot/policies/videovla/modeling_pi05.py b/src/lerobot/policies/videovla/modeling_pi05.py index d3d992411..493bd8e62 100644 --- a/src/lerobot/policies/videovla/modeling_pi05.py +++ b/src/lerobot/policies/videovla/modeling_pi05.py @@ -197,7 +197,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) # Handle dtype-specific clipping if images.dtype == torch.uint8: resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) - elif images.dtype == torch.float32: + elif images.dtype in (torch.float32, torch.float16, torch.bfloat16): resized_images = resized_images.clamp(-1.0, 1.0) else: raise ValueError(f"Unsupported image dtype: {images.dtype}") @@ -315,8 +315,8 @@ class PerceiverResampler(nn.Module): self.num_latents = num_latents self.dim = dim - # Learnable query tokens - self.latents = nn.Parameter(torch.randn(num_latents, dim)) + # Learnable query tokens (initialized with small values for stability) + self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02) # Cross-attention layer self.attn = nn.MultiheadAttention( @@ -328,6 +328,8 @@ class PerceiverResampler(nn.Module): # Layer norms for queries and key-values self.ln_q = nn.LayerNorm(dim) self.ln_kv = nn.LayerNorm(dim) + # Output layer norm (applied after residual connection) + self.ln_out = nn.LayerNorm(dim) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -347,7 +349,10 @@ class PerceiverResampler(nn.Module): kv = self.ln_kv(x) # Cross-attention: queries 
attend to video tokens - out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D) + attn_out, _ = self.attn(q, kv, kv, need_weights=False) # (B, num_latents, D) + + # Residual connection + output layer norm + out = self.ln_out(latents + attn_out) return out @@ -693,6 +698,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + + # Enable gradient checkpointing for video encoder if available and not frozen + if self.video_encoder is not None and not self.config.freeze_video_encoder: + if hasattr(self.video_encoder, "gradient_checkpointing_enable"): + self.video_encoder.gradient_checkpointing_enable() + logging.info("Enabled gradient checkpointing for video encoder") + logging.info("Enabled gradient checkpointing for PI05Pytorch model") def gradient_checkpointing_disable(self): @@ -701,6 +713,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + + # Disable gradient checkpointing for video encoder if available + if self.video_encoder is not None: + if hasattr(self.video_encoder, "gradient_checkpointing_disable"): + self.video_encoder.gradient_checkpointing_disable() + logging.info("Disabled gradient checkpointing for video encoder") + logging.info("Disabled gradient checkpointing for PI05Pytorch model") def _rtc_enabled(self): @@ -807,13 +826,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` Returns: Video embeddings of shape [B, num_video_tokens, hidden_dim] projected to - PaliGemma's hidden dimension. 
+ PaliGemma's hidden dimension, in the same dtype as PaliGemma's language model. """ if self.video_encoder is None: raise RuntimeError("Video encoder is not initialized. Set use_video_encoder=True in config.") device = video_frames.device - dtype = video_frames.dtype + + # Determine target dtype: match PaliGemma's language model dtype for consistency + paligemma_dtype = self.paligemma_with_expert.paligemma.language_model.embed_tokens.weight.dtype # Move video encoder to the same device if needed if next(self.video_encoder.parameters()).device != device: @@ -848,8 +869,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # Shape: [B, num_patches * num_frames, hidden_size] (e.g., [B, 4096, 768]) video_embeddings = video_outputs.last_hidden_state - # Convert to working dtype - video_embeddings = video_embeddings.to(dtype=dtype) + # Convert to PaliGemma's dtype for consistency with the rest of the model + video_embeddings = video_embeddings.to(dtype=paligemma_dtype) # Apply Perceiver Resampler to reduce tokens (e.g., 4096 -> 128) # This uses cross-attention from learnable queries to the video tokens @@ -1369,6 +1390,12 @@ class PI05VideoPolicy(PreTrainedPolicy): if self.config.image_features: return next(iter(self.config.image_features.keys())) + # Warn if video encoder is enabled but no image features found + logging.warning( + "use_video_encoder=True but no image features found in config. " + "Video encoding will be skipped. Either set video_encoder_camera_key " + "or ensure image_features contains at least one camera." 
+ ) return None def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: @@ -1496,7 +1523,9 @@ class PI05VideoPolicy(PreTrainedPolicy): # Single frame [B, C, H, W] - expand to video by repeating B, C, H, W = img.shape if self.config.video_padding_mode == "repeat": - video_frames = img.unsqueeze(1).expand(B, self.config.video_num_frames, C, H, W) + # Use repeat() instead of expand() to create actual copies, not views + # This prevents potential issues if downstream operations modify tensors in-place + video_frames = img.unsqueeze(1).repeat(1, self.config.video_num_frames, 1, 1, 1) else: # zero padding video_frames = torch.zeros( B, self.config.video_num_frames, C, H, W, dtype=img.dtype, device=img.device @@ -1513,8 +1542,9 @@ class PI05VideoPolicy(PreTrainedPolicy): if self.config.video_padding_mode == "repeat": # Repeat the first frame to fill missing frames at the beginning + # Use repeat() instead of expand() to create actual copies, not views first_frame = video_frames[:, 0:1] # [B, 1, C, H, W] - padding = first_frame.expand(B, num_missing, C, H, W) + padding = first_frame.repeat(1, num_missing, 1, 1, 1) video_frames = torch.cat([padding, video_frames], dim=1) else: # zero padding # Zero-pad at the beginning @@ -1625,11 +1655,21 @@ class PI05VideoPolicy(PreTrainedPolicy): return loss, loss_dict def _get_default_peft_targets(self) -> dict[str, any]: - """Return default PEFT target modules for PI0.5 fine-tuning.""" - common_projections = ( - "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" - ) - target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))" + """Return default PEFT target modules for PI0.5 fine-tuning. + + Note: PI05 does NOT have state_proj (that's PI0 only). PI05 tokenizes state + into the language prompt instead. 
+ """ + # Core PI05 projections (no state_proj in PI05) + core_projections = "action_in_proj|action_out_proj|time_mlp_in|time_mlp_out" + + # Video-related modules (only present if use_video_encoder=True) + video_projections = "video_proj|video_resampler\\..*" + + # Combine all projections + all_projections = f"{core_projections}|{video_projections}" + + target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({all_projections}))" return { "target_modules": target_modules, "modules_to_save": [],