From 4c2add41d764ae3e0e05a6702f75a84bae9a2d86 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Thu, 11 Sep 2025 14:18:09 +0200 Subject: [PATCH] remove files --- src/lerobot/configs/policies.py | 1 - src/lerobot/datasets/lerobot_dataset.py | 1 - src/lerobot/datasets/utils.py | 3 +- src/lerobot/policies/factory.py | 7 - src/lerobot/policies/normalize.py | 163 --- .../policies/smolpi0/configuration_smolpi0.py | 210 --- .../policies/smolpi0/flex_attention.py | 145 -- .../policies/smolpi0/modeling_smolpi0.py | 1190 ----------------- .../policies/smolpi0/smolvlm_with_expert.py | 920 ------------- .../policies/smolvla/modeling_smolvla.py | 956 +------------ .../policies/smolvla/modeling_smolvla_v2.py | 0 src/lerobot/policies/smolvla/saver.txt | 1 - src/lerobot/scripts/eval.py | 29 +- src/lerobot/scripts/train_2.py | 345 ----- src/lerobot/scripts/train_accelerate.py | 366 ----- 15 files changed, 3 insertions(+), 4334 deletions(-) delete mode 100644 src/lerobot/policies/smolpi0/configuration_smolpi0.py delete mode 100644 src/lerobot/policies/smolpi0/flex_attention.py delete mode 100644 src/lerobot/policies/smolpi0/modeling_smolpi0.py delete mode 100644 src/lerobot/policies/smolpi0/smolvlm_with_expert.py delete mode 100644 src/lerobot/policies/smolvla/modeling_smolvla_v2.py delete mode 100644 src/lerobot/policies/smolvla/saver.txt delete mode 100644 src/lerobot/scripts/train_2.py delete mode 100644 src/lerobot/scripts/train_accelerate.py diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 75863d3fc..f5fa727cf 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -62,7 +62,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP, # automatic gradient scaling is used. use_amp: bool = False - gradient_accumulation_steps: int = 1 push_to_hub: bool = True repo_id: str | None = None diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 6509993bb..875608c59 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -597,7 +597,6 @@ class LeRobotDataset(torch.utils.data.Dataset): """hf_dataset contains all the observations, states, actions, rewards, etc.""" if self.episodes is None: path = str(self.root / "data") - # added by jade hf_dataset = load_dataset("parquet", data_dir=path, split="train") else: files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index daa1de163..078c5351d 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -455,8 +455,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea shape = (shape[2], shape[0], shape[1]) elif key == "observation.environment_state": type = FeatureType.ENV - # changed by jade - elif key.startswith("observation") or key.startswith("state"): + elif key.startswith("observation"): type = FeatureType.STATE elif key.startswith("action"): type = FeatureType.ACTION diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index cc1b0480d..c3ae9cd54 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -31,7 +31,6 @@ from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig -from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig @@ -75,10 +74,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy - elif name == "smolpi0": - from lerobot.policies.smolpi0.modeling_smolpi0 import SMOLPI0Policy - - return SMOLPI0Policy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -102,8 +97,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return SmolVLAConfig(**kwargs) elif policy_type == "reward_classifier": return RewardClassifierConfig(**kwargs) - elif policy_type == "smolpi0": - return SMOLPI0Config(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") diff --git a/src/lerobot/policies/normalize.py b/src/lerobot/policies/normalize.py index 646c330cb..119055873 100644 --- a/src/lerobot/policies/normalize.py +++ b/src/lerobot/policies/normalize.py @@ -255,85 +255,6 @@ class Unnormalize(nn.Module): return batch -class NormalizePerRobotType(nn.Module): - """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values - are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing - mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape - is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values - are their normalization modes among: - - "mean_std": subtract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") - and values are dictionaries of statistic types and their values (e.g. - `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for - training the model for the first time, these statistics will overwrite the default buffers. If - not provided, as expected for finetuning or evaluation, the default buffers should to be - overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the - dataset is not needed to get the stats, since they are already in the policy state_dict. - """ - super().__init__() - self.features = features - self.norm_map = norm_map - for robot_type in stats.keys(): - stats_buffers = create_stats_buffers(features, norm_map, stats[robot_type]) - for key, buffer in stats_buffers.items(): - setattr(self, f"{robot_type}_buffer_" + key.replace(".", "_"), buffer) - - # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - batch = dict(batch) # shallow copy avoids mutating the input batch - assert "robot_type" in batch, "robot_type is not in the batch" - robot_types = batch["robot_type"] - - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - # FIXME(mshukor): make it more efficient - buffers = [ - getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types - ] - if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))], dim=0) - std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))], dim=0) - if batch[key].ndim == 3: - mean = mean.unsqueeze(1) - std = std.unsqueeze(1) - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = (batch[key] - mean) / (std + 1e-8) - elif norm_mode is NormalizationMode.MIN_MAX: - min = torch.stack([buffers[i]["min"] for i in range(len(robot_types))], dim=0) - max = torch.stack([buffers[i]["max"] for i in range(len(robot_types))], dim=0) - assert not torch.isinf(min).any(), _no_stats_error_str("min") - assert not torch.isinf(max).any(), _no_stats_error_str("max") - if batch[key].ndim == 3: - min = min.unsqueeze(1) - max = max.unsqueeze(1) - # normalize to [0,1] - batch[key] = (batch[key] - min) / (max - min + 1e-8) - # normalize to [-1, 1] - batch[key] = batch[key] * 2 - 1 - else: - raise ValueError(norm_mode) - return batch - - # TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization # and remove the `Normalize` and `Unnormalize` classes. def _initialize_stats_buffers( @@ -497,87 +418,3 @@ class UnnormalizeBuffer(nn.Module): raise ValueError(norm_mode) return batch - - -class UnnormalizePerRobotType(nn.Module): - """ - Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their - original range used by the environment. - """ - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values - are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing - mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape - is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values - are their normalization modes among: - - "mean_std": subtract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") - and values are dictionaries of statistic types and their values (e.g. - `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for - training the model for the first time, these statistics will overwrite the default buffers. If - not provided, as expected for finetuning or evaluation, the default buffers should to be - overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the - dataset is not needed to get the stats, since they are already in the policy state_dict. - """ - super().__init__() - self.features = features - self.norm_map = norm_map - self.stats = stats - # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` - for robot_type in stats.keys(): - stats_buffers = create_stats_buffers(features, norm_map, stats[robot_type]) - for key, buffer in stats_buffers.items(): - setattr(self, f"{robot_type}_buffer_" + key.replace(".", "_"), buffer) - - # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - batch = dict(batch) # shallow copy avoids mutating the input batch - assert "robot_type" in batch, "robot_type is not in the batch" - robot_types = batch["robot_type"] - - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - # buffer = getattr(self, "buffer_" + key.replace(".", "_")) - buffers = [ - getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types - ] - - if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))], dim=0) - std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))], dim=0) - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - if batch[key].ndim == 3: - mean = mean.unsqueeze(1) - std = std.unsqueeze(1) - batch[key] = batch[key] * std + mean - elif norm_mode is NormalizationMode.MIN_MAX: - min = torch.stack([buffers[i]["min"] for i in range(len(robot_types))], dim=0) - max = torch.stack([buffers[i]["max"] for i in range(len(robot_types))], dim=0) - assert not torch.isinf(min).any(), _no_stats_error_str("min") - assert not torch.isinf(max).any(), _no_stats_error_str("max") - if batch[key].ndim == 3: - min = min.unsqueeze(1) - max = max.unsqueeze(1) - batch[key] = (batch[key] + 1) / 2 - batch[key] = batch[key] * (max - min) + min - else: - raise ValueError(norm_mode) - return batch diff --git a/src/lerobot/policies/smolpi0/configuration_smolpi0.py b/src/lerobot/policies/smolpi0/configuration_smolpi0.py deleted file mode 100644 index e39d17f15..000000000 --- a/src/lerobot/policies/smolpi0/configuration_smolpi0.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field - -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import ( - CosineDecayWithWarmupSchedulerConfig, -) - - -@dataclass -class PEFTConfig: - r: int = 4 - lora_alpha: int = 16 - lora_dropout: float = 0.1 - target_modules: str = "q_proj,v_proj" - - -@PreTrainedConfig.register_subclass("smolpi0") -@dataclass -class SMOLPI0Config(PreTrainedConfig): - # Input / output structure. - n_obs_steps: int = 1 - chunk_size: int = 50 - n_action_steps: int = 50 - n_obs_gap: int = 1 - - normalization_mapping: dict[str, NormalizationMode] = field( - default_factory=lambda: { - "VISUAL": NormalizationMode.IDENTITY, - "STATE": NormalizationMode.MEAN_STD, - "ACTION": NormalizationMode.MEAN_STD, - } - ) - - # Shorter state and action vectors will be padded - max_state_dim: int = 32 - max_action_dim: int = 32 - - # Image preprocessing - resize_imgs_with_padding: tuple[int, int] = (512, 512) # (224, 224) - - # Add empty images. Used by pi0_aloha_sim which adds the empty - # left and right wrist cameras in addition to the top camera. - empty_cameras: int = 0 - - # Converts the joint and gripper values from the standard Aloha space to - # the space used by the pi internal runtime which was used to train the base model. - adapt_to_pi_aloha: bool = False - - # Converts joint dimensions to deltas with respect to the current state before passing to the model. - # Gripper dimensions will remain in absolute values. - use_delta_joint_actions_aloha: bool = False - - # Tokenizer - tokenizer_max_length: int = 48 - - # Projector - proj_width: int = 480 - - # Decoding - num_steps: int = 10 - - # Attention utils - use_cache: bool = True - attention_implementation: str = "eager" # or fa2, flex - - # Finetuning settings - freeze_vision_encoder: bool = True - train_expert_only: bool = False - train_state_proj: bool = True - - # Training presets - optimizer_lr: float = 2.5e-5 - optimizer_betas: tuple[float, float] = (0.9, 0.95) - optimizer_eps: float = 1e-8 - optimizer_weight_decay: float = 1e-10 - optimizer_grad_clip_norm: float = 10 - optimizer_lr_vlm: float = 0 - - scheduler_warmup_steps: int = 1_000 - scheduler_decay_steps: int = 30_000 - scheduler_decay_lr: float = 2.5e-6 - - # TODO: Add EMA - vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" - checkpoint_path: str = None - load_vlm_weights: bool = False - - peft_method: str = "" - peft_config: PEFTConfig = PEFTConfig() - peft_target_model: str = "" - - add_image_special_tokens: bool = False - add_prompt_template: bool = False - prefix_prompt_template: str = "<|im_start|>User: What action should the robot take to" - suffix_prompt_template: str = "?\nAssistant:" - - attention_mode: str = "self_attn" - - prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length - - past_obs_keys: str = "image" - - add_local_special_image_tokens: bool = False - - reverse_images_order: bool = False - - state_to_prefix: bool = False - - pad_language_to: str = "longest" # "max_length" - - num_expert_layers: int = -1 - num_vlm_layers: int = -1 - - causal_action_attention_mask: bool = False - - self_attn_every_n_layers: int = -1 - - expert_width_multiplier: float = 0.5 - - robot_type: str = "" - - self_attn_only_actions: bool = False - - causal_attention_on_history: bool = False - - predict_relative_actions: bool = False - relative_actions_mode: str = "first" - - shuffle_camera_positions: bool = False - vlm_img_size: int = -1 - - regression_loss: bool = False - - def __post_init__(self): - super().__post_init__() - if self.vlm_img_size > 0: - self.resize_imgs_with_padding = (self.vlm_img_size, self.vlm_img_size) - """Input validation (not exhaustive).""" - if self.n_action_steps > self.chunk_size: - raise ValueError( - f"The chunk size is the upper bound for the number of action steps per model invocation. Got " - f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." - ) - # if self.n_obs_steps != 1: - # raise ValueError( - # f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" - # ) - - if self.use_delta_joint_actions_aloha: - raise NotImplementedError( - "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot." - ) - - def validate_features(self) -> None: - # TODO: implement value error - # if not self.image_features and not self.env_state_feature: - # raise ValueError("You must provide at least one image or the environment state among the inputs.") - - for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" - empty_camera = PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 480, 640), - ) - self.input_features[key] = empty_camera - - def get_optimizer_preset(self) -> AdamWConfig: - return AdamWConfig( - lr=self.optimizer_lr, - betas=self.optimizer_betas, - eps=self.optimizer_eps, - weight_decay=self.optimizer_weight_decay, - grad_clip_norm=self.optimizer_grad_clip_norm, - ) - - def get_scheduler_preset(self): - return CosineDecayWithWarmupSchedulerConfig( - peak_lr=self.optimizer_lr, - decay_lr=self.scheduler_decay_lr, - num_warmup_steps=self.scheduler_warmup_steps, - num_decay_steps=self.scheduler_decay_steps, - ) - - @property - def observation_delta_indices(self) -> list: # FIXME(mshukor): support spacing between observations - return [-k for k in range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)][::-1] - - @property - def action_delta_indices(self) -> list: - return list(range(self.chunk_size)) - - @property - def reward_delta_indices(self) -> None: - return None diff --git a/src/lerobot/policies/smolpi0/flex_attention.py b/src/lerobot/policies/smolpi0/flex_attention.py deleted file mode 100644 index 732920af2..000000000 --- a/src/lerobot/policies/smolpi0/flex_attention.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn.functional as F # noqa: N812 -from packaging.version import Version - -if Version(torch.__version__) > Version("2.5.0"): - # Ffex attention is only available from torch 2.5 onwards - from torch.nn.attention.flex_attention import ( - _mask_mod_signature, - _round_up_to_multiple, - create_block_mask, - create_mask, - flex_attention, - ) - - -@torch.compile(dynamic=False) -def flex_attention_forward( - attention_mask: torch.Tensor, - batch_size: int, - head_dim: int, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - scaling=None, - num_att_heads: int = 8, - num_key_value_heads: int = 1, -): - """ - This is defined out of classes to make compile happy. - """ - - original_dtype = query_states.dtype - num_key_value_groups = num_att_heads // num_key_value_heads - key_states = key_states[:, :, :, None, :] - key_states = key_states.expand( - batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim - ) - key_states = key_states.reshape( - batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim - ) - - value_states = value_states[:, :, :, None, :] - value_states = value_states.expand( - batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim - ) - value_states = value_states.reshape( - batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim - ) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # query_states = query_states.to(torch.float32) - # key_states = key_states.to(torch.float32) - # value_states = value_states.to(torch.float32) - - causal_mask = attention_mask - if causal_mask is not None: - causal_mask = causal_mask[:, None, :, : key_states.shape[2]] - - if causal_mask.shape[1] == 1 and query_states.shape[1] > 1: - causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1) - - def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature: - def mask_mod(b, h, q_idx, kv_idx): - # Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs. - return precomputed_mask[b][h][q_idx][kv_idx] - - return mask_mod - - b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask - - block_size = 128 # limitation of flex attention - q_len_rounded = _round_up_to_multiple(q_len, block_size) - kv_len_rounded = _round_up_to_multiple(kv_len, block_size) - - # *CRITICAL* we do need to expand here, else we get a CUDA index error - - pad_q = q_len_rounded - q_len - pad_k = kv_len_rounded - kv_len - if pad_q > 0 or pad_k > 0: - padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0) - else: - padded_causal_mask = causal_mask - mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask) - - mask_4d = create_mask( - mod_fn=mask_mod_fn_orig, - B=b_mask, - H=h_mask, - Q_LEN=q_len_rounded, - KV_LEN=kv_len_rounded, - device=causal_mask.device, - ) - - mask_mod_fn_padded = precomputed_mask_factory(mask_4d) - # FIXME(mshukor): compile mask torch.compile(create_block_mask) - create_block_mask_compiled = torch.compile(create_block_mask) - block_mask = create_block_mask_compiled( - mask_mod=mask_mod_fn_padded, - B=b_mask, - H=None, # - Q_LEN=q_len_rounded, - KV_LEN=kv_len_rounded, - BLOCK_SIZE=block_size, - device=causal_mask.device, - _compile=False, - ) - padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states - padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states - padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states - # mask is applied inside the kernel, ideally more efficiently than score_mod. - attn_output, attention_weights = flex_attention( - padded_query_states, - padded_key_states, - padded_value_states, - block_mask=block_mask, - enable_gqa=True, # because we shaped query/key states for GQA - scale=head_dim**-0.5 if scaling is None else scaling, - return_lse=True, - ) - - attn_output = attn_output.to(dtype=original_dtype) - attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim] - attn_output = attn_output.reshape( - batch_size, - -1, - attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim] - ) - return attn_output[:, :-pad_k, :] if pad_k > 0 else attn_output diff --git a/src/lerobot/policies/smolpi0/modeling_smolpi0.py b/src/lerobot/policies/smolpi0/modeling_smolpi0.py deleted file mode 100644 index 765a5901a..000000000 --- a/src/lerobot/policies/smolpi0/modeling_smolpi0.py +++ /dev/null @@ -1,1190 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -π0: A Vision-Language-Action Flow Model for General Robot Control - -[Paper](https://www.physicalintelligence.company/download/pi0.pdf) -[Jax code](https://github.com/Physical-Intelligence/openpi) - -Designed by Physical Intelligence. Ported from Jax by Hugging Face. - -Install pi0 extra dependencies: -```bash -pip install -e ".[pi0]" -``` - -Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`): -```bash -python lerobot/scripts/train.py \ ---policy.path=lerobot/pi0 \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of finetuning the pi0 neural network with PaliGemma and expert Gemma -pretrained with VLM default parameters before pi0 finetuning: -```bash -python lerobot/scripts/train.py \ ---policy.type=pi0 \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of using the pi0 pretrained model outside LeRobot training framework: -```python -policy = Pi0Policy.from_pretrained("lerobot/pi0") -``` - -""" - -import math -import os -import re -from collections import deque - -import safetensors -import torch -import torch.nn.functional as F # noqa: N812 -from torch import Tensor, nn -from transformers import AutoProcessor - -from lerobot.constants import ACTION, OBS_STATE -from lerobot.policies.normalize import ( - Normalize, - NormalizePerRobotType, - Unnormalize, - UnnormalizePerRobotType, -) -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config -from lerobot.policies.smolpi0.smolvlm_with_expert import SmolVLMWithExpertModel -from lerobot.utils.utils import get_safe_dtype - -OBS_IMAGE = "observation.image" -OBS_IMAGES = "observation.images" -ACTION = "action" -OBS_IMAGE_2 = "observation.image2" -OBS_IMAGE_3 = "observation.image3" -OBS_IMAGE_4 = "observation.image4" -TASK = "task" -ROBOT = "robot_type" -IMAGES_ORDER = { - OBS_IMAGE: 0, - OBS_IMAGE_2: 1, - OBS_IMAGE_3: 2, - OBS_IMAGE_4: 3, -} -import random - -from lerobot.policies.utils import ( - populate_queues, -) - - -def create_sinusoidal_pos_embedding( - time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" -) -> Tensor: - """Computes sine-cosine positional embedding vectors for scalar positions.""" - if dimension % 2 != 0: - raise ValueError(f"dimension ({dimension}) must be divisible by 2") - - if time.ndim != 1: - raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") - - dtype = get_safe_dtype(torch.float64, device.type) - fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) - period = min_period * (max_period / min_period) ** fraction - - # Compute the outer product - scaling_factor = 1.0 / period * 2 * math.pi - sin_input = scaling_factor[None, :] * time[:, None] - pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) - return pos_emb - - -def sample_beta(alpha, beta, bsize, device): - gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) - - -def make_att_2d_masks(pad_masks, att_masks): - """Copied from big_vision. - - Tokens can attend to valid inputs tokens which have a cumulative mask_ar - smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to - setup several types of attention, for example: - - [[1 1 1 1 1 1]]: pure causal attention. - - [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between - themselves and the last 3 tokens have a causal attention. The first - entry could also be a 1 without changing behaviour. - - [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a - block can attend all previous blocks and all tokens on the same block. - - Args: - input_mask: bool[B, N] true if its part of the input, false if padding. - mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on - it and 0 where it shares the same attention mask as the previous token. - """ - if att_masks.ndim != 2: - raise ValueError(att_masks.ndim) - if pad_masks.ndim != 2: - raise ValueError(pad_masks.ndim) - - cumsum = torch.cumsum(att_masks, dim=1) - att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] - pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] - att_2d_masks = att_2d_masks & pad_2d_masks - return att_2d_masks - - -def resize_with_pad(img, width, height, pad_value=-1): - # assume no-op when width height fits already - if img.ndim != 4: - raise ValueError(f"(b,c,h,w) expected, but {img.shape}") - - cur_height, cur_width = img.shape[2:] - - ratio = max(cur_width / width, cur_height / height) - resized_height = int(cur_height / ratio) - resized_width = int(cur_width / ratio) - resized_img = F.interpolate( - img, size=(resized_height, resized_width), mode="bilinear", align_corners=False - ) - - pad_height = max(0, int(height - resized_height)) - pad_width = max(0, int(width - resized_width)) - - # pad on left and top of image - padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) - return padded_img - - -_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") - - -def canonicalise(k: str) -> str: - """ - Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a - normalisation-buffer key. - """ - return _VARIANT_RE.sub(".buffer_", k) - - -def standardise_state_dict( - checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True -) -> tuple[dict[str, torch.Tensor], list[str]]: - """ - • Re-keys `checkpoint ` so that every entry matches the *reference* key set. - • If several variant keys collapse to the same canonical name we keep the - first one and log the collision. - • Returns the new dict + a list of entries that could not be matched. - """ - out, collisions, unmatched = {}, {}, [] - - for k, v in checkpoint.items(): - canon = canonicalise(k) - if canon in ref_keys: - if canon in out: # duplicate after collapsing - collisions.setdefault(canon, []).append(k) - else: - out[canon] = v - else: - unmatched.append(k) - - if verbose: - for canon, variants in collisions.items(): - print(f"[standardise_state_dict] '{canon}' ← {variants}") - if unmatched: - print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") - - out.update({k: checkpoint[k] for k in unmatched}) - return out, unmatched - - -def load_smolvla( - model: torch.nn.Module, - filename: str | os.PathLike, - *, - device: str = "cpu", - checkpoint_keys_mapping: str = "", -) -> torch.nn.Module: - state_dict = safetensors.torch.load_file(filename, device=device) - - # Optional user-supplied renames (e.g. "model._orig_mod.//model.") - if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: - state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) - - state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) - - # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset - norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") - state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} - - missing, unexpected = model.load_state_dict(state_dict, strict=False) - if not all(key.startswith(norm_keys) for key in missing) or unexpected: - raise RuntimeError( - "SmolVLA %d missing / %d unexpected keys", - len(missing), - len(unexpected), - ) - - return model - - -def pad_vector(vector, new_dim): - """Can be (batch_size x sequence_length x features_dimension) - or (batch_size x features_dimension) - """ - if vector.shape[-1] == new_dim: - return vector - shape = list(vector.shape) - current_dim = shape[-1] - shape[-1] = new_dim - new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) - new_vector[..., :current_dim] = vector - return new_vector - - -def normalize(x, min_val, max_val): - return (x - min_val) / (max_val - min_val) - - -def unnormalize(x, min_val, max_val): - return x * (max_val - min_val) + min_val - - -def safe_arcsin(value): - # This ensures that the input stays within - # [−1,1] to avoid invalid values for arcsin - return torch.arcsin(torch.clamp(value, -1.0, 1.0)) - - -def aloha_gripper_to_angular(value): - # Aloha transforms the gripper positions into a linear space. The following code - # reverses this transformation to be consistent with pi0 which is pretrained in - # angular space. - # - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED - value = unnormalize(value, min_val=0.01844, max_val=0.05800) - - # This is the inverse of the angular to linear transformation inside the Interbotix code. - def linear_to_radian(linear_position, arm_length, horn_radius): - value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) - return safe_arcsin(value) - - # The constants are taken from the Interbotix code. - value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) - - # Normalize to [0, 1]. - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - return normalize(value, min_val=0.4, max_val=1.5) - - -def rename_checkpoint_keys(checkpoint: dict, rename_str: str): - """ - Renames keys in a checkpoint dictionary based on the given rename string. - - Args: - checkpoint (dict): The checkpoint dictionary. - rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". - - Returns: - dict: The modified checkpoint with renamed keys. - """ - - rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) - - new_checkpoint = {} - for k, v in checkpoint.items(): - for old_key, new_key in rename_dict.items(): - if old_key in k: - k = k.replace(old_key, new_key) - new_checkpoint[k] = v - return new_checkpoint - - -def aloha_gripper_from_angular(value): - # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. - # Note that the units are still angular but the range is different. - - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - value = unnormalize(value, min_val=0.4, max_val=1.5) - - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE - return normalize(value, min_val=-0.6213, max_val=1.4910) - - -def aloha_gripper_from_angular_inv(value): - # Directly inverts the gripper_from_angular function. - value = unnormalize(value, min_val=-0.6213, max_val=1.4910) - return normalize(value, min_val=0.4, max_val=1.5) - - -class SMOLPI0Policy(PreTrainedPolicy): - """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot.""" - - config_class = SMOLPI0Config - name = "smolpi0" - - def __init__( - self, - config: SMOLPI0Config, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - config: Policy configuration class instance or None, in which case the default instantiation of - the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. - """ - - super().__init__(config) - config.validate_features() - self.config = config - self.normalize_per_robot_type = getattr( - config, "normalize_per_robot_type", False - ) # FIXME(mshukor): assert in case of single dataset - if self.normalize_per_robot_type: - if not dataset_stats: - dataset_stats[config.robot_type] = {} - self.normalize_inputs = NormalizePerRobotType( - config.input_features, config.normalization_mapping, dataset_stats - ) - self.normalize_targets = NormalizePerRobotType( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = UnnormalizePerRobotType( - config.output_features, config.normalization_mapping, dataset_stats - ) - else: - self.normalize_inputs = Normalize( - config.input_features, config.normalization_mapping, dataset_stats - ) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - - self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer - self.model = VLAFlowMatching(config) - self.include_past_states = config.n_obs_steps > 1 and OBS_STATE in self.config.past_obs_keys.split( - "," - ) - self.include_past_images = config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split(",") - self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1 - self.reset() - - def reset(self): - """This should be called whenever the environment is reset.""" - # self._action_queue = deque([], maxlen=self.config.n_action_steps) - self._queues = { - ACTION: deque(maxlen=self.config.n_action_steps), - } - if self.config.n_obs_steps > 1: - for k in self.config.input_features: - if any([past_obs_key in k for past_obs_key in self.config.past_obs_keys.split(",")]): - self._queues[k] = deque(maxlen=self.config.n_obs_steps) - - def get_optim_params(self) -> dict: - if self.config.optimizer_lr_vlm > 0 and self.config.optimizer_lr_vlm != self.config.optimizer_lr: - params = [ - {"params": [p for n, p in self.named_parameters() if ".vlm." not in n and p.requires_grad]}, - { - "params": [p for n, p in self.named_parameters() if ".vlm." in n and p.requires_grad], - "lr": self.config.optimizer_lr_vlm, - }, - ] - return params - - else: - return self.parameters() - - def merge_peft_model_weights(self) -> None: - if "lora" in self.config.peft_method: - self.model.vlm_with_expert.merge_lora_weights() - - @torch.no_grad - def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: - """Select a single action given environment observations. - - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - self.eval() - - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - - batch = self.normalize_inputs(batch) - - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) - - actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) - # Unpad actions - original_action_dim = self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] - - actions = self.unnormalize_outputs({"action": actions, "robot_type": batch["robot_type"]})["action"] - - if self.config.adapt_to_pi_aloha: - actions = self._pi_aloha_encode_actions(actions) - - return actions - - # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues - @classmethod - def _load_as_safetensor( - cls, - model: "SmolVLAPolicy", - model_file: str, - map_location: str, - strict: bool, - **kwargs, - ): - safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) - return load_smolvla( - model, - model_file, - device=map_location, - checkpoint_keys_mapping="model._orig_mod.//model.", - ) - - @torch.no_grad - def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: - """Select a single action given environment observations. - - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - self.eval() - - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - - batch = self.normalize_inputs(batch) - - self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by - # querying the policy. - if len(self._queues[ACTION]) == 0: - for k in batch: - if k in self._queues: - batch[k] = torch.stack(list(self._queues[k]), dim=1) - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) - actions = self.model.sample_actions( - images, img_masks, lang_tokens, lang_masks, state, noise=noise - ) - if self.config.predict_relative_actions and actions.ndim == 3: - # If the model predicts relative actions, we need to unpad the actions - # and then convert them to absolute actions. - if self.config.relative_actions_mode == "first": - actions = torch.cat((actions[:, :1], actions[:, 1:] + actions[:, :1]), dim=1) - elif self.config.relative_actions_mode == "state": - actions = actions + state.unsqueeze(1) - else: - actions = torch.cat((actions[:, :1], actions[:, 1:] + actions[:, :-1]), dim=1) - # Unpad actions - original_action_dim = self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] - - actions = self.unnormalize_outputs({"action": actions})["action"] - - if self.config.adapt_to_pi_aloha: - actions = self._pi_aloha_encode_actions(actions) - - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) - return self._queues[ACTION].popleft() - - def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: - """Do a full training forward pass to compute the loss""" - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) - images, img_masks = self.prepare_images( - batch - ) # FIXME(mshukor): adapte it to take into account already padded images in the batch - state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) - actions = self.prepare_action(batch, state=state) - actions_is_pad = batch.get("actions_id_pad") - loss_dict = {} - losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) - loss_dict["losses_after_forward"] = losses.mean().clone() - - if actions_is_pad is not None: - in_episode_bound = ~actions_is_pad - losses = losses * in_episode_bound.unsqueeze(-1) - loss_dict["losses_after_in_ep_bound"] = losses.mean().clone() - - # Remove padding - losses = losses[:, :, : self.config.max_action_dim] - loss_dict["losses_after_rm_padding"] = losses.mean().clone() - - # For backward pass - loss = losses.mean() - # For backward pass - loss_dict["loss"] = loss - # # For logging - # loss_dict["l2_loss"] = loss.item() # remove for torch compile - return loss_dict - - def prepare_images(self, batch): - """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and - convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. - """ - images = [] - img_masks = [] - present_img_keys = [key for key in self.config.image_features if key in batch] - missing_img_keys = [key for key in self.config.image_features if key not in batch] - - present_img_keys = sorted( - present_img_keys, - key=lambda k: IMAGES_ORDER.get(k, float("inf")), - reverse=self.config.reverse_images_order, - ) - if self.config.shuffle_camera_positions and ACTION in batch: # only during training - present_img_keys = random.sample(present_img_keys, len(present_img_keys)) - if len(present_img_keys) == 0: - raise ValueError( - f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" - ) - for i in range(self.num_past_images): - # Preprocess image features present in the batch - for key in present_img_keys: - img = batch[key][:, i, :, :, :] if batch[key].ndim == 5 else batch[key] - if self.config.resize_imgs_with_padding is not None: - img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) - - # Normalize from range [0,1] to [-1,1] as expacted by siglip - img = img * 2.0 - 1.0 - - bsize = img.shape[0] - device = img.device - if f"{key}_padding_mask" in batch: - mask = batch[f"{key}_padding_mask"].bool() - else: - mask = torch.ones(bsize, dtype=torch.bool, device=device) - images.append(img) - img_masks.append(mask) - - # Create image features not present in the batch - # as fully 0 padded images. - for num_empty_cameras in range(len(missing_img_keys)): - if num_empty_cameras >= self.config.empty_cameras: - break - img = torch.ones_like(img) * -1 - mask = torch.zeros_like(mask) - images.append(img) - img_masks.append(mask) - return images, img_masks - - def prepare_language(self, batch) -> tuple[Tensor, Tensor]: - """Tokenize the text input""" - device = batch[OBS_STATE].device - tasks = batch["task"] - if len(tasks) == 1: - tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] - - if self.config.add_prompt_template: - tasks = [ - f"{self.config.prefix_prompt_template}{task}{self.config.suffix_prompt_template}" - for task in tasks - ] - else: - tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - tokenized_prompt = self.language_tokenizer.__call__( - tasks, - padding=self.config.pad_language_to, - padding_side="right", - max_length=self.config.tokenizer_max_length, - return_tensors="pt", - truncation=True, # FIXME(mshukor) - ) - - lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) - - return lang_tokens, lang_masks - - def _pi_aloha_decode_state(self, state): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - state[:, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) - return state - - def _pi_aloha_encode_actions(self, actions): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) - return actions - - def _pi_aloha_encode_actions_inv(self, actions): - # Flip the joints again. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) - return actions - - def prepare_state(self, batch): - """Pad state""" - state = ( - batch[OBS_STATE][:, -1, :] - if (batch[OBS_STATE].ndim > 2 and not self.include_past_states) - else batch[OBS_STATE] - ) # FIXME(mshukor): no state history for now - state = pad_vector(state, self.config.max_state_dim) - return state - - def prepare_action(self, batch, state=None): - """Pad action""" - actions = pad_vector(batch[ACTION], self.config.max_action_dim) - if self.config.predict_relative_actions and actions.ndim == 3: - if self.config.relative_actions_mode == "first": - actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :1]), dim=1) - elif self.config.relative_actions_mode == "state": - assert batch[ACTION].shape[-1] == batch[OBS_STATE].shape[-1], ( - "Relative action mode 'state' requires the action and state to have the same dimension." - ) - if state.ndim == 2: - state = state.unsqueeze(1) - actions = actions - state - else: - actions = torch.cat((actions[:, :1], actions[:, 1:] - actions[:, :-1]), dim=1) - return actions - - -def pad_tensor(tensor, max_len, pad_value=0): - """ - Efficiently pads a tensor along sequence dimension to match max_len. - - Args: - tensor (torch.Tensor): Shape (B, L, ...) or (B, L). - max_len (int): Fixed sequence length. - pad_value (int/float): Value for padding. - - Returns: - torch.Tensor: Shape (B, max_len, ...) or (B, max_len). - """ - B, L = tensor.shape[:2] - - # Create a padded tensor of max_len and copy the existing values - padded_tensor = torch.full( - (B, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device - ) - padded_tensor[:, :L] = tensor # Efficient in-place copy - - return padded_tensor - - -class VLAFlowMatching(nn.Module): - """ - π0: A Vision-Language-Action Flow Model for General Robot Control - - [Paper](https://www.physicalintelligence.company/download/pi0.pdf) - [Jax code](https://github.com/Physical-Intelligence/openpi) - - Designed by Physical Intelligence. Ported from Jax by Hugging Face. - ┌──────────────────────────────┐ - │ actions │ - │ ▲ │ - │ ┌┴─────┐ │ - │ kv cache │Gemma │ │ - │ ┌──────────►│Expert│ │ - │ │ │ │ │ - │ ┌┴────────┐ │x 10 │ │ - │ │ │ └▲──▲──┘ │ - │ │ VLM │ │ │ │ - │ │ │ │ robot state │ - │ │ │ noise │ - │ └▲──▲─────┘ │ - │ │ │ │ - │ │ image(s) │ - │ language tokens │ - └──────────────────────────────┘ - """ - - def __init__(self, config): - super().__init__() - self.config = config - - self.vlm_with_expert = SmolVLMWithExpertModel( - model_id=self.config.vlm_model_name, - freeze_vision_encoder=self.config.freeze_vision_encoder, - train_expert_only=self.config.train_expert_only, - attention_implementation=self.config.attention_implementation, - load_vlm_weights=self.config.load_vlm_weights, - attention_mode=self.config.attention_mode, - num_expert_layers=self.config.num_expert_layers, - num_vlm_layers=self.config.num_vlm_layers, - self_attn_every_n_layers=self.config.self_attn_every_n_layers, - expert_width_multiplier=self.config.expert_width_multiplier, - self_attn_only_actions=self.config.self_attn_only_actions, - ) - # self.paligemma_with_expert = self.configure_peft(paligemma_with_expert) - self.vlm_with_expert.configure_peft(config=self.config) - # Projections are float32 - self.state_to_prefix = self.config.state_to_prefix - if self.state_to_prefix: - self.state_proj = nn.Linear( - self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size - ) - else: - self.state_proj = nn.Linear(self.config.max_state_dim, self.vlm_with_expert.expert_hidden_size) - self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size) - self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim) - - self.action_time_mlp_in = nn.Linear( - self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size - ) - self.action_time_mlp_out = nn.Linear( - self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size - ) - - self.set_requires_grad() - # SmolVLM2 has: [fake_tok + crop_tok + crop + fake_tok + crop_tok ... + fake_tok + global_tok + global + fake_tok] + [second image] + ... - if any([k in self.config.vlm_model_name for k in ["SmolVLM-", "SmolVLA-"]]): - if "SmolVLM-Instruct" in self.config.vlm_model_name: - self.fake_image_token = 49152 - self.global_image_token = [44, 13906, 29, 6266, 46] - self.global_image_start_token = torch.tensor( - [self.fake_image_token] + self.global_image_token, dtype=torch.long - ) - else: - self.fake_image_token = 49189 - self.global_image_token = 49152 - self.global_image_start_token = torch.tensor( - [self.fake_image_token, self.global_image_token], dtype=torch.long - ) - else: - self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id - self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id - self.global_image_start_token = torch.tensor( - [self.fake_image_token, self.global_image_token], dtype=torch.long - ) - - self.add_image_special_tokens = self.config.add_image_special_tokens - self.add_local_special_image_tokens = self.config.add_local_special_image_tokens - self.local_image_tokens = [ - torch.tensor([self.fake_image_token, tok], dtype=torch.long) - for tok in [49153, 49154, 49155, 49159, 49160, 49161, 49165, 49166, 49167] - ] # assume 3 x 3 grid - - self.local_image_start_token = self.global_image_start_token - self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long) - self.prefix_length = self.config.prefix_length - self.include_past_images = self.config.n_obs_steps > 1 and "image" in self.config.past_obs_keys.split( - "," - ) - self.num_past_images = self.config.n_obs_steps if self.include_past_images else 1 - self.causal_attention_on_history = self.config.causal_attention_on_history - - # def configure_peft(self, model): - # # return model - # self.peft_method = self.config.peft_method - # if "lora" in self.peft_method: - # peft_config = self.config.peft_config - # target_modules = peft_config.target_modules - # if not isinstance(target_modules, list): - # target_modules = target_modules.split(",") - # lora_config = LoraConfig( - # task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.) - # r=peft_config.r, # The rank of the low-rank adaptation - # lora_alpha=peft_config.lora_alpha, # Scaling factor - # lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers - # target_modules=target_modules, # The components where LoRA is applied - # exclude_modules=["gemma_expert", "model.gemma_expert.model.layers"], # FIXME(mshukor): this does not work for now - # ) - # # LoraConfig(task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=1, lora_dropout=0, target_modules=["q_proj"], exclude_modules=["gemma_expert"]) - # self.lora_config = lora_config - # # Apply LoRA and ensure only LoRA parameters are trainable - - # model = get_peft_model(model, lora_config) - # assert self.config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here? - # for name, param in model.named_parameters(): - # if ( - # "lora" in name - # ): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer - # param.requires_grad = True - # return model - - def set_requires_grad(self): - for params in self.state_proj.parameters(): - params.requires_grad = self.config.train_state_proj - - def sample_noise(self, shape, device): - noise = torch.normal( - mean=0.0, - std=1.0, - size=shape, - dtype=torch.float32, - device=device, - ) - return noise - - def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) - time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) - - def embed_prefix( - self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Embed images with SigLIP and language tokens with embedding layer to prepare - for SmolVLM transformer processing. - """ - # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty - embs = [] - pad_masks = [] - att_masks = [] - num_images = len(images) // self.num_past_images - # TODO: remove for loop - for img_idx, ( - img, - img_mask, - ) in enumerate(zip(images, img_masks, strict=False)): - # FIXME(mshukor): add special tokens for the history each history_steps or not - if self.add_image_special_tokens: - if self.add_local_special_image_tokens and img_idx % num_images != num_images - 1: - local_token_idx = img_idx % num_images - image_start_token = ( - self.vlm_with_expert.embed_language_tokens( - self.local_image_tokens[local_token_idx].to( - device=self.vlm_with_expert.vlm.device - ) - ) - .unsqueeze(0) - .expand(img.shape[0], -1, -1) - ) - else: - image_start_token = ( - self.vlm_with_expert.embed_language_tokens( - self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device) - ) - .unsqueeze(0) - .expand(img.shape[0], -1, -1) - ) - image_start_mask = torch.ones_like( - image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device - ) - if self.causal_attention_on_history and img_idx % num_images == 0: - att_masks += [1] + [0] * (image_start_mask.shape[-1] - 1) - else: - att_masks += [0] * (image_start_mask.shape[-1]) - embs.append(image_start_token) - pad_masks.append(image_start_mask) - - img_emb = self.vlm_with_expert.embed_image(img) - img_emb = img_emb # .to(dtype=self.vlm_with_expert.type) - - # Normalize image embeddings - img_emb_dim = img_emb.shape[-1] - img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) - - bsize, num_img_embs = img_emb.shape[:2] - img_mask = img_mask[:, None].expand(bsize, num_img_embs) - - # FIXME(mshukor): add special image tokens. Assume no tiling fake global images fake - # template <|im_start|>User: What actions? image tokens \nAssistant: or processor.apply_chat_template? - # processor.fake_image_token - # processor.global_image_token - - embs.append(img_emb) - pad_masks.append(img_mask) - - att_masks += [0] * (num_img_embs) - if self.add_image_special_tokens: - if not self.add_local_special_image_tokens or ( - self.add_local_special_image_tokens and img_idx % num_images == num_images - 1 - ): - image_end_token = ( - self.vlm_with_expert.embed_language_tokens( - self.image_end_token.to(device=self.vlm_with_expert.vlm.device) - ) - .unsqueeze(0) - .expand(img.shape[0], -1, -1) - ) - image_end_mask = torch.ones_like( - image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device - ) - embs.append(image_end_token) - pad_masks.append(image_end_mask) - att_masks += [0] * (image_end_mask.shape[1]) - lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens) - # Normalize language embeddings - lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) # FIXME(mshukor): is this needed for smolvlm? - - embs.append(lang_emb) - pad_masks.append(lang_masks) - - # full attention between image and language inputs - num_lang_embs = lang_emb.shape[1] - att_masks += [0] * num_lang_embs - - if state is not None and self.state_to_prefix: - state_emb = self.state_proj(state) - state_emb = ( - state_emb[:, None, :] if state_emb.ndim == 2 else state_emb - ) # .to(dtype=self.vlm_with_expert.type) - embs.append(state_emb) - bsize = state_emb.shape[0] - dtype = state_emb.dtype - device = state_emb.device - - states_seq_len = state_emb.shape[1] - state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device) - pad_masks.append(state_mask) - - # Set attention masks so that image and language inputs do not attend to state or actions - # att_masks += [1] + [0]*(states_seq_len - 1) - att_masks += [1] * (states_seq_len) - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) - att_masks = att_masks[None, :] - - seq_len = pad_masks.shape[1] - if seq_len < self.prefix_length: - embs = pad_tensor(embs, self.prefix_length, pad_value=0) - pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0) - att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0) - - att_masks = att_masks.expand(bsize, -1) - - return embs, pad_masks, att_masks - - def embed_suffix(self, state, noisy_actions, timestep): - """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" - embs = [] - pad_masks = [] - att_masks = [] - - # Embed state - if not self.state_to_prefix: - state_emb = self.state_proj(state) - state_emb = ( - state_emb[:, None, :] if state_emb.ndim == 2 else state_emb - ) # .to(dtype=self.vlm_with_expert.type) - embs.append(state_emb) - bsize = state_emb.shape[0] - dtype = state_emb.dtype - device = state_emb.device - - states_seq_len = state_emb.shape[1] - state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device) - pad_masks.append(state_mask) - - # Set attention masks so that image and language inputs do not attend to state or actions - att_masks += [1] + [0] * (states_seq_len - 1) - - # Fuse timestep + action information using an MLP - action_emb = self.action_in_proj(noisy_actions) - device = action_emb.device - bsize = action_emb.shape[0] - dtype = action_emb.dtype - # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] - time_emb = create_sinusoidal_pos_embedding( - timestep, self.vlm_with_expert.expert_hidden_size, min_period=4e-3, max_period=4.0, device=device - ) - time_emb = time_emb.type(dtype=dtype) - - time_emb = time_emb[:, None, :].expand_as(action_emb) - action_time_emb = torch.cat([action_emb, time_emb], dim=2) - - action_time_emb = self.action_time_mlp_in(action_time_emb) - action_time_emb = F.silu(action_time_emb) # swish == silu - action_time_emb = self.action_time_mlp_out(action_time_emb) - - # Add to input tokens - embs.append(action_time_emb) - - bsize, action_time_dim = action_time_emb.shape[:2] - action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) - pad_masks.append(action_time_mask) - - # Set attention masks so that image, language and state inputs do not attend to action tokens - if self.config.causal_action_attention_mask: - att_masks += [1] * self.config.chunk_size - else: - att_masks += [1] + ([0] * (self.config.chunk_size - 1)) - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) - return embs, pad_masks, att_masks - - def forward( - self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None - ) -> Tensor: - """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" - if noise is None: - noise = self.sample_noise(actions.shape, actions.device) - - if time is None: - time = self.sample_time(actions.shape[0], actions.device) - - time_expanded = time[:, None, None] - if self.config.regression_loss: - # Hack to compare regression to flow matching - time = torch.zeros_like(time, dtype=time.dtype, device=time.device) - x_t = torch.zeros_like(actions, dtype=actions.dtype, device=actions.device) - u_t = actions - else: - x_t = time_expanded * noise + (1 - time_expanded) * actions - u_t = noise - actions - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( - images, img_masks, lang_tokens, lang_masks, state=state - ) - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time) - - pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) - att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) - - att_2d_masks = make_att_2d_masks(pad_masks, att_masks) - position_ids = torch.cumsum(pad_masks, dim=1) - 1 - (_, suffix_out), _ = self.vlm_with_expert.forward( - attention_mask=att_2d_masks, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - fill_kv_cache=False, - ) - suffix_out = suffix_out[:, -self.config.chunk_size :] - # Original openpi code, upcast attention output - suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) - if self.config.regression_loss: - losses = F.l1_loss(u_t, v_t, reduction="none") - else: - losses = F.mse_loss(u_t, v_t, reduction="none") - return losses - - def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: - """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" - bsize = state.shape[0] - device = state.device - - if noise is None: - actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) - noise = self.sample_noise(actions_shape, device) - - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( - images, img_masks, lang_tokens, lang_masks, state=state - ) - prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) - prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - # Compute image and language key value cache - _, past_key_values = self.vlm_with_expert.forward( - attention_mask=prefix_att_2d_masks, - position_ids=prefix_position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, None], - use_cache=self.config.use_cache, - fill_kv_cache=True, - ) - if self.config.regression_loss: - x_t = torch.zeros_like(noise, dtype=torch.float32, device=device) - expanded_time = torch.zeros(bsize, dtype=torch.float32, device=device) - x_t = self.denoise_step( - state, - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) - else: - dt = -1.0 / self.config.num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) - - x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) - while time >= -dt / 2: - expanded_time = time.expand(bsize) - v_t = self.denoise_step( - state, - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) - - # Euler step - x_t += dt * v_t - time += dt - return x_t - - def denoise_step( - self, - state, - prefix_pad_masks, - past_key_values, - x_t, - timestep, - ): - """Apply one denoising step of the noise `x_t` at a given timestep.""" - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep) - - suffix_len = suffix_pad_masks.shape[1] - batch_size = prefix_pad_masks.shape[0] - prefix_len = prefix_pad_masks.shape[1] - prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) - - suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) - - full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) - prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] - position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - - outputs_embeds, _ = self.vlm_with_expert.forward( - attention_mask=full_att_2d_masks, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=[None, suffix_embs], - use_cache=self.config.use_cache, - fill_kv_cache=False, - ) - suffix_out = outputs_embeds[1] - suffix_out = suffix_out[:, -self.config.chunk_size :] - suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) - return v_t diff --git a/src/lerobot/policies/smolpi0/smolvlm_with_expert.py b/src/lerobot/policies/smolpi0/smolvlm_with_expert.py deleted file mode 100644 index 0ccdcccc8..000000000 --- a/src/lerobot/policies/smolpi0/smolvlm_with_expert.py +++ /dev/null @@ -1,920 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -from functools import partial -from typing import List, Optional, Union - -import torch -import torch.nn.functional as F # noqa: N812 -import torch.version -from peft import LoraConfig, TaskType, get_peft_model -from pytest import Cache -from torch import nn -from transformers import ( - AutoConfig, - AutoModel, - AutoModelForImageTextToText, - AutoModelForVision2Seq, - AutoProcessor, - SmolVLMForConditionalGeneration, -) - -from lerobot.policies.smolpi0.flex_attention import flex_attention_forward - - -def _round_up_to_multiple(x, multiple): - return (x + multiple - 1) // multiple * multiple - - -def apply_rope(x, positions, max_wavelength=10_000): - """ - Applies RoPE positions [B, L] to x [B, L, H, D]. - """ - d_half = x.shape[-1] // 2 - device = x.device - dtype = x.dtype - x = x.to(torch.float32) - - freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) - timescale = max_wavelength**freq_exponents - radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) - - radians = radians[..., None, :] - - sin = torch.sin(radians) # .to(dtype=dtype) - cos = torch.cos(radians) # .to(dtype=dtype) - - x1, x2 = x.split(d_half, dim=-1) - res = torch.empty_like(x) - res[..., :d_half] = x1 * cos - x2 * sin - res[..., d_half:] = x2 * cos + x1 * sin - - return res.to(dtype) - - -# class SmolVLMWithExpertConfig(PretrainedConfig): -# model_type = "SmolVLMWithExpertModel" -# sub_configs = {"smolvlm_config": AutoConfig, "lm_expert_config": AutoConfig} - -# def __init__( -# self, -# smolvlm_config: dict | None = None, -# lm_expert_config: dict | None = None, -# freeze_vision_encoder: bool = True, -# train_expert_only: bool = True, -# attention_implementation: str = "eager", -# load_vlm_weights: bool = False, -# **kwargs, -# ): -# self.load_vlm_weights = load_vlm_weights -# self.freeze_vision_encoder = freeze_vision_encoder -# self.train_expert_only = train_expert_only -# self.attention_implementation = attention_implementation - -# if smolvlm_config is None: -# # Default config from Pi0 -# self.smolvlm_config = CONFIG_MAPPING["smolvlm"]( -# transformers_version="4.48.1", -# _vocab_size=257152, -# bos_token_id=2, -# eos_token_id=1, -# hidden_size=2048, -# image_token_index=257152, -# model_type="smolvlm", -# pad_token_id=0, -# projection_dim=2048, -# text_config={ -# "hidden_activation": "gelu_pytorch_tanh", -# "hidden_size": 2048, -# "intermediate_size": 16384, -# "model_type": "gemma", -# "num_attention_heads": 8, -# "num_hidden_layers": 18, -# "num_image_tokens": 256, -# "num_key_value_heads": 1, -# "torch_dtype": "float32", -# "vocab_size": 257152, -# }, -# vision_config={ -# "hidden_size": 1152, -# "intermediate_size": 4304, -# "model_type": "siglip_vision_model", -# "num_attention_heads": 16, -# "num_hidden_layers": 27, -# "num_image_tokens": 256, -# "patch_size": 14, -# "projection_dim": 2048, -# "projector_hidden_act": "gelu_fast", -# "torch_dtype": "float32", -# "vision_use_head": False, -# }, -# ) -# elif isinstance(self.paligemma_config, dict): -# # Override Pi0 default config for PaliGemma -# if "model_type" not in gemma_expert_config: -# paligemma_config["model_type"] = "paligemma" - -# cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] -# self.paligemma_config = cfg_cls(**paligemma_config) - -# if gemma_expert_config is None: -# # Default config from Pi0 -# self.gemma_expert_config = CONFIG_MAPPING["gemma"]( -# attention_bias=False, -# attention_dropout=0.0, -# bos_token_id=2, -# eos_token_id=1, -# head_dim=256, -# hidden_act="gelu_pytorch_tanh", -# hidden_activation="gelu_pytorch_tanh", -# hidden_size=1024, -# initializer_range=0.02, -# intermediate_size=4096, -# max_position_embeddings=8192, -# model_type="gemma", -# num_attention_heads=8, -# num_hidden_layers=18, -# num_key_value_heads=1, -# pad_token_id=0, -# rms_norm_eps=1e-06, -# rope_theta=10000.0, -# torch_dtype="float32", -# transformers_version="4.48.1", -# use_cache=True, -# vocab_size=257152, -# ) -# elif isinstance(self.gemma_expert_config, dict): -# # Override Pi0 default config for Gemma Expert -# if "model_type" not in gemma_expert_config: -# gemma_expert_config["model_type"] = "gemma" - -# cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] -# self.gemma_expert_config = cfg_cls(**gemma_expert_config) - -# super().__init__(**kwargs) - -# def __post_init__(self): -# super().__post_init__() -# if self.train_expert_only and not self.freeze_vision_encoder: -# raise ValueError( -# "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible." -# ) - -# if self.attention_implementation not in ["eager", "fa2", "flex"]: -# raise ValueError( -# f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'." -# ) - - -def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256): - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - return hidden_dim - - -class SmolVLMWithExpertModel(nn.Module): - # config_class = PaliGemmaWithExpertConfig - - def __init__( - self, - model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct", - load_vlm_weights: bool = True, - train_expert_only: bool = True, - freeze_vision_encoder: bool = False, - attention_implementation: str = "eager", - attention_mode: str = "self_attn", - num_expert_layers: int = -1, - num_vlm_layers: int = -1, - self_attn_every_n_layers: int = -1, - expert_width_multiplier: float = 0.5, - self_attn_only_actions: bool = False, - ): - super().__init__() - if load_vlm_weights: - print(f"Loading {model_id} weights ...") - if "SmolVLM-" in model_id: - self.vlm = AutoModelForVision2Seq.from_pretrained( - model_id, - device_map="cuda", - torch_dtype="bfloat16", - low_cpu_mem_usage=True, - ) - else: - # model_id = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" - self.vlm = AutoModelForImageTextToText.from_pretrained( - model_id, - device_map="cuda", - torch_dtype="bfloat16", - low_cpu_mem_usage=True, - # attn_implementation="eager", - # attn_implementation="flash_attention_2" - ) - config = self.vlm.config - else: - config = AutoConfig.from_pretrained(model_id) - self.vlm = SmolVLMForConditionalGeneration(config=config) - self.processor = AutoProcessor.from_pretrained(model_id) - if num_vlm_layers > 0: - print(f"Reducing the number of VLM layers to {num_vlm_layers} ...") - self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers] - self.num_vlm_layers = len(self.get_vlm_model().text_model.layers) - self.config = config - # Smaller lm expert - lm_expert_config = copy.deepcopy(config.text_config) - hidden_size = lm_expert_config.hidden_size - lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2 - lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier)) - lm_expert_config.num_hidden_layers = self.num_vlm_layers - if num_expert_layers > 0: - assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, ( - f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}" - ) - lm_expert_config.num_hidden_layers = num_expert_layers - # lm_expert_config.head_dim = lm_expert_config.head_dim * 2 - self.lm_expert = AutoModel.from_config(lm_expert_config) - - self.num_expert_layers = len(self.lm_expert.layers) - self.self_attn_every_n_layers = self_attn_every_n_layers - self.self_attn_only_actions = self_attn_only_actions - if "cross" in attention_mode: - # Reshape qkv projections to have the same input dimension as the vlm - for layer_idx in range(len(self.lm_expert.layers)): - if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0: - continue - self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear( - config.text_config.num_key_value_heads * config.text_config.head_dim, - lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, - bias=lm_expert_config.attention_bias, - ) - self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear( - config.text_config.num_key_value_heads * config.text_config.head_dim, - lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, - bias=lm_expert_config.attention_bias, - ) - # Remove unused embed_tokens - self.lm_expert.embed_tokens = None - - self.num_attention_heads = self.config.text_config.num_attention_heads - self.num_key_value_heads = self.config.text_config.num_key_value_heads - - self.freeze_vision_encoder = freeze_vision_encoder - self.train_expert_only = train_expert_only - self.attention_implementation = attention_implementation - self.attention_mode = attention_mode - self.expert_hidden_size = lm_expert_config.hidden_size - # self.to_bfloat16_like_physical_intelligence() - self.set_requires_grad() - - def configure_peft(self, config): - # return model - self.peft_method = config.peft_method - self.peft_target_model = config.peft_target_model - if "lora" in self.peft_method: - peft_config = config.peft_config - target_modules = peft_config.target_modules - if not isinstance(target_modules, list): - target_modules = target_modules.split(",") - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.) - r=peft_config.r, # The rank of the low-rank adaptation - lora_alpha=peft_config.lora_alpha, # Scaling factor - lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers - target_modules=target_modules, # The components where LoRA is applied - exclude_modules=[ - "lm_expert", - "model.lm_expert.model.layers", - ], # FIXME(mshukor): this does not work for now - ) - self.lora_config = lora_config - # Apply LoRA and ensure only LoRA parameters are trainable - if "text" in self.peft_target_model: - self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config) - else: - self.vlm = get_peft_model(self.vlm, lora_config) - # assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here? - for name, param in self.vlm.named_parameters(): - if ( - "lora" in name and "text_model.model.layers.17" not in name - ): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer - param.requires_grad = True - else: - param.requires_grad = False - - def merge_lora_weights(self): - """ - Merge LoRA weights into the base model. - """ - if "text" in self.peft_target_model: - self.get_vlm_model().text_model = self.get_vlm_model().text_model.merge_and_unload() - else: - self.vlm = self.vlm.merge_and_unload() - - def get_vlm_model( - self, - ): - if hasattr(self.vlm.model, "model"): # When using peft - return self.vlm.model.model - else: - return self.vlm.model - - def set_requires_grad(self): - if self.freeze_vision_encoder: - self.get_vlm_model().vision_model.eval() - for params in self.get_vlm_model().vision_model.parameters(): - params.requires_grad = False - if self.train_expert_only: - self.vlm.eval() - for params in self.vlm.parameters(): - params.requires_grad = False - else: - # To avoid unused params issue with distributed training - last_layers = [self.num_vlm_layers - 1] - if ( - self.num_vlm_layers != self.num_expert_layers - and self.num_vlm_layers % self.num_expert_layers == 0 - ): - last_layers.append(self.num_vlm_layers - 2) - frozen_layers = [ - "lm_head", - "text_model.model.norm.weight", - ] - for layer in last_layers: - frozen_layers.append(f"text_model.model.layers.{layer}.") - - for name, params in self.vlm.named_parameters(): - if any([k in name for k in frozen_layers]): - params.requires_grad = False - # To avoid unused params issue with distributed training - for name, params in self.lm_expert.named_parameters(): - if any( - [ - k in name - for k in [ - "lm_head", - ] - ] - ): - params.requires_grad = False - - def train(self, mode: bool = True): - super().train(mode) - - if self.freeze_vision_encoder: - self.get_vlm_model().vision_model.eval() - - if self.train_expert_only: - self.vlm.eval() - - # def to_bfloat16_like_physical_intelligence(self): - # self.vlm = self.vlm.to(dtype=torch.bfloat16) - - # params_to_change_dtype = [ - # "language_model.model.layers", - # "gemma_expert.model.layers", - # "vision_tower", - # "multi_modal", - # ] - # for name, param in self.named_parameters(): - # if any(selector in name for selector in params_to_change_dtype): - # param.data = param.data.to(dtype=torch.bfloat16) - - def embed_image(self, image: torch.Tensor): - patch_attention_mask = None - # # FIXME(mshukor): probably not needed as we don't have padded images here - # pixel_values = image.unsqueeze(1) - # batch_size, num_images, num_channels, height, width = pixel_values.shape - # pixel_values = pixel_values - # pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) - - # # Remove padding images - padding images are full 0. - # nb_values_per_image = pixel_values.shape[1:].numel() - # real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image - - # if not any(real_images_inds): - # # no images, leave one empty image. - # real_images_inds[0] = True - - # pixel_values = pixel_values[real_images_inds].contiguous() - - # # Handle the vision attention mask - - # pixel_attention_mask = torch.ones( - # size=[pixel_values.shape[i] for i in (0, 2, 3)], - # dtype=torch.bool, - # device=pixel_values.device, - # ) - - # patch_size = self.vlm.config.vision_config.patch_size - # patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) - # patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) - # patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - - # FIXME(mshukor): add special image tokens specific to smolvlm - # Get sequence from the vision encoder - image_hidden_states = ( - self.get_vlm_model() - .vision_model( - pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype), - patch_attention_mask=patch_attention_mask, - ) - .last_hidden_state - ) - # Modality projection & resampling - image_hidden_states = self.get_vlm_model().connector(image_hidden_states) - return image_hidden_states - - def embed_language_tokens(self, tokens: torch.Tensor): - return self.get_vlm_model().text_model.get_input_embeddings()(tokens) - - def forward_attn_layer( - self, - model_layers, - inputs_embeds, - layer_idx, - position_ids, - attention_mask, - batch_size, - head_dim, - use_cache: bool = True, - fill_kv_cache: bool = True, - past_key_values=None, - ) -> list[torch.Tensor]: - query_states = [] - key_states = [] - value_states = [] - for i, hidden_states in enumerate(inputs_embeds): - layer = model_layers[i][layer_idx] - if hidden_states is None or layer is None: - continue - - # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype) - # hidden_states = hidden_states * normalizer - hidden_states = layer.input_layernorm(hidden_states) - - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - - hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) - query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) - key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) - value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) - - query_states.append(query_state) - key_states.append(key_state) - value_states.append(value_state) - - # FIXME(mshukor): self attention always when having only the prefix - # B,L,H,D with L sequence length, H number of heads, D head dim - # concatenate on the number of embeddings/tokens - query_states = torch.cat(query_states, dim=1) - key_states = torch.cat(key_states, dim=1) - value_states = torch.cat(value_states, dim=1) - # FIXME(mshukor): seq should be B, H, L, D ? - seq_len = query_states.shape[1] - if seq_len < position_ids.shape[1]: - _position_ids = position_ids[:, :seq_len] - _attention_mask = attention_mask[:, :seq_len, :seq_len] - else: - _position_ids = position_ids - _attention_mask = attention_mask - - if self.self_attn_only_actions: - attention_mask_ = _attention_mask.clone() - position_ids_ = _position_ids.clone() - if inputs_embeds[1] is not None: - suffix_len = inputs_embeds[1].shape[1] - attention_mask_[:, -suffix_len:, :-suffix_len] = False - position_ids_[:, -suffix_len:] = ( - _position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None] - ) - else: - attention_mask_ = _attention_mask - position_ids_ = _position_ids - - query_states = apply_rope( - query_states, position_ids_ - ) # FIXME(mshukor): this assumes we have always the vlm features? - key_states = apply_rope(key_states, position_ids_) - - if use_cache and past_key_values is None: - past_key_values = {} - - if use_cache: - if fill_kv_cache: - past_key_values[layer_idx] = { - "key_states": key_states, - "value_states": value_states, - } - else: - # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. - # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach - # the max len, then we (for instance) double the cache size. This implementation already exists - # in `transformers`. (molbap) - key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) - value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1) - - attention_interface = self.get_attention_interface() - - att_output = attention_interface( - attention_mask_, batch_size, head_dim, query_states, key_states, value_states - ) - # att_output = att_output.to(dtype=models[i].dtype) - - return [att_output], past_key_values - - def forward_cross_attn_layer( - self, - model_layers, - inputs_embeds, - layer_idx, - position_ids, - attention_mask, - batch_size, - head_dim, - use_cache: bool = True, - fill_kv_cache: bool = True, - past_key_values=None, - ) -> list[torch.Tensor]: - attention_interface = self.get_attention_interface() - - att_outputs = [] - assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), ( - f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}" - ) - - if len(inputs_embeds) == 2 and not past_key_values: - # Prefix attention - seq_len = inputs_embeds[0].shape[1] - position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:] - prefix_attention_mask = attention_mask[:, :seq_len, :seq_len] - - layer = model_layers[0][layer_idx] - - hidden_states = layer.input_layernorm(inputs_embeds[0]) - - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - - hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) - query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) - key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) - value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape) - - # B,L,H,D with L sequence length, H number of heads, D head dim - query_states = apply_rope(query_state, position_id) - key_states = apply_rope(key_state, position_id) - - att_output = attention_interface( - prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states - ) - att_outputs.append(att_output) - else: - expert_position_id = position_ids - - if use_cache and past_key_values is None: - past_key_values = {} - - if use_cache: - if fill_kv_cache: - past_key_values[layer_idx] = { - "key_states": key_states, - "value_states": value_states, - } - else: - # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. - # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach - # the max len, then we (for instance) double the cache size. This implementation already exists - # in `transformers`. (molbap) - key_states = past_key_values[layer_idx]["key_states"] - value_states = past_key_values[layer_idx]["value_states"] - - # Expert - expert_layer = model_layers[1][layer_idx] - if expert_layer is not None: - expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1]) - - expert_input_shape = expert_hidden_states.shape[:-1] - expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim) - - expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype) - expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape) - - _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view( - *key_states.shape[:2], -1 - ) - expert_key_states = expert_layer.self_attn.k_proj(_key_states).view( - *_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim - ) # k_proj should have same dim as kv - - _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view( - *value_states.shape[:2], -1 - ) - expert_value_states = expert_layer.self_attn.v_proj(_value_states).view( - *_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim - ) - - expert_position_id = ( - expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values - ) # start from 0 - expert_attention_mask = attention_mask[ - :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] : - ] # take into account kv - - expert_query_states = apply_rope(expert_query_state, expert_position_id) - # expert_key_states = apply_rope(expert_key_state, expert_position_id) - - att_output = attention_interface( - expert_attention_mask, - batch_size, - head_dim, - expert_query_states, - expert_key_states, - expert_value_states, - ) - att_outputs.append(att_output) - else: - att_outputs.append(None) - - # att_output = att_output.to(dtype=models[i].dtype) - return att_outputs, past_key_values - - def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient? - vlm_layers = [] - expert_layers = [] - multiple_of = self.num_vlm_layers // self.num_expert_layers - for i in range(self.num_vlm_layers): - if multiple_of > 0 and i > 0 and i % multiple_of != 0: - expert_layer = None - else: - expert_layer_index = i // multiple_of if multiple_of > 0 else i - expert_layer = models[1].layers[expert_layer_index] - vlm_layers.append(models[0].layers[i]) - expert_layers.append(expert_layer) - return [vlm_layers, expert_layers] - - # TODO: break down this huge forward into modules or functions - def forward( - self, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | Cache | None = None, - inputs_embeds: list[torch.FloatTensor] = None, - use_cache: bool | None = None, - fill_kv_cache: bool | None = None, - ): - models = [self.get_vlm_model().text_model, self.lm_expert] - model_layers = self.get_model_layers(models) - for hidden_states in inputs_embeds: - # TODO this is very inefficient - # dtype is always the same, batch size too (if > 1 len) - # device could be trickier in multi gpu edge cases but that's it - if hidden_states is None: - continue - batch_size = hidden_states.shape[0] - - # # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train - if self.attention_implementation == "flex": - if ( - inputs_embeds[0] is not None - and inputs_embeds[1] is not None - and attention_mask.shape[-1] == attention_mask.shape[-2] - and past_key_values is None - ): # Now only during training - seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1] - padded_seq_len = _round_up_to_multiple( - seq_len, 128 - ) # FIXME(mshukor): more efficient to have a fixed seq len? - b_mask, q_len, kv_len = attention_mask.shape # The shape of your mask - pad = padded_seq_len - q_len - attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True) - inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0) - position_ids = F.pad(position_ids, (0, pad), value=0) - - # RMSNorm - num_layers = self.num_vlm_layers - head_dim = self.vlm.config.text_config.head_dim - for layer_idx in range(num_layers): - if ( - fill_kv_cache - or "cross" not in self.attention_mode - or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0) - ): - att_outputs, past_key_values = self.forward_attn_layer( - model_layers, - inputs_embeds, - layer_idx, - position_ids, - attention_mask, - batch_size, - head_dim, - use_cache=use_cache, - fill_kv_cache=fill_kv_cache, - past_key_values=past_key_values, - ) - else: - att_outputs, past_key_values = self.forward_cross_attn_layer( - model_layers, - inputs_embeds, - layer_idx, - position_ids, - attention_mask, - batch_size, - head_dim, - use_cache=use_cache, - fill_kv_cache=fill_kv_cache, - past_key_values=past_key_values, - ) - # query_states = [] - # key_states = [] - # value_states = [] - # for i, hidden_states in enumerate(inputs_embeds): - # if hidden_states is None: - # continue - # layer = models[i].layers[layer_idx] - # # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype) - # # hidden_states = hidden_states * normalizer - # hidden_states = layer.input_layernorm(hidden_states) - - # input_shape = hidden_states.shape[:-1] - # hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - - # hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) - # query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) - # key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) - # value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) - - # query_states.append(query_state) - # key_states.append(key_state) - # value_states.append(value_state) - - # # FIXME(mshukor): self attention always when having only the prefix - # # B,L,H,D with L sequence length, H number of heads, D head dim - # # concatenate on the number of embeddings/tokens - # query_states = torch.cat(query_states, dim=1) - # key_states = torch.cat(key_states, dim=1) - # value_states = torch.cat(value_states, dim=1) - # # FIXME(mshukor): seq should be B, H, L, D ? - # query_states = apply_rope(query_states, position_ids) - # key_states = apply_rope(key_states, position_ids) - - # if use_cache and past_key_values is None: - # past_key_values = {} - - # if use_cache: - # if fill_kv_cache: - # past_key_values[layer_idx] = { - # "key_states": key_states, - # "value_states": value_states, - # } - # else: - # # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. - # # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach - # # the max len, then we (for instance) double the cache size. This implementation already exists - # # in `transformers`. (molbap) - # key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) - # value_states = torch.cat( - # [past_key_values[layer_idx]["value_states"], value_states], dim=1 - # ) - - # attention_interface = self.get_attention_interface() - # att_output = attention_interface( - # attention_mask, batch_size, head_dim, query_states, key_states, value_states - # ) - - # att_output = att_output.to(dtype=models[i].dtype) - - # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) - outputs_embeds = [] - start = 0 - for i, hidden_states in enumerate(inputs_embeds): - # layer = models[i].layers[layer_idx] - layer = model_layers[i][layer_idx] - att_output = ( - att_outputs[i] if i < len(att_outputs) else att_outputs[0] - ) # in case of self_attn - if hidden_states is not None: - if layer is None: - outputs_embeds.append(hidden_states) - continue - end = start + hidden_states.shape[1] - - if att_output.dtype != layer.self_attn.o_proj.weight.dtype: - att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) - att_out = att_output[:, start:end] - out_emb = layer.self_attn.o_proj(att_out) - - # TODO: first dropout (by default 0.0) - # first residual - out_emb += hidden_states - after_first_residual = out_emb.clone() - - out_emb = layer.post_attention_layernorm(out_emb) - out_emb = layer.mlp(out_emb) - - # TODO: second dropout (by default 0.0) - - # second residual - out_emb += after_first_residual - - outputs_embeds.append(out_emb) - - start = end if len(att_outputs) == 1 else 0 - else: - outputs_embeds.append(None) - - inputs_embeds = outputs_embeds - - # final norm - outputs_embeds = [] - for i, hidden_states in enumerate(inputs_embeds): - if hidden_states is not None: - out_emb = models[i].norm(hidden_states) - outputs_embeds.append(out_emb) - else: - outputs_embeds.append(None) - return outputs_embeds, past_key_values - - def get_attention_interface(self): - if self.attention_implementation == "fa2": - attention_interface = self.flash_attention_forward - elif self.attention_implementation == "flex": - attention_interface = partial( - flex_attention_forward, - num_att_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - ) - else: - attention_interface = self.eager_attention_forward - return attention_interface - - def flash_attention_forward( - self, attention_mask, batch_size, head_dim, query_states, key_states, value_states - ): - raise NotImplementedError("FA2 is not implemented (yet)") - - def eager_attention_forward( - self, attention_mask, batch_size, head_dim, query_states, key_states, value_states - ): - num_att_heads = self.num_attention_heads - num_key_value_heads = self.num_key_value_heads - num_key_value_groups = num_att_heads // num_key_value_heads - - # query_states: batch_size, sequence_length, num_att_head, head_dim - # key_states: batch_size, sequence_length, num_key_value_head, head_dim - # value_states: batch_size, sequence_length, num_key_value_head, head_dim - sequence_length = key_states.shape[1] - - key_states = key_states[:, :, :, None, :].expand( - batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim - ) - key_states = key_states.reshape( - batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim - ) - - value_states = value_states[:, :, :, None, :].expand( - batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim - ) - value_states = value_states.reshape( - batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim - ) - - # Attention here is upcasted to float32 to match the original eager implementation. - - query_states = query_states.to(dtype=torch.float32) - key_states = key_states.to(dtype=torch.float32) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - - att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - att_weights *= head_dim**-0.5 - - att_weights = att_weights.to(dtype=torch.float32) - big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py - masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) - probs = nn.functional.softmax(masked_att_weights, dim=-1) - probs = probs.to(dtype=value_states.dtype) - - # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length - # value_states: batch_size, sequence_length, num_att_heads, head_dim - - att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) - - att_output = att_output.permute(0, 2, 1, 3) - # we use -1 because sequence length can change - att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) - - return att_output diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 9bb22d7f7..6bf956aa3 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -1,955 +1,3 @@ -# #!/usr/bin/env python - -# # Copyright 2025 HuggingFace Inc. team. All rights reserved. -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. - -# """ -# SmolVLA: - -# [Paper](https://huggingface.co/papers/2506.01844) - -# Designed by Hugging Face. - -# Install smolvla extra dependencies: -# ```bash -# pip install -e ".[smolvla]" -# ``` - -# Example of finetuning the smolvla pretrained model (`smolvla_base`): -# ```bash -# lerobot-train \ -# --policy.path=lerobot/smolvla_base \ -# --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ -# --batch_size=64 \ -# --steps=200000 -# ``` - -# Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM, -# and an action expert. -# ```bash -# lerobot-train \ -# --policy.type=smolvla \ -# --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ -# --batch_size=64 \ -# --steps=200000 -# ``` - -# Example of using the smolvla pretrained model outside LeRobot training framework: -# ```python -# policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") -# ``` - -# """ - -# import math -# import os -# import re -# from collections import deque - -# import safetensors -# import torch -# import torch.nn.functional as F # noqa: N812 -# from torch import Tensor, nn -# from transformers import AutoProcessor - -# from lerobot.constants import ACTION -# from lerobot.policies.normalize import ( -# Normalize, -# Unnormalize, -# ) -# from lerobot.policies.pretrained import PreTrainedPolicy -# from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig -# from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel -# from lerobot.policies.utils import ( -# populate_queues, -# ) -# from lerobot.utils.utils import get_safe_dtype -# OBS_STATE = 'state' -# ACTION = 'actions' -# # Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker -# _VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") - - -# def canonicalise(k: str) -> str: -# """ -# Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a -# normalisation-buffer key. -# """ -# return _VARIANT_RE.sub(".buffer_", k) - - -# def standardise_state_dict( -# checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True -# ) -> tuple[dict[str, torch.Tensor], list[str]]: -# """ -# • Re-keys `checkpoint ` so that every entry matches the *reference* key set. -# • If several variant keys collapse to the same canonical name we keep the -# first one and log the collision. -# • Returns the new dict + a list of entries that could not be matched. -# """ -# out, collisions, unmatched = {}, {}, [] - -# for k, v in checkpoint.items(): -# canon = canonicalise(k) -# if canon in ref_keys: -# if canon in out: # duplicate after collapsing -# collisions.setdefault(canon, []).append(k) -# else: -# out[canon] = v -# else: -# unmatched.append(k) - -# if verbose: -# for canon, variants in collisions.items(): -# print(f"[standardise_state_dict] '{canon}' ← {variants}") -# if unmatched: -# print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") - -# out.update({k: checkpoint[k] for k in unmatched}) -# return out, unmatched - - -# def rename_checkpoint_keys(checkpoint: dict, rename_str: str): -# """ -# Renames keys in a checkpoint dictionary based on the given rename string. - -# Args: -# checkpoint (dict): The checkpoint dictionary. -# rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". - -# Returns: -# dict: The modified checkpoint with renamed keys. -# """ - -# rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) - -# new_checkpoint = {} -# for k, v in checkpoint.items(): -# for old_key, new_key in rename_dict.items(): -# if old_key in k: -# k = k.replace(old_key, new_key) -# new_checkpoint[k] = v -# return new_checkpoint - - -# def load_smolvla( -# model: torch.nn.Module, -# filename: str | os.PathLike, -# *, -# device: str = "cpu", -# checkpoint_keys_mapping: str = "", -# ) -> torch.nn.Module: -# state_dict = safetensors.torch.load_file(filename, device=device) - -# # Optional user-supplied renames (e.g. "model._orig_mod.//model.") -# if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: -# state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) - -# state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) - -# # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset -# norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") -# state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} - -# missing, unexpected = model.load_state_dict(state_dict, strict=False) - -# if not all(key.startswith(norm_keys) for key in missing) or unexpected: -# raise RuntimeError( -# "SmolVLA %d missing / %d unexpected keys", -# len(missing), -# len(unexpected), -# ) - -# return model - - -# def create_sinusoidal_pos_embedding( -# time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" -# ) -> Tensor: -# """Computes sine-cosine positional embedding vectors for scalar positions.""" -# if dimension % 2 != 0: -# raise ValueError(f"dimension ({dimension}) must be divisible by 2") - -# if time.ndim != 1: -# raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") - -# dtype = get_safe_dtype(torch.float64, device.type) -# fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) -# period = min_period * (max_period / min_period) ** fraction - -# # Compute the outer product -# scaling_factor = 1.0 / period * 2 * math.pi -# sin_input = scaling_factor[None, :] * time[:, None] -# pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) -# return pos_emb - - -# def make_att_2d_masks(pad_masks, att_masks): -# """Copied from big_vision. - -# Tokens can attend to valid inputs tokens which have a cumulative mask_ar -# smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to -# setup several types of attention, for example: - -# [[1 1 1 1 1 1]]: pure causal attention. - -# [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between -# themselves and the last 3 tokens have a causal attention. The first -# entry could also be a 1 without changing behaviour. - -# [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a -# block can attend all previous blocks and all tokens on the same block. - -# Args: -# input_mask: bool[B, N] true if its part of the input, false if padding. -# mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on -# it and 0 where it shares the same attention mask as the previous token. -# """ -# if att_masks.ndim != 2: -# raise ValueError(att_masks.ndim) -# if pad_masks.ndim != 2: -# raise ValueError(pad_masks.ndim) - -# cumsum = torch.cumsum(att_masks, dim=1) -# att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] -# pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] -# att_2d_masks = att_2d_masks & pad_2d_masks -# return att_2d_masks - - -# def resize_with_pad(img, width, height, pad_value=-1): -# # assume no-op when width height fits already -# if img.ndim != 4: -# raise ValueError(f"(b,c,h,w) expected, but {img.shape}") - -# cur_height, cur_width = img.shape[2:] - -# ratio = max(cur_width / width, cur_height / height) -# resized_height = int(cur_height / ratio) -# resized_width = int(cur_width / ratio) -# resized_img = F.interpolate( -# img, size=(resized_height, resized_width), mode="bilinear", align_corners=False -# ) - -# pad_height = max(0, int(height - resized_height)) -# pad_width = max(0, int(width - resized_width)) - -# # pad on left and top of image -# padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) -# return padded_img - - -# def pad_vector(vector, new_dim): -# """Can be (batch_size x sequence_length x features_dimension) -# or (batch_size x features_dimension) -# """ -# if vector.shape[-1] == new_dim: -# return vector -# shape = list(vector.shape) -# current_dim = shape[-1] -# shape[-1] = new_dim -# new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) -# new_vector[..., :current_dim] = vector -# return new_vector - - -# def normalize(x, min_val, max_val): -# return (x - min_val) / (max_val - min_val) - - -# def unnormalize(x, min_val, max_val): -# return x * (max_val - min_val) + min_val - - -# def safe_arcsin(value): -# # This ensures that the input stays within -# # [−1,1] to avoid invalid values for arcsin -# return torch.arcsin(torch.clamp(value, -1.0, 1.0)) - - -# def aloha_gripper_to_angular(value): -# # Aloha transforms the gripper positions into a linear space. The following code -# # reverses this transformation to be consistent with smolvla which is pretrained in -# # angular space. -# # -# # These values are coming from the Aloha code: -# # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED -# value = unnormalize(value, min_val=0.01844, max_val=0.05800) - -# # This is the inverse of the angular to linear transformation inside the Interbotix code. -# def linear_to_radian(linear_position, arm_length, horn_radius): -# value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) -# return safe_arcsin(value) - -# # The constants are taken from the Interbotix code. -# value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) - -# # Normalize to [0, 1]. -# # The values 0.4 and 1.5 were measured on an actual Trossen robot. -# return normalize(value, min_val=0.4, max_val=1.5) - - -# def aloha_gripper_from_angular(value): -# # Convert from the gripper position used by smolvla to the gripper position that is used by Aloha. -# # Note that the units are still angular but the range is different. - -# # The values 0.4 and 1.5 were measured on an actual Trossen robot. -# value = unnormalize(value, min_val=0.4, max_val=1.5) - -# # These values are coming from the Aloha code: -# # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE -# return normalize(value, min_val=-0.6213, max_val=1.4910) - - -# def aloha_gripper_from_angular_inv(value): -# # Directly inverts the gripper_from_angular function. -# value = unnormalize(value, min_val=-0.6213, max_val=1.4910) -# return normalize(value, min_val=0.4, max_val=1.5) - - -# class SmolVLAPolicy(PreTrainedPolicy): -# """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot.""" - -# config_class = SmolVLAConfig -# name = "smolvla" - -# def __init__( -# self, -# config: SmolVLAConfig, -# dataset_stats: dict[str, dict[str, Tensor]] | None = None, -# ): -# """ -# Args: -# config: Policy configuration class instance or None, in which case the default instantiation of -# the configuration class is used. -# dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected -# that they will be passed with a call to `load_state_dict` before the policy is used. -# """ - -# super().__init__(config) -# config.validate_features() -# self.config = config -# self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) -# self.normalize_targets = Normalize( -# config.output_features, config.normalization_mapping, dataset_stats -# ) -# self.unnormalize_outputs = Unnormalize( -# config.output_features, config.normalization_mapping, dataset_stats -# ) - -# self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer -# self.model = VLAFlowMatching(config) -# self.reset() - -# def reset(self): -# """This should be called whenever the environment is reset.""" -# self._queues = { -# ACTION: deque(maxlen=self.config.n_action_steps), -# } - -# # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues -# @classmethod -# def _load_as_safetensor( -# cls, -# model: "SmolVLAPolicy", -# model_file: str, -# map_location: str, -# strict: bool, -# ): -# safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) -# return load_smolvla( -# model, -# model_file, -# device=map_location, -# checkpoint_keys_mapping="model._orig_mod.//model.", -# ) - -# def get_optim_params(self) -> dict: -# return self.parameters() - -# def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: -# # TODO: Check if this for loop is needed. -# # Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch -# # In the case of offline inference, we have the action in the batch -# # that why without the k != ACTION check, it will raise an error because we are trying to stack -# # on an empty container. -# for k in batch: -# if k in self._queues and k != ACTION: -# batch[k] = torch.stack(list(self._queues[k]), dim=1) - -# images, img_masks = self.prepare_images(batch) -# state = self.prepare_state(batch) -# lang_tokens, lang_masks = self.prepare_language(batch) - -# actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) - -# # Unpad actions -# original_action_dim = self.config.action_feature.shape[0] -# actions = actions[:, :, :original_action_dim] - -# actions = self.unnormalize_outputs({ACTION: actions})[ACTION] - -# if self.config.adapt_to_pi_aloha: -# actions = self._pi_aloha_encode_actions(actions) - -# return actions - -# def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: -# if self.config.adapt_to_pi_aloha: -# batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - -# batch = self.normalize_inputs(batch) - -# return batch - -# @torch.no_grad() -# def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: -# self.eval() - -# batch = self._prepare_batch(batch) -# self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) - -# actions = self._get_action_chunk(batch, noise) -# return actions - -# @torch.no_grad() -# def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: -# """Select a single action given environment observations. - -# This method wraps `select_actions` in order to return one action at a time for execution in the -# environment. It works by managing the actions in a queue and only calling `select_actions` when the -# queue is empty. -# """ -# self.eval() -# batch = self._prepare_batch(batch) -# self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) - -# # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by -# # querying the policy. -# if len(self._queues[ACTION]) == 0: -# actions = self._get_action_chunk(batch, noise) - -# # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue -# # effectively has shape (n_action_steps, batch_size, *), hence the transpose. -# self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) - -# return self._queues[ACTION].popleft() - -# def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: -# """Do a full training forward pass to compute the loss""" -# if self.config.adapt_to_pi_aloha: -# batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) -# batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) -# batch = self.normalize_inputs(batch) -# batch = self.normalize_targets(batch) -# images, img_masks = self.prepare_images(batch) -# state = self.prepare_state(batch) -# lang_tokens, lang_masks = self.prepare_language(batch) -# actions = self.prepare_action(batch) -# actions_is_pad = batch.get("actions_id_pad") -# loss_dict = {} -# losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) -# loss_dict["losses_after_forward"] = losses.clone() - -# if actions_is_pad is not None: -# in_episode_bound = ~actions_is_pad -# losses = losses * in_episode_bound.unsqueeze(-1) -# loss_dict["losses_after_in_ep_bound"] = losses.clone() - -# # Remove padding -# losses = losses[:, :, : self.config.max_action_dim] -# loss_dict["losses_after_rm_padding"] = losses.clone() - -# # For backward pass -# loss = losses.mean() -# # For backward pass -# loss_dict["loss"] = loss.item() -# return loss, loss_dict - -# def prepare_images(self, batch): -# """Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and -# convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. -# """ -# images = [] -# img_masks = [] -# present_img_keys = [key for key in self.config.image_features if key in batch] -# missing_img_keys = [key for key in self.config.image_features if key not in batch] - -# if len(present_img_keys) == 0: -# raise ValueError( -# f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" -# ) -# # Preprocess image features present in the batch -# for key in present_img_keys: -# img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key] -# if self.config.resize_imgs_with_padding is not None: -# img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) - -# # Normalize from range [0,1] to [-1,1] as expacted by siglip -# img = img * 2.0 - 1.0 - -# bsize = img.shape[0] -# device = img.device -# if f"{key}_padding_mask" in batch: -# mask = batch[f"{key}_padding_mask"].bool() -# else: -# mask = torch.ones(bsize, dtype=torch.bool, device=device) -# images.append(img) -# img_masks.append(mask) - -# # Create image features not present in the batch -# # as fully 0 padded images. -# for num_empty_cameras in range(len(missing_img_keys)): -# if num_empty_cameras >= self.config.empty_cameras: -# break -# img = torch.ones_like(img) * -1 -# mask = torch.zeros_like(mask) -# images.append(img) -# img_masks.append(mask) -# return images, img_masks - -# def prepare_language(self, batch) -> tuple[Tensor, Tensor]: -# """Tokenize the text input""" -# device = batch[OBS_STATE].device -# tasks = batch["task"] -# if isinstance(tasks, str): -# tasks = [tasks] - -# if len(tasks) == 1: -# tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] - -# tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - -# tokenized_prompt = self.language_tokenizer.__call__( -# tasks, -# padding=self.config.pad_language_to, -# padding_side="right", -# max_length=self.config.tokenizer_max_length, -# return_tensors="pt", -# ) -# lang_tokens = tokenized_prompt["input_ids"].to(device=device) -# lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) - -# return lang_tokens, lang_masks - -# def _pi_aloha_decode_state(self, state): -# # Flip the joints. -# for motor_idx in [1, 2, 8, 9]: -# state[:, motor_idx] *= -1 -# # Reverse the gripper transformation that is being applied by the Aloha runtime. -# for motor_idx in [6, 13]: -# state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) -# return state - -# def _pi_aloha_encode_actions(self, actions): -# # Flip the joints. -# for motor_idx in [1, 2, 8, 9]: -# actions[:, :, motor_idx] *= -1 -# # Reverse the gripper transformation that is being applied by the Aloha runtime. -# for motor_idx in [6, 13]: -# actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) -# return actions - -# def _pi_aloha_encode_actions_inv(self, actions): -# # Flip the joints again. -# for motor_idx in [1, 2, 8, 9]: -# actions[:, :, motor_idx] *= -1 -# # Reverse the gripper transformation that is being applied by the Aloha runtime. -# for motor_idx in [6, 13]: -# actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) -# return actions - -# def prepare_state(self, batch): -# """Pad state""" -# state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE] -# state = pad_vector(state, self.config.max_state_dim) -# return state - -# def prepare_action(self, batch): -# """Pad action""" -# actions = pad_vector(batch[ACTION], self.config.max_action_dim) -# return actions - - -# def pad_tensor(tensor, max_len, pad_value=0): -# """ -# Efficiently pads a tensor along sequence dimension to match max_len. - -# Args: -# tensor (torch.Tensor): Shape (B, L, ...) or (B, L). -# max_len (int): Fixed sequence length. -# pad_value (int/float): Value for padding. - -# Returns: -# torch.Tensor: Shape (B, max_len, ...) or (B, max_len). -# """ -# b, d = tensor.shape[:2] - -# # Create a padded tensor of max_len and copy the existing values -# padded_tensor = torch.full( -# (b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device -# ) -# padded_tensor[:, :d] = tensor # Efficient in-place copy - -# return padded_tensor - - -# class VLAFlowMatching(nn.Module): -# """ -# SmolVLA - -# [Paper]() - -# Designed by Hugging Face. -# ┌──────────────────────────────┐ -# │ actions │ -# │ ▲ │ -# │ ┌─────────┐ ┌─|────┐ │ -# │ | │────► │ │ │ -# │ | │ kv │ │ │ -# │ | │────► │Action│ │ -# │ | VLM │cache │Expert│ | -# │ │ │────► | │ │ -# │ │ │ │ │ │ -# │ └▲──▲───▲─┘ └───▲──┘ | -# │ │ | | │ | -# │ | | | noise │ -# │ │ │ state │ -# │ │ language tokens │ -# │ image(s) │ -# └──────────────────────────────┘ -# """ - -# def __init__(self, config: SmolVLAConfig): -# super().__init__() -# self.config = config - -# self.vlm_with_expert = SmolVLMWithExpertModel( -# model_id=self.config.vlm_model_name, -# freeze_vision_encoder=self.config.freeze_vision_encoder, -# train_expert_only=self.config.train_expert_only, -# load_vlm_weights=self.config.load_vlm_weights, -# attention_mode=self.config.attention_mode, -# num_expert_layers=self.config.num_expert_layers, -# num_vlm_layers=self.config.num_vlm_layers, -# self_attn_every_n_layers=self.config.self_attn_every_n_layers, -# expert_width_multiplier=self.config.expert_width_multiplier, -# ) -# self.state_proj = nn.Linear( -# self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size -# ) -# self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size) -# self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim) - -# self.action_time_mlp_in = nn.Linear( -# self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size -# ) -# self.action_time_mlp_out = nn.Linear( -# self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size -# ) - -# self.set_requires_grad() -# self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id -# self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id -# self.global_image_start_token = torch.tensor( -# [self.fake_image_token, self.global_image_token], dtype=torch.long -# ) - -# self.add_image_special_tokens = self.config.add_image_special_tokens -# self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long) -# self.prefix_length = self.config.prefix_length - -# def set_requires_grad(self): -# for params in self.state_proj.parameters(): -# params.requires_grad = self.config.train_state_proj - -# def sample_noise(self, shape, device): -# noise = torch.normal( -# mean=0.0, -# std=1.0, -# size=shape, -# dtype=torch.float32, -# device=device, -# ) -# return noise - -# def sample_time(self, bsize, device): -# beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) -# time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) -# time = time_beta * 0.999 + 0.001 -# return time - -# def embed_prefix( -# self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None -# ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: -# """Embed images with SigLIP and language tokens with embedding layer to prepare -# for SmolVLM transformer processing. -# """ -# embs = [] -# pad_masks = [] -# att_masks = [] -# for _img_idx, ( -# img, -# img_mask, -# ) in enumerate(zip(images, img_masks, strict=False)): -# if self.add_image_special_tokens: -# image_start_token = ( -# self.vlm_with_expert.embed_language_tokens( -# self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device) -# ) -# .unsqueeze(0) -# .expand(img.shape[0], -1, -1) -# ) -# image_start_mask = torch.ones_like( -# image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device -# ) -# att_masks += [0] * (image_start_mask.shape[-1]) -# embs.append(image_start_token) -# pad_masks.append(image_start_mask) - -# img_emb = self.vlm_with_expert.embed_image(img) -# img_emb = img_emb - -# # Normalize image embeddings -# img_emb_dim = img_emb.shape[-1] -# img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) - -# bsize, num_img_embs = img_emb.shape[:2] -# img_mask = img_mask[:, None].expand(bsize, num_img_embs) - -# embs.append(img_emb) -# pad_masks.append(img_mask) - -# att_masks += [0] * (num_img_embs) -# if self.add_image_special_tokens: -# image_end_token = ( -# self.vlm_with_expert.embed_language_tokens( -# self.image_end_token.to(device=self.vlm_with_expert.vlm.device) -# ) -# .unsqueeze(0) -# .expand(img.shape[0], -1, -1) -# ) -# image_end_mask = torch.ones_like( -# image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device -# ) -# embs.append(image_end_token) -# pad_masks.append(image_end_mask) -# att_masks += [0] * (image_end_mask.shape[1]) -# lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens) -# # Normalize language embeddings -# lang_emb_dim = lang_emb.shape[-1] -# lang_emb = lang_emb * math.sqrt(lang_emb_dim) - -# embs.append(lang_emb) -# pad_masks.append(lang_masks) - -# num_lang_embs = lang_emb.shape[1] -# att_masks += [0] * num_lang_embs - -# state_emb = self.state_proj(state) -# state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb -# embs.append(state_emb) -# bsize = state_emb.shape[0] -# device = state_emb.device - -# states_seq_len = state_emb.shape[1] -# state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device) -# pad_masks.append(state_mask) - -# # Set attention masks so that image and language inputs do not attend to state or actions -# att_masks += [1] * (states_seq_len) -# embs = torch.cat(embs, dim=1) -# pad_masks = torch.cat(pad_masks, dim=1) -# att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) -# att_masks = att_masks[None, :] - -# seq_len = pad_masks.shape[1] -# if seq_len < self.prefix_length: -# embs = pad_tensor(embs, self.prefix_length, pad_value=0) -# pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0) -# att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0) - -# att_masks = att_masks.expand(bsize, -1) - -# return embs, pad_masks, att_masks - -# def embed_suffix(self, noisy_actions, timestep): -# """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" -# embs = [] -# pad_masks = [] -# att_masks = [] - -# # Fuse timestep + action information using an MLP -# action_emb = self.action_in_proj(noisy_actions) -# device = action_emb.device -# bsize = action_emb.shape[0] -# dtype = action_emb.dtype -# # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] -# time_emb = create_sinusoidal_pos_embedding( -# timestep, -# self.vlm_with_expert.expert_hidden_size, -# self.config.min_period, -# self.config.max_period, -# device=device, -# ) -# time_emb = time_emb.type(dtype=dtype) - -# time_emb = time_emb[:, None, :].expand_as(action_emb) -# action_time_emb = torch.cat([action_emb, time_emb], dim=2) - -# action_time_emb = self.action_time_mlp_in(action_time_emb) -# action_time_emb = F.silu(action_time_emb) # swish == silu -# action_time_emb = self.action_time_mlp_out(action_time_emb) - -# # Add to input tokens -# embs.append(action_time_emb) - -# bsize, action_time_dim = action_time_emb.shape[:2] -# action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) -# pad_masks.append(action_time_mask) - -# # Set attention masks so that image, language and state inputs do not attend to action tokens -# att_masks += [1] * self.config.chunk_size -# embs = torch.cat(embs, dim=1) -# pad_masks = torch.cat(pad_masks, dim=1) -# att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) -# att_masks = att_masks[None, :].expand(bsize, len(att_masks)) -# # added by jade -# seq_len = pad_masks.shape[1] -# if seq_len < self.config.chunk_size: -# embs = pad_tensor(embs, self.config.chunk_size, pad_value=0) -# pad_masks = pad_tensor(pad_masks, self.config.chunk_size, pad_value=0) -# att_masks = pad_tensor(att_masks, self.config.chunk_size, pad_value=0) -# return embs, pad_masks, att_masks - -# def forward( -# self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None -# ) -> Tensor: -# """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" -# #added by jade -# if actions.ndim == 2: -# actions = actions[:, None, :].expand(-1, self.config.chunk_size, -1) -# if noise is None: -# noise = self.sample_noise(actions.shape, actions.device) - -# if time is None: -# time = self.sample_time(actions.shape[0], actions.device) - -# time_expanded = time[:, None, None] -# x_t = time_expanded * noise + (1 - time_expanded) * actions -# u_t = noise - actions -# prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( -# images, img_masks, lang_tokens, lang_masks, state=state -# ) -# suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time) - -# pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) -# att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) - -# att_2d_masks = make_att_2d_masks(pad_masks, att_masks) -# position_ids = torch.cumsum(pad_masks, dim=1) - 1 -# (_, suffix_out), _ = self.vlm_with_expert.forward( -# attention_mask=att_2d_masks, -# position_ids=position_ids, -# past_key_values=None, -# inputs_embeds=[prefix_embs, suffix_embs], -# use_cache=False, -# fill_kv_cache=False, -# ) -# # suffix_out = suffix_out[:, -self.config.chunk_size :] -# suffix_out = suffix_out[:, -self.config.chunk_size:, :] -# # Original openpi code, upcast attention output -# suffix_out = suffix_out.to(dtype=torch.float32) -# v_t = self.action_out_proj(suffix_out) -# losses = F.mse_loss(u_t, v_t, reduction="none") -# return losses - -# def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: -# """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" -# bsize = state.shape[0] -# device = state.device - -# if noise is None: -# actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) -# noise = self.sample_noise(actions_shape, device) - -# prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( -# images, img_masks, lang_tokens, lang_masks, state=state -# ) -# prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) -# prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 -# # Compute image and language key value cache -# _, past_key_values = self.vlm_with_expert.forward( -# attention_mask=prefix_att_2d_masks, -# position_ids=prefix_position_ids, -# past_key_values=None, -# inputs_embeds=[prefix_embs, None], -# use_cache=self.config.use_cache, -# fill_kv_cache=True, -# ) -# dt = -1.0 / self.config.num_steps -# dt = torch.tensor(dt, dtype=torch.float32, device=device) - -# x_t = noise -# time = torch.tensor(1.0, dtype=torch.float32, device=device) -# while time >= -dt / 2: -# expanded_time = time.expand(bsize) -# v_t = self.denoise_step( -# prefix_pad_masks, -# past_key_values, -# x_t, -# expanded_time, -# ) -# # Euler step -# x_t += dt * v_t -# time += dt -# return x_t - -# def denoise_step( -# self, -# prefix_pad_masks, -# past_key_values, -# x_t, -# timestep, -# ): -# """Apply one denoising step of the noise `x_t` at a given timestep.""" -# suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep) - -# suffix_len = suffix_pad_masks.shape[1] -# batch_size = prefix_pad_masks.shape[0] -# prefix_len = prefix_pad_masks.shape[1] -# prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) - -# suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) - -# full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) -# prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] -# position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - -# outputs_embeds, _ = self.vlm_with_expert.forward( -# attention_mask=full_att_2d_masks, -# position_ids=position_ids, -# past_key_values=past_key_values, -# inputs_embeds=[None, suffix_embs], -# use_cache=self.config.use_cache, -# fill_kv_cache=False, -# ) -# suffix_out = outputs_embeds[1] -# suffix_out = suffix_out[:, -self.config.chunk_size :] -# suffix_out = suffix_out.to(dtype=torch.float32) -# v_t = self.action_out_proj(suffix_out) -# return v_t #!/usr/bin/env python # Copyright 2025 HuggingFace Inc. team. All rights reserved. @@ -1028,7 +76,6 @@ from lerobot.policies.utils import ( ) from lerobot.utils.utils import get_safe_dtype -# OBS_STATE = 'state' # Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker _VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") @@ -1348,7 +395,6 @@ class SmolVLAPolicy(PreTrainedPolicy): # Unpad actions original_action_dim = self.config.action_feature.shape[0] - original_action_dim = 7 actions = actions[:, :, :original_action_dim] actions = self.unnormalize_outputs({ACTION: actions})[ACTION] @@ -1892,4 +938,4 @@ class VLAFlowMatching(nn.Module): suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) v_t = self.action_out_proj(suffix_out) - return v_t + return v_t \ No newline at end of file diff --git a/src/lerobot/policies/smolvla/modeling_smolvla_v2.py b/src/lerobot/policies/smolvla/modeling_smolvla_v2.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/lerobot/policies/smolvla/saver.txt b/src/lerobot/policies/smolvla/saver.txt deleted file mode 100644 index f2ad6c76f..000000000 --- a/src/lerobot/policies/smolvla/saver.txt +++ /dev/null @@ -1 +0,0 @@ -c diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index b96fbb8a3..bd131f708 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -132,10 +132,6 @@ def rollout( # Reset the policy and environments. policy.reset() - # added by jade - # for k in list(policy.config.input_features.keys()): - # if k.startswith("observation.image"): - # policy.config.input_features["observation.images." + k.split("observation.", 1)[1]] = policy.config.input_features.pop(k) observation, info = env.reset(seed=seeds) if render_callback is not None: render_callback(env) @@ -171,26 +167,6 @@ def rollout( # Infer "task" from attributes of environments. # TODO: works with SyncVectorEnv but not AsyncVectorEnv observation = add_envs_task(env, observation) - # breakpoint() - # observation = { - # k.replace("observation.images.", "observation.") if k.startswith("observation.images.") else k: v - # for k, v in observation.items() - # # } - # if "observation.image" in observation: - # observation["image"] = observation.pop("observation.image").to( - # device, non_blocking=device.type == "cuda" - # ) - - # if "observation.image2" in observation: - # observation["wrist_image"] = observation.pop("observation.image2").to( - # device, non_blocking=device.type == "cuda" - # ) - - # if "observation.state" in observation: - # observation["state"] = observation.pop("observation.state").to( - # device, non_blocking=device.type == "cuda" - # ) - with torch.inference_mode(): action = policy.select_action(observation) # Convert to CPU / numpy. @@ -550,15 +526,12 @@ def eval_main(cfg: EvalPipelineConfig): logging.info("Making environment.") envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) - breakpoint() + logging.info("Making policy.") policy = make_policy( cfg=cfg.policy, env_cfg=cfg.env, ) - breakpoint() - # policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy) - # rename "image" -> "observation.image" policy.eval() with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): diff --git a/src/lerobot/scripts/train_2.py b/src/lerobot/scripts/train_2.py deleted file mode 100644 index 5b82ef044..000000000 --- a/src/lerobot/scripts/train_2.py +++ /dev/null @@ -1,345 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import time -from contextlib import nullcontext -from pprint import pformat -from typing import Any - -import torch -from termcolor import colored -from torch.amp import GradScaler -from torch.optim import Optimizer - -from lerobot.configs import parser -from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.factory import make_dataset -from lerobot.datasets.sampler import EpisodeAwareSampler -from lerobot.datasets.utils import cycle -from lerobot.envs.factory import make_env -from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import get_device_from_parameters -from lerobot.scripts.eval import eval_policy, eval_policy_multitask -from lerobot.utils.logging_utils import AverageMeter, MetricsTracker -from lerobot.utils.random_utils import set_seed -from lerobot.utils.train_utils import ( - get_step_checkpoint_dir, - get_step_identifier, - load_training_state, - save_checkpoint, - update_last_checkpoint, -) -from lerobot.utils.utils import ( - format_big_number, - get_safe_torch_device, - has_method, - init_logging, -) -from lerobot.utils.wandb_utils import WandBLogger - - -def update_policy( - train_metrics: MetricsTracker, - policy: PreTrainedPolicy, - batch: Any, - optimizer: Optimizer, - grad_clip_norm: float, - grad_scaler: GradScaler, - lr_scheduler=None, - use_amp: bool = False, - lock=None, -) -> tuple[MetricsTracker, dict]: - start_time = time.perf_counter() - device = get_device_from_parameters(policy) - policy.train() - with torch.autocast(device_type=device.type) if use_amp else nullcontext(): - loss, output_dict = policy.forward(batch) - # TODO(rcadene): policy.unnormalize_outputs(out_dict) - grad_scaler.scale(loss).backward() - - # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**. - grad_scaler.unscale_(optimizer) - - grad_norm = torch.nn.utils.clip_grad_norm_( - policy.parameters(), - grad_clip_norm, - error_if_nonfinite=False, - ) - - # Optimizer's gradients are already unscaled, so scaler.step does not unscale them, - # although it still skips optimizer.step() if the gradients contain infs or NaNs. - with lock if lock is not None else nullcontext(): - grad_scaler.step(optimizer) - # Updates the scale for next iteration. - grad_scaler.update() - - optimizer.zero_grad() - - # Step through pytorch scheduler at every batch instead of epoch - if lr_scheduler is not None: - lr_scheduler.step() - - if has_method(policy, "update"): - # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). - policy.update() - - train_metrics.loss = loss.item() - train_metrics.grad_norm = grad_norm.item() - train_metrics.lr = optimizer.param_groups[0]["lr"] - train_metrics.update_s = time.perf_counter() - start_time - return train_metrics, output_dict - - -def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata): - """Recreate normalization layers with dataset stats if missing (Adil's workaround).""" - from lerobot.policies.normalize import Normalize, Unnormalize - - if not hasattr(dataset_meta, "stats") or not dataset_meta.stats: - print("⚠️ Dataset has no stats, skipping normalization injection.") - return - - stats = {} - for key, stat_dict in dataset_meta.stats.items(): - stats[key] = { - stat_type: torch.as_tensor(stat_array) if isinstance(stat_array, np.ndarray) else stat_array - for stat_type, stat_array in stat_dict.items() - } - - normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats) - normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats) - unnormalize_outputs = Unnormalize( - policy.config.output_features, policy.config.normalization_mapping, stats - ) - - policy.normalize_inputs = normalize_inputs - policy.normalize_targets = normalize_targets - policy.unnormalize_outputs = unnormalize_outputs - - print("✅ Normalization layers injected with dataset stats.") - - -@parser.wrap() -def train(cfg: TrainPipelineConfig): - cfg.validate() - logging.info(pformat(cfg.to_dict())) - - if cfg.wandb.enable and cfg.wandb.project: - wandb_logger = WandBLogger(cfg) - else: - wandb_logger = None - logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) - - if cfg.seed is not None: - set_seed(cfg.seed) - - # Check device is available - device = get_safe_torch_device(cfg.policy.device, log=True) - torch.backends.cudnn.benchmark = True - torch.backends.cuda.matmul.allow_tf32 = True - - logging.info("Creating dataset") - dataset = make_dataset(cfg) - - # Create environment used for evaluating checkpoints during training on simulation data. - # On real-world data, no need to create an environment as evaluations are done outside train.py, - # using the eval.py instead, with gym_dora environment and dora-rs. - eval_env = None - if cfg.eval_freq > 0 and cfg.env is not None: - logging.info("Creating env") - eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) - - logging.info("Creating policy") - policy = make_policy( - cfg=cfg.policy, - ds_meta=dataset.meta, - ) - logging.info("Creating optimizer and scheduler") - optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) - grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) - - step = 0 # number of policy updates (forward + backward + optim) - - if cfg.resume: - step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) - - num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) - num_total_params = sum(p.numel() for p in policy.parameters()) - - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") - if cfg.env is not None: - logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") - logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") - logging.info(f"{dataset.num_episodes=}") - logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") - logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") - - # create dataloader for offline training - if hasattr(cfg.policy, "drop_n_last_frames"): - shuffle = False - sampler = EpisodeAwareSampler( - dataset.episode_data_index, - drop_n_last_frames=cfg.policy.drop_n_last_frames, - shuffle=True, - ) - else: - shuffle = True - sampler = None - - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=cfg.num_workers, - batch_size=cfg.batch_size, - shuffle=shuffle, - sampler=sampler, - pin_memory=device.type == "cuda", - drop_last=False, - ) - dl_iter = cycle(dataloader) - - policy.train() - train_metrics = { - "loss": AverageMeter("loss", ":.3f"), - "grad_norm": AverageMeter("grdn", ":.3f"), - "lr": AverageMeter("lr", ":0.1e"), - "update_s": AverageMeter("updt_s", ":.3f"), - "dataloading_s": AverageMeter("data_s", ":.3f"), - } - - train_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step - ) - - logging.info("Start offline training on a fixed dataset") - for _ in range(step, cfg.steps): - start_time = time.perf_counter() - batch = next(dl_iter) - train_tracker.dataloading_s = time.perf_counter() - start_time - - for key in batch: - if isinstance(batch[key], torch.Tensor): - batch[key] = batch[key].to(device, non_blocking=device.type == "cuda") - - train_tracker, output_dict = update_policy( - train_tracker, - policy, - batch, - optimizer, - cfg.optimizer.grad_clip_norm, - grad_scaler=grad_scaler, - lr_scheduler=lr_scheduler, - use_amp=cfg.policy.use_amp, - ) - - # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we - # increment `step` here. - step += 1 - train_tracker.step() - is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 - is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps - is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 - - if is_log_step: - logging.info(train_tracker) - if wandb_logger: - wandb_log_dict = train_tracker.to_dict() - if output_dict: - wandb_log_dict.update(output_dict) - wandb_logger.log_dict(wandb_log_dict, step) - train_tracker.reset_averages() - - if cfg.save_checkpoint and is_saving_step: - logging.info(f"Checkpoint policy after step {step}") - checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler) - update_last_checkpoint(checkpoint_dir) - if wandb_logger: - wandb_logger.log_policy(checkpoint_dir) - - if cfg.env and is_eval_step: - step_id = get_step_identifier(step, cfg.steps) - logging.info(f"Eval policy at step {step}") - with ( - torch.no_grad(), - torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), - ): - if cfg.env.multitask_eval: - eval_info = eval_policy_multitask( - eval_env, - policy, - cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", - max_episodes_rendered=4, - start_seed=cfg.seed, - max_parallel_tasks=cfg.env.max_parallel_tasks, - ) - aggregated = eval_info["overall"]["aggregated"] - # Print per-suite stats, log? - for task_group, task_group_info in eval_info.items(): - if task_group == "overall": - continue # Skip the overall stats since we already printed it - print(f"\nAggregated Metrics for {task_group}:") - print(task_group_info["aggregated"]) - else: - eval_info = eval_policy( - eval_env, - policy, - cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", - max_episodes_rendered=4, - start_seed=cfg.seed, - ) - aggregated = eval_info["aggregated"] - - eval_metrics = { - "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), - "pc_success": AverageMeter("success", ":.1f"), - "eval_s": AverageMeter("eval_s", ":.3f"), - } - eval_tracker = MetricsTracker( - cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step - ) - eval_tracker.eval_s = aggregated.pop("eval_s") - eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") - eval_tracker.pc_success = aggregated.pop("pc_success") - logging.info(eval_tracker) - if wandb_logger: - wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} - wandb_logger.log_dict(wandb_log_dict, step, mode="eval") - wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval") - - if eval_env: - if cfg.env.multitask_eval: - for _task_group, envs_dict in eval_env.items(): - for _idx, env in envs_dict.items(): - env.close() - else: - eval_env.close() - logging.info("End of training") - - if cfg.policy.push_to_hub: - policy.push_model_to_hub(cfg) - - -def main(): - init_logging() - train() - - -if __name__ == "__main__": - main() diff --git a/src/lerobot/scripts/train_accelerate.py b/src/lerobot/scripts/train_accelerate.py deleted file mode 100644 index 1e8a59a64..000000000 --- a/src/lerobot/scripts/train_accelerate.py +++ /dev/null @@ -1,366 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import time -from pprint import pformat -from typing import Any - -import torch -from accelerate import Accelerator -from accelerate.utils import set_seed as accelerate_set_seed -from termcolor import colored -from torch.optim import Optimizer - -from lerobot.configs import parser -from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.factory import make_dataset -from lerobot.datasets.sampler import EpisodeAwareSampler -from lerobot.envs.factory import make_env -from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.scripts.eval import eval_policy -from lerobot.utils.logging_utils import AverageMeter, MetricsTracker -from lerobot.utils.train_utils import ( - get_step_checkpoint_dir, - get_step_identifier, - load_training_state, - save_checkpoint, - update_last_checkpoint, -) -from lerobot.utils.utils import ( - format_big_number, - has_method, - init_logging, -) - - -def update_policy( - train_metrics: MetricsTracker, - policy: PreTrainedPolicy, - batch: Any, - optimizer: Optimizer, - grad_clip_norm: float, - accelerator: Accelerator, - lr_scheduler=None, -) -> tuple[MetricsTracker, dict]: - start_time = time.perf_counter() - policy.train() - - # Use accelerator's autocast context if mixed precision is enabled - with accelerator.autocast(): - loss, output_dict = policy.forward(batch) - # TODO(rcadene): policy.unnormalize_outputs(out_dict) - - # Use accelerator for backward pass - accelerator.backward(loss) - - # Gradient clipping - accelerator handles unscaling automatically - if accelerator.sync_gradients and grad_clip_norm > 0: - grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) - else: - grad_norm = torch.tensor(0.0) - - optimizer.step() - lr_scheduler.step() if lr_scheduler is not None else None - optimizer.zero_grad() - - # Update policy-specific buffers if needed - if has_method(policy, "update"): - policy.update() - - # Gather metrics across all processes - loss_value = accelerator.gather(loss.detach()).mean().item() - grad_norm_value = accelerator.gather(grad_norm).mean().item() - - train_metrics.loss = loss_value - train_metrics.grad_norm = grad_norm_value - train_metrics.lr = optimizer.param_groups[0]["lr"] - train_metrics.update_s = time.perf_counter() - start_time - return train_metrics, output_dict - - -@parser.wrap() -def train(cfg: TrainPipelineConfig): - cfg.validate() - logging.info(pformat(cfg.to_dict())) - - # Initialize accelerator - from accelerate.utils import DistributedDataParallelKwargs - - # added by jade 2 lines - ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) - accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs]) - - from lerobot.utils.wandb_utils import cfg_to_group, get_wandb_run_id_from_filesystem - - ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - accelerator = Accelerator( - mixed_precision="bf16" if cfg.policy.use_amp else "no", - gradient_accumulation_steps=cfg.policy.gradient_accumulation_steps, - log_with="wandb" if cfg.wandb.enable else None, - kwargs_handlers=[ddp_kwargs], - project_dir=cfg.output_dir, - ) - - accelerator.init_trackers( - project_name=cfg.wandb.project, - init_kwargs={ - "wandb": { - "entity": cfg.wandb.entity, - "name": cfg.job_name, - "notes": cfg.wandb.notes, - "tags": cfg_to_group(cfg, return_list=True), - "dir": cfg.output_dir, - "config": cfg.to_dict(), - "save_code": False, - "job_type": "train_eval", - "mode": cfg.wandb.mode if cfg.wandb.mode in ["online", "offline", "disabled"] else "online", - "resume": "must" if cfg.resume else None, - "id": cfg.wandb.run_id - if cfg.wandb.run_id - else (get_wandb_run_id_from_filesystem(cfg.output_dir) if cfg.resume else None), - } - }, - ) - - # Set seed for reproducibility - if cfg.seed is not None: - accelerate_set_seed(cfg.seed) - - # Setup device - accelerator handles device placement - torch.backends.cudnn.benchmark = True - torch.backends.cuda.matmul.allow_tf32 = True - - # Create dataset - if accelerator.is_main_process: - logging.info("Creating dataset") - dataset = make_dataset(cfg) - print("c") - # Create evaluation environment (only on main process) - eval_env = None - if cfg.eval_freq > 0 and cfg.env is not None and accelerator.is_main_process: - logging.info("Creating env") - eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) - - # Create policy - if accelerator.is_main_process: - logging.info("Creating policy") - - # Use accelerator's device instead of cfg.policy.device - with accelerator.main_process_first(): - policy = make_policy( - cfg=cfg.policy, - ds_meta=dataset.meta, - ) - - # Create optimizer and scheduler - if accelerator.is_main_process: - logging.info("Creating optimizer and scheduler") - optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) - - step = 0 # number of policy updates - - if cfg.resume: - step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) - - # Prepare dataloader - if hasattr(cfg.policy, "drop_n_last_frames"): - shuffle = False - sampler = EpisodeAwareSampler( - dataset.episode_data_index, - drop_n_last_frames=cfg.policy.drop_n_last_frames, - shuffle=True, - ) - else: - shuffle = True - sampler = None - - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=cfg.num_workers, - batch_size=cfg.batch_size, - shuffle=shuffle, - sampler=sampler, - pin_memory=True, - drop_last=True, # Important for distributed training - ) - - # Prepare for distributed training - policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( - policy, optimizer, dataloader, lr_scheduler - ) - - # Log training info (only on main process) - if accelerator.is_main_process: - num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) - num_total_params = sum(p.numel() for p in policy.parameters()) - - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") - if cfg.env is not None: - logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") - logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") - logging.info(f"{dataset.num_episodes=}") - logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") - logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") - logging.info(f"Number of processes: {accelerator.num_processes}") - logging.info(f"Device: {accelerator.device}") - logging.info(f"Mixed precision: {accelerator.mixed_precision}") - - # Create metrics trackers - train_metrics = { - "loss": AverageMeter("loss", ":.3f"), - "grad_norm": AverageMeter("grdn", ":.3f"), - "lr": AverageMeter("lr", ":0.1e"), - "update_s": AverageMeter("updt_s", ":.3f"), - "dataloading_s": AverageMeter("data_s", ":.3f"), - } - - train_tracker = MetricsTracker( - cfg.batch_size * accelerator.num_processes, # Account for all processes - dataset.num_frames, - dataset.num_episodes, - train_metrics, - initial_step=step, - ) - - # Training loop - policy.train() - if accelerator.is_main_process: - logging.info("Start offline training on a fixed dataset") - - # Create iterator from dataloader - dl_iter = iter(dataloader) - - for current_step in range(step, cfg.steps): - start_time = time.perf_counter() - # Get next batch, cycling through dataloader if needed - try: - batch = next(dl_iter) - print("data laoder batch keys: ", batch.keys()) - breakpoint() - except StopIteration: - dl_iter = iter(dataloader) - batch = next(dl_iter) - train_tracker.dataloading_s = time.perf_counter() - start_time - # Update policy - train_tracker, output_dict = update_policy( - train_tracker, - policy, - batch, - optimizer, - cfg.optimizer.grad_clip_norm, - accelerator, - lr_scheduler=lr_scheduler, - ) - - # Increment step counter - step += 1 - train_tracker.step() - - # Determine if we should log, save, or evaluate - is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 - is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps - is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 - - # Logging (only on main process) - if is_log_step and accelerator.is_main_process: - logging.info(train_tracker) - wandb_log_dict = train_tracker.to_dict() - if output_dict: - wandb_log_dict.update(output_dict) - for k, v in wandb_log_dict.items(): - accelerator.log({f"{'train'}/{k}": v}, step=step) - train_tracker.reset_averages() - - # Checkpointing (only on main process) - if cfg.save_checkpoint and is_saving_step: - # ✅ all processes wait here - accelerator.wait_for_everyone() - - if accelerator.is_main_process: - logging.info(f"Checkpoint policy after step {step}") - checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - - unwrapped_policy = accelerator.unwrap_model(policy) - save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler) - update_last_checkpoint(checkpoint_dir) - - # ✅ all processes sync again after saving - accelerator.wait_for_everyone() - - # if wandb_logger: - # wandb_logger.log_policy(checkpoint_dir) - - # Evaluation (only on main process) - if cfg.env and is_eval_step and accelerator.is_main_process: - step_id = get_step_identifier(step, cfg.steps) - logging.info(f"Eval policy at step {step}") - - # Unwrap model for evaluation - unwrapped_policy = accelerator.unwrap_model(policy) - unwrapped_policy.eval() - - with torch.no_grad(): - eval_info = eval_policy( - eval_env, - unwrapped_policy, - cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", - max_episodes_rendered=4, - start_seed=cfg.seed, - ) - - eval_metrics = { - "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), - "pc_success": AverageMeter("success", ":.1f"), - "eval_s": AverageMeter("eval_s", ":.3f"), - } - eval_tracker = MetricsTracker( - cfg.batch_size * accelerator.num_processes, - dataset.num_frames, - dataset.num_episodes, - eval_metrics, - initial_step=step, - ) - eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s") - eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward") - eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success") - logging.info(eval_tracker) - - wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} - for k, v in wandb_log_dict.items(): - accelerator.log({f"{'eval'}/{k}": v}, step=step) - - # Set back to training mode - policy.train() - - # Wait for all processes to finish - accelerator.wait_for_everyone() - - # Cleanup - if eval_env and accelerator.is_main_process: - eval_env.close() - - if accelerator.is_main_process: - logging.info("End of training") - accelerator.end_training() # added by jade - - -if __name__ == "__main__": - init_logging() - train()