From 9779c58b07608d7b5283dee18c17770ba41df9eb Mon Sep 17 00:00:00 2001 From: Jade Date: Tue, 8 Jul 2025 23:57:46 -0400 Subject: [PATCH] add training support --- lerobot/common/datasets/factory.py | 96 +- lerobot/common/datasets/lerobot_dataset.py | 3 +- lerobot/common/datasets/utils_must.py | 24 +- .../smolvla2/configuration_smolvla2.py | 191 +++ .../policies/smolvla2/modeling_smolvla2.py | 1085 +++++++++++++++++ .../policies/smolvla2/smolvlm_with_expert2.py | 599 +++++++++ 6 files changed, 1995 insertions(+), 3 deletions(-) create mode 100644 lerobot/common/policies/smolvla2/configuration_smolvla2.py create mode 100644 lerobot/common/policies/smolvla2/modeling_smolvla2.py create mode 100644 lerobot/common/policies/smolvla2/smolvlm_with_expert2.py diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 88d3f767f..e9b05bce1 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -32,6 +32,7 @@ IMAGENET_STATS = { "std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1) } +from lerobot.common.datasets.utils_must import (EPISODES_DATASET_MAPPING, TRAINING_FEATURES, FEATURE_KEYS_MAPPING) def resolve_delta_timestamps( cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata @@ -66,7 +67,7 @@ def resolve_delta_timestamps( return delta_timestamps -def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset: +def make_dataset1(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset: """Handles the logic of setting up delta timestamps and image transforms before creating a dataset. Args: @@ -116,3 +117,96 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) return dataset + +def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDataset: + """Handles the logic of setting up delta timestamps and image transforms before creating a dataset. + + Args: + cfg (TrainPipelineConfig): A TrainPipelineConfig config which contains a DatasetConfig and a PreTrainedConfig. + + Raises: + NotImplementedError: The MultiLeRobotDataset is currently deactivated. + + Returns: + LeRobotDataset | MultiLeRobotDataset + """ + image_transforms = ( + ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None + ) + if "," in cfg.dataset.repo_id: + repo_id = cfg.dataset.repo_id.split(",") + repo_id = [r for r in repo_id if r] + else: + repo_id = cfg.dataset.repo_id + sampling_weights = cfg.dataset.sampling_weights.split(",") if cfg.dataset.sampling_weights else None + feature_keys_mapping = FEATURE_KEYS_MAPPING + if isinstance(repo_id, str): + revision = getattr(cfg.dataset, "revision", None) + ds_meta = LeRobotDatasetMetadata( + cfg.dataset.repo_id, + local_files_only=cfg.dataset.local_files_only, + feature_keys_mapping=feature_keys_mapping, + revision=revision, + ) + delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=getattr(cfg.dataset, "root", None), + episodes=cfg.dataset.episodes, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + revision=revision, + video_backend=cfg.dataset.video_backend, + local_files_only=cfg.dataset.local_files_only, + feature_keys_mapping=feature_keys_mapping, + max_action_dim=cfg.dataset.max_action_dim, + max_state_dim=cfg.dataset.max_state_dim, + max_num_images=cfg.dataset.max_num_images, + max_image_dim=cfg.dataset.max_image_dim, + ) + else: + delta_timestamps = {} + episodes = {} + for i in range(len(repo_id)): + ds_meta = LeRobotDatasetMetadata( + repo_id[i], + local_files_only=cfg.dataset.local_files_only, + feature_keys_mapping=feature_keys_mapping, + ) # FIXME(mshukor): ? + delta_timestamps[repo_id[i]] = resolve_delta_timestamps(cfg.policy, ds_meta) + episodes[repo_id[i]] = EPISODES_DATASET_MAPPING.get(repo_id[i], cfg.dataset.episodes) + training_features = TRAINING_FEATURES.get(cfg.dataset.features_version, None) + dataset = MultiLeRobotDataset( + repo_id, + # TODO(aliberts): add proper support for multi dataset + episodes=episodes, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + video_backend=cfg.dataset.video_backend, + local_files_only=cfg.dataset.local_files_only, + sampling_weights=sampling_weights, + feature_keys_mapping=feature_keys_mapping, + max_action_dim=cfg.dataset.max_action_dim, + max_state_dim=cfg.dataset.max_state_dim, + max_num_images=cfg.dataset.max_num_images, + max_image_dim=cfg.dataset.max_image_dim, + train_on_all_features=cfg.dataset.train_on_all_features, + training_features=training_features, + discard_first_n_frames=cfg.dataset.discard_first_n_frames, + min_fps=cfg.dataset.min_fps, + max_fps=cfg.dataset.max_fps, + discard_first_idle_frames=cfg.dataset.discard_first_idle_frames, + motion_threshold=cfg.dataset.motion_threshold, + motion_window_size=cfg.dataset.motion_window_size, + motion_buffer=cfg.dataset.motion_buffer, + ) + logging.info( + "Multiple datasets were provided. Applied the following index mapping to the provided datasets: " + f"{pformat(dataset.repo_id_to_index , indent=2)}" + ) + if cfg.dataset.use_imagenet_stats: + for key in dataset.meta.camera_keys: + for stats_type, stats in IMAGENET_STATS.items(): + dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + + return dataset diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 7b777e0d7..df13bed72 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -83,6 +83,7 @@ from lerobot.common.datasets.utils_must import ( create_padded_features, pad_tensor, map_dict_keys, + find_start_of_motion, ROBOT_TYPE_KEYS_MAPPING, OBS_IMAGE, OBS_IMAGE_2, @@ -562,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.subset_frame_ids += [frame_idx for frame_idx in range(from_ + int(self.fps*self.discard_first_n_frames), to_)] elif self.discard_first_idle_frames: print(f"Discarding first idle frames: motion_threshold={self.motion_threshold}, motion_window_size={self.motion_window_size}, motion_buffer={self.motion_buffer}") - self.robot_states = torch.stack(self.hf_dataset[OBS_ROBOT]).numpy() # shape: [T, D] + self.robot_states = torch.stack(self.hf_dataset[OBS_STATE]).numpy() # shape: [T, D] self.subset_frame_ids = [] for ep_idx in range(self.num_episodes): from_ = self.episode_data_index["from"][ep_idx] diff --git a/lerobot/common/datasets/utils_must.py b/lerobot/common/datasets/utils_must.py index 93927f6c9..040627dfa 100644 --- a/lerobot/common/datasets/utils_must.py +++ b/lerobot/common/datasets/utils_must.py @@ -191,6 +191,13 @@ def map_dict_keys(item: dict, feature_keys_mapping: dict, training_features: lis features[key] = item[key] return features +def find_start_of_motion(velocities, window_size, threshold, motion_buffer): + for t in range(len(velocities) - window_size): + window_mean = velocities[t:t+window_size].mean() + if window_mean > threshold: + return max(0, t - motion_buffer) # include slight context before motion + return 0 + import yaml import requests def load_yaml_mapping(name: str) -> dict: @@ -205,4 +212,19 @@ def load_yaml_mapping(name: str) -> dict: return yaml.safe_load(response.text) # Example usage -TASKS_KEYS_MAPPING = load_yaml_mapping("features") +TASKS_KEYS_MAPPING = load_yaml_mapping("tasks") +FEATURE_KEYS_MAPPING = load_yaml_mapping("features") +EPISODES_DATASET_MAPPING = { + "cadene/droid_1.0.1": list(range(50)), + "danaaubakirova/svla_so100_task5_v3": [0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51], + "danaaubakirova/svla_so100_task4_v3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53], +} +ACTION = "action" +OBS_STATE = "observation.state" +TASK = "task" +ROBOT = "robot_type" +TRAINING_FEATURES = { + 0: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE], + 1: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2], + 2: [ACTION, OBS_STATE, TASK, ROBOT, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3], +} diff --git a/lerobot/common/policies/smolvla2/configuration_smolvla2.py b/lerobot/common/policies/smolvla2/configuration_smolvla2.py new file mode 100644 index 000000000..9b839c62a --- /dev/null +++ b/lerobot/common/policies/smolvla2/configuration_smolvla2.py @@ -0,0 +1,191 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.common.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, +) +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +@dataclass +class PEFTConfig: + r: int = 4 + lora_alpha: int = 16 + lora_dropout: float = 0.1 + target_modules: str = "q_proj,v_proj" + + +@PreTrainedConfig.register_subclass("smolvla2") +@dataclass +class SmolVLA2Config(PreTrainedConfig): + # Input / output structure. + n_obs_steps: int = 1 + chunk_size: int = 50 + n_action_steps: int = 50 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # Shorter state and action vectors will be padded + max_state_dim: int = 32 + max_action_dim: int = 32 + + # Image preprocessing + resize_imgs_with_padding: tuple[int, int] = (512, 512) + + # Add empty images. Used by smolvla_aloha_sim which adds the empty + # left and right wrist cameras in addition to the top camera. + empty_cameras: int = 0 + + # Converts the joint and gripper values from the standard Aloha space to + # the space used by the pi internal runtime which was used to train the base model. + adapt_to_pi_aloha: bool = False + + # Converts joint dimensions to deltas with respect to the current state before passing to the model. + # Gripper dimensions will remain in absolute values. + use_delta_joint_actions_aloha: bool = False + + # Tokenizer + tokenizer_max_length: int = 48 + proj_width: int = 480 + # Decoding + num_steps: int = 10 + + # Attention utils + use_cache: bool = True + + # Finetuning settings + freeze_vision_encoder: bool = True + train_expert_only: bool = False + train_state_proj: bool = True + + # Training presets + optimizer_lr: float = 2.5e-5 # 1e-4 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + optimizer_grad_clip_norm: float = 10 + optimizer_lr_vlm: float = 0 + + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone. + load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights + checkpoint_path: str = None + peft_method: str = "" + peft_config: PEFTConfig = field(default_factory=PEFTConfig) + peft_target_model: str = "" + add_image_special_tokens: bool = False # Whether to use special image tokens around image features. + + attention_mode: str = "cross_attn" + + prefix_length: int = -1 + + pad_language_to: str = "longest" # "max_length" + + num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers. + num_vlm_layers: int = 16 + past_obs_keys: str = "image" + add_local_special_image_tokens: bool = False + + reverse_images_order: bool = False + + state_to_prefix: bool = False + + pad_language_to: str = "longest" # "max_length" + causal_action_attention_mask: bool = False + + self_attn_every_n_layers: int = -1 # Number of layers used in the VLM (first num_vlm_layers layers) + # self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers + expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM) + + min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding + max_period: float = 4.0 + + robot_type: str = "" + + self_attn_only_actions: bool = False + + causal_attention_on_history: bool = False + + predict_relative_actions: bool = False + relative_actions_mode: str = "first" + + shuffle_camera_positions: bool = False + vlm_img_size: int = -1 + + regression_loss: bool = False + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.use_delta_joint_actions_aloha: + raise NotImplementedError( + "`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot." + ) + + def validate_features(self) -> None: + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 480, 640), + ) + self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> list: + return [0] + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/smolvla2/modeling_smolvla2.py b/lerobot/common/policies/smolvla2/modeling_smolvla2.py new file mode 100644 index 000000000..3e85ba0ff --- /dev/null +++ b/lerobot/common/policies/smolvla2/modeling_smolvla2.py @@ -0,0 +1,1085 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SmolVLA: + +[Paper](https://huggingface.co/papers/2506.01844) + +Designed by Hugging Face. + +Install smolvla extra dependencies: +```bash +pip install -e ".[smolvla]" +``` + +Example of finetuning the smolvla pretrained model (`smolvla_base`): +```bash +python lerobot/scripts/train.py \ +--policy.path=lerobot/smolvla_base \ +--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--batch_size=64 \ +--steps=200000 +``` + +Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM, +and an action expert. +```bash +python lerobot/scripts/train.py \ +--policy.type=smolvla \ +--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--batch_size=64 \ +--steps=200000 +``` + +Example of using the smolvla pretrained model outside LeRobot training framework: +```python +policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") +``` + +""" + +import math +import os +import random +import re +from collections import deque + +import safetensors +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from transformers import AutoProcessor + +from lerobot.common.constants import ACTION, OBS_STATE +from lerobot.common.policies.normalize import ( + Normalize, + Unnormalize, +) +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel +from lerobot.common.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config +from lerobot.common.policies.utils import ( + populate_queues, +) +from lerobot.common.utils.utils import get_safe_dtype +from lerobot.datasets import IMAGES_ORDER + +# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker +_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") + + +def canonicalise(k: str) -> str: + """ + Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a + normalisation-buffer key. + """ + return _VARIANT_RE.sub(".buffer_", k) + + +def standardise_state_dict( + checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True +) -> tuple[dict[str, torch.Tensor], list[str]]: + """ + • Re-keys `checkpoint ` so that every entry matches the *reference* key set. + • If several variant keys collapse to the same canonical name we keep the + first one and log the collision. + • Returns the new dict + a list of entries that could not be matched. + """ + out, collisions, unmatched = {}, {}, [] + + for k, v in checkpoint.items(): + canon = canonicalise(k) + if canon in ref_keys: + if canon in out: # duplicate after collapsing + collisions.setdefault(canon, []).append(k) + else: + out[canon] = v + else: + unmatched.append(k) + + if verbose: + for canon, variants in collisions.items(): + print(f"[standardise_state_dict] '{canon}' ← {variants}") + if unmatched: + print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") + + out.update({k: checkpoint[k] for k in unmatched}) + return out, unmatched + + +def rename_checkpoint_keys(checkpoint: dict, rename_str: str): + """ + Renames keys in a checkpoint dictionary based on the given rename string. + + Args: + checkpoint (dict): The checkpoint dictionary. + rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". + + Returns: + dict: The modified checkpoint with renamed keys. + """ + + rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) + + new_checkpoint = {} + for k, v in checkpoint.items(): + for old_key, new_key in rename_dict.items(): + if old_key in k: + k = k.replace(old_key, new_key) + new_checkpoint[k] = v + return new_checkpoint + + +def load_smolvla( + model: torch.nn.Module, + filename: str | os.PathLike, + *, + device: str = "cpu", + checkpoint_keys_mapping: str = "", +) -> torch.nn.Module: + state_dict = safetensors.torch.load_file(filename, device=device) + + # Optional user-supplied renames (e.g. "model._orig_mod.//model.") + if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: + state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) + + state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) + + # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset + norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") + state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} + + missing, unexpected = model.load_state_dict(state_dict, strict=False) + + if not all(key.startswith(norm_keys) for key in missing) or unexpected: + raise RuntimeError( + "SmolVLA %d missing / %d unexpected keys", + len(missing), + len(unexpected), + ) + + return model + + +def create_sinusoidal_pos_embedding( + time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + + dtype = get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + # Compute the outer product + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = scaling_factor[None, :] * time[:, None] + pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + return pos_emb + + +def sample_beta(alpha, beta, bsize, device): + gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) + gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) + return gamma1 / (gamma1 + gamma2) + + +def make_att_2d_masks(pad_masks, att_masks): + """Copied from big_vision. + + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + return att_2d_masks + + +def resize_with_pad(img, width, height, pad_value=-1): + # assume no-op when width height fits already + if img.ndim != 4: + raise ValueError(f"(b,c,h,w) expected, but {img.shape}") + + cur_height, cur_width = img.shape[2:] + + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + resized_img = F.interpolate( + img, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + pad_height = max(0, int(height - resized_height)) + pad_width = max(0, int(width - resized_width)) + + # pad on left and top of image + padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) + return padded_img + + +def pad_vector(vector, new_dim): + """Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] == new_dim: + return vector + shape = list(vector.shape) + current_dim = shape[-1] + shape[-1] = new_dim + new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) + new_vector[..., :current_dim] = vector + return new_vector + + +def normalize(x, min_val, max_val): + return (x - min_val) / (max_val - min_val) + + +def unnormalize(x, min_val, max_val): + return x * (max_val - min_val) + min_val + + +def safe_arcsin(value): + # This ensures that the input stays within + # [−1,1] to avoid invalid values for arcsin + return torch.arcsin(torch.clamp(value, -1.0, 1.0)) + + +def aloha_gripper_to_angular(value): + # Aloha transforms the gripper positions into a linear space. The following code + # reverses this transformation to be consistent with smolvla which is pretrained in + # angular space. + # + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED + value = unnormalize(value, min_val=0.01844, max_val=0.05800) + + # This is the inverse of the angular to linear transformation inside the Interbotix code. + def linear_to_radian(linear_position, arm_length, horn_radius): + value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) + return safe_arcsin(value) + + # The constants are taken from the Interbotix code. + value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) + + # Normalize to [0, 1]. + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + return normalize(value, min_val=0.4, max_val=1.5) + + +def aloha_gripper_from_angular(value): + # Convert from the gripper position used by smolvla to the gripper position that is used by Aloha. + # Note that the units are still angular but the range is different. + + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + value = unnormalize(value, min_val=0.4, max_val=1.5) + + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE + return normalize(value, min_val=-0.6213, max_val=1.4910) + + +def aloha_gripper_from_angular_inv(value): + # Directly inverts the gripper_from_angular function. + value = unnormalize(value, min_val=-0.6213, max_val=1.4910) + return normalize(value, min_val=0.4, max_val=1.5) + + +class SmolVLA2Policy(PreTrainedPolicy): + """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot.""" + + config_class = SmolVLA2Config + name = "smolvla2" + + def __init__( + self, + config: SmolVLA2Config, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + + super().__init__(config) + config.validate_features() + self.config = config + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer + self.model = VLAFlowMatching(config) + self.reset() + + def reset(self): + """This should be called whenever the environment is reset.""" + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + if self.config.n_obs_steps > 1: + for k in self.config.input_features: + if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]): + self._queues[k] = deque(maxlen=self.config.n_obs_steps) + + # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues + @classmethod + def _load_as_safetensor( + cls, + model: "SmolVLA2Policy", + model_file: str, + map_location: str, + strict: bool, + ): + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) + return load_smolvla( + model, + model_file, + device=map_location, + checkpoint_keys_mapping="model._orig_mod.//model.", + ) + + def get_optim_params(self) -> dict: + return self.parameters() + + def merge_peft_model_weights(self) -> None: + if "lora" in self.config.peft_method: + self.model.vlm_with_expert.merge_lora_weights() + + @torch.no_grad + def select_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + + if self.config.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + + batch = self.normalize_inputs(batch) + + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + + actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) + # Unpad actions + original_action_dim = self.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({"action": actions, "robot_type": batch["robot_type"]})["action"] + + if self.config.adapt_to_pi_aloha: + actions = self._pi_aloha_encode_actions(actions) + + return actions + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + + if self.config.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + + batch = self.normalize_inputs(batch) + + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._queues[ACTION]) == 0: + for k in batch: + if k in self._queues: + batch[k] = torch.stack(list(self._queues[k]), dim=1) + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + + actions = self.model.sample_actions( + images, img_masks, lang_tokens, lang_masks, state, noise=noise + ) + if self.config.predict_relative_actions and actions.ndim == 3: + # If the model predicts relative actions, we need to unpad the actions + # and then convert them to absolute actions. + if self.config.relative_actions_mode == "first": + actions = torch.cat((actions[:, :1], actions[:, 1:] + actions[:, :1]), dim=1) + elif self.config.relative_actions_mode == "state": + actions = actions + state.unsqueeze(1) + else: + actions = torch.cat((actions[:, :1], actions[:, 1:] + actions[:, :-1]), dim=1) + # Unpad actions + + original_action_dim = self.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.adapt_to_pi_aloha: + actions = self._pi_aloha_encode_actions(actions) + + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. + self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) + return self._queues[ACTION].popleft() + + def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: + """Do a full training forward pass to compute the loss""" + if self.config.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + actions = self.prepare_action(batch, state=state) + actions_is_pad = batch.get("actions_id_pad") + loss_dict = {} + losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) + loss_dict["losses_after_forward"] = losses.clone() + + if actions_is_pad is not None: + in_episode_bound = ~actions_is_pad + losses = losses * in_episode_bound.unsqueeze(-1) + loss_dict["losses_after_in_ep_bound"] = losses.clone() + + # Remove padding + losses = losses[:, :, : self.config.max_action_dim] + loss_dict["losses_after_rm_padding"] = losses.clone() + + # For backward pass + loss = losses.mean() + # For backward pass + loss_dict["loss"] = loss.item() + return loss, loss_dict + + def prepare_images(self, batch): + """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and + convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. + """ + images = [] + img_masks = [] + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + present_img_keys = sorted( + present_img_keys, + key=lambda k: IMAGES_ORDER.get(k, float("inf")), + reverse=self.config.reverse_images_order, + ) + if self.config.shuffle_camera_positions and ACTION in batch: # only during training + present_img_keys = random.sample(present_img_keys, len(present_img_keys)) + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" + ) + for i in range(self.num_past_images): + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key][:, i, :, :, :] if batch[key].ndim == 5 else batch[key] + if self.config.resize_imgs_with_padding is not None: + img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) + + # Normalize from range [0,1] to [-1,1] as expacted by siglip + img = img * 2.0 - 1.0 + + bsize = img.shape[0] + device = img.device + if f"{key}_padding_mask" in batch: + mask = batch[f"{key}_padding_mask"].bool() + else: + mask = torch.ones(bsize, dtype=torch.bool, device=device) + images.append(img) + img_masks.append(mask) + + # Create image features not present in the batch + # as fully 0 padded images. + for num_empty_cameras in range(len(missing_img_keys)): + if num_empty_cameras >= self.config.empty_cameras: + break + img = torch.ones_like(img) * -1 + mask = torch.zeros_like(mask) + images.append(img) + img_masks.append(mask) + return images, img_masks + + def prepare_language(self, batch) -> tuple[Tensor, Tensor]: + """Tokenize the text input""" + device = batch[OBS_STATE].device + tasks = batch["task"] + if isinstance(tasks, str): + tasks = [tasks] + + if len(tasks) == 1: + tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] + + tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] + + tokenized_prompt = self.language_tokenizer.__call__( + tasks, + padding=self.config.pad_language_to, + padding_side="right", + max_length=self.config.tokenizer_max_length, + return_tensors="pt", + truncation=True, # FIXME(mshukor) + ) + lang_tokens = tokenized_prompt["input_ids"].to(device=device) + lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) + + return lang_tokens, lang_masks + + def _pi_aloha_decode_state(self, state): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + state[:, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) + return state + + def _pi_aloha_encode_actions(self, actions): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) + return actions + + def _pi_aloha_encode_actions_inv(self, actions): + # Flip the joints again. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) + return actions + + def prepare_state(self, batch): + """Pad state""" + state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE] + state = pad_vector(state, self.config.max_state_dim) + return state + + def prepare_action(self, batch, state=None): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + if self.config.predict_relative_actions and actions.ndim == 3: + if self.config.relative_actions_mode == "first": + actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1) + elif self.config.relative_actions_mode == "state": + assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], ( + "Relative action mode 'state' requires the action and state to have the same dimension." + ) + if state.ndim == 2: + state = state.unsqueeze(1) + actions = actions - state + else: + actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1) + return actions + + +def pad_tensor(tensor, max_len, pad_value=0): + """ + Efficiently pads a tensor along sequence dimension to match max_len. + + Args: + tensor (torch.Tensor): Shape (B, L, ...) or (B, L). + max_len (int): Fixed sequence length. + pad_value (int/float): Value for padding. + + Returns: + torch.Tensor: Shape (B, max_len, ...) or (B, max_len). + """ + b, d = tensor.shape[:2] + + # Create a padded tensor of max_len and copy the existing values + padded_tensor = torch.full( + (b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device + ) + padded_tensor[:, :d] = tensor # Efficient in-place copy + + return padded_tensor + + +class VLAFlowMatching(nn.Module): + """ + SmolVLA + + [Paper]() + + Designed by Hugging Face. + ┌──────────────────────────────┐ + │ actions │ + │ ▲ │ + │ ┌─────────┐ ┌─|────┐ │ + │ | │────► │ │ │ + │ | │ kv │ │ │ + │ | │────► │Action│ │ + │ | VLM │cache │Expert│ | + │ │ │────► | │ │ + │ │ │ │ │ │ + │ └▲──▲───▲─┘ └───▲──┘ | + │ │ | | │ | + │ | | | noise │ + │ │ │ state │ + │ │ language tokens │ + │ image(s) │ + └──────────────────────────────┘ + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.vlm_with_expert = SmolVLMWithExpertModel( + model_id=self.config.vlm_model_name, + freeze_vision_encoder=self.config.freeze_vision_encoder, + train_expert_only=self.config.train_expert_only, + attention_implementation=self.config.attention_implementation, + load_vlm_weights=self.config.load_vlm_weights, + attention_mode=self.config.attention_mode, + num_expert_layers=self.config.num_expert_layers, + num_vlm_layers=self.config.num_vlm_layers, + self_attn_every_n_layers=self.config.self_attn_every_n_layers, + expert_width_multiplier=self.config.expert_width_multiplier, + ) + self.vlm_with_expert.configure_peft(config=self.config) + # Projections are float32 + self.state_to_prefix = self.config.state_to_prefix + if self.state_to_prefix: + self.state_proj = nn.Linear( + self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size + ) + else: + self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size) + self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size) + self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim) + + self.action_time_mlp_in = nn.Linear( + self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size + ) + self.action_time_mlp_out = nn.Linear( + self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size + ) + + self.set_requires_grad() + # SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ... + self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id + self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id + self.global_image_start_token = torch.tensor( + [self.fake_image_token, self.global_image_token], dtype=torch.long + ) + + self.add_image_special_tokens = self.config.add_image_special_tokens + self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long) + self.prefix_length = self.config.prefix_length + self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split( + "," + ) + self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1 + self.causal_attention_on_history = self.config.causal_attention_on_history + + def set_requires_grad(self): + for params in self.state_proj.parameters(): + params.requires_grad = self.config.train_state_proj + + def sample_noise(self, shape, device): + noise = torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + return noise + + def sample_time(self, bsize, device): + time_beta = sample_beta(1.5, 1.0, bsize, device) + time = time_beta * 0.999 + 0.001 + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, + images, + img_masks, + lang_tokens, + lang_masks, + state: torch.Tensor = None, + pointtrackers=None, + pt_masks=None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed multiple modalities for vlm processing. + + Simple, extensible approach using list + torch.cat. + Easy to add new information/modalities like point trackers, audio, etc. + + Args: + images: List of image tensors + img_masks: List of image masks + lang_tokens: Language token tensor + lang_masks: Language mask tensor + state: Optional state tensor + pointtrackers: Optional point tracker tensors (future extension) + pt_masks: Optional point tracker masks (future extension) + **kwargs: Additional modalities for future extensions + """ + embs = [] + pad_masks = [] + att_masks = [] + + # Process each modality type + self._add_image_embeddings(images, img_masks, embs, pad_masks, att_masks) + self._add_language_embeddings(lang_tokens, lang_masks, embs, pad_masks, att_masks) + + if state is not None and self.state_to_prefix: + self._add_state_embeddings(state, embs, pad_masks, att_masks) + + # Future extensions - easy to add new modalities + if pointtrackers is not None: + self._add_pointtracker_embeddings(pointtrackers, pt_masks, embs, pad_masks, att_masks) + + # Add more modalities here as needed: + # if audio is not None: + # self._add_audio_embeddings(audio, audio_masks, embs, pad_masks, att_masks) + + # Concatenate all embeddings + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + # Handle prefix length padding + seq_len = pad_masks.shape[1] + if seq_len < self.prefix_length: + embs = pad_tensor(embs, self.prefix_length, pad_value=0) + pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0) + att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0) + + # Expand attention masks to batch size + bsize = pad_masks.shape[0] + att_masks = att_masks[None, :].expand(bsize, -1) + + return embs, pad_masks, att_masks + + def _add_image_embeddings(self, images, img_masks, embs, pad_masks, att_masks): + """Add image embeddings with special tokens to the lists.""" + for img, img_mask in zip(images, img_masks, strict=False): + # Add image start tokens if enabled + if self.add_image_special_tokens: + start_emb = ( + self.vlm_with_expert.embed_language_tokens( + self.global_image_start_token.to(device=img.device) + ) + .unsqueeze(0) + .expand(img.shape[0], -1, -1) + ) + + start_mask = torch.ones_like(start_emb[:, :, 0], dtype=torch.bool) + embs.append(start_emb) + pad_masks.append(start_mask) + att_masks += [0] * start_emb.shape[1] + + # Process image embedding + img_emb = self.vlm_with_expert.embed_image(img) + + # Normalize image embeddings + img_emb_dim = img_emb.shape[-1] + img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) + + # Expand mask to match image embedding sequence length + bsize, num_img_embs = img_emb.shape[:2] + expanded_mask = img_mask[:, None].expand(bsize, num_img_embs) + + embs.append(img_emb) + pad_masks.append(expanded_mask) + att_masks += [0] * num_img_embs + + # Add image end tokens if enabled + if self.add_image_special_tokens: + end_emb = ( + self.vlm_with_expert.embed_language_tokens(self.image_end_token.to(device=img.device)) + .unsqueeze(0) + .expand(img.shape[0], -1, -1) + ) + + end_mask = torch.ones_like(end_emb[:, :, 0], dtype=torch.bool) + embs.append(end_emb) + pad_masks.append(end_mask) + att_masks += [0] * end_emb.shape[1] + + def _add_language_embeddings(self, lang_tokens, lang_masks, embs, pad_masks, att_masks): + """Add language embeddings to the lists.""" + lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens) + + # Normalize language embeddings + lang_emb_dim = lang_emb.shape[-1] + lang_emb = lang_emb * math.sqrt(lang_emb_dim) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + att_masks += [0] * lang_emb.shape[1] + + def _add_state_embeddings(self, state, embs, pad_masks, att_masks): + """Add state embeddings to the lists.""" + state_emb = self.state_proj(state) + state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb + + bsize, states_seq_len = state_emb.shape[:2] + state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=state_emb.device) + + embs.append(state_emb) + pad_masks.append(state_mask) + att_masks += [1] * states_seq_len # State tokens get causal attention + + def _add_pointtracker_embeddings(self, pointtrackers, pt_masks, embs, pad_masks, att_masks): + """Add point tracker embeddings to the lists (future extension).""" + # TODO: Implement point tracker processing + # Example implementation: + # for pt, pt_mask in zip(pointtrackers, pt_masks): + # pt_emb = self.pointtracker_encoder(pt) # Need to add this + # embs.append(pt_emb) + # pad_masks.append(pt_mask) + # att_masks += [0] * pt_emb.shape[1] + pass + + def embed_suffix(self, state, noisy_actions, timestep): + """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + # Embed state + if not self.state_to_prefix: + state_emb = self.state_proj(state) + state_emb = ( + state_emb[:, None, :] if state_emb.ndim == 2 else state_emb + ) # .to(dtype=self.vlm_with_expert.type) + embs.append(state_emb) + bsize = state_emb.shape[0] + dtype = state_emb.dtype + device = state_emb.device + + states_seq_len = state_emb.shape[1] + state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + + # Set attention masks so that image and language inputs do not attend to state or actions + att_masks += [1] + [0] * (states_seq_len - 1) + # Fuse timestep + action information using an MLP + action_emb = self.action_in_proj(noisy_actions) + device = action_emb.device + bsize = action_emb.shape[0] + dtype = action_emb.dtype + # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + time_emb = create_sinusoidal_pos_embedding( + timestep, + self.vlm_with_expert.expert_hidden_size, + self.config.min_period, + self.config.max_period, + device=device, + ) + time_emb = time_emb.type(dtype=dtype) + + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + action_time_emb = self.action_time_mlp_in(action_time_emb) + action_time_emb = F.silu(action_time_emb) # swish == silu + action_time_emb = self.action_time_mlp_out(action_time_emb) + + # Add to input tokens + embs.append(action_time_emb) + + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] * self.config.chunk_size + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + return embs, pad_masks, att_masks + + def forward( + self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None + ) -> Tensor: + """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks, state=state + ) + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + (_, suffix_out), _ = self.vlm_with_expert.forward( + attention_mask=att_2d_masks, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + fill_kv_cache=False, + ) + suffix_out = suffix_out[:, -self.config.chunk_size :] + # Original openpi code, upcast attention output + suffix_out = suffix_out.to(dtype=torch.float32) + v_t = self.action_out_proj(suffix_out) + losses = F.mse_loss(u_t, v_t, reduction="none") + return losses + + def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + bsize = state.shape[0] + device = state.device + + if noise is None: + actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) + noise = self.sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks, state=state + ) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + # Compute image and language key value cache + _, past_key_values = self.vlm_with_expert.forward( + attention_mask=prefix_att_2d_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=self.config.use_cache, + fill_kv_cache=True, + ) + if self.config.regression_loss: + x_t = torch.zeros_like(noise, dtype=torch.float32, device=device) + expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device) + x_t = self.denoise_step( + state, + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + else: + dt = -1.0 / self.config.num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self.denoise_step( + state, + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + # Euler step + x_t += dt * v_t + time += dt + return x_t + + def denoise_step( + self, + state, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + outputs_embeds, _ = self.vlm_with_expert.forward( + attention_mask=full_att_2d_masks, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=self.config.use_cache, + fill_kv_cache=False, + ) + suffix_out = outputs_embeds[1] + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + v_t = self.action_out_proj(suffix_out) + return v_t diff --git a/lerobot/common/policies/smolvla2/smolvlm_with_expert2.py b/lerobot/common/policies/smolvla2/smolvlm_with_expert2.py new file mode 100644 index 000000000..f1030cc4b --- /dev/null +++ b/lerobot/common/policies/smolvla2/smolvlm_with_expert2.py @@ -0,0 +1,599 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import List, Optional + +import torch +from torch import nn +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForImageTextToText, + AutoProcessor, + SmolVLMForConditionalGeneration, +) + + +def apply_rope(x, positions, max_wavelength=10_000): + """ + Applies RoPE positions [B, L] to x [B, L, H, D]. + """ + d_half = x.shape[-1] // 2 + device = x.device + dtype = x.dtype + x = x.to(torch.float32) + + freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) + timescale = max_wavelength**freq_exponents + radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) + + radians = radians[..., None, :] + + sin = torch.sin(radians) # .to(dtype=dtype) + cos = torch.cos(radians) # .to(dtype=dtype) + + x1, x2 = x.split(d_half, dim=-1) + res = torch.empty_like(x) + res[..., :d_half] = x1 * cos - x2 * sin + res[..., d_half:] = x2 * cos + x1 * sin + + return res.to(dtype) + + +def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256): + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + return hidden_dim + + +class SmolVLMWithExpertModel(nn.Module): + def __init__( + self, + model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct", + load_vlm_weights: bool = True, + train_expert_only: bool = True, + freeze_vision_encoder: bool = False, + attention_mode: str = "self_attn", + num_expert_layers: int = -1, + num_vlm_layers: int = -1, + self_attn_every_n_layers: int = -1, + expert_width_multiplier: float = 0.5, + ): + super().__init__() + if load_vlm_weights: + print(f"Loading {model_id} weights ...") + self.vlm = AutoModelForImageTextToText.from_pretrained( + model_id, + device_map="auto", + torch_dtype="bfloat16", + low_cpu_mem_usage=True, + ) + config = self.vlm.config + else: + config = AutoConfig.from_pretrained(model_id) + self.vlm = SmolVLMForConditionalGeneration(config=config) + self.processor = AutoProcessor.from_pretrained(model_id) + if num_vlm_layers > 0: + print(f"Reducing the number of VLM layers to {num_vlm_layers} ...") + self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers] + self.num_vlm_layers = len(self.get_vlm_model().text_model.layers) + self.config = config + # Smaller lm expert + lm_expert_config = copy.deepcopy(config.text_config) + hidden_size = lm_expert_config.hidden_size + lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2 + lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier)) + lm_expert_config.num_hidden_layers = self.num_vlm_layers + if num_expert_layers > 0: + assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, ( + f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}" + ) + lm_expert_config.num_hidden_layers = num_expert_layers + self.lm_expert = AutoModel.from_config(lm_expert_config) + + self.num_expert_layers = len(self.lm_expert.layers) + self.self_attn_every_n_layers = self_attn_every_n_layers + if "cross" in attention_mode: + # Reshape qkv projections to have the same input dimension as the vlm + for layer_idx in range(len(self.lm_expert.layers)): + if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0: + continue + self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear( + config.text_config.num_key_value_heads * config.text_config.head_dim, + lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, + bias=lm_expert_config.attention_bias, + ) + self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear( + config.text_config.num_key_value_heads * config.text_config.head_dim, + lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, + bias=lm_expert_config.attention_bias, + ) + # Remove unused embed_tokens + self.lm_expert.embed_tokens = None + + self.num_attention_heads = self.config.text_config.num_attention_heads + self.num_key_value_heads = self.config.text_config.num_key_value_heads + + self.freeze_vision_encoder = freeze_vision_encoder + self.train_expert_only = train_expert_only + self.attention_mode = attention_mode + self.expert_hidden_size = lm_expert_config.hidden_size + self.set_requires_grad() + + def configure_peft(self, config): + # return model + self.peft_method = config.peft_method + self.peft_target_model = config.peft_target_model + if "lora" in self.peft_method: + peft_config = config.peft_config + target_modules = peft_config.target_modules + if not isinstance(target_modules, list): + target_modules = target_modules.split(",") + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.) + r=peft_config.r, # The rank of the low-rank adaptation + lora_alpha=peft_config.lora_alpha, # Scaling factor + lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers + target_modules=target_modules, # The components where LoRA is applied + exclude_modules=[ + "lm_expert", + "model.lm_expert.model.layers", + ], # FIXME(mshukor): this does not work for now + ) + self.lora_config = lora_config + # Apply LoRA and ensure only LoRA parameters are trainable + if "text" in self.peft_target_model: + self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config) + else: + self.vlm = get_peft_model(self.vlm, lora_config) + # assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here? + for name, param in self.vlm.named_parameters(): + if ( + "lora" in name and "text_model.model.layers.17" not in name + ): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer + param.requires_grad = True + else: + param.requires_grad = False + + def merge_lora_weights(self): + """ + Merge LoRA weights into the base model. + """ + if "text" in self.peft_target_model: + self.get_vlm_model().text_model = self.get_vlm_model().text_model.merge_and_unload() + else: + self.vlm = self.vlm.merge_and_unload() + + def get_vlm_model( + self, + ): + if hasattr(self.vlm.model, "model"): # When using peft + return self.vlm.model.model + else: + return self.vlm.model + + def set_requires_grad(self): + if self.freeze_vision_encoder: + self.get_vlm_model().vision_model.eval() + for params in self.get_vlm_model().vision_model.parameters(): + params.requires_grad = False + if self.train_expert_only: + self.vlm.eval() + for params in self.vlm.parameters(): + params.requires_grad = False + else: + # To avoid unused params issue with distributed training + last_layers = [self.num_vlm_layers - 1] + if ( + self.num_vlm_layers != self.num_expert_layers + and self.num_vlm_layers % self.num_expert_layers == 0 + ): + last_layers.append(self.num_vlm_layers - 2) + frozen_layers = [ + "lm_head", + "text_model.model.norm.weight", + ] + for layer in last_layers: + frozen_layers.append(f"text_model.model.layers.{layer}.") + + for name, params in self.vlm.named_parameters(): + if any(k in name for k in frozen_layers): + params.requires_grad = False + # To avoid unused params issue with distributed training + for name, params in self.lm_expert.named_parameters(): + if "lm_head" in name: + params.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + + if self.freeze_vision_encoder: + self.get_vlm_model().vision_model.eval() + + if self.train_expert_only: + self.vlm.eval() + + def embed_image(self, image: torch.Tensor): + patch_attention_mask = None + # Get sequence from the vision encoder + image_hidden_states = ( + self.get_vlm_model() + .vision_model( + pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype), + patch_attention_mask=patch_attention_mask, + ) + .last_hidden_state + ) + # Modality projection & resampling + image_hidden_states = self.get_vlm_model().connector(image_hidden_states) + return image_hidden_states + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.get_vlm_model().text_model.get_input_embeddings()(tokens) + + def forward_attn_layer( + self, + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache: bool = True, + fill_kv_cache: bool = True, + past_key_values=None, + ) -> list[torch.Tensor]: + query_states = [] + key_states = [] + value_states = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = model_layers[i][layer_idx] + if hidden_states is None or layer is None: + continue + hidden_states = layer.input_layernorm(hidden_states) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + + hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # B,L,H,D with L sequence length, H number of heads, D head dim + # concatenate on the number of embeddings/tokens + query_states = torch.cat(query_states, dim=1) + key_states = torch.cat(key_states, dim=1) + value_states = torch.cat(value_states, dim=1) + seq_len = query_states.shape[1] + if seq_len < position_ids.shape[1]: + _position_ids = position_ids[:, :seq_len] + _attention_mask = attention_mask[:, :seq_len, :seq_len] + else: + _position_ids = position_ids + _attention_mask = attention_mask + + attention_mask_ = _attention_mask + position_ids_ = _position_ids + + query_states = apply_rope(query_states, position_ids_) + key_states = apply_rope(key_states, position_ids_) + + if use_cache and past_key_values is None: + past_key_values = {} + + if use_cache: + if fill_kv_cache: + past_key_values[layer_idx] = { + "key_states": key_states, + "value_states": value_states, + } + else: + # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. + # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach + # the max len, then we (for instance) double the cache size. This implementation already exists + # in `transformers`. (molbap) + key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) + value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1) + + attention_interface = self.get_attention_interface() + + att_output = attention_interface( + attention_mask_, batch_size, head_dim, query_states, key_states, value_states + ) + return [att_output], past_key_values + + def forward_cross_attn_layer( + self, + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache: bool = True, + fill_kv_cache: bool = True, + past_key_values=None, + ) -> list[torch.Tensor]: + attention_interface = self.get_attention_interface() + + att_outputs = [] + assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), ( + f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}" + ) + + if len(inputs_embeds) == 2 and not past_key_values: + # Prefix attention + seq_len = inputs_embeds[0].shape[1] + position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:] + prefix_attention_mask = attention_mask[:, :seq_len, :seq_len] + + layer = model_layers[0][layer_idx] + + hidden_states = layer.input_layernorm(inputs_embeds[0]) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + + hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + # B,L,H,D with L sequence length, H number of heads, D head dim + query_states = apply_rope(query_state, position_id) + key_states = apply_rope(key_state, position_id) + + att_output = attention_interface( + prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states + ) + att_outputs.append(att_output) + else: + expert_position_id = position_ids + + if use_cache and past_key_values is None: + past_key_values = {} + + if use_cache: + if fill_kv_cache: + past_key_values[layer_idx] = { + "key_states": key_states, + "value_states": value_states, + } + else: + # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. + # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach + # the max len, then we (for instance) double the cache size. This implementation already exists + # in `transformers`. (molbap) + key_states = past_key_values[layer_idx]["key_states"] + value_states = past_key_values[layer_idx]["value_states"] + + # Expert + expert_layer = model_layers[1][layer_idx] + if expert_layer is not None: + expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1]) + + expert_input_shape = expert_hidden_states.shape[:-1] + expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim) + + expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype) + expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape) + + _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view( + *key_states.shape[:2], -1 + ) + expert_key_states = expert_layer.self_attn.k_proj(_key_states).view( + *_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim + ) # k_proj should have same dim as kv + + _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view( + *value_states.shape[:2], -1 + ) + expert_value_states = expert_layer.self_attn.v_proj(_value_states).view( + *_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim + ) + + expert_position_id = ( + expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values + ) # start from 0 + expert_attention_mask = attention_mask[ + :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] : + ] # take into account kv + + expert_query_states = apply_rope(expert_query_state, expert_position_id) + + att_output = attention_interface( + expert_attention_mask, + batch_size, + head_dim, + expert_query_states, + expert_key_states, + expert_value_states, + ) + att_outputs.append(att_output) + else: + att_outputs.append(None) + + # att_output = att_output.to(dtype=models[i].dtype) + return att_outputs, past_key_values + + def get_model_layers(self, models: list) -> list: + vlm_layers = [] + expert_layers = [] + multiple_of = self.num_vlm_layers // self.num_expert_layers + for i in range(self.num_vlm_layers): + if multiple_of > 0 and i > 0 and i % multiple_of != 0: + expert_layer = None + else: + expert_layer_index = i // multiple_of if multiple_of > 0 else i + expert_layer = models[1].layers[expert_layer_index] + vlm_layers.append(models[0].layers[i]) + expert_layers.append(expert_layer) + return [vlm_layers, expert_layers] + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: List[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + fill_kv_cache: Optional[bool] = None, + ): + models = [self.get_vlm_model().text_model, self.lm_expert] + model_layers = self.get_model_layers(models) + for hidden_states in inputs_embeds: + # TODO this is very inefficient + # dtype is always the same, batch size too (if > 1 len) + # device could be trickier in multi gpu edge cases but that's it + if hidden_states is None: + continue + batch_size = hidden_states.shape[0] + + # RMSNorm + num_layers = self.num_vlm_layers + head_dim = self.vlm.config.text_config.head_dim + for layer_idx in range(num_layers): + if ( + fill_kv_cache + or "cross" not in self.attention_mode + or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0) + ): + att_outputs, past_key_values = self.forward_attn_layer( + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache=use_cache, + fill_kv_cache=fill_kv_cache, + past_key_values=past_key_values, + ) + else: + att_outputs, past_key_values = self.forward_cross_attn_layer( + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache=use_cache, + fill_kv_cache=fill_kv_cache, + past_key_values=past_key_values, + ) + outputs_embeds = [] + start = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = model_layers[i][layer_idx] + att_output = ( + att_outputs[i] if i < len(att_outputs) else att_outputs[0] + ) # in case of self_attn + if hidden_states is not None: + if layer is None: + outputs_embeds.append(hidden_states) + continue + end = start + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + att_out = att_output[:, start:end] + out_emb = layer.self_attn.o_proj(att_out) + + out_emb += hidden_states + after_first_residual = out_emb.clone() + + out_emb = layer.post_attention_layernorm(out_emb) + out_emb = layer.mlp(out_emb) + + out_emb += after_first_residual + + outputs_embeds.append(out_emb) + + start = end if len(att_outputs) == 1 else 0 + else: + outputs_embeds.append(None) + + inputs_embeds = outputs_embeds + + # final norm + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + if hidden_states is not None: + out_emb = models[i].norm(hidden_states) + outputs_embeds.append(out_emb) + else: + outputs_embeds.append(None) + return outputs_embeds, past_key_values + + def get_attention_interface(self): + attention_interface = self.eager_attention_forward + return attention_interface + + def eager_attention_forward( + self, attention_mask, batch_size, head_dim, query_states, key_states, value_states + ): + num_att_heads = self.num_attention_heads + num_key_value_heads = self.num_key_value_heads + num_key_value_groups = num_att_heads // num_key_value_heads + + sequence_length = key_states.shape[1] + + key_states = key_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + key_states = key_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + value_states = value_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + value_states = value_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + # Attention here is upcasted to float32 to match the original eager implementation. + query_states = query_states.to(dtype=torch.float32) + key_states = key_states.to(dtype=torch.float32) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + att_weights *= head_dim**-0.5 + + att_weights = att_weights.to(dtype=torch.float32) + big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py + masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) + probs = nn.functional.softmax(masked_att_weights, dim=-1) + probs = probs.to(dtype=value_states.dtype) + + att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) + + att_output = att_output.permute(0, 2, 1, 3) + # we use -1 because sequence length can change + att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) + + return att_output