#!/usr/bin/env python # Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import builtins import copy import logging import math from collections import deque from pathlib import Path from typing import TYPE_CHECKING, Literal, TypedDict, Unpack import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn from lerobot.utils.import_utils import _transformers_available # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma from lerobot.policies.pi_gemma import ( PaliGemmaForConditionalGenerationWithPiGemma, PiGemmaForCausalLM, _gated_residual, layernorm_forward, ) else: CONFIG_MAPPING = None modeling_gemma = None PiGemmaForCausalLM = None _gated_residual = None layernorm_forward = None PaliGemmaForConditionalGenerationWithPiGemma = None from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE, OPENPI_ATTENTION_MASK_VALUE, ) class ActionSelectKwargs(TypedDict, total=False): inference_delay: int | None prev_chunk_left_over: Tensor | None execution_horizon: int | None def get_safe_dtype(target_dtype, device_type): """Get a safe dtype for the given device type.""" if device_type == "mps" and target_dtype == torch.float64: return torch.float32 if device_type == "cpu": # CPU doesn't support bfloat16, use float32 instead if target_dtype == torch.bfloat16: return torch.float32 if target_dtype == torch.float64: return torch.float64 return target_dtype def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy) time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu" ) -> Tensor: """Computes sine-cosine positional embedding vectors for scalar positions.""" if dimension % 2 != 0: raise ValueError(f"dimension ({dimension}) must be divisible by 2") if time.ndim != 1: raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") dtype = get_safe_dtype(torch.float64, device.type) fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) period = min_period * (max_period / min_period) ** fraction # Compute the outer product scaling_factor = 1.0 / period * 2 * math.pi sin_input = scaling_factor[None, :] * time[:, None] return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) # Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU alpha_t = torch.tensor(alpha, dtype=torch.float32) beta_t = torch.tensor(beta, dtype=torch.float32) dist = torch.distributions.Beta(alpha_t, beta_t) return dist.sample((bsize,)).to(device) def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy) """Copied from big_vision. Tokens can attend to valid inputs tokens which have a cumulative mask_ar smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to setup several types of attention, for example: [[1 1 1 1 1 1]]: pure causal attention. [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between themselves and the last 3 tokens have a causal attention. The first entry could also be a 1 without changing behaviour. [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a block can attend all previous blocks and all tokens on the same block. Args: input_mask: bool[B, N] true if its part of the input, false if padding. mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on it and 0 where it shares the same attention mask as the previous token. """ if att_masks.ndim != 2: raise ValueError(att_masks.ndim) if pad_masks.ndim != 2: raise ValueError(pad_masks.ndim) cumsum = torch.cumsum(att_masks, dim=1) att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] return att_2d_masks & pad_2d_masks def pad_vector(vector, new_dim): """Pad the last dimension of a vector to new_dim with zeros. Can be (batch_size x sequence_length x features_dimension) or (batch_size x features_dimension) """ if vector.shape[-1] >= new_dim: return vector return F.pad(vector, (0, new_dim - vector.shape[-1])) def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) images: torch.Tensor, height: int, width: int, mode: str = "bilinear", ) -> torch.Tensor: """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion by padding with black. If the image is float32, it must be in the range [-1, 1]. Args: images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] height: Target height width: Target width mode: Interpolation mode ('bilinear', 'nearest', etc.) Returns: Resized and padded tensor with same shape format as input """ # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] if images.shape[-1] <= 4: # Assume channels-last format channels_last = True if images.dim() == 3: images = images.unsqueeze(0) # Add batch dimension images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] else: channels_last = False if images.dim() == 3: images = images.unsqueeze(0) # Add batch dimension batch_size, channels, cur_height, cur_width = images.shape # Calculate resize ratio ratio = max(cur_width / width, cur_height / height) resized_height = int(cur_height / ratio) resized_width = int(cur_width / ratio) # Resize resized_images = F.interpolate( images, size=(resized_height, resized_width), mode=mode, align_corners=False if mode == "bilinear" else None, ) # Handle dtype-specific clipping if images.dtype == torch.uint8: resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) elif images.dtype == torch.float32: resized_images = resized_images.clamp(0.0, 1.0) else: raise ValueError(f"Unsupported image dtype: {images.dtype}") # Calculate padding pad_h0, remainder_h = divmod(height - resized_height, 2) pad_h1 = pad_h0 + remainder_h pad_w0, remainder_w = divmod(width - resized_width, 2) pad_w1 = pad_w0 + remainder_w # Pad constant_value = 0 if images.dtype == torch.uint8 else 0.0 padded_images = F.pad( resized_images, (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom mode="constant", value=constant_value, ) # Convert back to original format if needed if channels_last: padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] return padded_images # Define the complete layer computation function for gradient checkpointing def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert ): models = [paligemma.model.language_model, gemma_expert.model] query_states = [] key_states = [] value_states = [] gates = [] for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i]) gates.append(gate) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states.append(query_state) key_states.append(key_state) value_states.append(value_state) # Concatenate and process attention query_states = torch.cat(query_states, dim=2) key_states = torch.cat(key_states, dim=2) value_states = torch.cat(value_states, dim=2) dummy_tensor = torch.zeros( query_states.shape[0], query_states.shape[2], query_states.shape[-1], device=query_states.device, dtype=query_states.dtype, ) cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) query_states, key_states = modeling_gemma.apply_rotary_pos_emb( query_states, key_states, cos, sin, unsqueeze_dim=1 ) batch_size = query_states.shape[0] scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling # Attention computation att_output, _ = modeling_gemma.eager_attention_forward( paligemma.model.language_model.layers[layer_idx].self_attn, query_states, key_states, value_states, attention_mask, scaling, ) # Get head_dim from the current layer, not from the model head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) # Process layer outputs outputs_embeds = [] start_pos = 0 for i, hidden_states in enumerate(inputs_embeds): layer = models[i].layers[layer_idx] end_pos = start_pos + hidden_states.shape[1] if att_output.dtype != layer.self_attn.o_proj.weight.dtype: att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) # first residual out_emb = _gated_residual(hidden_states, out_emb, gates[i]) after_first_residual = out_emb.clone() out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i]) # Convert to bfloat16 if the next layer (mlp) uses bfloat16 if layer.mlp.up_proj.weight.dtype == torch.bfloat16: out_emb = out_emb.to(dtype=torch.bfloat16) out_emb = layer.mlp(out_emb) # second residual out_emb = _gated_residual(after_first_residual, out_emb, gate) outputs_embeds.append(out_emb) start_pos = end_pos return outputs_embeds class GemmaConfig: # see openpi `gemma.py: Config` """Configuration for Gemma model variants.""" def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim): self.width = width self.depth = depth self.mlp_dim = mlp_dim self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config` """Returns config for specified gemma variant.""" if variant == "gemma_300m": return GemmaConfig( width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256, ) elif variant == "gemma_2b": return GemmaConfig( width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256, ) else: raise ValueError(f"Unknown variant: {variant}") class PaliGemmaWithExpertModel( nn.Module ): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi """PaliGemma model with action expert for PI0.""" def __init__( self, vlm_config, action_expert_config, use_adarms=None, precision: Literal["bfloat16", "float32"] = "bfloat16", image_size: int = DEFAULT_IMAGE_SIZE, freeze_vision_encoder: bool = False, train_expert_only: bool = False, ): if use_adarms is None: use_adarms = [False, False] super().__init__() self.freeze_vision_encoder = freeze_vision_encoder self.train_expert_only = train_expert_only vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf._vocab_size = 257152 # noqa: SLF001 vlm_config_hf.image_token_index = 257152 vlm_config_hf.text_config.hidden_size = vlm_config.width vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads vlm_config_hf.text_config.head_dim = vlm_config.head_dim vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" vlm_config_hf.text_config.dtype = "float32" vlm_config_hf.text_config.vocab_size = 257152 vlm_config_hf.text_config.use_adarms = use_adarms[0] vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None vlm_config_hf.vision_config.image_size = image_size vlm_config_hf.vision_config.intermediate_size = 4304 vlm_config_hf.vision_config.projection_dim = 2048 vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" vlm_config_hf.vision_config.dtype = "float32" action_expert_config_hf = CONFIG_MAPPING["gemma"]( head_dim=action_expert_config.head_dim, hidden_size=action_expert_config.width, intermediate_size=action_expert_config.mlp_dim, num_attention_heads=action_expert_config.num_heads, num_hidden_layers=action_expert_config.depth, num_key_value_heads=action_expert_config.num_kv_heads, vocab_size=257152, hidden_activation="gelu_pytorch_tanh", dtype="float32", use_adarms=use_adarms[1], adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, ) self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf) self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf) self.gemma_expert.model.embed_tokens = None self.to_bfloat16_for_selected_params(precision) self._set_requires_grad() def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): if precision == "bfloat16": self.to(dtype=torch.bfloat16) elif precision == "float32": self.to(dtype=torch.float32) return else: raise ValueError(f"Invalid precision: {precision}") # Keep full vision path in float32 so we never toggle (toggle causes optimizer # "same dtype" error). Align with PI05. params_to_keep_float32 = [ "vision_tower", "multi_modal_projector", "input_layernorm", "post_attention_layernorm", "model.norm", ] for name, param in self.named_parameters(): if any(selector in name for selector in params_to_keep_float32): param.data = param.data.to(dtype=torch.float32) def _set_requires_grad(self): if self.freeze_vision_encoder: self.paligemma.model.vision_tower.eval() for param in self.paligemma.model.vision_tower.parameters(): param.requires_grad = False if self.train_expert_only: self.paligemma.eval() for param in self.paligemma.parameters(): param.requires_grad = False def train(self, mode: bool = True): super().train(mode) if self.freeze_vision_encoder: self.paligemma.model.vision_tower.eval() if self.train_expert_only: self.paligemma.eval() def embed_image(self, image: torch.Tensor): # Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05. out_dtype = image.dtype if image.dtype != torch.float32: image = image.to(torch.float32) image_outputs = self.paligemma.model.get_image_features(image) features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5 if features.dtype != out_dtype: features = features.to(out_dtype) return features def embed_language_tokens(self, tokens: torch.Tensor): return self.paligemma.model.language_model.embed_tokens(tokens) def forward( self, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: list[torch.FloatTensor] | None = None, inputs_embeds: list[torch.FloatTensor] | None = None, use_cache: bool | None = None, adarms_cond: list[torch.Tensor] | None = None, ): if adarms_cond is None: adarms_cond = [None, None] if inputs_embeds[1] is None: prefix_output = self.paligemma.model.language_model.forward( inputs_embeds=inputs_embeds[0], attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, adarms_cond=adarms_cond[0] if adarms_cond is not None else None, ) prefix_past_key_values = prefix_output.past_key_values prefix_output = prefix_output.last_hidden_state suffix_output = None elif inputs_embeds[0] is None: suffix_output = self.gemma_expert.model.forward( inputs_embeds=inputs_embeds[1], attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, adarms_cond=adarms_cond[1] if adarms_cond is not None else None, ) suffix_output = suffix_output.last_hidden_state prefix_output = None prefix_past_key_values = None else: models = [self.paligemma.model.language_model, self.gemma_expert.model] num_layers = self.paligemma.config.text_config.num_hidden_layers # Check if gradient checkpointing is enabled for any of the models use_gradient_checkpointing = ( hasattr(self.gemma_expert.model, "gradient_checkpointing") and self.gemma_expert.model.gradient_checkpointing and self.training ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) # Process all layers with gradient checkpointing if enabled for layer_idx in range(num_layers): if use_gradient_checkpointing: inputs_embeds = torch.utils.checkpoint.checkpoint( compute_layer_complete, layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, use_reentrant=False, preserve_rng_state=False, paligemma=self.paligemma, gemma_expert=self.gemma_expert, ) else: inputs_embeds = compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma=self.paligemma, gemma_expert=self.gemma_expert, ) # final norm def compute_final_norms(inputs_embeds, adarms_cond): outputs_embeds = [] for i, hidden_states in enumerate(inputs_embeds): out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i]) outputs_embeds.append(out_emb) return outputs_embeds # Apply gradient checkpointing to final norm if enabled if use_gradient_checkpointing: outputs_embeds = torch.utils.checkpoint.checkpoint( compute_final_norms, inputs_embeds, adarms_cond, use_reentrant=False, preserve_rng_state=False, ) else: outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) prefix_output = outputs_embeds[0] suffix_output = outputs_embeds[1] prefix_past_key_values = None return [prefix_output, suffix_output], prefix_past_key_values class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` """Core PI0 PyTorch model.""" def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None): super().__init__() self.config = config self.rtc_processor = rtc_processor paligemma_config = get_gemma_config(config.paligemma_variant) action_expert_config = get_gemma_config(config.action_expert_variant) if config.image_resolution[0] != config.image_resolution[1]: raise ValueError( f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}" ) self.paligemma_with_expert = PaliGemmaWithExpertModel( paligemma_config, action_expert_config, use_adarms=[False, False], precision=config.dtype, image_size=config.image_resolution[0], freeze_vision_encoder=config.freeze_vision_encoder, train_expert_only=config.train_expert_only, ) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim) self.state_proj = nn.Linear(config.max_state_dim, action_expert_config.width) self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) # Initialize gradient checkpointing flag self.gradient_checkpointing_enabled = False # Compile model if requested if config.compile_model: torch.set_float32_matmul_precision("high") self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) # Also compile the main forward pass used during training self.forward = torch.compile(self.forward, mode=config.compile_mode) def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True logging.info("Enabled gradient checkpointing for PI0Pytorch model") def gradient_checkpointing_disable(self): """Disable gradient checkpointing.""" self.gradient_checkpointing_enabled = False self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI0Pytorch model") def _rtc_enabled(self): return self.config.rtc_config is not None and self.config.rtc_config.enabled def _apply_checkpoint(self, func, *args, **kwargs): """Helper method to apply gradient checkpointing if enabled.""" if self.gradient_checkpointing_enabled and self.training: return torch.utils.checkpoint.checkpoint( func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs ) return func(*args, **kwargs) def _prepare_attention_masks_4d(self, att_2d_masks): """Helper method to prepare 4D attention masks for transformer.""" att_2d_masks_4d = att_2d_masks[:, None, :, :] return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) def sample_noise(self, shape, device): return torch.normal( mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device, ) def sample_time(self, bsize, device): time_beta = sample_beta( self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device ) time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset return time.to(dtype=torch.float32, device=device) def embed_prefix( self, images, img_masks, lang_tokens, lang_masks ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Embed images with SigLIP and language tokens with embedding layer.""" embs = [] pad_masks = [] att_masks = [] # Process images for img, img_mask in zip(images, img_masks, strict=True): def image_embed_func(img): return self.paligemma_with_expert.embed_image(img) img_emb = self._apply_checkpoint(image_embed_func, img) bsize, num_img_embs = img_emb.shape[:2] embs.append(img_emb) pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) att_masks += [0] * num_img_embs # Process language tokens def lang_embed_func(lang_tokens): lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) lang_emb_dim = lang_emb.shape[-1] return lang_emb * math.sqrt(lang_emb_dim) lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) embs.append(lang_emb) pad_masks.append(lang_masks) num_lang_embs = lang_emb.shape[1] att_masks += [0] * num_lang_embs embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) bsize = pad_masks.shape[0] att_masks = att_masks[None, :].expand(bsize, len(att_masks)) return embs, pad_masks, att_masks def embed_suffix(self, state, noisy_actions, timestep): """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" embs = [] pad_masks = [] att_masks = [] if self.state_proj.weight.dtype == torch.float32: state = state.to(torch.float32) def state_proj_func(state): return self.state_proj(state) state_emb = self._apply_checkpoint(state_proj_func, state) embs.append(state_emb[:, None, :]) bsize = state_emb.shape[0] device = state_emb.device state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) pad_masks.append(state_mask) att_masks += [1] # Embed timestep using sine-cosine positional encoding time_emb = create_sinusoidal_pos_embedding( timestep, self.action_in_proj.out_features, min_period=self.config.min_period, max_period=self.config.max_period, device=timestep.device, ) time_emb = time_emb.type(dtype=timestep.dtype) # Fuse timestep + action information using an MLP def action_proj_func(noisy_actions): return self.action_in_proj(noisy_actions) action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) time_emb = time_emb[:, None, :].expand_as(action_emb) action_time_emb = torch.cat([action_emb, time_emb], dim=2) def mlp_func(action_time_emb): x = self.action_time_mlp_in(action_time_emb) x = F.silu(x) return self.action_time_mlp_out(x) action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) adarms_cond = None embs.append(action_time_emb) bsize, action_time_dim = action_time_emb.shape[:2] action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) pad_masks.append(action_time_mask) # Set attention masks so that image, language and state inputs do not attend to action tokens att_masks += [1] + ([0] * (self.config.chunk_size - 1)) embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) att_masks = att_masks[None, :].expand(bsize, len(att_masks)) return embs, pad_masks, att_masks, adarms_cond def forward( self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None ) -> Tensor: """Do a full training forward pass and compute the loss.""" if noise is None: noise = self.sample_noise(actions.shape, actions.device) if time is None: time = self.sample_time(actions.shape[0], actions.device) time_expanded = time[:, None, None] x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, img_masks, lang_tokens, lang_masks ) suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) if ( self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 ): suffix_embs = suffix_embs.to(dtype=torch.bfloat16) prefix_embs = prefix_embs.to(dtype=torch.bfloat16) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) att_2d_masks = make_att_2d_masks(pad_masks, att_masks) position_ids = torch.cumsum(pad_masks, dim=1) - 1 att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): (_, suffix_out), _ = self.paligemma_with_expert.forward( attention_mask=att_2d_masks_4d, position_ids=position_ids, past_key_values=None, inputs_embeds=[prefix_embs, suffix_embs], use_cache=False, adarms_cond=[None, adarms_cond], ) return suffix_out suffix_out = self._apply_checkpoint( forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond ) suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) def action_out_proj_func(suffix_out): return self.action_out_proj(suffix_out) v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) return F.mse_loss(u_t, v_t, reduction="none") @torch.no_grad() # see openpi `sample_actions` (slightly adapted) def sample_actions( self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None, **kwargs: Unpack[ActionSelectKwargs], ) -> Tensor: """Do a full inference forward and compute the action.""" if num_steps is None: num_steps = self.config.num_inference_steps bsize = state.shape[0] device = state.device if noise is None: # Sample noise with padded dimension as expected by action_in_proj actions_shape = ( bsize, self.config.chunk_size, self.config.max_action_dim, ) # Use config max_action_dim for internal processing noise = self.sample_noise(actions_shape, device) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, img_masks, lang_tokens, lang_masks ) prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001 _, past_key_values = self.paligemma_with_expert.forward( attention_mask=prefix_att_2d_masks_4d, position_ids=prefix_position_ids, past_key_values=None, inputs_embeds=[prefix_embs, None], use_cache=True, ) dt = -1.0 / num_steps x_t = noise for step in range(num_steps): time = 1.0 + step * dt time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): return self.denoise_step( state=state, prefix_pad_masks=prefix_pad_masks, past_key_values=past_key_values, x_t=input_x_t, timestep=current_timestep, ) if self._rtc_enabled(): inference_delay = kwargs.get("inference_delay") prev_chunk_left_over = kwargs.get("prev_chunk_left_over") execution_horizon = kwargs.get("execution_horizon") v_t = self.rtc_processor.denoise_step( x_t=x_t, prev_chunk_left_over=prev_chunk_left_over, inference_delay=inference_delay, time=time, original_denoise_step_partial=denoise_step_partial_call, execution_horizon=execution_horizon, ) else: v_t = denoise_step_partial_call(x_t) x_t = x_t + dt * v_t if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) return x_t def denoise_step( self, state, prefix_pad_masks, past_key_values, x_t, timestep, ): """Apply one denoising step of the noise `x_t` at a given timestep.""" suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep) suffix_len = suffix_pad_masks.shape[1] batch_size = prefix_pad_masks.shape[0] prefix_len = prefix_pad_masks.shape[1] prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 past_key_values = copy.deepcopy(past_key_values) outputs_embeds, _ = self.paligemma_with_expert.forward( attention_mask=full_att_2d_masks_4d, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=[None, suffix_embs], use_cache=False, adarms_cond=[None, adarms_cond], ) suffix_out = outputs_embeds[1] suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) return self.action_out_proj(suffix_out) class PI0Policy(PreTrainedPolicy): """PI0 OpenPI Policy for LeRobot.""" config_class = PI0Config name = "pi0" def __init__( self, config: PI0Config, **kwargs, ): """ Args: config: Policy configuration class instance. """ super().__init__(config) config.validate_features() self.config = config # Initialize the core PI0 model self.init_rtc_processor() self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor) # Enable gradient checkpointing if requested if config.gradient_checkpointing: self.model.gradient_checkpointing_enable() self.model.to(config.device) self.reset() @classmethod def from_pretrained( cls: builtins.type[T], pretrained_name_or_path: str | Path, *, config: PreTrainedConfig | None = None, force_download: bool = False, resume_download: bool | None = None, proxies: dict | None = None, token: str | bool | None = None, cache_dir: str | Path | None = None, local_files_only: bool = False, revision: str | None = None, strict: bool = True, **kwargs, ) -> T: """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( "The PI0 model is a direct port of the OpenPI implementation. \n" "This implementation follows the original OpenPI structure for compatibility. \n" "Original implementation: https://github.com/Physical-Intelligence/openpi" ) if pretrained_name_or_path is None: raise ValueError("pretrained_name_or_path is required") # Use provided config if available, otherwise create default config if config is None: config = PreTrainedConfig.from_pretrained( pretrained_name_or_path=pretrained_name_or_path, force_download=force_download, resume_download=resume_download, proxies=proxies, token=token, cache_dir=cache_dir, local_files_only=local_files_only, revision=revision, **kwargs, ) # Initialize model without loading weights # Check if dataset_stats were provided in kwargs model = cls(config, **kwargs) # Load state dict (expects keys with "model." prefix) try: print(f"Loading model from: {pretrained_name_or_path}") try: from transformers.utils import cached_file resolved_file = cached_file( pretrained_name_or_path, "model.safetensors", cache_dir=kwargs.get("cache_dir"), force_download=kwargs.get("force_download", False), resume_download=kwargs.get("resume_download"), proxies=kwargs.get("proxies"), token=kwargs.get("token"), revision=kwargs.get("revision"), local_files_only=kwargs.get("local_files_only", False), ) from safetensors.torch import load_file original_state_dict = load_file(resolved_file) print("✓ Loaded state dict from model.safetensors") except Exception as e: print(f"Could not load state dict from remote files: {e}") print("Returning model without loading pretrained weights") return model # First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys) fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it remapped_state_dict = {} remap_count = 0 for key, value in fixed_state_dict.items(): if not key.startswith("model."): new_key = f"model.{key}" remapped_state_dict[new_key] = value remap_count += 1 else: remapped_state_dict[key] = value if remap_count > 0: print(f"Remapped {remap_count} state dict keys") # Load the remapped state dict into the model missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) if missing_keys: print(f"Missing keys when loading state dict: {len(missing_keys)} keys") if len(missing_keys) <= 5: for key in missing_keys: print(f" - {key}") else: for key in missing_keys[:5]: print(f" - {key}") print(f" ... and {len(missing_keys) - 5} more") if unexpected_keys: print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") if len(unexpected_keys) <= 5: for key in unexpected_keys: print(f" - {key}") else: for key in unexpected_keys[:5]: print(f" - {key}") print(f" ... and {len(unexpected_keys) - 5} more") if not missing_keys and not unexpected_keys: print("All keys loaded successfully!") except Exception as e: print(f"Warning: Could not load state dict: {e}") return model def _fix_pytorch_state_dict_keys( self, state_dict, model_config ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys` """Fix state dict keys to match current model architecture.""" import re fixed_state_dict = {} for key, value in state_dict.items(): new_key = key # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias # For gemma expert layers if re.match( r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", key, ): # Check if the model actually has adaRMS enabled for the expert expert_uses_adarms = getattr( self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False ) if expert_uses_adarms: logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") continue if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): # Check if the model actually has adaRMS enabled for the expert expert_uses_adarms = getattr( self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False ) if expert_uses_adarms: logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") continue # Handle MLP naming changes for pi0 # non-pi05 model expects action_time_mlp_*, but checkpoint might have time_mlp_* if key.startswith("time_mlp_in."): new_key = key.replace("time_mlp_in.", "action_time_mlp_in.") elif key.startswith("time_mlp_out."): new_key = key.replace("time_mlp_out.", "action_time_mlp_out.") # Handle vision tower embedding layer potential differences if "patch_embedding" in key: # Some checkpoints might have this, but current model expects different structure logging.warning(f"Vision embedding key might need handling: {key}") if ( key == "model.paligemma_with_expert.paligemma.lm_head.weight" or key == "paligemma_with_expert.paligemma.lm_head.weight" ): fixed_state_dict[ "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" ] = value.clone() fixed_state_dict[new_key] = value return fixed_state_dict def get_optim_params(self) -> dict: return self.parameters() def reset(self): """Reset internal state - called when environment resets.""" self._action_queue = deque(maxlen=self.config.n_action_steps) self._queues = { ACTION: deque(maxlen=self.config.n_action_steps), } def init_rtc_processor(self): """Initialize RTC processor if RTC is enabled in config.""" self.rtc_processor = None # Create processor if config provided # If RTC is not enabled - we can still track the denoising data if self.config.rtc_config is not None: self.rtc_processor = RTCProcessor(self.config.rtc_config) model_value = getattr(self, "model", None) if model_value is not None: model_value.rtc_processor = self.rtc_processor def _rtc_enabled(self) -> bool: return self.config.rtc_config is not None and self.config.rtc_config.enabled def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: """Preprocess images for the model. Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1]. PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1]. """ images = [] img_masks = [] # Get device from model parameters device = next(self.parameters()).device present_img_keys = [key for key in self.config.image_features if key in batch] missing_img_keys = [key for key in self.config.image_features if key not in batch] if len(present_img_keys) == 0: raise ValueError( f"All image features are missing from the batch. At least one expected. " f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" ) for key in present_img_keys: img = batch[key] # Ensure tensor is on the same device as the model if img.device != device: img = img.to(device) # Ensure float32 dtype for consistency if img.dtype != torch.float32: img = img.to(torch.float32) # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 if is_channels_first: # Convert [B, C, H, W] to [B, H, W, C] for processing img = img.permute(0, 2, 3, 1) # from openpi preprocess_observation_pytorch: Resize with padding if needed if img.shape[1:3] != self.config.image_resolution: img = resize_with_pad_torch(img, *self.config.image_resolution) # Normalize from [0,1] to [-1,1] as expected by siglip img = img * 2.0 - 1.0 # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first if is_channels_first: img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] images.append(img) # Create mask (all ones for real images) bsize = img.shape[0] mask = torch.ones(bsize, dtype=torch.bool, device=device) img_masks.append(mask) # Create image features not present in the batch as fully 0 padded images for _num_empty_cameras in range(len(missing_img_keys)): img = torch.ones_like(img) * -1 # padded with -1 for SigLIP mask = torch.zeros_like(mask) # mask is zero for empty cameras images.append(img) img_masks.append(mask) return images, img_masks def prepare_state(self, batch): """Pad state""" state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) return state def prepare_action(self, batch): """Pad action""" actions = pad_vector(batch[ACTION], self.config.max_action_dim) return actions @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" assert not self._rtc_enabled(), ( "RTC is not supported for select_action, use it with predict_action_chunk" ) self.eval() # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: """Predict a chunk of actions given environment observations.""" self.eval() # Prepare inputs images, img_masks = self._preprocess_images(batch) lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] state = self.prepare_state(batch) # Sample actions using the model (pass through RTC kwargs) actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs) # Unpad actions to actual action dimension original_action_dim = self.config.output_features[ACTION].shape[0] actions = actions[:, :, :original_action_dim] return actions def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training. Args: batch: Training batch containing observations and actions. reduction: How to reduce the loss. Options: - "mean": Return scalar mean loss (default, backward compatible) - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting """ # Prepare inputs images, img_masks = self._preprocess_images(batch) lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] state = self.prepare_state(batch) actions = self.prepare_action(batch) # Compute loss losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) # Truncate losses to actual action dimensions original_action_dim = self.config.output_features[ACTION].shape[0] losses = losses[:, :, :original_action_dim] loss_dict = { "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), } if reduction == "none": # Return per-sample losses (B,) by averaging over time and action dims per_sample_loss = losses.mean(dim=(1, 2)) loss_dict["loss"] = per_sample_loss.mean().item() return per_sample_loss, loss_dict else: # Default: return scalar mean loss loss = losses.mean() loss_dict["loss"] = loss.item() return loss, loss_dict def _get_default_peft_targets(self) -> dict[str, any]: """Return default PEFT target modules for PI0 fine-tuning.""" common_projections = ( "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" ) target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))" return { "target_modules": target_modules, "modules_to_save": [], }