diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 104ec63bf..08bfbb98c 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -460,8 +460,8 @@ class PaliGemmaWithExpertModel( inputs_embeds=inputs_embeds[1], attention_mask=attention_mask, position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, + use_cache=False, + past_key_values=None, #jadechoghari adarms_cond=adarms_cond[1] if adarms_cond is not None else None, ) suffix_output = suffix_output.last_hidden_state @@ -575,13 +575,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" - try: - from transformers.models.siglip import check + # try: + # from transformers.models.siglip import check - if not check.check_whether_transformers_replace_is_installed_correctly(): - raise ValueError(msg) - except ImportError: - raise ValueError(msg) from None + # if not check.check_whether_transformers_replace_is_installed_correctly(): + # raise ValueError(msg) + # except ImportError: + # raise ValueError(msg) from None def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" diff --git a/src/lerobot/policies/videovla/README.md b/src/lerobot/policies/videovla/README.md new file mode 100644 index 000000000..2ae69d978 --- /dev/null +++ b/src/lerobot/policies/videovla/README.md @@ -0,0 +1,49 @@ +# π₀.₅ (pi05) + +This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence. +It is designed as a **Vision-Language-Action model with open-world generalization**. 
+ +--- + +## Model Overview + +| Feature | π₀ | π₀.₅ | +| -------------------- | ------------------------------------------------------ | ----------------------------------------- | +| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning | +| AdaRMS | Not used | Used in action expert | +| Tokenizer Length | 48 tokens | 200 tokens | +| Discrete State Input | False (Uses `state_proj` layer) | True | +| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) | + +--- + +## Citation + +If you use this work, please cite both **OpenPI** and the π₀.₅ paper: + +```bibtex +@misc{openpi2024, + author = {Physical Intelligence Lab}, + title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies}, + year = {2024}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/Physical-Intelligence/openpi}}, + license = {Apache-2.0} +} + +@misc{intelligence2025pi05visionlanguageactionmodelopenworld, + title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization}, + author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. 
Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky}, + year = {2025}, + eprint = {2504.16054}, + archivePrefix= {arXiv}, + primaryClass = {cs.LG}, + url = {https://arxiv.org/abs/2504.16054}, +} +``` + +--- + +## License + +This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). diff --git a/src/lerobot/policies/videovla/__init__.py b/src/lerobot/policies/videovla/__init__.py new file mode 100644 index 000000000..a8580913c --- /dev/null +++ b/src/lerobot/policies/videovla/__init__.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Lazy imports to avoid conflicts with lerobot.policies.pi05.PI05Config
# when only importing subpackages like videoprism
def __getattr__(name):
    """Import and return one of this package's public names on first access.

    Python invokes this module-level hook only after normal attribute lookup
    on the package has failed, so the submodules below are imported on demand
    instead of at package import time.

    Args:
        name: Attribute name being looked up on the package.

    Returns:
        The object called ``name`` imported from the matching submodule.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    if name == "PI05Config":
        from .configuration_pi05 import PI05Config as resolved
    elif name == "PI05Policy":
        from .modeling_pi05 import PI05Policy as resolved
    elif name == "make_pi05_pre_post_processors":
        from .processor_pi05 import make_pi05_pre_post_processors as resolved
    else:
        # Same message format as Python's own "missing module attribute" error.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return resolved
# NOTE(review): this class registers under the key "pi05". The package
# __init__ mentions avoiding conflicts with lerobot.policies.pi05.PI05Config,
# so presumably that package uses the same key -- confirm that the registry
# tolerates (or intends) a second registration before merging.
@PreTrainedConfig.register_subclass("pi05")
@dataclass
class PI05Config(PreTrainedConfig):
    """Configuration dataclass for the PI05 policy.

    Holds model-variant selection, flow-matching hyper-parameters, image and
    tokenizer settings, and the optimizer / scheduler presets. Values are
    validated in ``__post_init__``.
    """

    paligemma_variant: str = "gemma_2b"
    action_expert_variant: str = "gemma_300m"
    dtype: str = "float32"  # Options: "bfloat16", "float32"

    n_obs_steps: int = 1
    chunk_size: int = 50  # Number of action steps to predict, in openpi called "action_horizon"
    n_action_steps: int = 50  # Number of action steps to execute

    # Shorter state and action vectors will be padded to these dimensions
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Flow matching parameters: see openpi `PI0Pytorch`
    num_inference_steps: int = 10
    time_sampling_beta_alpha: float = 1.5
    time_sampling_beta_beta: float = 1.0
    time_sampling_scale: float = 0.999
    time_sampling_offset: float = 0.001
    min_period: float = 4e-3
    max_period: float = 4.0

    # Real-Time Chunking (RTC) configuration
    rtc_config: RTCConfig | None = None

    image_resolution: tuple[int, int] = (
        DEFAULT_IMAGE_SIZE,
        DEFAULT_IMAGE_SIZE,
    )  # see openpi `preprocessing_pytorch.py`

    # Add empty images. Used to add empty cameras when no image features are present.
    empty_cameras: int = 0

    # Declared once: a second identical declaration further down the class
    # body was redundant (same name, same default) and has been removed.
    tokenizer_max_length: int = 200  # see openpi `__post_init__`

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for state
            "ACTION": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for action
        }
    )

    # Training settings
    gradient_checkpointing: bool = False  # Enable gradient checkpointing for memory optimization
    compile_model: bool = False  # Whether to use torch.compile for model optimization
    compile_mode: str = "max-autotune"  # Torch compile mode
    device: str | None = None  # Device to use for the model (None = auto-detect)

    # Finetuning settings
    freeze_vision_encoder: bool = False  # Freeze only the vision encoder
    train_expert_only: bool = False  # Freeze entire VLM, train only action expert and projections

    # Optimizer settings: see openpi `AdamW`
    optimizer_lr: float = 2.5e-5  # see openpi `CosineDecaySchedule: peak_lr`
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 0.01
    optimizer_grad_clip_norm: float = 1.0

    # Scheduler settings: see openpi `CosineDecaySchedule`
    # Note: These will auto-scale if --steps < scheduler_decay_steps
    # For example, --steps=3000 will scale warmup to 100 and decay to 3000
    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    def __post_init__(self):
        """Run the parent initialisation, then validate this config.

        Raises:
            ValueError: If ``n_action_steps`` exceeds ``chunk_size``, or if
                ``paligemma_variant``, ``action_expert_variant`` or ``dtype``
                is not one of the supported values.
        """
        super().__post_init__()

        # Validate configuration
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
            )

        supported_variants = ("gemma_300m", "gemma_2b")

        if self.paligemma_variant not in supported_variants:
            raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")

        if self.action_expert_variant not in supported_variants:
            raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")

        if self.dtype not in ("bfloat16", "float32"):
            raise ValueError(f"Invalid dtype: {self.dtype}")

    def validate_features(self) -> None:
        """Validate and set up input/output features.

        Adds one visual feature per configured empty camera and inserts
        default state / action features (sized to the padded dimensions)
        when they are not already present. Mutates ``input_features`` and
        ``output_features`` in place.
        """
        for i in range(self.empty_cameras):
            key = OBS_IMAGES + f".empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, *self.image_resolution),  # Use configured image resolution
            )
            self.input_features[key] = empty_camera

        if OBS_STATE not in self.input_features:
            state_feature = PolicyFeature(
                type=FeatureType.STATE,
                shape=(self.max_state_dim,),  # Padded to max_state_dim
            )
            self.input_features[OBS_STATE] = state_feature

        if ACTION not in self.output_features:
            action_feature = PolicyFeature(
                type=FeatureType.ACTION,
                shape=(self.max_action_dim,),  # Padded to max_action_dim
            )
            self.output_features[ACTION] = action_feature

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return an ``AdamWConfig`` built from the ``optimizer_*`` fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self):
        """Return a cosine-decay-with-warmup scheduler config.

        The peak learning rate is ``optimizer_lr``; the remaining values come
        from the ``scheduler_*`` fields.
        """
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        """Always ``None`` for this policy."""
        return None

    @property
    def action_delta_indices(self) -> list:
        """Indices ``0 .. chunk_size - 1``, one per predicted action step."""
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """Always ``None`` for this policy."""
        return None
Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import logging +import math +from collections import deque +from pathlib import Path +from typing import TYPE_CHECKING, Literal, TypedDict + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from typing_extensions import Unpack + +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +else: + CONFIG_MAPPING = None + modeling_gemma = None + GemmaForCausalLM = None + PaliGemmaForConditionalGeneration = None + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.utils.constants import ( + ACTION, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OPENPI_ATTENTION_MASK_VALUE, +) + + +class ActionSelectKwargs(TypedDict, total=False): + inference_delay: int | None + 
def get_safe_dtype(target_dtype, device_type):
    """Return a dtype that is safe to use on the given device type.

    Args:
        target_dtype: The torch dtype the caller would like to use.
        device_type: Device type string such as ``"cpu"``, ``"cuda"`` or ``"mps"``.

    Returns:
        ``torch.float32`` when ``float64`` is requested on ``"mps"`` or
        ``bfloat16`` is requested on ``"cpu"``; otherwise ``target_dtype``
        unchanged.
    """
    if device_type == "mps" and target_dtype == torch.float64:
        return torch.float32
    # bfloat16 is downgraded to float32 on CPU; every other CPU dtype
    # (including float64) passes through unchanged.
    if device_type == "cpu" and target_dtype == torch.bfloat16:
        return torch.float32
    return target_dtype


def create_sinusoidal_pos_embedding(  # see openpi `create_sinusoidal_pos_embedding`
    time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
    """Computes sine-cosine positional embedding vectors for scalar positions.

    Args:
        time: 1-D tensor of scalar positions, shape ``(batch_size,)``.
        dimension: Size of the returned embedding; must be even.
        min_period: Shortest sinusoid period.
        max_period: Longest sinusoid period.
        device: Device for the frequency table, given either as a
            ``torch.device`` or as a device string (e.g. ``"cpu"``).

    Returns:
        Tensor of shape ``(batch_size, dimension)``: the sine half followed by
        the cosine half.

    Raises:
        ValueError: If ``dimension`` is odd or ``time`` is not 1-D.
    """
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    # Normalise so that string devices (including the "cpu" default) work:
    # a plain str has no `.type` attribute.
    device = torch.device(device)
    dtype = get_safe_dtype(torch.float64, device.type)
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    # Periods are spaced geometrically between min_period and max_period.
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)


def sample_beta(alpha, beta, bsize, device):  # see openpi `sample_beta` (exact copy)
    """Draw ``bsize`` samples from a Beta(alpha, beta) distribution.

    Args:
        alpha: First concentration parameter.
        beta: Second concentration parameter.
        bsize: Number of samples to draw.
        device: Device on which the parameters (and samples) are created.

    Returns:
        float32 tensor of shape ``(bsize,)``.
    """
    alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
    beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
    dist = torch.distributions.Beta(alpha_t, beta_t)
    return dist.sample((bsize,))


def make_att_2d_masks(pad_masks, att_masks):  # see openpi `make_att_2d_masks`
    """Copied from big_vision.

    Tokens can attend to valid inputs tokens which have a cumulative
    `att_masks` smaller or equal to theirs. This way `att_masks` int[B, N] can
    be used to setup several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      pad_masks: bool[B, N] true if its part of the input, false if padding.
      att_masks: int[B, N] mask that's 1 where previous tokens cannot depend on
        it and 0 where it shares the same attention mask as the previous token.

    Returns:
      bool[B, N, N] where entry ``[b, i, j]`` is true when token ``i`` may
      attend to token ``j``.

    Raises:
      ValueError: If either argument is not 2-D.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    # Query i may see key j when j's cumulative block id is not larger than i's.
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    # Both the query and the key position must be real (non-padding) tokens.
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    return att_2d_masks & pad_2d_masks


def pad_vector(vector, new_dim):
    """Pad the last dimension of a vector to new_dim with zeros.

    Can be (batch_size x sequence_length x features_dimension)
    or (batch_size x features_dimension)

    Returns the input object unchanged when its last dimension is already at
    least ``new_dim``.
    """
    if vector.shape[-1] >= new_dim:
        return vector
    return F.pad(vector, (0, new_dim - vector.shape[-1]))


def resize_with_pad_torch(  # see openpi `resize_with_pad_torch`
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].

    Args:
        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor in the same channel layout as the input.
        A 3-D input gains a leading batch dimension that is not removed, so
        the result is always 4-D.

    Raises:
        ValueError: If the image dtype is neither ``uint8`` nor ``float32``.
    """
    # A trailing dimension of at most 4 is taken to mean channels-last
    # [*b, h, w, c]; anything else is treated as channels-first [*b, c, h, w].
    channels_last = images.shape[-1] <= 4
    if images.dim() == 3:
        images = images.unsqueeze(0)  # Add batch dimension
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]

    _, _, cur_height, cur_width = images.shape

    # Scale by the more constraining axis so the whole image fits the target.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    # Resize
    resized_images = F.interpolate(
        images,
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
    )

    # Handle dtype-specific clipping
    if images.dtype == torch.uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    elif images.dtype == torch.float32:
        resized_images = resized_images.clamp(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Split the padding evenly; an odd remainder goes to the bottom / right.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    # Pad value is "black": 0 for uint8 images, -1.0 for [-1, 1] float images.
    constant_value = 0 if images.dtype == torch.uint8 else -1.0
    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        value=constant_value,
    )

    # Convert back to original format if needed
    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]

    return padded_images
paligemma, gemma_expert +): + models = [paligemma.language_model, gemma_expert.model] + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + gates.append(gate) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + batch_size = query_states.shape[0] + scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + paligemma.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + # Get head_dim from the current layer, not from the model + head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, 
class GemmaConfig:  # see openpi `gemma.py: Config`
    """Configuration for Gemma model variants.

    Plain attribute holder for the transformer dimensions: ``width``,
    ``depth``, ``mlp_dim``, ``num_heads``, ``num_kv_heads`` and ``head_dim``.
    """

    def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim):
        self.width = width
        self.depth = depth
        self.mlp_dim = mlp_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim


def get_gemma_config(variant: str) -> GemmaConfig:  # see openpi `gemma.py: get_config`
    """Returns config for specified gemma variant.

    Args:
        variant: Either ``"gemma_300m"`` or ``"gemma_2b"``.

    Raises:
        ValueError: If ``variant`` is not one of the supported names.
    """
    if variant == "gemma_300m":
        return GemmaConfig(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    elif variant == "gemma_2b":
        return GemmaConfig(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    else:
        raise ValueError(f"Unknown variant: {variant}")


class PaliGemmaWithExpertModel(
    nn.Module
):  # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost an exact copy of PaliGemmaWithExpertModel in openpi
    """PaliGemma model with action expert for PI05.

    Wraps a ``PaliGemmaForConditionalGeneration`` (the "prefix" model) and a
    ``GemmaForCausalLM`` action expert (the "suffix" model). ``forward`` can
    run either model on its own or both jointly, layer by layer.
    """

    def __init__(
        self,
        vlm_config,
        action_expert_config,
        use_adarms=None,
        precision: Literal["bfloat16", "float32"] = "bfloat16",
        image_size: int = DEFAULT_IMAGE_SIZE,
        freeze_vision_encoder: bool = False,
        train_expert_only: bool = False,
    ):
        """Build both sub-models from the given Gemma configs.

        Args:
            vlm_config: ``GemmaConfig``-like object for the PaliGemma text model.
            action_expert_config: ``GemmaConfig``-like object for the expert.
            use_adarms: Two flags ``[vlm, expert]`` enabling AdaRMS
                conditioning; defaults to ``[False, False]``.
            precision: ``"bfloat16"`` or ``"float32"`` parameter precision.
            image_size: Input image size for the vision config.
            freeze_vision_encoder: Freeze the vision tower parameters.
            train_expert_only: Freeze the whole PaliGemma model.
        """
        if use_adarms is None:
            use_adarms = [False, False]
        super().__init__()
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only

        vlm_config_hf = CONFIG_MAPPING["paligemma"]()
        vlm_config_hf._vocab_size = 257152  # noqa: SLF001
        vlm_config_hf.image_token_index = 257152
        vlm_config_hf.text_config.hidden_size = vlm_config.width
        vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
        vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
        vlm_config_hf.text_config.head_dim = vlm_config.head_dim
        vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
        vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
        vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
        vlm_config_hf.text_config.torch_dtype = "float32"
        vlm_config_hf.text_config.vocab_size = 257152
        vlm_config_hf.text_config.use_adarms = use_adarms[0]
        vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
        vlm_config_hf.vision_config.image_size = image_size
        vlm_config_hf.vision_config.intermediate_size = 4304
        vlm_config_hf.vision_config.projection_dim = 2048
        vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
        vlm_config_hf.vision_config.torch_dtype = "float32"

        action_expert_config_hf = CONFIG_MAPPING["gemma"](
            head_dim=action_expert_config.head_dim,
            hidden_size=action_expert_config.width,
            intermediate_size=action_expert_config.mlp_dim,
            num_attention_heads=action_expert_config.num_heads,
            num_hidden_layers=action_expert_config.depth,
            num_key_value_heads=action_expert_config.num_kv_heads,
            vocab_size=257152,
            hidden_activation="gelu_pytorch_tanh",
            torch_dtype="float32",
            use_adarms=use_adarms[1],
            adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
        )

        self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
        self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
        # The expert is only ever fed `inputs_embeds`, so its token embedding
        # table is dropped.
        self.gemma_expert.model.embed_tokens = None

        self.to_bfloat16_for_selected_params(precision)
        self._set_requires_grad()

    def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
        """Cast the module to ``precision``.

        For ``"bfloat16"`` the whole module is cast first and then a fixed set
        of parameters (vision patch/position embeddings and the norm layers)
        is cast back to float32. For ``"float32"`` everything is float32.

        Raises:
            ValueError: If ``precision`` is neither supported value.
        """
        if precision == "bfloat16":
            self.to(dtype=torch.bfloat16)
        elif precision == "float32":
            self.to(dtype=torch.float32)
            return
        else:
            raise ValueError(f"Invalid precision: {precision}")

        # Substrings of parameter names that stay in float32.
        params_to_keep_float32 = [
            "vision_tower.vision_model.embeddings.patch_embedding.weight",
            "vision_tower.vision_model.embeddings.patch_embedding.bias",
            "vision_tower.vision_model.embeddings.position_embedding.weight",
            "input_layernorm",
            "post_attention_layernorm",
            "model.norm",
        ]

        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_keep_float32):
                param.data = param.data.to(dtype=torch.float32)

    def _set_requires_grad(self):
        """Freeze (and put in eval mode) the sub-modules selected by the flags."""
        if self.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
            for param in self.paligemma.vision_tower.parameters():
                param.requires_grad = False
        if self.train_expert_only:
            self.paligemma.eval()
            for param in self.paligemma.parameters():
                param.requires_grad = False

    def train(self, mode: bool = True):
        """Set training mode, keeping frozen sub-modules in eval mode."""
        super().train(mode)
        if self.freeze_vision_encoder:
            self.paligemma.vision_tower.eval()
        if self.train_expert_only:
            self.paligemma.eval()

    def embed_image(self, image: torch.Tensor):
        """Return PaliGemma image features for ``image``."""
        return self.paligemma.model.get_image_features(image)

    def embed_language_tokens(self, tokens: torch.Tensor):
        """Return PaliGemma token embeddings for ``tokens``."""
        return self.paligemma.language_model.embed_tokens(tokens)

    def forward(
        self,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: list[torch.FloatTensor] | None = None,
        inputs_embeds: list[torch.FloatTensor] | None = None,
        use_cache: bool | None = None,
        adarms_cond: list[torch.Tensor] | None = None,
    ):
        """Run the prefix model, the suffix (expert) model, or both jointly.

        Args:
            attention_mask: Attention mask forwarded to the model(s).
            position_ids: Position ids forwarded to the model(s).
            past_key_values: Cached keys/values forwarded to whichever single
                model is run (prefix-only or suffix-only call).
            inputs_embeds: Two-element list ``[prefix_embeds, suffix_embeds]``.
                Set exactly one entry to ``None`` to run only the other model;
                provide both to run the joint layer-by-layer pass.
            use_cache: Forwarded to whichever single model is run.
            adarms_cond: Two-element list of AdaRMS conditioning tensors
                ``[prefix, suffix]``; defaults to ``[None, None]``.

        Returns:
            ``([prefix_output, suffix_output], prefix_past_key_values)``.
            Outputs of a model that was not run are ``None``; the cache is only
            returned by the prefix-only call.
        """
        if adarms_cond is None:
            adarms_cond = [None, None]
        if inputs_embeds[1] is None:
            prefix_output = self.paligemma.language_model.forward(
                inputs_embeds=inputs_embeds[0],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[0],
            )
            prefix_past_key_values = prefix_output.past_key_values
            prefix_output = prefix_output.last_hidden_state
            suffix_output = None
        elif inputs_embeds[0] is None:
            # Forward the caller's cache arguments rather than hard-coding
            # `use_cache=False, past_key_values=None`: overriding them here
            # would silently discard a prefix cache supplied by the caller.
            suffix_output = self.gemma_expert.model.forward(
                inputs_embeds=inputs_embeds[1],
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                adarms_cond=adarms_cond[1],
            )
            suffix_output = suffix_output.last_hidden_state
            prefix_output = None
            prefix_past_key_values = None
        else:
            models = [self.paligemma.language_model, self.gemma_expert.model]
            num_layers = self.paligemma.config.text_config.num_hidden_layers

            # Check if gradient checkpointing is enabled for any of the models
            use_gradient_checkpointing = (
                hasattr(self.gemma_expert.model, "gradient_checkpointing")
                and self.gemma_expert.model.gradient_checkpointing
                and self.training
            ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)

            # Process all layers with gradient checkpointing if enabled
            for layer_idx in range(num_layers):
                if use_gradient_checkpointing:
                    inputs_embeds = torch.utils.checkpoint.checkpoint(
                        compute_layer_complete,
                        layer_idx,
                        inputs_embeds,
                        attention_mask,
                        position_ids,
                        adarms_cond,
                        use_reentrant=False,
                        preserve_rng_state=False,
                        paligemma=self.paligemma,
                        gemma_expert=self.gemma_expert,
                    )
                else:
                    inputs_embeds = compute_layer_complete(
                        layer_idx,
                        inputs_embeds,
                        attention_mask,
                        position_ids,
                        adarms_cond,
                        paligemma=self.paligemma,
                        gemma_expert=self.gemma_expert,
                    )

            # final norm
            def compute_final_norms(inputs_embeds, adarms_cond):
                outputs_embeds = []
                for i, hidden_states in enumerate(inputs_embeds):
                    out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
                    outputs_embeds.append(out_emb)
                return outputs_embeds

            # Apply gradient checkpointing to final norm if enabled
            if use_gradient_checkpointing:
                outputs_embeds = torch.utils.checkpoint.checkpoint(
                    compute_final_norms,
                    inputs_embeds,
                    adarms_cond,
                    use_reentrant=False,
                    preserve_rng_state=False,
                )
            else:
                outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)

            prefix_output = outputs_embeds[0]
            suffix_output = outputs_embeds[1]
            # The joint pass does not build a key/value cache.
            prefix_past_key_values = None

        return [prefix_output, suffix_output], prefix_past_key_values
action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim) + + self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) + self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + # Also compile the main forward pass used during training + self.forward = torch.compile(self.forward, mode=config.compile_mode) + + msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" + + # try: + # from transformers.models.siglip import check + + # if not check.check_whether_transformers_replace_is_installed_correctly(): + # raise ValueError(msg) + # except ImportError: + # raise ValueError(msg) from None + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing for PI05Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + logging.info("Disabled gradient checkpointing for PI05Pytorch model") + + def 
_rtc_enabled(self): + return self.config.rtc_config is not None and self.config.rtc_config.enabled + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + + def sample_noise(self, shape, device): + return torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + + def sample_time(self, bsize, device): + time_beta = sample_beta( + self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device + ) + time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, tokens, masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + att_masks += [0] * num_img_embs + + # Process language tokens + def lang_embed_func(tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = 
def embed_suffix(self, noisy_actions, timestep):
    """Embed the noisy action chunk and compute the timestep conditioning for the expert.

    Returns:
        A 4-tuple `(embs, pad_masks, att_masks, adarms_cond)`:
        the projected action embeddings, an all-True padding mask over the action
        tokens, the attention-mask pattern expanded over the batch, and the
        timestep embedding after the time MLP (returned as `adarms_cond`).
    """
    # Sine-cosine encoding of the timestep, cast back to the timestep's dtype.
    time_emb = create_sinusoidal_pos_embedding(
        timestep,
        self.action_in_proj.out_features,
        min_period=self.config.min_period,
        max_period=self.config.max_period,
        device=timestep.device,
    ).type(dtype=timestep.dtype)

    def project_actions(actions):
        return self.action_in_proj(actions)

    def time_mlp(encoded_time):
        hidden = F.silu(self.time_mlp_in(encoded_time))
        return F.silu(self.time_mlp_out(hidden))

    # The actions are projected on their own; the timestep only enters through
    # the conditioning returned as `adarms_cond`.
    action_emb = self._apply_checkpoint(project_actions, noisy_actions)
    adarms_cond = self._apply_checkpoint(time_mlp, time_emb)

    bsize, num_action_tokens = action_emb.shape[:2]
    pad_masks = torch.ones(bsize, num_action_tokens, dtype=torch.bool, device=timestep.device)

    # A leading 1 followed by zeros, so that image, language and state inputs
    # do not attend to action tokens.
    att_pattern = [1] + [0] * (self.config.chunk_size - 1)
    att_masks = torch.tensor(att_pattern, dtype=action_emb.dtype, device=action_emb.device)
    att_masks = att_masks[None, :].expand(bsize, len(att_pattern))

    return action_emb, pad_masks, att_masks, adarms_cond
@torch.no_grad()  # see openpi `sample_actions` (slightly adapted)
def sample_actions(
    self,
    images,
    img_masks,
    tokens,
    masks,
    noise=None,
    num_steps=None,
    **kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
    """Run inference: encode the prefix once, then integrate the denoising steps.

    Starting from `noise` at time 1.0, the action chunk is updated with
    `num_steps` Euler steps of size `-1 / num_steps` using the velocity
    returned by `denoise_step` (optionally routed through the RTC processor).
    """
    if num_steps is None:
        num_steps = self.config.num_inference_steps

    device = tokens.device
    bsize = tokens.shape[0]

    if noise is None:
        # Noise uses the padded action dimension expected by `action_in_proj`.
        noise = self.sample_noise((bsize, self.config.chunk_size, self.config.max_action_dim), device)

    # Prefix (images + language) forward pass, with `use_cache=True` so the
    # returned key/values can be handed to every denoising step.
    prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks)
    prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(
        make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
    )
    prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
    self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager"  # noqa: SLF001

    _, past_key_values = self.paligemma_with_expert.forward(
        attention_mask=prefix_att_2d_masks_4d,
        position_ids=prefix_position_ids,
        past_key_values=None,
        inputs_embeds=[prefix_embs, None],
        use_cache=True,
    )

    # RTC arguments do not change between steps, so they are read once.
    inference_delay = kwargs.get("inference_delay")
    prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
    execution_horizon = kwargs.get("execution_horizon")

    dt = -1.0 / num_steps
    x_t = noise
    for step in range(num_steps):
        time = 1.0 + step * dt
        time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)

        # The timestep is bound as a default argument so each closure keeps its own value.
        def denoise_at_current_time(input_x_t, current_timestep=time_tensor):
            return self.denoise_step(
                prefix_pad_masks=prefix_pad_masks,
                past_key_values=past_key_values,
                x_t=input_x_t,
                timestep=current_timestep,
            )

        if not self._rtc_enabled():
            v_t = denoise_at_current_time(x_t)
        else:
            v_t = self.rtc_processor.denoise_step(
                x_t=x_t,
                prev_chunk_left_over=prev_chunk_left_over,
                inference_delay=inference_delay,
                time=time,
                original_denoise_step_partial=denoise_at_current_time,
                execution_horizon=execution_horizon,
            )

        # Out-of-place update: `noise` passed in by the caller is never modified.
        x_t = x_t + dt * v_t

        if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
            self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)

    return x_t
@classmethod
def from_pretrained(
    cls: builtins.type[T],
    pretrained_name_or_path: str | Path,
    *,
    config: PreTrainedConfig | None = None,
    force_download: bool = False,
    resume_download: bool | None = None,
    proxies: dict | None = None,
    token: str | bool | None = None,
    cache_dir: str | Path | None = None,
    local_files_only: bool = False,
    revision: str | None = None,
    strict: bool = True,
    **kwargs,
) -> T:
    """Instantiate the policy and load key-remapped weights from `model.safetensors`.

    The model is built from `config` (fetched from `pretrained_name_or_path` when not
    given), then the checkpoint is resolved, its keys are fixed with
    `_fix_pytorch_state_dict_keys`, prefixed with ``model.`` where missing, and loaded.

    Loading is best-effort: if the checkpoint cannot be resolved or read, or if the
    remapping/loading step fails, a message is printed and the model is returned
    with whatever weights it has at that point.

    Args:
        pretrained_name_or_path: Hub repo id or local directory of the checkpoint.
        config: Optional ready-made configuration; loaded from the checkpoint if None.
        force_download, resume_download, proxies, token, cache_dir, local_files_only,
            revision: Download options, forwarded both to the config loader and to the
            weight-file resolution.
        strict: Forwarded to `load_state_dict`.
        **kwargs: Extra arguments forwarded to the config loader and the constructor.

    Raises:
        ValueError: If `pretrained_name_or_path` is None.
    """
    print(
        "The PI05 model is a direct port of the OpenPI implementation. \n"
        "This implementation follows the original OpenPI structure for compatibility. \n"
        "Original implementation: https://github.com/Physical-Intelligence/openpi"
    )
    if pretrained_name_or_path is None:
        raise ValueError("pretrained_name_or_path is required")

    # Use provided config if available, otherwise load it from the checkpoint location
    if config is None:
        config = PreTrainedConfig.from_pretrained(
            pretrained_name_or_path=pretrained_name_or_path,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
            **kwargs,
        )

    # Initialize model without loading weights
    model = cls(config, **kwargs)

    # Now manually load and remap the state dict
    try:
        print(f"Loading model from: {pretrained_name_or_path}")
        try:
            from transformers.utils import cached_file

            # The download options are explicit keyword-only parameters of this method,
            # so they never appear in `kwargs`; they must be forwarded by name or the
            # caller's cache/revision/offline/auth settings are silently ignored.
            resolved_file = cached_file(
                pretrained_name_or_path,
                "model.safetensors",
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                # Fall back to the legacy `use_auth_token` kwarg if a caller still passes it.
                token=token if token is not None else kwargs.get("use_auth_token"),
                revision=revision,
                local_files_only=local_files_only,
            )
            from safetensors.torch import load_file

            original_state_dict = load_file(resolved_file)
            print("✓ Loaded state dict from model.safetensors")
        except Exception as e:
            # Deliberate best-effort: hand back the freshly initialized model.
            print(f"Could not load state dict from remote files: {e}")
            print("Returning model without loading pretrained weights")
            return model

        # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
        fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)

        # Then add "model." prefix for all keys that don't already have it
        remapped_state_dict = {}
        remap_count = 0
        for key, value in fixed_state_dict.items():
            if key.startswith("model."):
                remapped_state_dict[key] = value
                continue
            new_key = f"model.{key}"
            remapped_state_dict[new_key] = value
            remap_count += 1
            if remap_count <= 10:  # Only print first 10 to avoid spam
                print(f"Remapped: {key} -> {new_key}")

        if remap_count > 0:
            print(f"Remapped {remap_count} state dict keys")

        # Load the remapped state dict into the model
        missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)

        # Report at most five keys per category, then a count of the remainder.
        for label, keys in (("Missing", missing_keys), ("Unexpected", unexpected_keys)):
            if not keys:
                continue
            print(f"{label} keys when loading state dict: {len(keys)} keys")
            for key in keys[:5]:
                print(f" - {key}")
            if len(keys) > 5:
                print(f" ... and {len(keys) - 5} more")

        if not missing_keys and not unexpected_keys:
            print("All keys loaded successfully!")

    except Exception as e:
        # Deliberate best-effort: report the failure and still return the model.
        print(f"Warning: Could not remap state dict keys: {e}")

    return model
+ if "patch_embedding" in key: + # Some checkpoints might have this, but current model expects different structure + logging.warning(f"Vision embedding key might need handling: {key}") + + fixed_state_dict[new_key] = value + + return fixed_state_dict + + def get_optim_params(self) -> dict: + return self.parameters() + + def reset(self): + """Reset internal state - called when environment resets.""" + self._action_queue = deque(maxlen=self.config.n_action_steps) + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + + def init_rtc_processor(self): + """Initialize RTC processor if RTC is enabled in config.""" + self.rtc_processor = None + + # Create processor if config provided + # If RTC is not enabled - we can still track the denoising data + if self.config.rtc_config is not None: + self.rtc_processor = RTCProcessor(self.config.rtc_config) + + model_value = getattr(self, "model", None) + if model_value is not None: + model_value.rtc_processor = self.rtc_processor + + def _rtc_enabled(self) -> bool: + return self.config.rtc_config is not None and self.config.rtc_config.enabled + + def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: + """Preprocess images for the model. + + Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1]. + PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1]. + """ + images = [] + img_masks = [] + + # Get device from model parameters + device = next(self.parameters()).device + + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. 
" + f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" + ) + + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key] + + # Ensure tensor is on the same device as the model + if img.device != device: + img = img.to(device) + + # Ensure float32 dtype for consistency + if img.dtype != torch.float32: + img = img.to(torch.float32) + + # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + img = img.permute(0, 2, 3, 1) + + # from openpi preprocess_observation_pytorch: Resize with padding if needed + if img.shape[1:3] != self.config.image_resolution: + img = resize_with_pad_torch(img, *self.config.image_resolution) + + # Normalize from [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 + + # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + images.append(img) + # Create mask (all ones for real images) + bsize = img.shape[0] + mask = torch.ones(bsize, dtype=torch.bool, device=device) + img_masks.append(mask) + + # Create image features not present in the batch as fully 0 padded images + for _num_empty_cameras in range(len(missing_img_keys)): + img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP + mask = torch.zeros_like(mask) # Mask is zero for empty cameras + images.append(img) + img_masks.append(mask) + + return images, img_masks + + def prepare_action(self, batch): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + return actions + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations.""" + assert not 
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
    """Predict a chunk of actions from the observations in `batch`.

    Switches the policy to eval mode, samples actions with the core model
    (extra keyword arguments are passed straight to `sample_actions`), and
    trims the padded action dimension back to the configured output size.
    """
    self.eval()

    # Model inputs: preprocessed camera images plus the tokenized prompt.
    images, img_masks = self._preprocess_images(batch)
    tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
    masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]

    padded_actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)

    # Drop the padding that was added up to `max_action_dim`.
    action_dim = self.config.output_features[ACTION].shape[0]
    return padded_actions[:, :, :action_dim]
Options: + - "mean": Return scalar mean loss (default, backward compatible) + - "none": Return per-sample losses of shape (batch_size,) for RA-BC weighting + """ + # Prepare inputs + images, img_masks = self._preprocess_images(batch) + tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + + actions = self.prepare_action(batch) + + # Compute loss (no separate state needed for PI05) + losses = self.model.forward(images, img_masks, tokens, masks, actions) + + # Truncate losses to actual action dimensions + original_action_dim = self.config.output_features[ACTION].shape[0] + losses = losses[:, :, :original_action_dim] + + loss_dict = { + "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), + } + + if reduction == "none": + # Return per-sample losses (B,) by averaging over time and action dims + per_sample_loss = losses.mean(dim=(1, 2)) + loss_dict["loss"] = per_sample_loss.mean().item() + return per_sample_loss, loss_dict + else: + # Default: return scalar mean loss + loss = losses.mean() + loss_dict["loss"] = loss.item() + return loss, loss_dict + + def _get_default_peft_targets(self) -> dict[str, any]: + """Return default PEFT target modules for PI0.5 fine-tuning.""" + common_projections = ( + "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" + ) + target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))" + return { + "target_modules": target_modules, + "modules_to_save": [], + } diff --git a/src/lerobot/policies/videovla/processor_pi05.py b/src/lerobot/policies/videovla/processor_pi05.py new file mode 100644 index 000000000..e29bc4c23 --- /dev/null +++ b/src/lerobot/policies/videovla/processor_pi05.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.policies.pi05.configuration_pi05 import PI05Config +from lerobot.policies.pi05.modeling_pi05 import pad_vector +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.utils.constants import ( + OBS_STATE, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) + + +@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step") +@dataclass +class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): + """ + Processor step to prepare the state and tokenize the language input. 
def __call__(self, transition: EnvTransition) -> EnvTransition:
    """Fold the discretized state into the task prompt of each sample.

    The state is padded to `max_state_dim`, discretized into 256 bins over
    [-1, 1], and each task string is rewritten as
    ``"Task: <task>, State: <bins>;\\nAction: "``. The rewritten prompts replace
    the entry stored under `task_key` in the complementary data.

    Args:
        transition: Transition holding the state under `OBS_STATE` in the observation
            and the task(s) under `task_key` in the complementary data.

    Returns:
        A shallow copy of `transition` whose complementary data is a new dict
        containing the rewritten prompts; the input's complementary data is not modified.

    Raises:
        ValueError: If the state or the task entry is missing.
    """
    transition = transition.copy()

    state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
    if state is None:
        raise ValueError("State is required for PI05")
    complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
    tasks = complementary_data.get(self.task_key)
    if tasks is None:
        raise ValueError("No task found in complementary data")

    # A bare string would otherwise be iterated character by character below.
    if isinstance(tasks, str):
        tasks = [tasks]

    # TODO: check if this necessary
    state = deepcopy(state)

    # Prepare state (pad to max_state_dim)
    state = pad_vector(state, self.max_state_dim)

    # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
    # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
    state_np = state.cpu().numpy()
    discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

    full_prompts = []
    for i, task in enumerate(tasks):
        cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
        state_str = " ".join(map(str, discretized_states[i]))
        full_prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")

    # `transition.copy()` is shallow, so writing into the nested dict would also
    # change the caller's transition; store the prompts in a fresh dict instead.
    transition[TransitionKey.COMPLEMENTARY_DATA] = {**complementary_data, self.task_key: full_prompts}
    return transition
def make_pi05_pre_post_processors(
    config: PI05Config,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Build the pre-processor and post-processor pipelines for the PI05 policy.

    The pre-processing pipeline applies, in order:
    1. An observation rename step with an empty rename map.
    2. Adding a batch dimension.
    3. Normalizing input and output features with `dataset_stats`.
    4. Rewriting the task prompt to include the discretized state.
    5. Tokenizing the prompt with the PaliGemma tokenizer (right-padded to
       `config.tokenizer_max_length`).
    6. Moving all data to `config.device`.

    The post-processing pipeline applies, in order:
    1. Unnormalizing the output features.
    2. Moving data to the CPU.

    Args:
        config: The configuration object for the PI05 policy.
        dataset_stats: A dictionary of statistics for normalization.

    Returns:
        A tuple containing the configured pre-processor and post-processor pipelines.
    """
    normalizer = NormalizerProcessorStep(
        features={**config.input_features, **config.output_features},
        norm_map=config.normalization_mapping,
        stats=dataset_stats,
    )
    tokenizer = TokenizerProcessorStep(
        tokenizer_name="google/paligemma-3b-pt-224",
        max_length=config.tokenizer_max_length,
        padding_side="right",
        padding="max_length",
    )
    unnormalizer = UnnormalizerProcessorStep(
        features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
    )

    # NOTE: the normalizer MUST run before Pi05PrepareStateTokenizerProcessorStep,
    # because that step expects the state in the [-1, 1] range for discretization.
    input_steps: list[ProcessorStep] = [
        RenameObservationsProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
        AddBatchDimensionProcessorStep(),
        normalizer,
        Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
        tokenizer,
        DeviceProcessorStep(device=config.device),
    ]
    output_steps: list[ProcessorStep] = [
        unnormalizer,
        DeviceProcessorStep(device="cpu"),
    ]

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=input_steps,
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
    )
    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=output_steps,
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return preprocessor, postprocessor
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig +from .modeling_videoprism import ( + VideoPrismClipModel, + VideoPrismForVideoClassification, + VideoPrismPreTrainedModel, + VideoPrismTextModel, + VideoPrismVideoModel, + VideoPrismVisionModel, +) +from .video_processing_videoprism import VideoPrismVideoProcessor + +__all__ = [ + "VideoPrismConfig", + "VideoPrismTextConfig", + "VideoPrismVisionConfig", + "VideoPrismClipModel", + "VideoPrismForVideoClassification", + "VideoPrismPreTrainedModel", + "VideoPrismTextModel", + "VideoPrismVideoModel", + "VideoPrismVisionModel", + "VideoPrismVideoProcessor", +] diff --git a/src/lerobot/policies/videovla/videoprism/configuration_videoprism.py b/src/lerobot/policies/videovla/videoprism/configuration_videoprism.py new file mode 100644 index 000000000..896f4a544 --- /dev/null +++ b/src/lerobot/policies/videovla/videoprism/configuration_videoprism.py @@ -0,0 +1,269 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_videoprism.py file directly. One of our CI enforces this. 
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from transformers import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class VideoPrismVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VideoPrismVisionModel`]. It is used to instantiate a + VideoPrism vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the VideoPrism + [google/videoprism](https://huggingface.co/google/videoprism) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 288): + The size of the input image. + num_frames (`int`, *optional*, defaults to 16): + The number of frames in the input video. + tubelet_size (`List[int]`, *optional*, defaults to `[1, 18, 18]`): + The size of the tubelet patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_spatial_layers (`int`, *optional*, defaults to 12): + Number of spatial transformer blocks. + num_temporal_layers (`int`, *optional*, defaults to 4): + Number of temporal transformer blocks. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`): + The non-linear activation function (function or string). 
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the qkv projections in attention layers. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): + Softcapping constant for attention logits. + num_auxiliary_layers (`int`, *optional*, defaults to 2): + Number of auxiliary layers. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel. + apply_l2_norm (`bool`, *optional*, defaults to `True`): + Whether to apply L2 normalization to the output. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel. 
+ + Example: + + ```python + >>> from transformers import VideoPrismVisionConfig, VideoPrismVisionModel + + >>> # Initializing a VideoPrismVisionConfig with default values + >>> configuration = VideoPrismVisionConfig() + + >>> # Initializing a VideoPrismVisionModel with the configuration + >>> model = VideoPrismVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "videoprism_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + image_size=288, + num_frames=16, + tubelet_size=[1, 18, 18], + num_channels=3, + hidden_size=768, + num_spatial_layers=12, + num_temporal_layers=4, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu_python", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-06, + qkv_bias=True, + attn_logit_softcapping=50.0, + num_auxiliary_layers=2, + apply_l2_norm=True, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.num_spatial_layers = num_spatial_layers + self.num_temporal_layers = num_temporal_layers + self.attn_logit_softcapping = attn_logit_softcapping + self.num_auxiliary_layers = num_auxiliary_layers + self.apply_l2_norm = apply_l2_norm + + +class VideoPrismTextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VideoPrismTextModel`]. 
It is used to instantiate a + VideoPrism text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the VideoPrism + [google/videoprism](https://huggingface.co/google/videoprism) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_text_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the text Transformer encoder. + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the text model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`VideoPrismTextModel`]. + apply_l2_norm (`bool`, *optional*, defaults to `True`): + Whether to apply L2 normalization to the output text embeddings. + hidden_act (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the query, key, and value projections in the attention layers. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): + Softcapping constant for attention logits. + + Example: + + ```python + >>> from transformers import VideoPrismTextConfig, VideoPrismTextModel + + >>> # Initializing a VideoPrismTextConfig with default values + >>> configuration = VideoPrismTextConfig() + + >>> # Initializing a VideoPrismTextModel (with random weights) from the configuration + >>> model = VideoPrismTextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "videoprism_text_model" + base_config_key = "text_config" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_attention_heads=12, + num_text_layers=12, + vocab_size=32000, + apply_l2_norm=True, + hidden_act="relu", + attention_probs_dropout_prob=0.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + layer_norm_eps=1e-06, + initializer_range=0.02, + attn_logit_softcapping=50.0, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_text_layers = num_text_layers + self.vocab_size = vocab_size + self.apply_l2_norm = apply_l2_norm + self.hidden_act = hidden_act + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.attn_logit_softcapping = attn_logit_softcapping + + +class VideoPrismConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VideoPrismModel`]. 
It is used to instantiate a + VideoPrism model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the VideoPrism + [google/videoprism](https://huggingface.co/google/videoprism) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`VideoPrismTextConfig`, *optional*): + Configuration for the text model. + vision_config (`VideoPrismVisionConfig`, *optional*): + Configuration for the vision model. + kwargs (*optional*): + Dictionary of keyword arguments. + + Example: + + ```python + >>> from transformers import VideoPrismConfig, VideoPrismModel + + >>> # Initializing a VideoPrismConfig with default values + >>> configuration = VideoPrismConfig() + + >>> # Initializing a VideoPrismClipModel with the configuration + >>> model = VideoPrismClipModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "videoprism" + sub_configs = {"text_config": VideoPrismTextConfig, "vision_config": VideoPrismVisionConfig} + + def __init__(self, text_config=None, vision_config=None, **kwargs): + if text_config is None: + text_config = VideoPrismTextConfig() + logger.info("`text_config` is `None`. Initializing the `VideoPrismTextConfig` with default values.") + elif isinstance(text_config, dict): + text_config = VideoPrismTextConfig(**text_config) + + if vision_config is None: + vision_config = VideoPrismVisionConfig() + logger.info("`vision_config` is `None`. 
initializing the `VideoPrismVisionConfig` with default values.") + elif isinstance(vision_config, dict): + vision_config = VideoPrismVisionConfig(**vision_config) + + self.text_config = text_config + self.vision_config = vision_config + + super().__init__(**kwargs) + + +__all__ = ["VideoPrismVisionConfig", "VideoPrismTextConfig", "VideoPrismConfig"] diff --git a/src/lerobot/policies/videovla/videoprism/initialization.py b/src/lerobot/policies/videovla/videoprism/initialization.py new file mode 100644 index 000000000..9905b4b49 --- /dev/null +++ b/src/lerobot/policies/videovla/videoprism/initialization.py @@ -0,0 +1,245 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +from collections import defaultdict +from contextlib import contextmanager + +import torch + + +# Record all the torch primitives in advance, so that we can use them without them being modified when we patch torch +# in context managers +TORCH_INIT_FUNCTIONS = { + "uniform_": torch.nn.init.uniform_, + "normal_": torch.nn.init.normal_, + "constant_": torch.nn.init.constant_, + "ones_": torch.nn.init.ones_, + "zeros_": torch.nn.init.zeros_, + "eye_": torch.nn.init.eye_, + "dirac_": torch.nn.init.dirac_, + "xavier_uniform_": torch.nn.init.xavier_uniform_, + "xavier_normal_": torch.nn.init.xavier_normal_, + "kaiming_uniform_": torch.nn.init.kaiming_uniform_, + "kaiming_normal_": torch.nn.init.kaiming_normal_, + "trunc_normal_": torch.nn.init.trunc_normal_, + "orthogonal_": torch.nn.init.orthogonal_, + "sparse_": torch.nn.init.sparse_, +} + + +def uniform_( + tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator) + return tensor + + +def normal_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator) + return tensor + + +def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val) + return tensor + + +def ones_(tensor: torch.Tensor) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["ones_"](tensor) + return tensor + + +def zeros_(tensor: torch.Tensor) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["zeros_"](tensor) + return tensor + + 
+def eye_(tensor: torch.Tensor) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["eye_"](tensor) + return tensor + + +def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups) + return tensor + + +def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator) + return tensor + + +def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator) + return tensor + + +def kaiming_uniform_( + tensor: torch.Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", + generator: torch.Generator | None = None, +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["kaiming_uniform_"]( + tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator + ) + return tensor + + +def kaiming_normal_( + tensor: torch.Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", + generator: torch.Generator | None = None, +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["kaiming_normal_"]( + tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator + ) + return tensor + + +def trunc_normal_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, + generator: torch.Generator | None = None, +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return 
TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator) + return tensor + + +def orthogonal_( + tensor: torch.Tensor, + gain: float = 1, + generator: torch.Generator | None = None, +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator) + return tensor + + +def sparse_( + tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None +) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator) + return tensor + + +def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + if not getattr(tensor, "_is_hf_initialized", False): + with torch.no_grad(): + return tensor.copy_(other) + return tensor + + +# Here, we need to check several modules imported, and hot patch all of them, as sometimes torch does +# something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules.activations, +# where MultiHeadAttention lives), so the function name is binded at import time and just doing +# `setattr(torch.nn.init, name, globals()[name])` is thus not enough +# The following list should be enough for all torch versions we work with +TORCH_MODULES_TO_PATCH = ( + "torch.nn.init", + "torch.nn.modules.activation", + "torch.nn.modules.transformer", + "torch.nn.modules.linear", + "torch.nn.modules.loss", + "torch.nn.modules.batchnorm", + "torch.nn.modules.conv", + "torch.nn.modules.normalization", + "torch.nn.modules.rnn", + "torch.nn.modules.sparse", +) + + +@contextmanager +def guard_torch_init_functions(): + """ + Guard the `torch.nn.init` primitive functions to behave exactly like the functions in this file, i.e. be + protected against the `_is_hf_initialized` flag to avoid re-init if the param was already loaded. 
+ + Usually, all models are using the init from `transformers` which are already guarded, but just to make extra sure + and for remote code, we also use this context manager. + """ + originals = defaultdict(dict) + try: + # Replace all torch funcs by the ones in this file + for module_name in TORCH_MODULES_TO_PATCH: + if module_name in sys.modules: + module = sys.modules[module_name] + for func_name in TORCH_INIT_FUNCTIONS.keys(): + if hasattr(module, func_name): + originals[module][func_name] = getattr(module, func_name) + setattr(module, func_name, globals()[func_name]) + yield + finally: + # Set back the original functions on all modules + for module, functions in originals.items(): + for func_name, func in functions.items(): + setattr(module, func_name, func) + + +@contextmanager +def no_init_weights(): + """ + Disable weight initialization both at the torch-level, and at the transformers-level (`init_weights`). + This is used to speed-up initializing an empty model with deepspeed, as we do not initialize the model on meta device + with deepspeed, but we still don't need to run expensive weight initializations as we are loading params afterwards. 
+ """ + from .modeling_utils import PreTrainedModel + + def empty_func(*args, **kwargs): + pass + + originals = defaultdict(dict) + try: + # Replace all torch funcs by empty ones + for module_name in TORCH_MODULES_TO_PATCH: + if module_name in sys.modules: + module = sys.modules[module_name] + for func_name in TORCH_INIT_FUNCTIONS.keys(): + if hasattr(module, func_name): + originals[module][func_name] = getattr(module, func_name) + setattr(module, func_name, empty_func) + + # Also patch our own `init_weights` + original_init_weights = PreTrainedModel.init_weights + PreTrainedModel.init_weights = empty_func + + yield + finally: + # Set back the original torch functions on all modules + for module, functions in originals.items(): + for func_name, func in functions.items(): + setattr(module, func_name, func) + # Set back `init_weights` + PreTrainedModel.init_weights = original_init_weights diff --git a/src/lerobot/policies/videovla/videoprism/modeling_videoprism.py b/src/lerobot/policies/videovla/videoprism/modeling_videoprism.py new file mode 100644 index 000000000..f3f74fdaa --- /dev/null +++ b/src/lerobot/policies/videovla/videoprism/modeling_videoprism.py @@ -0,0 +1,994 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_videoprism.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from collections.abc import Callable +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import _calculate_fan_in_and_fan_out + +from . 
import initialization as init +from transformers.activations import ACT2FN +from transformers.masking_utils import create_causal_mask +from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.modeling_outputs import BaseModelOutput, ImageClassifierOutput +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.file_utils import ModelOutput + +from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig + +def torch_int(x): + """ + Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int. + """ + if not torch.is_available(): + return int(x) + + return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x) + +@dataclass +class BaseModelOutputWithSpatialAndTemporalStates(ModelOutput): + """ + Base class for model outputs that include spatial and temporal states. + + Args: + last_hidden_state (Optional[torch.FloatTensor]): + The last hidden state of the model, typically of shape + (batch_size, num_patches * num_frames, hidden_size). + + temporal_hidden_state (Optional[torch.FloatTensor]): + The last hidden_state of the temporal encoder, typically of shape + (batch_size * num_patches, num_frames, hidden_size). + + spatial_hidden_state (Optional[torch.FloatTensor]): + The last hidden_state of the spatial encoder, typically of shape + (batch_size * num_frames, num_patches, hidden_size). + """ + + last_hidden_state: torch.FloatTensor | None = None + temporal_hidden_state: torch.FloatTensor | None = None + spatial_hidden_state: torch.FloatTensor | None = None + + +@dataclass +class VideoPrismClipOutput(ModelOutput): + """ + Base class for VideoPrismClip model outputs. 
+ """ + + logits_per_video: torch.FloatTensor | None = None + logits_per_text: torch.FloatTensor | None = None + video_embeds: torch.FloatTensor | None = None + text_embeds: torch.FloatTensor | None = None + + +@dataclass +class VideoPrismVideoOutput(ModelOutput): + """ + Base class for VideoPrismVideo model outputs. + """ + + video_last_hidden_state: torch.FloatTensor | None = None + auxiliary_output: torch.FloatTensor | None = None + attention_pooling_output: torch.FloatTensor | None = None + + +class VideoPrismTubeletEmbeddings(nn.Module): + """ + Construct VideoPrism Tubelet embeddings. + + This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of + shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder. + + The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) * + (width // tubelet_size[2]). + """ + + def __init__(self, config: VideoPrismVisionConfig): + super().__init__() + self.config = config + self.num_frames = config.num_frames + self.image_size = ( + config.image_size + if isinstance(self.config.image_size, tuple) + else (self.config.image_size, self.config.image_size) + ) + self.patch_size = config.tubelet_size + self.embed_dim = config.hidden_size + + self.projection = nn.Conv3d( + config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size + ) + self.pos_emb_shape = [self.image_size[0] // self.patch_size[1], self.image_size[1] // self.patch_size[2]] + self.num_patches = self.pos_emb_shape[0] * self.pos_emb_shape[1] + + def forward(self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_frames, num_channels, height, width = pixel_values_videos.shape + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Image size ({height}*{width}) doesn't match 
model ({self.image_size[0]}*{self.image_size[1]}). Set interpolate_pos_encoding=True to automatically resize the model position embeddings." + ) + # permute to (batch_size, num_channels, num_frames, height, width) + pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4) + + hidden_states = self.projection(pixel_values_videos) + # flatten the spatial part and permute to (B, T, num_patches, dim) + hidden_states = hidden_states.flatten(3).permute(0, 2, 3, 1) + # combine batch and time dimension + batch_size, num_frames, num_patches, hidden_size = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size * num_frames, num_patches, hidden_size) + + return hidden_states + + +class VideoPrismSpatialEmbeddings(nn.Module): + """ + VideoPrism Spatial Embeddings. + + Creates embeddings from a video using VideoPrismSpatialTubeletEmbeddings and adds positional embeddings. + """ + + def __init__(self, config: VideoPrismVisionConfig): + super().__init__() + self.config = config + self.patch_embeddings = VideoPrismTubeletEmbeddings(config) + self.position_embeddings = nn.Parameter(torch.zeros(1, self.patch_embeddings.num_patches, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.tubelet_size[1:] + self.tubelet_size = config.tubelet_size + + # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] + num_positions = self.position_embeddings.shape[1] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + dim = embeddings.shape[-1] + + num_row_patches = height // self.patch_size[0] + num_col_patches = width // self.patch_size[1] + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = self.position_embeddings.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(num_row_patches, num_col_patches), + mode="bilinear", + antialias=True, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward( + self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool | None = False + ) -> torch.Tensor: + b, t, c, h, w = pixel_values_videos.shape + assert h == w, "Input image height and width must be the same" + embeddings = self.patch_embeddings(pixel_values_videos, interpolate_pos_encoding) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, h, w) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class VideoPrismTemporalEmbeddings(nn.Module): + """ + VideoPrism Temporal Embeddings. 
+ + Receives embeddings from spatial encoder, reshapes the hidden state to + (batch_size * num_patches, num_frames, hidden_size) and adds positional embeddings. + """ + + def __init__(self, config: VideoPrismVisionConfig): + super().__init__() + self.config = config + + self.position_embeddings = nn.Parameter(torch.zeros(1, self.config.num_frames, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + target_emb_length = embeddings.shape[1] + source_emb_length = self.position_embeddings.shape[1] + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and target_emb_length == source_emb_length: + return self.position_embeddings + + source_emb = self.position_embeddings + dim = embeddings.shape[-1] + source_emb = source_emb.unsqueeze(1) + source_emb = nn.functional.interpolate( + source_emb, + size=(target_emb_length, dim), + mode="bilinear", + antialias=True, + ) + + return source_emb.squeeze(1) + + def forward( + self, + pixel_values_videos: torch.Tensor, + input_shape: torch.Size, + interpolate_pos_encoding: bool | None = False, + ) -> torch.Tensor: + if input_shape is not None: + b, t, c, h, w = input_shape + _, features, dim = pixel_values_videos.shape + hidden_states = pixel_values_videos.view(b, t, features, dim) + 
hidden_states = hidden_states.permute(0, 2, 1, 3) + embeddings = hidden_states.reshape(b * features, t, dim) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings) + else: + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + softcap: float | None = None, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: + attn_weights = attn_weights + attention_mask.expand(*attn_weights.shape) + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class VideoPrismSelfAttention(nn.Module): + def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." 
+ ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scale = self.attention_head_size**-0.5 + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.shape[0] + new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size + query = self.query(hidden_states).view(*new_shape).transpose(1, 2) + key = self.key(hidden_states).view(*new_shape).transpose(1, 2) + value = self.value(hidden_states).view(*new_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query, + key, + value, + attention_mask, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout_prob, + softcap=self.config.attn_logit_softcapping, + **kwargs, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + return (context_layer, attention_probs) + + +class VideoPrismSelfOutput(nn.Module): + """ + The residual connection is defined in VideoPrismLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
class VideoPrismSelfOutput(nn.Module):
    """
    The residual connection is defined in VideoPrismLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: VideoPrismConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted for interface parity but is not used here.
        return self.dropout(self.dense(hidden_states))


class VideoPrismAttention(nn.Module):
    """Self-attention followed by its output projection; attention weights are discarded."""

    def __init__(self, config: VideoPrismConfig):
        super().__init__()
        self.attention = VideoPrismSelfAttention(config)
        self.output = VideoPrismSelfOutput(config)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        context, _unused_probs = self.attention(hidden_states, attention_mask, **kwargs)
        return self.output(context, hidden_states)


class VideoPrismLayerNorm(nn.LayerNorm):
    """LayerNorm whose effective scale is ``weight + 1`` (a zero weight means unit scale)."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        shifted_weight = self.weight + 1
        return F.layer_norm(hidden_states, self.normalized_shape, shifted_weight, self.bias, self.eps)


class VideoPrismIntermediate(nn.Module):
    """Expansion half of the feed-forward block: linear, activation, dropout."""

    def __init__(self, config: VideoPrismConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # A string names an entry of ACT2FN; anything else is used as the activation directly.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        activated = self.intermediate_act_fn(self.dense(hidden_states))
        return self.dropout(activated)


class VideoPrismOutput(nn.Module):
    """Projection half of the feed-forward block, including the residual addition."""

    def __init__(self, config: VideoPrismConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
class VideoPrismLayer(GradientCheckpointingLayer):
    """This corresponds to the EncoderBlock class in the scenic/videoprism implementation."""

    def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig):
        super().__init__()
        self.config = config
        self.attention = VideoPrismAttention(config)
        self.intermediate = VideoPrismIntermediate(config)
        self.output = VideoPrismOutput(config)
        self.layernorm_before = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)
        self.layernorm_after = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply pre-norm attention and pre-norm feed-forward, each with a residual connection."""
        normed = self.layernorm_before(hidden_states)
        attention_output = self.attention(normed, attention_mask, **kwargs)

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in VideoPrism, layernorm is also applied after self-attention
        feed_forward = self.intermediate(self.layernorm_after(hidden_states))

        # second residual connection is done inside `self.output`
        return self.output(feed_forward, hidden_states)


class VideoPrismSpatialEncoder(nn.Module):
    """Stack of ``config.num_spatial_layers`` `VideoPrismLayer` blocks applied in order."""

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_spatial_layers)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> BaseModelOutput:
        """Run every layer on ``hidden_states``.

        Args:
            hidden_states: Input to the first layer.
            **kwargs: Forwarded unchanged to each layer. `VideoPrismVisionModel`
                calls this encoder with ``**kwargs``, so they must be accepted
                here; the auxiliary and text encoders forward them the same way.

        Returns:
            `BaseModelOutput` whose ``last_hidden_state`` is the output of the final layer.
        """
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, **kwargs)

        return BaseModelOutput(last_hidden_state=hidden_states)
class VideoPrismTemporalEncoder(nn.Module):
    """Stack of ``config.num_temporal_layers`` `VideoPrismLayer` blocks applied in order."""

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_temporal_layers)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> BaseModelOutput:
        """Run every layer on ``hidden_states``.

        ``**kwargs`` are forwarded to each layer: `VideoPrismVisionModel` calls
        this encoder with ``**kwargs``, so they must be accepted here.
        """
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, **kwargs)

        return BaseModelOutput(last_hidden_state=hidden_states)


class VideoPrismAuxiliaryEncoder(nn.Module):
    """Stack of ``config.num_auxiliary_layers`` `VideoPrismLayer` blocks applied in order."""

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VideoPrismLayer(self.config) for _ in range(config.num_auxiliary_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutput:
        """Run every layer, passing the same mask and keyword arguments to each."""
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask, **kwargs)

        return BaseModelOutput(last_hidden_state=hidden_states)


class VideoPrismTextEncoder(nn.Module):
    """Stack of ``config.num_text_layers`` `VideoPrismLayer` blocks applied in order."""

    def __init__(self, config: VideoPrismTextConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_text_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutput:
        """Run every layer, passing the same mask and keyword arguments to each."""
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask, **kwargs)

        return BaseModelOutput(last_hidden_state=hidden_states)


def variance_scaling_(tensor, mode="fan_in", distribution="normal"):
    """Fill ``tensor`` in place with values of variance ``1 / denom``.

    Args:
        tensor: Tensor to initialise; its fans are computed with
            ``_calculate_fan_in_and_fan_out``.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``; selects ``denom``.
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: If ``mode`` or ``distribution`` is not one of the values above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Without this branch `denom` would be unbound and fail with an obscure UnboundLocalError.
        raise ValueError(f"invalid mode {mode}")

    variance = 1.0 / denom

    if distribution == "truncated_normal":
        # The divisor rescales the std passed to trunc_normal_; the constant is kept as in the original code.
        init.trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        init.normal_(tensor, std=math.sqrt(variance))
    elif distribution == "uniform":
        # A uniform distribution on [-b, b] has variance b**2 / 3.
        bound = math.sqrt(3 * variance)
        init.uniform_(tensor, -bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
    """Initialise ``tensor`` in place via `variance_scaling_` with fan-in mode and a truncated normal."""
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


class VideoPrismPreTrainedModel(PreTrainedModel):
    """Common base class for the VideoPrism models: shared class-level settings and weight initialisation."""

    config_class = VideoPrismConfig
    config: VideoPrismConfig
    base_model_prefix = "videoprism"
    main_input_name = "pixel_values_videos"
    input_modalities = ("video", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "VideoPrismSpatialEmbeddings",
        "VideoPrismTemporalEmbeddings",
        "VideoPrismSpatialEncoder",
        "VideoPrismTemporalEncoder",
        "VideoPrismAuxiliaryEncoder",
        "VideoPrismTextEncoder",
        "VideoPrismMultiheadAttentionPoolingHead",
    ]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True
    _supports_flex_attention = True

    def _init_weights(self, module):
        """Initialise one submodule.

        Linear and Conv3d weights use `lecun_normal_` and their bias (when
        present) is zeroed; LayerNorm gets a zero bias and a weight of one.
        Other module types are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            lecun_normal_(module.weight)
            # Layers built with `bias=False` (e.g. when `qkv_bias` is off) have `bias is None`.
            if module.bias is not None:
                init.zeros_(module.bias)

        elif isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
class VideoPrismVisionModel(VideoPrismPreTrainedModel):
    """Video backbone: spatial embeddings and encoder, then temporal embeddings and encoder."""

    config_class = VideoPrismVisionConfig
    config: VideoPrismVisionConfig

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__(config)
        self.config = config
        hidden_size, eps = self.config.hidden_size, self.config.layer_norm_eps
        self.layernorm1 = VideoPrismLayerNorm(hidden_size, eps=eps)
        self.layernorm2 = VideoPrismLayerNorm(hidden_size, eps=eps)
        self.spatial_embeddings = VideoPrismSpatialEmbeddings(self.config)
        self.temporal_embeddings = VideoPrismTemporalEmbeddings(self.config)
        self.spatial_encoder = VideoPrismSpatialEncoder(self.config)
        self.temporal_encoder = VideoPrismTemporalEncoder(self.config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch-embedding module of the spatial embeddings."""
        return self.spatial_embeddings.patch_embeddings

    def forward(
        self,
        pixel_values_videos: torch.FloatTensor | None = None,
        interpolate_pos_encoding: bool | None = False,
        **kwargs,
    ) -> BaseModelOutputWithSpatialAndTemporalStates:
        r"""
        Args:
            pixel_values_videos (`torch.FloatTensor`):
                Pixel values of the video frames of shape (batch_size, num_frames, num_channels, height, width).
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate positional encodings to match input size.

        Example:

        ```python
        >>> from transformers import VideoPrismVideoProcessor, VideoPrismVisionModel
        >>> import torch

        >>> processor = VideoPrismVideoProcessor.from_pretrained("google/videoprism")
        >>> model = VideoPrismVisionModel.from_pretrained("google/videoprism")

        >>> video = "sample_video.mp4"
        >>> inputs = processor(videos=video)
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     features = outputs.last_hidden_state
        ```
        """
        if pixel_values_videos is None:
            raise ValueError("You have to specify pixel_values_videos")

        input_shape = pixel_values_videos.shape
        batch = input_shape[0]

        # Spatial stage. Output shape is (B * num_frames, num_patches, dim).
        spatial_embeds = self.spatial_embeddings(pixel_values_videos, interpolate_pos_encoding)
        spatial_sequence_output = self.spatial_encoder(hidden_states=spatial_embeds, **kwargs).last_hidden_state
        spatial_features = self.layernorm1(spatial_sequence_output)

        # Temporal stage. Output shape is (B * num_patches, num_frames, dim).
        temporal_embeds = self.temporal_embeddings(spatial_features, input_shape, interpolate_pos_encoding)
        temporal_sequence_output = self.temporal_encoder(hidden_states=temporal_embeds, **kwargs).last_hidden_state
        features = self.layernorm2(temporal_sequence_output)

        # Regroup per sample, put frames before patches, then merge both into one token axis.
        _, num_frames, dim = features.shape
        features = features.view(batch, -1, num_frames, dim).permute(0, 2, 1, 3).contiguous()
        _, num_frames, num_patches, dim = features.shape
        features = features.view(batch, num_frames * num_patches, -1)

        return BaseModelOutputWithSpatialAndTemporalStates(
            last_hidden_state=features,
            temporal_hidden_state=temporal_sequence_output,
            spatial_hidden_state=spatial_sequence_output,
        )
class VideoPrismMultiheadAttentionPoolingHead(nn.Module):
    """Attention pooling in which a single learned query attends over the input sequence."""

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        self.num_attention_heads = self.config.num_attention_heads
        self.attention_head_size = int(self.config.intermediate_size / self.config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = self.config.attention_probs_dropout_prob
        # PerDimScale
        self.dim = int(self.config.intermediate_size / self.config.num_attention_heads)
        self.per_dim_scale = nn.Parameter(torch.zeros(self.dim))
        # 1.442695041 ~= 1 / softplus(0), so a zero `per_dim_scale` yields a scale of dim ** -0.5.
        r_softplus_0 = 1.442695041
        scale = torch.tensor(r_softplus_0 / (self.dim**0.5))
        softplus = nn.functional.softplus(self.per_dim_scale)
        # Detach: the buffer is a constant snapshot. Left attached to the autograd graph of
        # `per_dim_scale`, it would be a non-leaf tensor, which breaks `copy.deepcopy` of the
        # module and makes a second backward pass fail on the already-freed init-time graph.
        scale = (scale * softplus).detach()
        self.register_buffer("scale", scale)

        self.pooling_attention_query = nn.Parameter(torch.zeros(1, 1, self.config.hidden_size))
        self.query = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
        self.key = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
        self.value = nn.Linear(self.config.hidden_size, self.config.intermediate_size, bias=self.config.qkv_bias)
        self.projection = nn.Linear(self.config.intermediate_size, self.config.hidden_size, bias=self.config.qkv_bias)
        self.layernorm = VideoPrismLayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Pool ``hidden_states`` with the learned query.

        Args:
            hidden_states: 3-D input; the first dimension is used as the batch size.
            attention_mask: Optional mask handed to the attention backend unchanged.
            **kwargs: Forwarded to the attention backend.

        Returns:
            ``(outputs, attention_probs)``: the projected and layer-normalised
            attention output, and the attention probabilities.
        """
        # Unpacking three values also enforces that the input is 3-D.
        batch_size, _seq_length, _hidden_size = hidden_states.shape

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            return projected.view(
                batch_size, -1, self.num_attention_heads, self.attention_head_size
            ).transpose(1, 2)

        query = self.pooling_attention_query.expand(batch_size, -1, -1)
        query_layer = split_heads(self.query(query))
        # The scaling lives in the query (per-dimension), hence `scaling=1.0` below.
        query_layer = query_layer * self.scale.expand(*query_layer.shape)

        key_layer = split_heads(self.key(hidden_states))
        value_layer = split_heads(self.value(hidden_states))

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            scaling=1.0,
            dropout=0.0 if not self.training else self.dropout_prob,
            softcap=None,
            **kwargs,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)
        outputs = self.projection(context_layer)
        outputs = self.layernorm(outputs)
        return (outputs, attention_probs)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """This function is intended to align with the l2norm implementation in the FLA library.

    Scales ``x`` by ``1 / sqrt(sum(x * x) + eps)``, where the sum runs over ``dim``.
    """
    squared_sum = torch.sum(x * x, dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
class VideoPrismTextModel(VideoPrismPreTrainedModel):
    """Text tower: token embeddings plus sinusoidal positions, an appended class token, and the text encoder."""

    config_class = VideoPrismTextConfig
    config: VideoPrismTextConfig

    def __init__(self, config: VideoPrismTextConfig):
        super().__init__(config)
        self.config = config
        self.text_encoder = VideoPrismTextEncoder(self.config)
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.cls_emb = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.layernorm = VideoPrismLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.normalize = config.apply_l2_norm
        self.post_init()

    def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
        """Return a ``(num_pos, dim)`` table (for even ``dim``): sines in the first half of the columns, cosines in the second.

        The table is created on the default device; callers move it as needed.
        """
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / (dim - 2)))
        sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
        return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutput:
        r"""
        Args:
            input_ids (`torch.Tensor`):
                Input token IDs.
            attention_mask (`torch.Tensor`, *optional*):
                Attention mask to avoid performing attention on padding token indices.

        Returns:
            `BaseModelOutput` whose `last_hidden_state` holds, per sample, the encoder output at the last
            position (the appended class token), L2-normalised when `config.apply_l2_norm` is set.
        """
        batch_size, seq_length = input_ids.shape
        hidden_states = self.token_embeddings(input_ids)
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        if attention_mask is not None:
            # Extend the mask by one always-visible column for the class token. It is built on the
            # mask's own device and dtype; a bare `torch.ones(...)` lives on the default device and
            # cannot be concatenated with a mask that sits on an accelerator.
            cls_padding = torch.ones(batch_size, 1, dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat((attention_mask, cls_padding), dim=1)
            attention_mask = create_causal_mask(
                config=self.config,
                input_embeds=hidden_states,
                attention_mask=attention_mask,
                cache_position=torch.arange(hidden_states.shape[1] + 1, device=hidden_states.device),
                past_key_values=None,
            )

        # Move the position table next to the embeddings before adding (only the device is changed).
        positions = self.create_sinusoidal_positions(seq_length, self.config.hidden_size)
        features = hidden_states + positions.to(hidden_states.device)

        cls_emb = self.cls_emb * (self.config.hidden_size**0.5)
        cls_emb = cls_emb.expand(features.shape[0], -1, -1)
        features = torch.cat((features, cls_emb), dim=1)

        text_encoder_output = self.text_encoder(features, attention_mask)
        features = self.layernorm(text_encoder_output.last_hidden_state)
        text_embeddings = features[:, -1]

        if self.normalize:
            text_embeddings = l2norm(text_embeddings, dim=-1)

        return BaseModelOutput(
            last_hidden_state=text_embeddings,
        )
class VideoPrismVideoModel(VideoPrismPreTrainedModel):
    """Video tower: backbone, auxiliary encoder and attention pooling head."""

    config_class = VideoPrismVisionConfig
    config: VideoPrismVisionConfig

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__(config)
        self.config = config
        self.backbone = VideoPrismVisionModel(self.config)
        self.auxiliary_encoder = VideoPrismAuxiliaryEncoder(self.config)
        self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(self.config)
        self.normalize = self.config.apply_l2_norm
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch-embedding module of the backbone's spatial embeddings."""
        return self.backbone.spatial_embeddings.patch_embeddings

    def forward(
        self,
        pixel_values_videos: torch.FloatTensor,
        interpolate_pos_encoding: bool | None = False,
        **kwargs,
    ) -> VideoPrismVideoOutput:
        r"""
        Args:
            pixel_values_videos (`torch.FloatTensor`):
                Pixel values of the video frames.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate positional encodings to match input size.
        """
        video_features = self.backbone(
            pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
        ).last_hidden_state

        # The auxiliary encoder is called without a mask or extra keyword arguments.
        auxiliary_output = self.auxiliary_encoder(video_features)
        pooled = self.contrastive_vision_pooler(auxiliary_output.last_hidden_state, **kwargs)

        # The pooler returns (outputs, attention_probs); the first entry is the embedding.
        video_embeddings = l2norm(pooled[0], dim=-1) if self.normalize else pooled[0]

        return VideoPrismVideoOutput(
            video_last_hidden_state=video_embeddings,
            auxiliary_output=auxiliary_output,
            attention_pooling_output=pooled,
        )
class VideoPrismClipModel(VideoPrismPreTrainedModel):
    """Pairs the video tower with the text tower and scores videos against texts."""

    config_class = VideoPrismConfig

    def __init__(self, config: VideoPrismConfig):
        super().__init__(config)
        self.config = config
        self.vision_config = config.vision_config
        self.text_config = config.text_config
        self.video_model = VideoPrismVideoModel(self.vision_config)
        self.text_model = VideoPrismTextModel(self.text_config)
        self.post_init()

    def forward(
        self,
        pixel_values_videos: torch.FloatTensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        interpolate_pos_encoding: bool | None = False,
        temperature: float | None = None,
        **kwargs,
    ) -> VideoPrismClipOutput:
        r"""
        Args:
            pixel_values_videos (`torch.FloatTensor`):
                Pixel values of the video frames.
            input_ids (`torch.Tensor`):
                Input token IDs for text.
            attention_mask (`torch.Tensor`, *optional*):
                Attention mask for text inputs.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate positional encodings.
            temperature (`float`, *optional*):
                Temperature parameter for scaling similarity scores.

        Raises:
            ValueError: If the video and text embeddings have different sizes.

        Example:

        ```python
        >>> from transformers import VideoPrismProcessor, VideoPrismClipModel
        >>> import torch

        >>> processor = VideoPrismProcessor.from_pretrained("google/videoprism")
        >>> model = VideoPrismClipModel.from_pretrained("google/videoprism")

        >>> video = "sample_video.mp4"
        >>> texts = ["a dog", "a cat"]
        >>> inputs = processor(videos=video, texts=texts, return_tensors="pt", padding=True)

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     logits_per_video = outputs.logits_per_video
        ```
        """
        video_model_outputs = self.video_model(
            pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
        )
        text_model_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        video_embeddings = video_model_outputs.video_last_hidden_state
        text_embeddings = text_model_outputs.last_hidden_state
        emb_dim = video_embeddings[0].shape[-1]
        text_dim = text_embeddings[0].shape[-1]
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if emb_dim != text_dim:
            raise ValueError(f"Video and text embedding sizes differ: {emb_dim} vs {text_dim}.")

        video_embeds = video_embeddings.reshape(-1, emb_dim)
        text_embeds = text_embeddings.reshape(-1, emb_dim)
        similarity_matrix = torch.matmul(video_embeds, text_embeds.T)

        if temperature is not None:
            similarity_matrix /= temperature

        # Exponentiate once, then normalise each view along its own first dimension.
        logits_per_video = torch.exp(similarity_matrix)
        logits_per_text = logits_per_video.T
        logits_per_video = logits_per_video / torch.sum(logits_per_video, dim=0, keepdims=True)
        logits_per_text = logits_per_text / torch.sum(logits_per_text, dim=0, keepdims=True)

        return VideoPrismClipOutput(
            logits_per_video=logits_per_video,
            logits_per_text=logits_per_text,
            video_embeds=video_embeds,
            text_embeds=text_embeds,
        )
class VideoPrismForVideoClassification(VideoPrismPreTrainedModel):
    """Video backbone with an attention pooling head and a linear classifier on top."""

    config_class = VideoPrismVisionConfig
    config: VideoPrismVisionConfig

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__(config)
        self.config = config
        self.encoder = VideoPrismVisionModel(self.config)
        self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(self.config)
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch-embedding module of the encoder's spatial embeddings."""
        return self.encoder.spatial_embeddings.patch_embeddings

    def forward(
        self,
        pixel_values_videos: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool | None = False,
        **kwargs,
    ) -> ImageClassifierOutput:
        r"""
        Args:
            pixel_values_videos (`torch.FloatTensor`):
                Pixel values of the video frames.
            labels (`torch.LongTensor`, *optional*):
                Video classification labels.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate positional encodings.

        Example:

        ```python
        >>> from transformers import VideoPrismVideoProcessor, VideoPrismForVideoClassification
        >>> import torch

        >>> processor = VideoPrismVideoProcessor.from_pretrained("google/videoprism")
        >>> model = VideoPrismForVideoClassification.from_pretrained("google/videoprism", num_labels=1000)

        >>> video = "sample_video.mp4"
        >>> inputs = processor(videos=video, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     logits = outputs.logits
        ```
        """
        encoder_outputs = self.encoder(
            pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
        )
        sequence_output = encoder_outputs.last_hidden_state

        # The pooling head returns a plain tuple `(outputs, attention_probs)`; it has no
        # `.pooled_output` attribute, so the pooled features are taken by index.
        pooled_output = self.contrastive_vision_pooler(sequence_output, **kwargs)[0]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config, **kwargs)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.last_hidden_state,
        )
class VideoPrismVideoProcessor(BaseVideoProcessor):
    r"""
    Constructs a VideoPrism video processor.

    This processor inherits from [`BaseVideoProcessor`] and sets default parameters for VideoPrism models.
    Video frames are resized to 288x288 using bicubic resampling without normalization.

    Args:
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 288, "width": 288}`):
            The size to resize the video frames to.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            The resampling filter to use when resizing images.
        do_normalize (`bool`, *optional*, defaults to `False`):
            Whether to normalize the video frames.
    """

    resample = PILImageResampling.BICUBIC
    # NOTE(review): CLIP mean/std are set although `do_normalize` is False below; presumably they
    # only matter when a caller enables normalization - confirm against BaseVideoProcessor.
    image_mean = OPENAI_CLIP_MEAN
    image_std = OPENAI_CLIP_STD

    size = {"height": 288, "width": 288}
    rescale_factor = 1 / 255
    default_to_square = False
    crop_size = None
    do_resize = True
    do_center_crop = None
    do_rescale = True
    do_normalize = False
    do_convert_rgb = True
    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models