remove files

This commit is contained in:
Jade Choghari
2025-09-11 14:18:09 +02:00
parent a19d7fb6bf
commit 4c2add41d7
15 changed files with 3 additions and 4334 deletions

View File

@@ -62,7 +62,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
gradient_accumulation_steps: int = 1
push_to_hub: bool = True
repo_id: str | None = None

View File

@@ -597,7 +597,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if self.episodes is None:
path = str(self.root / "data")
# added by jade
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
else:
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]

View File

@@ -455,8 +455,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state":
type = FeatureType.ENV
# changed by jade
elif key.startswith("observation") or key.startswith("state"):
elif key.startswith("observation"):
type = FeatureType.STATE
elif key.startswith("action"):
type = FeatureType.ACTION

View File

@@ -31,7 +31,6 @@ from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -75,10 +74,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "smolpi0":
from lerobot.policies.smolpi0.modeling_smolpi0 import SMOLPI0Policy
return SMOLPI0Policy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -102,8 +97,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "smolpi0":
return SMOLPI0Config(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

View File

@@ -255,85 +255,6 @@ class Unnormalize(nn.Module):
return batch
class NormalizePerRobotType(nn.Module):
    """Normalize batch features with statistics selected per sample by robot type.

    One set of statistic buffers is registered per robot type found in `stats`. At call time the
    batch must carry a "robot_type" entry (one robot type per sample); the statistics of each
    sample's robot type are stacked along the batch axis and applied to the matching features.
    """

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, dict[str, Tensor]]] | None = None,
    ):
        """
        Args:
            features (dict): Maps a feature key (e.g. "observation.image") to its `PolicyFeature`.
                Only `ft.type` is read in this class; the features are also forwarded to
                `create_stats_buffers`.
            norm_map (dict): Maps a feature type to its normalization mode:
                - MEAN_STD: subtract the mean and divide by the standard deviation.
                - MIN_MAX: map to the [-1, 1] range.
                Feature types missing from the map are left untouched (IDENTITY).
            stats (dict): Maps a robot type to the statistics of that robot, which are passed
                unchanged to `create_stats_buffers` (presumably `{feature_key: {"mean": ...,
                "std": ..., "min": ..., "max": ...}}` - confirm against that helper).

        Raises:
            ValueError: If `stats` is None. The buffers are named after the robot types, so they
                cannot be created without knowing them.
        """
        super().__init__()
        if stats is None:
            # The original code crashed here with an opaque AttributeError on `None.keys()`.
            raise ValueError(
                "`stats` must map each robot type to its statistics: per-robot-type buffers "
                "cannot be created without knowing the robot types."
            )
        self.features = features
        self.norm_map = norm_map

        for robot_type, robot_stats in stats.items():
            stats_buffers = create_stats_buffers(features, norm_map, robot_stats)
            for key, buffer in stats_buffers.items():
                # Attribute layout: "<robot_type>_buffer_<feature key with dots replaced>".
                setattr(self, f"{robot_type}_buffer_" + key.replace(".", "_"), buffer)

    def _stacked_stat(self, key: str, robot_types, stat_name: str, add_axis: bool) -> Tensor:
        """Stack statistic `stat_name` of feature `key` for every sample's robot type (batch axis first)."""
        suffix = "_buffer_" + key.replace(".", "_")
        # FIXME(mshukor): make it more efficient
        stacked = torch.stack(
            [getattr(self, f"{robot_type}{suffix}")[stat_name] for robot_type in robot_types], dim=0
        )
        assert not torch.isinf(stacked).any(), _no_stats_error_str(stat_name)
        # 3-D inputs get an extra axis at position 1 so the statistics broadcast over it
        # (presumably a time/sequence axis - TODO confirm).
        return stacked.unsqueeze(1) if add_axis else stacked

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of `batch` whose known features are normalized per robot type."""
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        assert "robot_type" in batch, "robot_type is not in the batch"
        robot_types = batch["robot_type"]

        for key, ft in self.features.items():
            if key not in batch:
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            add_axis = batch[key].ndim == 3
            if norm_mode is NormalizationMode.MEAN_STD:
                mean = self._stacked_stat(key, robot_types, "mean", add_axis)
                std = self._stacked_stat(key, robot_types, "std", add_axis)
                batch[key] = (batch[key] - mean) / (std + 1e-8)
            elif norm_mode is NormalizationMode.MIN_MAX:
                min_val = self._stacked_stat(key, robot_types, "min", add_axis)
                max_val = self._stacked_stat(key, robot_types, "max", add_axis)
                # normalize to [0,1]
                batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
                # normalize to [-1, 1]
                batch[key] = batch[key] * 2 - 1
            else:
                raise ValueError(norm_mode)
        return batch
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
# and remove the `Normalize` and `Unnormalize` classes.
def _initialize_stats_buffers(
@@ -497,87 +418,3 @@ class UnnormalizeBuffer(nn.Module):
raise ValueError(norm_mode)
return batch
class UnnormalizePerRobotType(nn.Module):
    """Map normalized features back to their original range, per sample and per robot type.

    This is the inverse of per-robot-type normalization: one set of statistic buffers is
    registered per robot type found in `stats`, and at call time the batch must carry a
    "robot_type" entry (one robot type per sample) that selects which statistics are applied.
    """

    def __init__(
        self,
        features: dict[str, PolicyFeature],
        norm_map: dict[str, NormalizationMode],
        stats: dict[str, dict[str, dict[str, Tensor]]] | None = None,
    ):
        """
        Args:
            features (dict): Maps a feature key (e.g. "action") to its `PolicyFeature`. Only
                `ft.type` is read in this class; the features are also forwarded to
                `create_stats_buffers`.
            norm_map (dict): Maps a feature type to its normalization mode:
                - MEAN_STD: multiply by the standard deviation and add the mean.
                - MIN_MAX: map from the [-1, 1] range back to [min, max].
                Feature types missing from the map are left untouched (IDENTITY).
            stats (dict): Maps a robot type to the statistics of that robot, which are passed
                unchanged to `create_stats_buffers` (presumably `{feature_key: {"mean": ...,
                "std": ..., "min": ..., "max": ...}}` - confirm against that helper).

        Raises:
            ValueError: If `stats` is None. The buffers are named after the robot types, so they
                cannot be created without knowing them.
        """
        super().__init__()
        if stats is None:
            # The original code crashed here with an opaque AttributeError on `None.keys()`.
            raise ValueError(
                "`stats` must map each robot type to its statistics: per-robot-type buffers "
                "cannot be created without knowing the robot types."
            )
        self.features = features
        self.norm_map = norm_map
        self.stats = stats

        for robot_type, robot_stats in stats.items():
            stats_buffers = create_stats_buffers(features, norm_map, robot_stats)
            for key, buffer in stats_buffers.items():
                # Attribute layout: "<robot_type>_buffer_<feature key with dots replaced>".
                setattr(self, f"{robot_type}_buffer_" + key.replace(".", "_"), buffer)

    def _stacked_stat(self, key: str, robot_types, stat_name: str, add_axis: bool) -> Tensor:
        """Stack statistic `stat_name` of feature `key` for every sample's robot type (batch axis first)."""
        suffix = "_buffer_" + key.replace(".", "_")
        stacked = torch.stack(
            [getattr(self, f"{robot_type}{suffix}")[stat_name] for robot_type in robot_types], dim=0
        )
        assert not torch.isinf(stacked).any(), _no_stats_error_str(stat_name)
        # 3-D inputs get an extra axis at position 1 so the statistics broadcast over it
        # (presumably a time/sequence axis - TODO confirm).
        return stacked.unsqueeze(1) if add_axis else stacked

    # TODO(rcadene): should we remove torch.no_grad?
    @torch.no_grad
    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Return a shallow copy of `batch` whose known features are unnormalized per robot type."""
        batch = dict(batch)  # shallow copy avoids mutating the input batch
        assert "robot_type" in batch, "robot_type is not in the batch"
        robot_types = batch["robot_type"]

        for key, ft in self.features.items():
            if key not in batch:
                continue

            norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
            if norm_mode is NormalizationMode.IDENTITY:
                continue

            add_axis = batch[key].ndim == 3
            if norm_mode is NormalizationMode.MEAN_STD:
                mean = self._stacked_stat(key, robot_types, "mean", add_axis)
                std = self._stacked_stat(key, robot_types, "std", add_axis)
                batch[key] = batch[key] * std + mean
            elif norm_mode is NormalizationMode.MIN_MAX:
                min_val = self._stacked_stat(key, robot_types, "min", add_axis)
                max_val = self._stacked_stat(key, robot_types, "max", add_axis)
                # [-1, 1] -> [0, 1] -> [min, max]
                batch[key] = (batch[key] + 1) / 2
                batch[key] = batch[key] * (max_val - min_val) + min_val
            else:
                raise ValueError(norm_mode)
        return batch

View File

@@ -1,210 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
@dataclass
class PEFTConfig:
    """Hyper-parameters of the LoRA adapter used for parameter-efficient finetuning.

    The field order is part of the interface (positional construction), so it is kept as is.
    """

    # Rank of the low-rank update matrices.
    r: int = 4
    # LoRA scaling factor.
    lora_alpha: int = 16
    # Dropout probability applied to the LoRA layers.
    lora_dropout: float = 0.1
    # Comma-separated names of the modules that receive an adapter.
    target_modules: str = "q_proj,v_proj"
@PreTrainedConfig.register_subclass("smolpi0")
@dataclass
class SMOLPI0Config(PreTrainedConfig):
    """Configuration of the `smolpi0` policy (registered under the name "smolpi0")."""

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 50
    n_action_steps: int = 50
    # Spacing (in frames) between two consecutive observation steps, see `observation_delta_indices`.
    n_obs_gap: int = 1

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (512, 512)  # (224, 224)

    # Add empty images. Used by pi0_aloha_sim which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False

    # Tokenizer
    tokenizer_max_length: int = 48

    # Projector
    proj_width: int = 480

    # Decoding
    num_steps: int = 10

    # Attention utils
    use_cache: bool = True
    attention_implementation: str = "eager"  # or fa2, flex

    # Finetuning settings
    freeze_vision_encoder: bool = True
    train_expert_only: bool = False
    train_state_proj: bool = True

    # Training presets
    optimizer_lr: float = 2.5e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-10
    optimizer_grad_clip_norm: float = 10
    optimizer_lr_vlm: float = 0

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    # TODO: Add EMA

    vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
    # The default is None, so the annotation must allow it.
    checkpoint_path: str | None = None
    load_vlm_weights: bool = False

    peft_method: str = ""
    # A dataclass instance is a mutable default: sharing one instance between configs is unsafe and
    # is rejected by `dataclasses` on Python 3.11+, hence the factory.
    peft_config: PEFTConfig = field(default_factory=PEFTConfig)
    peft_target_model: str = ""

    add_image_special_tokens: bool = False
    add_prompt_template: bool = False
    prefix_prompt_template: str = "<|im_start|>User: What action should the robot take to"
    suffix_prompt_template: str = "?\nAssistant:"

    attention_mode: str = "self_attn"
    prefix_length: int = -1  # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length
    past_obs_keys: str = "image"
    add_local_special_image_tokens: bool = False
    reverse_images_order: bool = False
    state_to_prefix: bool = False
    pad_language_to: str = "longest"  # "max_length"

    num_expert_layers: int = -1
    num_vlm_layers: int = -1
    causal_action_attention_mask: bool = False
    self_attn_every_n_layers: int = -1
    expert_width_multiplier: float = 0.5
    robot_type: str = ""
    self_attn_only_actions: bool = False
    causal_attention_on_history: bool = False
    predict_relative_actions: bool = False
    relative_actions_mode: str = "first"
    shuffle_camera_positions: bool = False
    vlm_img_size: int = -1
    regression_loss: bool = False

    def __post_init__(self):
        """Derive dependent fields and validate the inputs (not exhaustive).

        Raises:
            ValueError: If `n_action_steps` exceeds `chunk_size`.
            NotImplementedError: If `use_delta_joint_actions_aloha` is enabled.
        """
        super().__post_init__()

        # A positive `vlm_img_size` overrides the (square) resize target.
        if self.vlm_img_size > 0:
            self.resize_imgs_with_padding = (self.vlm_img_size, self.vlm_img_size)

        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        # if self.n_obs_steps != 1:
        #     raise ValueError(
        #         f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
        #     )

        if self.use_delta_joint_actions_aloha:
            raise NotImplementedError(
                "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
            )

    def validate_features(self) -> None:
        """Register one placeholder visual feature per requested empty camera in `input_features`."""
        # TODO: implement value error
        # if not self.image_features and not self.env_state_feature:
        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")

        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the AdamW configuration built from the `optimizer_*` fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self):
        """Return the cosine-decay-with-warmup scheduler configuration built from the `scheduler_*` fields."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> list:  # FIXME(mshukor): support spacing between observations
        """Non-positive frame offsets of the observation steps, oldest first.

        E.g. `n_obs_steps=3, n_obs_gap=2` gives `[-4, -2, 0]`.
        """
        return [-k for k in range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)][::-1]

    @property
    def action_delta_indices(self) -> list:
        """Frame offsets of the predicted action chunk: `0 .. chunk_size - 1`."""
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """No reward offsets are defined for this policy."""
        return None

View File

@@ -1,145 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F # noqa: N812
from packaging.version import Version
# Flex attention is only available from torch 2.5 onwards. The comparison is inclusive so that a
# plain "2.5.0" build (no local version suffix) is accepted as well.
if Version(torch.__version__) >= Version("2.5.0"):
    from torch.nn.attention.flex_attention import (
        _mask_mod_signature,
        _round_up_to_multiple,
        create_block_mask,
        create_mask,
        flex_attention,
    )
@torch.compile(dynamic=False)
def flex_attention_forward(
    attention_mask: torch.Tensor,
    batch_size: int,
    head_dim: int,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    scaling=None,
    num_att_heads: int = 8,
    num_key_value_heads: int = 1,
):
    """Attention forward pass based on torch flex attention with a precomputed mask.

    This is defined out of classes to make compile happy.

    Args:
        attention_mask: Mask indexed as [batch, query, key]; a head axis is inserted below.
        batch_size: Batch size used to expand/reshape keys and values.
        head_dim: Size of one attention head.
        query_states: Queries, transposed below from [B, Q_LEN, H, head_dim] to [B, H, Q_LEN, head_dim].
        key_states: Keys of shape [B, KV_LEN, num_key_value_heads, head_dim].
        value_states: Values, same layout as `key_states`.
        scaling: Softmax scale; defaults to `head_dim**-0.5` when None.
        num_att_heads: Number of query heads.
        num_key_value_heads: Number of key/value heads (repeated to match the query heads).

    Returns:
        Tensor of shape [B, Q_LEN, H * head_dim] in the dtype of `query_states`.

    Raises:
        ValueError: If `attention_mask` is None (the block mask is built from it).
    """
    if attention_mask is None:
        # The original code failed later with an AttributeError on `None.shape`.
        raise ValueError("`flex_attention_forward` requires an attention mask, got None.")

    original_dtype = query_states.dtype
    num_key_value_groups = num_att_heads // num_key_value_heads

    # Repeat every key/value head `num_key_value_groups` times to match the query heads.
    key_states = key_states[:, :, :, None, :]
    key_states = key_states.expand(
        batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
    )
    key_states = key_states.reshape(
        batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
    )

    value_states = value_states[:, :, :, None, :]
    value_states = value_states.expand(
        batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
    )
    value_states = value_states.reshape(
        batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
    )

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    # query_states = query_states.to(torch.float32)
    # key_states = key_states.to(torch.float32)
    # value_states = value_states.to(torch.float32)

    # Add the head axis and crop the key axis to the actual key length.
    causal_mask = attention_mask[:, None, :, : key_states.shape[2]]
    if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
        causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)

    def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
        def mask_mod(b, h, q_idx, kv_idx):
            # Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
            return precomputed_mask[b][h][q_idx][kv_idx]

        return mask_mod

    b_mask, h_mask, q_len, kv_len = causal_mask.shape  # The shape of your mask

    block_size = 128  # limitation of flex attention
    q_len_rounded = _round_up_to_multiple(q_len, block_size)
    kv_len_rounded = _round_up_to_multiple(kv_len, block_size)

    # *CRITICAL* we do need to expand here, else we get a CUDA index error
    pad_q = q_len_rounded - q_len
    pad_k = kv_len_rounded - kv_len

    if pad_q > 0 or pad_k > 0:
        padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
    else:
        padded_causal_mask = causal_mask

    mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)

    mask_4d = create_mask(
        mod_fn=mask_mod_fn_orig,
        B=b_mask,
        H=h_mask,
        Q_LEN=q_len_rounded,
        KV_LEN=kv_len_rounded,
        device=causal_mask.device,
    )

    mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
    # FIXME(mshukor): compile mask torch.compile(create_block_mask)
    create_block_mask_compiled = torch.compile(create_block_mask)
    block_mask = create_block_mask_compiled(
        mask_mod=mask_mod_fn_padded,
        B=b_mask,
        H=None,
        Q_LEN=q_len_rounded,
        KV_LEN=kv_len_rounded,
        BLOCK_SIZE=block_size,
        device=causal_mask.device,
        _compile=False,
    )

    # Pad the sequence axes to the rounded lengths the block mask was built for.
    padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states
    padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states
    padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states

    # mask is applied inside the kernel, ideally more efficiently than score_mod.
    # `return_lse=True` makes the call return (output, logsumexp); only the output is used.
    attn_output, _ = flex_attention(
        padded_query_states,
        padded_key_states,
        padded_value_states,
        block_mask=block_mask,
        enable_gqa=True,  # because we shaped query/key states for GQA
        scale=head_dim**-0.5 if scaling is None else scaling,
        return_lse=True,
    )

    attn_output = attn_output.to(dtype=original_dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()  # [B, Q_LEN, H, head_dim]
    attn_output = attn_output.reshape(
        batch_size,
        -1,
        attn_output.shape[2] * attn_output.shape[3],  # merges [H, head_dim]
    )
    # The output follows the (padded) query axis, so the query padding is what must be removed.
    # Trimming by `pad_k` was only correct when queries and keys had the same length.
    return attn_output[:, :-pad_q, :] if pad_q > 0 else attn_output

File diff suppressed because it is too large Load Diff

View File

@@ -1,920 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from functools import partial
from typing import List, Optional, Union
import torch
import torch.nn.functional as F # noqa: N812
import torch.version
from peft import LoraConfig, TaskType, get_peft_model
from pytest import Cache
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AutoProcessor,
SmolVLMForConditionalGeneration,
)
from lerobot.policies.smolpi0.flex_attention import flex_attention_forward
def _round_up_to_multiple(x, multiple):
return (x + multiple - 1) // multiple * multiple
def apply_rope(x, positions, max_wavelength=10_000):
    """
    Applies RoPE positions [B, L] to x [B, L, H, D].

    The rotation is computed in float32 and the result is cast back to the dtype of `x`.
    """
    input_dtype = x.dtype
    head_dim = x.shape[-1]
    d_half = head_dim // 2

    x_fp32 = x.to(torch.float32)

    # One timescale per rotated pair: max_wavelength ** (2 * i / D).
    exponents = (2.0 / head_dim) * torch.arange(d_half, dtype=torch.float32, device=x.device)
    timescale = max_wavelength**exponents

    # [B, L, d_half] angles, with a singleton head axis so they broadcast over H.
    angles = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
    angles = angles[..., None, :]
    sin, cos = torch.sin(angles), torch.cos(angles)

    first_half, second_half = x_fp32.split(d_half, dim=-1)
    rotated = torch.cat(
        (first_half * cos - second_half * sin, second_half * cos + first_half * sin),
        dim=-1,
    )
    return rotated.to(input_dtype)
# class SmolVLMWithExpertConfig(PretrainedConfig):
# model_type = "SmolVLMWithExpertModel"
# sub_configs = {"smolvlm_config": AutoConfig, "lm_expert_config": AutoConfig}
# def __init__(
# self,
# smolvlm_config: dict | None = None,
# lm_expert_config: dict | None = None,
# freeze_vision_encoder: bool = True,
# train_expert_only: bool = True,
# attention_implementation: str = "eager",
# load_vlm_weights: bool = False,
# **kwargs,
# ):
# self.load_vlm_weights = load_vlm_weights
# self.freeze_vision_encoder = freeze_vision_encoder
# self.train_expert_only = train_expert_only
# self.attention_implementation = attention_implementation
# if smolvlm_config is None:
# # Default config from Pi0
# self.smolvlm_config = CONFIG_MAPPING["smolvlm"](
# transformers_version="4.48.1",
# _vocab_size=257152,
# bos_token_id=2,
# eos_token_id=1,
# hidden_size=2048,
# image_token_index=257152,
# model_type="smolvlm",
# pad_token_id=0,
# projection_dim=2048,
# text_config={
# "hidden_activation": "gelu_pytorch_tanh",
# "hidden_size": 2048,
# "intermediate_size": 16384,
# "model_type": "gemma",
# "num_attention_heads": 8,
# "num_hidden_layers": 18,
# "num_image_tokens": 256,
# "num_key_value_heads": 1,
# "torch_dtype": "float32",
# "vocab_size": 257152,
# },
# vision_config={
# "hidden_size": 1152,
# "intermediate_size": 4304,
# "model_type": "siglip_vision_model",
# "num_attention_heads": 16,
# "num_hidden_layers": 27,
# "num_image_tokens": 256,
# "patch_size": 14,
# "projection_dim": 2048,
# "projector_hidden_act": "gelu_fast",
# "torch_dtype": "float32",
# "vision_use_head": False,
# },
# )
# elif isinstance(self.paligemma_config, dict):
# # Override Pi0 default config for PaliGemma
# if "model_type" not in gemma_expert_config:
# paligemma_config["model_type"] = "paligemma"
# cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
# self.paligemma_config = cfg_cls(**paligemma_config)
# if gemma_expert_config is None:
# # Default config from Pi0
# self.gemma_expert_config = CONFIG_MAPPING["gemma"](
# attention_bias=False,
# attention_dropout=0.0,
# bos_token_id=2,
# eos_token_id=1,
# head_dim=256,
# hidden_act="gelu_pytorch_tanh",
# hidden_activation="gelu_pytorch_tanh",
# hidden_size=1024,
# initializer_range=0.02,
# intermediate_size=4096,
# max_position_embeddings=8192,
# model_type="gemma",
# num_attention_heads=8,
# num_hidden_layers=18,
# num_key_value_heads=1,
# pad_token_id=0,
# rms_norm_eps=1e-06,
# rope_theta=10000.0,
# torch_dtype="float32",
# transformers_version="4.48.1",
# use_cache=True,
# vocab_size=257152,
# )
# elif isinstance(self.gemma_expert_config, dict):
# # Override Pi0 default config for Gemma Expert
# if "model_type" not in gemma_expert_config:
# gemma_expert_config["model_type"] = "gemma"
# cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
# self.gemma_expert_config = cfg_cls(**gemma_expert_config)
# super().__init__(**kwargs)
# def __post_init__(self):
# super().__post_init__()
# if self.train_expert_only and not self.freeze_vision_encoder:
# raise ValueError(
# "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
# )
# if self.attention_implementation not in ["eager", "fa2", "flex"]:
# raise ValueError(
# f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
# )
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
    """Return the feed-forward width derived from `hidden_dim`.

    The width is `ffn_dim_multiplier` times two thirds of `hidden_dim` (each step truncated to an
    int), rounded up to the next multiple of `multiple_of`.
    """
    two_thirds = int(2 * hidden_dim / 3)
    scaled = int(ffn_dim_multiplier * two_thirds)
    n_blocks = (scaled + multiple_of - 1) // multiple_of
    return multiple_of * n_blocks
class SmolVLMWithExpertModel(nn.Module):
# config_class = PaliGemmaWithExpertConfig
def __init__(
    self,
    model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
    load_vlm_weights: bool = True,
    train_expert_only: bool = True,
    freeze_vision_encoder: bool = False,
    attention_implementation: str = "eager",
    attention_mode: str = "self_attn",
    num_expert_layers: int = -1,
    num_vlm_layers: int = -1,
    self_attn_every_n_layers: int = -1,
    expert_width_multiplier: float = 0.5,
    self_attn_only_actions: bool = False,
) -> None:
    """Build the VLM backbone and a narrower language-model "expert" derived from its text config.

    Args:
        model_id: Hub id used for the weights (or config only) and for the processor.
        load_vlm_weights: If True, load pretrained weights (placed on "cuda", in bfloat16);
            otherwise instantiate `SmolVLMForConditionalGeneration` from the config alone.
        train_expert_only: Stored and used by `set_requires_grad` / `train`.
        freeze_vision_encoder: Stored and used by `set_requires_grad` / `train`.
        attention_implementation: Stored as-is; not otherwise used in this method.
        attention_mode: When it contains "cross", the expert k/v projections are rebuilt to take
            inputs of the VLM key/value width.
        num_expert_layers: If > 0, number of expert layers; must divide the number of VLM layers.
            Otherwise the expert gets as many layers as the (possibly truncated) VLM.
        num_vlm_layers: If > 0, keep only the first `num_vlm_layers` VLM text layers.
        self_attn_every_n_layers: If > 0, expert layers whose index is a multiple of it keep
            their original k/v projections in "cross" mode.
        expert_width_multiplier: Expert hidden size as a fraction of the VLM text hidden size.
        self_attn_only_actions: Stored as-is; not otherwise used in this method.
    """
    super().__init__()
    if load_vlm_weights:
        print(f"Loading {model_id} weights ...")
        # The model class is picked from the model id: ids containing "SmolVLM-" use the
        # Vision2Seq auto class, everything else the ImageTextToText one.
        if "SmolVLM-" in model_id:
            self.vlm = AutoModelForVision2Seq.from_pretrained(
                model_id,
                device_map="cuda",
                torch_dtype="bfloat16",
                low_cpu_mem_usage=True,
            )
        else:
            # model_id = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
            self.vlm = AutoModelForImageTextToText.from_pretrained(
                model_id,
                device_map="cuda",
                torch_dtype="bfloat16",
                low_cpu_mem_usage=True,
                # attn_implementation="eager",
                # attn_implementation="flash_attention_2"
            )
        config = self.vlm.config
    else:
        config = AutoConfig.from_pretrained(model_id)
        self.vlm = SmolVLMForConditionalGeneration(config=config)
    self.processor = AutoProcessor.from_pretrained(model_id)
    if num_vlm_layers > 0:
        print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
        self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
    self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
    self.config = config
    # Smaller lm expert
    lm_expert_config = copy.deepcopy(config.text_config)
    hidden_size = lm_expert_config.hidden_size
    lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier)  # hidden_size // 2
    lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
    lm_expert_config.num_hidden_layers = self.num_vlm_layers
    if num_expert_layers > 0:
        assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
            f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
        )
        lm_expert_config.num_hidden_layers = num_expert_layers
    # lm_expert_config.head_dim = lm_expert_config.head_dim * 2
    self.lm_expert = AutoModel.from_config(lm_expert_config)

    self.num_expert_layers = len(self.lm_expert.layers)
    self.self_attn_every_n_layers = self_attn_every_n_layers
    self.self_attn_only_actions = self_attn_only_actions
    if "cross" in attention_mode:
        # Reshape qkv projections to have the same input dimension as the vlm
        for layer_idx in range(len(self.lm_expert.layers)):
            # Layers selected for self-attention keep their original projections.
            if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
                continue
            self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
                config.text_config.num_key_value_heads * config.text_config.head_dim,
                lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
                bias=lm_expert_config.attention_bias,
            )
            self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
                config.text_config.num_key_value_heads * config.text_config.head_dim,
                lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
                bias=lm_expert_config.attention_bias,
            )
    # Remove unused embed_tokens
    self.lm_expert.embed_tokens = None

    self.num_attention_heads = self.config.text_config.num_attention_heads
    self.num_key_value_heads = self.config.text_config.num_key_value_heads

    self.freeze_vision_encoder = freeze_vision_encoder
    self.train_expert_only = train_expert_only
    self.attention_implementation = attention_implementation
    self.attention_mode = attention_mode
    self.expert_hidden_size = lm_expert_config.hidden_size
    # self.to_bfloat16_like_physical_intelligence()
    # Must run last: it reads the flags and layer counts stored above.
    self.set_requires_grad()
def configure_peft(self, config) -> None:
    """Wrap the VLM (or only its text model) with LoRA adapters as described by `config`.

    Reads `config.peft_method`, `config.peft_target_model` and `config.peft_config` (fields `r`,
    `lora_alpha`, `lora_dropout`, `target_modules`). Nothing is wrapped unless "lora" appears in
    `config.peft_method`.

    NOTE(review): the source this was recovered from had lost its indentation; the trailing
    parameter loop is assumed to belong to the LoRA branch - confirm against the original file.
    """
    # return model
    self.peft_method = config.peft_method
    self.peft_target_model = config.peft_target_model
    if "lora" in self.peft_method:
        peft_config = config.peft_config
        target_modules = peft_config.target_modules
        # Accept either a list or a comma-separated string of module names.
        if not isinstance(target_modules, list):
            target_modules = target_modules.split(",")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,  # Based on the task type (e.g., language modeling, etc.)
            r=peft_config.r,  # The rank of the low-rank adaptation
            lora_alpha=peft_config.lora_alpha,  # Scaling factor
            lora_dropout=peft_config.lora_dropout,  # Dropout applied to LoRA layers
            target_modules=target_modules,  # The components where LoRA is applied
            exclude_modules=[
                "lm_expert",
                "model.lm_expert.model.layers",
            ],  # FIXME(mshukor): this does not work for now
        )
        self.lora_config = lora_config
        # Apply LoRA and ensure only LoRA parameters are trainable
        if "text" in self.peft_target_model:
            self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config)
        else:
            self.vlm = get_peft_model(self.vlm, lora_config)
        # assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here?
        # Only LoRA parameters stay trainable, except those of text layer 17.
        # NOTE(review): the layer index 17 is hard-coded - presumably the last text layer of the
        # default model; confirm it still matches when `num_vlm_layers` truncates the backbone.
        for name, param in self.vlm.named_parameters():
            if (
                "lora" in name and "text_model.model.layers.17" not in name
            ):  # lm_head is not a parameter in most LLMs because it's tied to the embedding layer
                param.requires_grad = True
            else:
                param.requires_grad = False
def merge_lora_weights(self):
    """
    Fold the LoRA weights into the base model and drop the adapter wrapper.

    The wrapper is removed from the same place it was applied to: the text model when "text" is
    in `peft_target_model`, the whole VLM otherwise.
    """
    if "text" not in self.peft_target_model:
        self.vlm = self.vlm.merge_and_unload()
        return

    merged_text_model = self.get_vlm_model().text_model.merge_and_unload()
    self.get_vlm_model().text_model = merged_text_model
def get_vlm_model(self):
    """Return the inner VLM model, looking through one extra `.model` level if present.

    The extra level exists when the VLM has been wrapped by peft.
    """
    inner = self.vlm.model
    # When using peft, the actual model sits one attribute deeper.
    return inner.model if hasattr(inner, "model") else inner
def set_requires_grad(self):
    """Disable gradients on the parameters that must stay frozen.

    * With ``freeze_vision_encoder`` the vision model is put in eval mode and
      all of its parameters are frozen.
    * With ``train_expert_only`` the whole VLM is put in eval mode and frozen.
      Otherwise only the VLM's ``lm_head``, final text norm weight and last
      text layer(s) are frozen.
    * The expert's ``lm_head`` parameters are always frozen.
    """
    if self.freeze_vision_encoder:
        vision_tower = self.get_vlm_model().vision_model
        vision_tower.eval()
        for weight in vision_tower.parameters():
            weight.requires_grad = False

    if self.train_expert_only:
        self.vlm.eval()
        for weight in self.vlm.parameters():
            weight.requires_grad = False
    else:
        # To avoid unused params issue with distributed training
        tail_layers = [self.num_vlm_layers - 1]
        layer_counts_differ_but_divide = (
            self.num_vlm_layers != self.num_expert_layers
            and self.num_vlm_layers % self.num_expert_layers == 0
        )
        if layer_counts_differ_but_divide:
            tail_layers.append(self.num_vlm_layers - 2)

        frozen_patterns = ["lm_head", "text_model.model.norm.weight"] + [
            f"text_model.model.layers.{layer_index}." for layer_index in tail_layers
        ]
        for param_name, weight in self.vlm.named_parameters():
            if any(pattern in param_name for pattern in frozen_patterns):
                weight.requires_grad = False

    # To avoid unused params issue with distributed training
    for param_name, weight in self.lm_expert.named_parameters():
        if "lm_head" in param_name:
            weight.requires_grad = False
def train(self, mode: bool = True):
    """Switch training mode while keeping frozen sub-modules in eval mode.

    After delegating to the parent ``train``, the vision model (when
    ``freeze_vision_encoder`` is set) and the whole VLM (when
    ``train_expert_only`` is set) are forced back to eval mode, regardless of
    ``mode``.

    Args:
        mode: ``True`` for training mode, ``False`` for evaluation mode.

    Returns:
        Whatever the parent ``train`` returns. ``torch.nn.Module.train``
        returns the module itself, so propagating it keeps call chaining such
        as ``model.train().to(device)`` working.
    """
    module = super().train(mode)
    if self.freeze_vision_encoder:
        self.get_vlm_model().vision_model.eval()
    if self.train_expert_only:
        self.vlm.eval()
    return module
# def to_bfloat16_like_physical_intelligence(self):
# self.vlm = self.vlm.to(dtype=torch.bfloat16)
# params_to_change_dtype = [
# "language_model.model.layers",
# "gemma_expert.model.layers",
# "vision_tower",
# "multi_modal",
# ]
# for name, param in self.named_parameters():
# if any(selector in name for selector in params_to_change_dtype):
# param.data = param.data.to(dtype=torch.bfloat16)
def embed_image(self, image: torch.Tensor):
    """Run an image batch through the vision model and the connector.

    The image is cast to the vision model's dtype, encoded without a patch
    attention mask, and the encoder's ``last_hidden_state`` is passed through
    the VLM connector (modality projection & resampling).

    Args:
        image: Pixel values handed to the vision model.

    Returns:
        The connector output for the encoded image.
    """
    vlm_model = self.get_vlm_model()
    vision_tower = vlm_model.vision_model

    # FIXME(mshukor): no patch attention mask is built; probably not needed as
    # we don't have padded images here.
    # FIXME(mshukor): add special image tokens specific to smolvlm
    pixels = image.to(dtype=vision_tower.dtype)
    vision_output = vision_tower(pixel_values=pixels, patch_attention_mask=None)

    # Modality projection & resampling
    return vlm_model.connector(vision_output.last_hidden_state)
def embed_language_tokens(self, tokens: torch.Tensor):
    """Look up ``tokens`` in the input embedding layer of the VLM text model."""
    embedding_layer = self.get_vlm_model().text_model.get_input_embeddings()
    return embedding_layer(tokens)
def forward_attn_layer(
    self,
    model_layers,
    inputs_embeds,
    layer_idx,
    position_ids,
    attention_mask,
    batch_size,
    head_dim,
    use_cache: bool = True,
    fill_kv_cache: bool = True,
    past_key_values=None,
) -> tuple[list[torch.Tensor], dict | None]:
    """Run one joint self-attention step over all active streams of a layer.

    Every non-``None`` entry of ``inputs_embeds`` whose layer at ``layer_idx``
    exists is normalised and projected to queries/keys/values by its own
    layer; the projections are concatenated along the token dimension and
    attended to jointly.

    Args:
        model_layers: Per-stream lists of layers, indexed as
            ``model_layers[i][layer_idx]``; an entry may be ``None``, in which
            case that stream is skipped.
        inputs_embeds: List of hidden states, one per stream; ``None`` entries
            are skipped.
        layer_idx: Index of the layer to run; also the key used in the cache.
        position_ids: Position ids; cropped to the concatenated length when
            longer.
        attention_mask: Boolean mask; cropped to the concatenated length on
            both of its last two dimensions when ``position_ids`` is longer.
        batch_size: Forwarded to the attention interface.
        head_dim: Forwarded to the attention interface.
        use_cache: Whether to read/write ``past_key_values``.
        fill_kv_cache: With ``use_cache``, ``True`` stores this layer's
            keys/values, ``False`` prepends the cached ones to the new ones.
        past_key_values: Dict keyed by layer index holding ``"key_states"``
            and ``"value_states"``; created when ``None`` and ``use_cache``.

    Returns:
        ``([att_output], past_key_values)`` - a single-element list with the
        attention output, and the (possibly updated) cache.
    """
    query_states = []
    key_states = []
    value_states = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = model_layers[i][layer_idx]
        if hidden_states is None or layer is None:
            continue
        # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
        # hidden_states = hidden_states * normalizer
        hidden_states = layer.input_layernorm(hidden_states)
        input_shape = hidden_states.shape[:-1]
        # Split the projection output into (..., heads, head_dim).
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        # Match the projection weights' dtype before projecting.
        hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
        value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
        query_states.append(query_state)
        key_states.append(key_state)
        value_states.append(value_state)
    # FIXME(mshukor): self attention always when having only the prefix
    # B,L,H,D with L sequence length, H number of heads, D head dim
    # concatenate on the number of embeddings/tokens
    query_states = torch.cat(query_states, dim=1)
    key_states = torch.cat(key_states, dim=1)
    value_states = torch.cat(value_states, dim=1)
    # FIXME(mshukor): seq should be B, H, L, D ?
    seq_len = query_states.shape[1]
    # Crop positions and mask when fewer tokens are present than they cover.
    if seq_len < position_ids.shape[1]:
        _position_ids = position_ids[:, :seq_len]
        _attention_mask = attention_mask[:, :seq_len, :seq_len]
    else:
        _position_ids = position_ids
        _attention_mask = attention_mask
    if self.self_attn_only_actions:
        # Work on copies so the caller's mask and positions are not modified.
        attention_mask_ = _attention_mask.clone()
        position_ids_ = _position_ids.clone()
        if inputs_embeds[1] is not None:
            suffix_len = inputs_embeds[1].shape[1]
            # The last `suffix_len` tokens may not attend to anything before them ...
            attention_mask_[:, -suffix_len:, :-suffix_len] = False
            # ... and their positions are shifted to restart from 0.
            position_ids_[:, -suffix_len:] = (
                _position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None]
            )
    else:
        attention_mask_ = _attention_mask
        position_ids_ = _position_ids
    query_states = apply_rope(
        query_states, position_ids_
    )  # FIXME(mshukor): this assumes we have always the vlm features?
    key_states = apply_rope(key_states, position_ids_)
    if use_cache and past_key_values is None:
        past_key_values = {}
    if use_cache:
        if fill_kv_cache:
            # Store this layer's keys/values for later calls.
            past_key_values[layer_idx] = {
                "key_states": key_states,
                "value_states": value_states,
            }
        else:
            # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
            # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
            # the max len, then we (for instance) double the cache size. This implementation already exists
            # in `transformers`. (molbap)
            # Cached keys/values come first, the new ones are appended along the token dimension.
            key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
            value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
    attention_interface = self.get_attention_interface()
    att_output = attention_interface(
        attention_mask_, batch_size, head_dim, query_states, key_states, value_states
    )
    # att_output = att_output.to(dtype=models[i].dtype)
    return [att_output], past_key_values
def forward_cross_attn_layer(
    self,
    model_layers,
    inputs_embeds,
    layer_idx,
    position_ids,
    attention_mask,
    batch_size,
    head_dim,
    use_cache: bool = True,
    fill_kv_cache: bool = True,
    past_key_values=None,
) -> tuple[list[torch.Tensor | None], dict | None]:
    """Run one cross-attention step: stream 0 self-attends, stream 1 attends to it.

    When two streams are given and no cache is available, stream 0 (the
    "prefix") first runs self-attention with its own layer. The expert layer
    (stream 1) then builds queries from its own hidden states and keys/values
    by re-projecting the prefix keys/values (fresh or cached) with its own
    ``k_proj`` / ``v_proj``.

    Args:
        model_layers: ``model_layers[0][layer_idx]`` is the prefix layer and
            ``model_layers[1][layer_idx]`` the expert layer (may be ``None``).
        inputs_embeds: ``[prefix_hidden_states, expert_hidden_states]``.
        layer_idx: Index of the layer to run; also the key used in the cache.
        position_ids: Position ids; split at the prefix length when the prefix
            is processed, otherwise used entirely for the expert.
        attention_mask: Boolean mask; sliced for the prefix and the expert.
        batch_size: Forwarded to the attention interface.
        head_dim: Forwarded to the attention interface.
        use_cache: Whether to read/write ``past_key_values``.
        fill_kv_cache: With ``use_cache``, ``True`` stores the prefix
            keys/values, ``False`` loads them from the cache.
        past_key_values: Dict keyed by layer index holding ``"key_states"``
            and ``"value_states"``.

    Returns:
        ``(att_outputs, past_key_values)``. ``att_outputs`` holds the prefix
        output (only when the prefix was processed) followed by the expert
        output, or ``None`` when there is no expert layer at this index.

    NOTE(review): ``key_states`` / ``value_states`` are only bound by the
    prefix branch or by the cache-read branch. Taking the ``else`` branch with
    ``use_cache`` false, or with ``fill_kv_cache`` true, would hit an unbound
    name - presumably callers never do this; confirm.
    """
    attention_interface = self.get_attention_interface()
    att_outputs = []
    assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
        f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
    )
    if len(inputs_embeds) == 2 and not past_key_values:
        # Prefix attention
        seq_len = inputs_embeds[0].shape[1]
        # First `seq_len` positions belong to the prefix, the rest to the expert.
        position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
        prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
        layer = model_layers[0][layer_idx]
        hidden_states = layer.input_layernorm(inputs_embeds[0])
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
        value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
        # B,L,H,D with L sequence length, H number of heads, D head dim
        query_states = apply_rope(query_state, position_id)
        key_states = apply_rope(key_state, position_id)
        att_output = attention_interface(
            prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
        )
        att_outputs.append(att_output)
    else:
        # No prefix to process: all positions belong to the expert.
        expert_position_id = position_ids
    if use_cache and past_key_values is None:
        past_key_values = {}
    if use_cache:
        if fill_kv_cache:
            # Store the prefix keys/values for later calls.
            past_key_values[layer_idx] = {
                "key_states": key_states,
                "value_states": value_states,
            }
        else:
            # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
            # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
            # the max len, then we (for instance) double the cache size. This implementation already exists
            # in `transformers`. (molbap)
            # Reuse the cached prefix keys/values as-is (nothing is appended here).
            key_states = past_key_values[layer_idx]["key_states"]
            value_states = past_key_values[layer_idx]["value_states"]
    # Expert
    expert_layer = model_layers[1][layer_idx]
    if expert_layer is not None:
        expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
        expert_input_shape = expert_hidden_states.shape[:-1]
        expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
        expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
        expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
        # Flatten the prefix keys to (B, L, -1) and re-project them with the expert's k_proj.
        _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
            *key_states.shape[:2], -1
        )
        expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
            *_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
        )  # k_proj should have same dim as kv
        # Same treatment for the prefix values with the expert's v_proj.
        _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
            *value_states.shape[:2], -1
        )
        expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
            *_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
        )
        expert_position_id = (
            expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
        )  # start from 0
        # Rows: the last expert-length queries; columns: the first kv-length keys.
        expert_attention_mask = attention_mask[
            :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
        ]  # take into account kv
        # Only the expert queries get rotary embeddings here; the re-projected keys do not.
        expert_query_states = apply_rope(expert_query_state, expert_position_id)
        # expert_key_states = apply_rope(expert_key_state, expert_position_id)
        att_output = attention_interface(
            expert_attention_mask,
            batch_size,
            head_dim,
            expert_query_states,
            expert_key_states,
            expert_value_states,
        )
        att_outputs.append(att_output)
    else:
        # Keep the output list aligned with the streams when the expert has no layer here.
        att_outputs.append(None)
    # att_output = att_output.to(dtype=models[i].dtype)
    return att_outputs, past_key_values
def get_model_layers(self, models: list) -> list:  # FIXME(mshukor): is this efficient?
    """Pair every VLM layer with the expert layer that runs at the same depth.

    With ``stride = num_vlm_layers // num_expert_layers``, an expert layer is
    placed at every ``stride``-th VLM depth and ``None`` elsewhere; when the
    stride is 0 the expert layer with the same index is used at every depth.

    Args:
        models: ``[vlm_text_model, expert_model]``; both expose ``.layers``.

    Returns:
        ``[vlm_layers, expert_layers]``, two lists of length
        ``num_vlm_layers``.
    """
    stride = self.num_vlm_layers // self.num_expert_layers
    vlm_layers, expert_layers = [], []
    for depth in range(self.num_vlm_layers):
        has_expert = stride <= 0 or depth <= 0 or depth % stride == 0
        if has_expert:
            expert_index = depth // stride if stride > 0 else depth
            expert_layers.append(models[1].layers[expert_index])
        else:
            expert_layers.append(None)
        vlm_layers.append(models[0].layers[depth])
    return [vlm_layers, expert_layers]
# TODO: break down this huge forward into modules or functions
def forward(
    self,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: list[torch.FloatTensor] | Cache | None = None,
    inputs_embeds: list[torch.FloatTensor] = None,
    use_cache: bool | None = None,
    fill_kv_cache: bool | None = None,
):
    """Run the VLM text model and the expert jointly over all layers.

    For every layer the attention step is computed either jointly
    (``forward_attn_layer``) or with cross-attention
    (``forward_cross_attn_layer``); each stream then goes through its own
    output projection, residuals, post-attention norm and MLP. A final norm
    is applied per stream.

    Args:
        attention_mask: Boolean attention mask shared by both streams.
        position_ids: Position ids shared by both streams.
        past_key_values: Per-layer key/value cache, or ``None``.
        inputs_embeds: ``[vlm_hidden_states, expert_hidden_states]``; either
            entry may be ``None``.
        use_cache: Whether the attention steps read/write the cache.
        fill_kv_cache: Whether the cache is being filled (``True``) or
            consumed (``False``).

    Returns:
        ``(outputs_embeds, past_key_values)`` where ``outputs_embeds`` has one
        entry per input stream (``None`` where the input was ``None``).

    NOTE(review): with the ``"flex"`` implementation the caller's
    ``inputs_embeds`` list is modified in place (its first entry is replaced
    by a padded tensor).
    """
    models = [self.get_vlm_model().text_model, self.lm_expert]
    model_layers = self.get_model_layers(models)
    # Take the batch size from the non-None streams (the last one wins).
    for hidden_states in inputs_embeds:
        # TODO this is very inefficient
        # dtype is always the same, batch size too (if > 1 len)
        # device could be trickier in multi gpu edge cases but that's it
        if hidden_states is None:
            continue
        batch_size = hidden_states.shape[0]
    # Pad on the right so that the total (prefix + suffix) length is a multiple of 128.
    if self.attention_implementation == "flex":
        if (
            inputs_embeds[0] is not None
            and inputs_embeds[1] is not None
            and attention_mask.shape[-1] == attention_mask.shape[-2]
            and past_key_values is None
        ):  # Now only during training
            seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1]
            padded_seq_len = _round_up_to_multiple(
                seq_len, 128
            )  # FIXME(mshukor): more efficient to have a fixed seq len?
            b_mask, q_len, kv_len = attention_mask.shape  # The shape of your mask
            pad = padded_seq_len - q_len
            # Extend the mask's last two dimensions with `True` entries.
            attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True)
            # Extend the first stream with zero embeddings along its second-to-last dimension.
            inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0)
            position_ids = F.pad(position_ids, (0, pad), value=0)
    # RMSNorm
    num_layers = self.num_vlm_layers
    head_dim = self.vlm.config.text_config.head_dim
    for layer_idx in range(num_layers):
        # Joint self-attention when filling the cache, when the attention mode
        # has no "cross" in it, or on every n-th layer; cross-attention otherwise.
        if (
            fill_kv_cache
            or "cross" not in self.attention_mode
            or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
        ):
            att_outputs, past_key_values = self.forward_attn_layer(
                model_layers,
                inputs_embeds,
                layer_idx,
                position_ids,
                attention_mask,
                batch_size,
                head_dim,
                use_cache=use_cache,
                fill_kv_cache=fill_kv_cache,
                past_key_values=past_key_values,
            )
        else:
            att_outputs, past_key_values = self.forward_cross_attn_layer(
                model_layers,
                inputs_embeds,
                layer_idx,
                position_ids,
                attention_mask,
                batch_size,
                head_dim,
                use_cache=use_cache,
                fill_kv_cache=fill_kv_cache,
                past_key_values=past_key_values,
            )
        # NOTE: an earlier inline version of the attention step used to live
        # here as commented-out code; `forward_attn_layer` above contains the
        # same projection / rope / cache / attention sequence.
        # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
        outputs_embeds = []
        start = 0
        for i, hidden_states in enumerate(inputs_embeds):
            # layer = models[i].layers[layer_idx]
            layer = model_layers[i][layer_idx]
            att_output = (
                att_outputs[i] if i < len(att_outputs) else att_outputs[0]
            )  # in case of self_attn
            if hidden_states is not None:
                if layer is None:
                    # No layer for this stream at this depth: pass the hidden states through.
                    outputs_embeds.append(hidden_states)
                    continue
                end = start + hidden_states.shape[1]
                if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
                    att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
                # Slice this stream's tokens out of the attention output.
                att_out = att_output[:, start:end]
                out_emb = layer.self_attn.o_proj(att_out)
                # TODO: first dropout (by default 0.0)
                # first residual
                out_emb += hidden_states
                after_first_residual = out_emb.clone()
                out_emb = layer.post_attention_layernorm(out_emb)
                out_emb = layer.mlp(out_emb)
                # TODO: second dropout (by default 0.0)
                # second residual
                out_emb += after_first_residual
                outputs_embeds.append(out_emb)
                # A single joint output is consumed stream after stream;
                # separate per-stream outputs are each read from offset 0.
                start = end if len(att_outputs) == 1 else 0
            else:
                outputs_embeds.append(None)
        # The outputs of this layer feed the next one.
        inputs_embeds = outputs_embeds
    # final norm
    outputs_embeds = []
    for i, hidden_states in enumerate(inputs_embeds):
        if hidden_states is not None:
            out_emb = models[i].norm(hidden_states)
            outputs_embeds.append(out_emb)
        else:
            outputs_embeds.append(None)
    return outputs_embeds, past_key_values
def get_attention_interface(self):
    """Return the attention callable selected by ``self.attention_implementation``.

    ``"fa2"`` maps to ``flash_attention_forward``, ``"flex"`` to
    ``flex_attention_forward`` with the head counts pre-bound, and any other
    value to ``eager_attention_forward``.
    """
    implementation = self.attention_implementation
    if implementation == "fa2":
        return self.flash_attention_forward
    if implementation == "flex":
        return partial(
            flex_attention_forward,
            num_att_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
        )
    return self.eager_attention_forward
def flash_attention_forward(
    self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
    """Placeholder for a FlashAttention-2 backend.

    Shares the signature of ``eager_attention_forward`` but always raises.

    Raises:
        NotImplementedError: Always; this backend is not implemented.
    """
    message = "FA2 is not implemented (yet)"
    raise NotImplementedError(message)
def eager_attention_forward(
    self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
    """Plain (non-fused) grouped-query attention.

    Args:
        attention_mask: Boolean mask indexed as ``[batch, query, key]``;
            ``False`` entries are masked out.
        batch_size: Batch size used when repeating key/value heads and when
            reshaping the output.
        head_dim: Size of one attention head.
        query_states: ``(batch, query_len, num_attention_heads, head_dim)``.
        key_states: ``(batch, key_len, num_key_value_heads, head_dim)``.
        value_states: ``(batch, key_len, num_key_value_heads, head_dim)``.

    Returns:
        Tensor of shape ``(batch_size, query_len, num_heads * head_dim)`` in
        the dtype of ``value_states``.
    """
    kv_heads = self.num_key_value_heads
    groups = self.num_attention_heads // kv_heads
    kv_len = key_states.shape[1]

    def _repeat_kv_heads(states):
        # Repeat every key/value head `groups` times so that each query head
        # has a matching key/value head.
        expanded = states[:, :, :, None, :].expand(batch_size, kv_len, kv_heads, groups, head_dim)
        return expanded.reshape(batch_size, kv_len, kv_heads * groups, head_dim)

    keys = _repeat_kv_heads(key_states)
    values = _repeat_kv_heads(value_states)

    # Attention here is upcasted to float32 to match the original eager implementation.
    queries = query_states.to(dtype=torch.float32).transpose(1, 2)
    keys = keys.to(dtype=torch.float32).transpose(1, 2)

    scores = torch.matmul(queries, keys.transpose(2, 3)) * head_dim**-0.5
    scores = scores.to(dtype=torch.float32)

    # Masked positions get the most negative finite value (see gemma/modules.py).
    mask_fill = torch.finfo(scores.dtype).min
    scores = torch.where(attention_mask[:, None, :, :], scores, mask_fill)

    probs = nn.functional.softmax(scores, dim=-1).to(dtype=values.dtype)

    # (batch, heads, query_len, head_dim) -> (batch, query_len, heads, head_dim)
    context = torch.matmul(probs, values.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)
    # we use -1 because sequence length can change
    return context.reshape(batch_size, -1, kv_heads * groups * head_dim)

View File

@@ -1,955 +1,3 @@
# #!/usr/bin/env python
# # Copyright 2025 HuggingFace Inc. team. All rights reserved.
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
# """
# SmolVLA:
# [Paper](https://huggingface.co/papers/2506.01844)
# Designed by Hugging Face.
# Install smolvla extra dependencies:
# ```bash
# pip install -e ".[smolvla]"
# ```
# Example of finetuning the smolvla pretrained model (`smolvla_base`):
# ```bash
# lerobot-train \
# --policy.path=lerobot/smolvla_base \
# --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
# --batch_size=64 \
# --steps=200000
# ```
# Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
# and an action expert.
# ```bash
# lerobot-train \
# --policy.type=smolvla \
# --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
# --batch_size=64 \
# --steps=200000
# ```
# Example of using the smolvla pretrained model outside LeRobot training framework:
# ```python
# policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
# ```
# """
# import math
# import os
# import re
# from collections import deque
# import safetensors
# import torch
# import torch.nn.functional as F # noqa: N812
# from torch import Tensor, nn
# from transformers import AutoProcessor
# from lerobot.constants import ACTION
# from lerobot.policies.normalize import (
# Normalize,
# Unnormalize,
# )
# from lerobot.policies.pretrained import PreTrainedPolicy
# from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
# from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
# from lerobot.policies.utils import (
# populate_queues,
# )
# from lerobot.utils.utils import get_safe_dtype
# OBS_STATE = 'state'
# ACTION = 'actions'
# # Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
# _VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
# def canonicalise(k: str) -> str:
# """
# Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
# normalisation-buffer key.
# """
# return _VARIANT_RE.sub(".buffer_", k)
# def standardise_state_dict(
# checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
# ) -> tuple[dict[str, torch.Tensor], list[str]]:
# """
# • Re-keys `checkpoint` so that every entry matches the *reference* key set.
# • If several variant keys collapse to the same canonical name we keep the
# first one and log the collision.
# • Returns the new dict + a list of entries that could not be matched.
# """
# out, collisions, unmatched = {}, {}, []
# for k, v in checkpoint.items():
# canon = canonicalise(k)
# if canon in ref_keys:
# if canon in out: # duplicate after collapsing
# collisions.setdefault(canon, []).append(k)
# else:
# out[canon] = v
# else:
# unmatched.append(k)
# if verbose:
# for canon, variants in collisions.items():
# print(f"[standardise_state_dict] '{canon}' ← {variants}")
# if unmatched:
# print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
# out.update({k: checkpoint[k] for k in unmatched})
# return out, unmatched
# def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
# """
# Renames keys in a checkpoint dictionary based on the given rename string.
# Args:
# checkpoint (dict): The checkpoint dictionary.
# rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
# Returns:
# dict: The modified checkpoint with renamed keys.
# """
# rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
# new_checkpoint = {}
# for k, v in checkpoint.items():
# for old_key, new_key in rename_dict.items():
# if old_key in k:
# k = k.replace(old_key, new_key)
# new_checkpoint[k] = v
# return new_checkpoint
# def load_smolvla(
# model: torch.nn.Module,
# filename: str | os.PathLike,
# *,
# device: str = "cpu",
# checkpoint_keys_mapping: str = "",
# ) -> torch.nn.Module:
# state_dict = safetensors.torch.load_file(filename, device=device)
# # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
# if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
# state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
# state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
# # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
# norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
# state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
# missing, unexpected = model.load_state_dict(state_dict, strict=False)
# if not all(key.startswith(norm_keys) for key in missing) or unexpected:
# raise RuntimeError(
# "SmolVLA %d missing / %d unexpected keys",
# len(missing),
# len(unexpected),
# )
# return model
# def create_sinusoidal_pos_embedding(
# time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
# ) -> Tensor:
# """Computes sine-cosine positional embedding vectors for scalar positions."""
# if dimension % 2 != 0:
# raise ValueError(f"dimension ({dimension}) must be divisible by 2")
# if time.ndim != 1:
# raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
# dtype = get_safe_dtype(torch.float64, device.type)
# fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
# period = min_period * (max_period / min_period) ** fraction
# # Compute the outer product
# scaling_factor = 1.0 / period * 2 * math.pi
# sin_input = scaling_factor[None, :] * time[:, None]
# pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
# return pos_emb
# def make_att_2d_masks(pad_masks, att_masks):
# """Copied from big_vision.
# Tokens can attend to valid inputs tokens which have a cumulative mask_ar
# smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
# setup several types of attention, for example:
# [[1 1 1 1 1 1]]: pure causal attention.
# [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
# themselves and the last 3 tokens have a causal attention. The first
# entry could also be a 1 without changing behaviour.
# [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
# block can attend all previous blocks and all tokens on the same block.
# Args:
# input_mask: bool[B, N] true if its part of the input, false if padding.
# mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
# it and 0 where it shares the same attention mask as the previous token.
# """
# if att_masks.ndim != 2:
# raise ValueError(att_masks.ndim)
# if pad_masks.ndim != 2:
# raise ValueError(pad_masks.ndim)
# cumsum = torch.cumsum(att_masks, dim=1)
# att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
# pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
# att_2d_masks = att_2d_masks & pad_2d_masks
# return att_2d_masks
# def resize_with_pad(img, width, height, pad_value=-1):
# # assume no-op when width height fits already
# if img.ndim != 4:
# raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
# cur_height, cur_width = img.shape[2:]
# ratio = max(cur_width / width, cur_height / height)
# resized_height = int(cur_height / ratio)
# resized_width = int(cur_width / ratio)
# resized_img = F.interpolate(
# img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
# )
# pad_height = max(0, int(height - resized_height))
# pad_width = max(0, int(width - resized_width))
# # pad on left and top of image
# padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
# return padded_img
# def pad_vector(vector, new_dim):
# """Can be (batch_size x sequence_length x features_dimension)
# or (batch_size x features_dimension)
# """
# if vector.shape[-1] == new_dim:
# return vector
# shape = list(vector.shape)
# current_dim = shape[-1]
# shape[-1] = new_dim
# new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
# new_vector[..., :current_dim] = vector
# return new_vector
# def normalize(x, min_val, max_val):
# return (x - min_val) / (max_val - min_val)
# def unnormalize(x, min_val, max_val):
# return x * (max_val - min_val) + min_val
# def safe_arcsin(value):
# # This ensures that the input stays within
# # [-1, 1] to avoid invalid values for arcsin
# return torch.arcsin(torch.clamp(value, -1.0, 1.0))
# def aloha_gripper_to_angular(value):
# # Aloha transforms the gripper positions into a linear space. The following code
# # reverses this transformation to be consistent with smolvla which is pretrained in
# # angular space.
# #
# # These values are coming from the Aloha code:
# # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
# value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# # This is the inverse of the angular to linear transformation inside the Interbotix code.
# def linear_to_radian(linear_position, arm_length, horn_radius):
# value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
# return safe_arcsin(value)
# # The constants are taken from the Interbotix code.
# value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# # Normalize to [0, 1].
# # The values 0.4 and 1.5 were measured on an actual Trossen robot.
# return normalize(value, min_val=0.4, max_val=1.5)
# def aloha_gripper_from_angular(value):
# # Convert from the gripper position used by smolvla to the gripper position that is used by Aloha.
# # Note that the units are still angular but the range is different.
# # The values 0.4 and 1.5 were measured on an actual Trossen robot.
# value = unnormalize(value, min_val=0.4, max_val=1.5)
# # These values are coming from the Aloha code:
# # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
# return normalize(value, min_val=-0.6213, max_val=1.4910)
# def aloha_gripper_from_angular_inv(value):
# # Directly inverts the gripper_from_angular function.
# value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
# return normalize(value, min_val=0.4, max_val=1.5)
# class SmolVLAPolicy(PreTrainedPolicy):
# """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
# config_class = SmolVLAConfig
# name = "smolvla"
# def __init__(
# self,
# config: SmolVLAConfig,
# dataset_stats: dict[str, dict[str, Tensor]] | None = None,
# ):
# """
# Args:
# config: Policy configuration class instance or None, in which case the default instantiation of
# the configuration class is used.
# dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
# that they will be passed with a call to `load_state_dict` before the policy is used.
# """
# super().__init__(config)
# config.validate_features()
# self.config = config
# self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
# self.normalize_targets = Normalize(
# config.output_features, config.normalization_mapping, dataset_stats
# )
# self.unnormalize_outputs = Unnormalize(
# config.output_features, config.normalization_mapping, dataset_stats
# )
# self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
# self.model = VLAFlowMatching(config)
# self.reset()
# def reset(self):
# """This should be called whenever the environment is reset."""
# self._queues = {
# ACTION: deque(maxlen=self.config.n_action_steps),
# }
# # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
# @classmethod
# def _load_as_safetensor(
# cls,
# model: "SmolVLAPolicy",
# model_file: str,
# map_location: str,
# strict: bool,
# ):
# safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
# return load_smolvla(
# model,
# model_file,
# device=map_location,
# checkpoint_keys_mapping="model._orig_mod.//model.",
# )
# def get_optim_params(self) -> dict:
# return self.parameters()
# def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
# # TODO: Check if this for loop is needed.
# # Context: self._queues contains only the ACTION field. During online inference the batch has no
# # action, but during offline inference it does. That is why, without the `k != ACTION` check, this
# # loop raises an error: it would try to stack an empty container.
# for k in batch:
# if k in self._queues and k != ACTION:
# batch[k] = torch.stack(list(self._queues[k]), dim=1)
# images, img_masks = self.prepare_images(batch)
# state = self.prepare_state(batch)
# lang_tokens, lang_masks = self.prepare_language(batch)
# actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
# # Unpad actions
# original_action_dim = self.config.action_feature.shape[0]
# actions = actions[:, :, :original_action_dim]
# actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
# if self.config.adapt_to_pi_aloha:
# actions = self._pi_aloha_encode_actions(actions)
# return actions
# def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# if self.config.adapt_to_pi_aloha:
# batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
# batch = self.normalize_inputs(batch)
# return batch
# @torch.no_grad()
# def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
# self.eval()
# batch = self._prepare_batch(batch)
# self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# actions = self._get_action_chunk(batch, noise)
# return actions
# @torch.no_grad()
# def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
# """Select a single action given environment observations.
# This method wraps `select_actions` in order to return one action at a time for execution in the
# environment. It works by managing the actions in a queue and only calling `select_actions` when the
# queue is empty.
# """
# self.eval()
# batch = self._prepare_batch(batch)
# self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# # querying the policy.
# if len(self._queues[ACTION]) == 0:
# actions = self._get_action_chunk(batch, noise)
# # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
# self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
# return self._queues[ACTION].popleft()
# def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
# """Do a full training forward pass to compute the loss"""
# if self.config.adapt_to_pi_aloha:
# batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
# batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
# batch = self.normalize_inputs(batch)
# batch = self.normalize_targets(batch)
# images, img_masks = self.prepare_images(batch)
# state = self.prepare_state(batch)
# lang_tokens, lang_masks = self.prepare_language(batch)
# actions = self.prepare_action(batch)
# actions_is_pad = batch.get("actions_id_pad")
# loss_dict = {}
# losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
# loss_dict["losses_after_forward"] = losses.clone()
# if actions_is_pad is not None:
# in_episode_bound = ~actions_is_pad
# losses = losses * in_episode_bound.unsqueeze(-1)
# loss_dict["losses_after_in_ep_bound"] = losses.clone()
# # Remove padding
# losses = losses[:, :, : self.config.max_action_dim]
# loss_dict["losses_after_rm_padding"] = losses.clone()
# # For backward pass
# loss = losses.mean()
# # For backward pass
# loss_dict["loss"] = loss.item()
# return loss, loss_dict
# def prepare_images(self, batch):
# """Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
# convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
# """
# images = []
# img_masks = []
# present_img_keys = [key for key in self.config.image_features if key in batch]
# missing_img_keys = [key for key in self.config.image_features if key not in batch]
# if len(present_img_keys) == 0:
# raise ValueError(
# f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
# )
# # Preprocess image features present in the batch
# for key in present_img_keys:
# img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
# if self.config.resize_imgs_with_padding is not None:
# img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# # Normalize from range [0, 1] to [-1, 1] as expected by SigLIP
# img = img * 2.0 - 1.0
# bsize = img.shape[0]
# device = img.device
# if f"{key}_padding_mask" in batch:
# mask = batch[f"{key}_padding_mask"].bool()
# else:
# mask = torch.ones(bsize, dtype=torch.bool, device=device)
# images.append(img)
# img_masks.append(mask)
# # Create image features not present in the batch
# # as fully 0 padded images.
# for num_empty_cameras in range(len(missing_img_keys)):
# if num_empty_cameras >= self.config.empty_cameras:
# break
# img = torch.ones_like(img) * -1
# mask = torch.zeros_like(mask)
# images.append(img)
# img_masks.append(mask)
# return images, img_masks
# def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
# """Tokenize the text input"""
# device = batch[OBS_STATE].device
# tasks = batch["task"]
# if isinstance(tasks, str):
# tasks = [tasks]
# if len(tasks) == 1:
# tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
# tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
# tokenized_prompt = self.language_tokenizer.__call__(
# tasks,
# padding=self.config.pad_language_to,
# padding_side="right",
# max_length=self.config.tokenizer_max_length,
# return_tensors="pt",
# )
# lang_tokens = tokenized_prompt["input_ids"].to(device=device)
# lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
# return lang_tokens, lang_masks
# def _pi_aloha_decode_state(self, state):
# # Flip the joints.
# for motor_idx in [1, 2, 8, 9]:
# state[:, motor_idx] *= -1
# # Reverse the gripper transformation that is being applied by the Aloha runtime.
# for motor_idx in [6, 13]:
# state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
# return state
# def _pi_aloha_encode_actions(self, actions):
# # Flip the joints.
# for motor_idx in [1, 2, 8, 9]:
# actions[:, :, motor_idx] *= -1
# # Reverse the gripper transformation that is being applied by the Aloha runtime.
# for motor_idx in [6, 13]:
# actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
# return actions
# def _pi_aloha_encode_actions_inv(self, actions):
# # Flip the joints again.
# for motor_idx in [1, 2, 8, 9]:
# actions[:, :, motor_idx] *= -1
# # Reverse the gripper transformation that is being applied by the Aloha runtime.
# for motor_idx in [6, 13]:
# actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
# return actions
# def prepare_state(self, batch):
# """Pad state"""
# state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE]
# state = pad_vector(state, self.config.max_state_dim)
# return state
# def prepare_action(self, batch):
# """Pad action"""
# actions = pad_vector(batch[ACTION], self.config.max_action_dim)
# return actions
# def pad_tensor(tensor, max_len, pad_value=0):
# """
# Efficiently pads a tensor along sequence dimension to match max_len.
# Args:
# tensor (torch.Tensor): Shape (B, L, ...) or (B, L).
# max_len (int): Fixed sequence length.
# pad_value (int/float): Value for padding.
# Returns:
# torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
# """
# b, d = tensor.shape[:2]
# # Create a padded tensor of max_len and copy the existing values
# padded_tensor = torch.full(
# (b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
# )
# padded_tensor[:, :d] = tensor # Efficient in-place copy
# return padded_tensor
# class VLAFlowMatching(nn.Module):
# """
# SmolVLA
# [Paper]()
# Designed by Hugging Face.
# ┌──────────────────────────────┐
# │ actions │
# │ ▲ │
# │ ┌─────────┐ ┌─|────┐ │
# │ | │────► │ │ │
# │ | │ kv │ │ │
# │ | │────► │Action│ │
# │ | VLM │cache │Expert│ |
# │ │ │────► | │ │
# │ │ │ │ │ │
# │ └▲──▲───▲─┘ └───▲──┘ |
# │ │ | | │ |
# │ | | | noise │
# │ │ │ state │
# │ │ language tokens │
# │ image(s) │
# └──────────────────────────────┘
# """
# def __init__(self, config: SmolVLAConfig):
# super().__init__()
# self.config = config
# self.vlm_with_expert = SmolVLMWithExpertModel(
# model_id=self.config.vlm_model_name,
# freeze_vision_encoder=self.config.freeze_vision_encoder,
# train_expert_only=self.config.train_expert_only,
# load_vlm_weights=self.config.load_vlm_weights,
# attention_mode=self.config.attention_mode,
# num_expert_layers=self.config.num_expert_layers,
# num_vlm_layers=self.config.num_vlm_layers,
# self_attn_every_n_layers=self.config.self_attn_every_n_layers,
# expert_width_multiplier=self.config.expert_width_multiplier,
# )
# self.state_proj = nn.Linear(
# self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
# )
# self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
# self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
# self.action_time_mlp_in = nn.Linear(
# self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
# )
# self.action_time_mlp_out = nn.Linear(
# self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
# )
# self.set_requires_grad()
# self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
# self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
# self.global_image_start_token = torch.tensor(
# [self.fake_image_token, self.global_image_token], dtype=torch.long
# )
# self.add_image_special_tokens = self.config.add_image_special_tokens
# self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
# self.prefix_length = self.config.prefix_length
# def set_requires_grad(self):
# for params in self.state_proj.parameters():
# params.requires_grad = self.config.train_state_proj
# def sample_noise(self, shape, device):
# noise = torch.normal(
# mean=0.0,
# std=1.0,
# size=shape,
# dtype=torch.float32,
# device=device,
# )
# return noise
# def sample_time(self, bsize, device):
# beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
# time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
# time = time_beta * 0.999 + 0.001
# return time
# def embed_prefix(
# self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
# ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# """Embed images with SigLIP and language tokens with embedding layer to prepare
# for SmolVLM transformer processing.
# """
# embs = []
# pad_masks = []
# att_masks = []
# for _img_idx, (
# img,
# img_mask,
# ) in enumerate(zip(images, img_masks, strict=False)):
# if self.add_image_special_tokens:
# image_start_token = (
# self.vlm_with_expert.embed_language_tokens(
# self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
# )
# .unsqueeze(0)
# .expand(img.shape[0], -1, -1)
# )
# image_start_mask = torch.ones_like(
# image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
# )
# att_masks += [0] * (image_start_mask.shape[-1])
# embs.append(image_start_token)
# pad_masks.append(image_start_mask)
# img_emb = self.vlm_with_expert.embed_image(img)
# img_emb = img_emb
# # Normalize image embeddings
# img_emb_dim = img_emb.shape[-1]
# img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
# bsize, num_img_embs = img_emb.shape[:2]
# img_mask = img_mask[:, None].expand(bsize, num_img_embs)
# embs.append(img_emb)
# pad_masks.append(img_mask)
# att_masks += [0] * (num_img_embs)
# if self.add_image_special_tokens:
# image_end_token = (
# self.vlm_with_expert.embed_language_tokens(
# self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
# )
# .unsqueeze(0)
# .expand(img.shape[0], -1, -1)
# )
# image_end_mask = torch.ones_like(
# image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
# )
# embs.append(image_end_token)
# pad_masks.append(image_end_mask)
# att_masks += [0] * (image_end_mask.shape[1])
# lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
# # Normalize language embeddings
# lang_emb_dim = lang_emb.shape[-1]
# lang_emb = lang_emb * math.sqrt(lang_emb_dim)
# embs.append(lang_emb)
# pad_masks.append(lang_masks)
# num_lang_embs = lang_emb.shape[1]
# att_masks += [0] * num_lang_embs
# state_emb = self.state_proj(state)
# state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
# embs.append(state_emb)
# bsize = state_emb.shape[0]
# device = state_emb.device
# states_seq_len = state_emb.shape[1]
# state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
# pad_masks.append(state_mask)
# # Set attention masks so that image and language inputs do not attend to state or actions
# att_masks += [1] * (states_seq_len)
# embs = torch.cat(embs, dim=1)
# pad_masks = torch.cat(pad_masks, dim=1)
# att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
# att_masks = att_masks[None, :]
# seq_len = pad_masks.shape[1]
# if seq_len < self.prefix_length:
# embs = pad_tensor(embs, self.prefix_length, pad_value=0)
# pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0)
# att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0)
# att_masks = att_masks.expand(bsize, -1)
# return embs, pad_masks, att_masks
# def embed_suffix(self, noisy_actions, timestep):
# """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
# embs = []
# pad_masks = []
# att_masks = []
# # Fuse timestep + action information using an MLP
# action_emb = self.action_in_proj(noisy_actions)
# device = action_emb.device
# bsize = action_emb.shape[0]
# dtype = action_emb.dtype
# # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
# time_emb = create_sinusoidal_pos_embedding(
# timestep,
# self.vlm_with_expert.expert_hidden_size,
# self.config.min_period,
# self.config.max_period,
# device=device,
# )
# time_emb = time_emb.type(dtype=dtype)
# time_emb = time_emb[:, None, :].expand_as(action_emb)
# action_time_emb = torch.cat([action_emb, time_emb], dim=2)
# action_time_emb = self.action_time_mlp_in(action_time_emb)
# action_time_emb = F.silu(action_time_emb) # swish == silu
# action_time_emb = self.action_time_mlp_out(action_time_emb)
# # Add to input tokens
# embs.append(action_time_emb)
# bsize, action_time_dim = action_time_emb.shape[:2]
# action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
# pad_masks.append(action_time_mask)
# # Set attention masks so that image, language and state inputs do not attend to action tokens
# att_masks += [1] * self.config.chunk_size
# embs = torch.cat(embs, dim=1)
# pad_masks = torch.cat(pad_masks, dim=1)
# att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
# att_masks = att_masks[None, :].expand(bsize, len(att_masks))
# # added by jade
# seq_len = pad_masks.shape[1]
# if seq_len < self.config.chunk_size:
# embs = pad_tensor(embs, self.config.chunk_size, pad_value=0)
# pad_masks = pad_tensor(pad_masks, self.config.chunk_size, pad_value=0)
# att_masks = pad_tensor(att_masks, self.config.chunk_size, pad_value=0)
# return embs, pad_masks, att_masks
# def forward(
# self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
# ) -> Tensor:
# """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
# #added by jade
# if actions.ndim == 2:
# actions = actions[:, None, :].expand(-1, self.config.chunk_size, -1)
# if noise is None:
# noise = self.sample_noise(actions.shape, actions.device)
# if time is None:
# time = self.sample_time(actions.shape[0], actions.device)
# time_expanded = time[:, None, None]
# x_t = time_expanded * noise + (1 - time_expanded) * actions
# u_t = noise - actions
# prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
# images, img_masks, lang_tokens, lang_masks, state=state
# )
# suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time)
# pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
# att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
# att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
# position_ids = torch.cumsum(pad_masks, dim=1) - 1
# (_, suffix_out), _ = self.vlm_with_expert.forward(
# attention_mask=att_2d_masks,
# position_ids=position_ids,
# past_key_values=None,
# inputs_embeds=[prefix_embs, suffix_embs],
# use_cache=False,
# fill_kv_cache=False,
# )
# # suffix_out = suffix_out[:, -self.config.chunk_size :]
# suffix_out = suffix_out[:, -self.config.chunk_size:, :]
# # Original openpi code, upcast attention output
# suffix_out = suffix_out.to(dtype=torch.float32)
# v_t = self.action_out_proj(suffix_out)
# losses = F.mse_loss(u_t, v_t, reduction="none")
# return losses
# def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
# """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
# bsize = state.shape[0]
# device = state.device
# if noise is None:
# actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
# noise = self.sample_noise(actions_shape, device)
# prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
# images, img_masks, lang_tokens, lang_masks, state=state
# )
# prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
# prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# # Compute image and language key value cache
# _, past_key_values = self.vlm_with_expert.forward(
# attention_mask=prefix_att_2d_masks,
# position_ids=prefix_position_ids,
# past_key_values=None,
# inputs_embeds=[prefix_embs, None],
# use_cache=self.config.use_cache,
# fill_kv_cache=True,
# )
# dt = -1.0 / self.config.num_steps
# dt = torch.tensor(dt, dtype=torch.float32, device=device)
# x_t = noise
# time = torch.tensor(1.0, dtype=torch.float32, device=device)
# while time >= -dt / 2:
# expanded_time = time.expand(bsize)
# v_t = self.denoise_step(
# prefix_pad_masks,
# past_key_values,
# x_t,
# expanded_time,
# )
# # Euler step
# x_t += dt * v_t
# time += dt
# return x_t
# def denoise_step(
# self,
# prefix_pad_masks,
# past_key_values,
# x_t,
# timestep,
# ):
# """Apply one denoising step of the noise `x_t` at a given timestep."""
# suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
# suffix_len = suffix_pad_masks.shape[1]
# batch_size = prefix_pad_masks.shape[0]
# prefix_len = prefix_pad_masks.shape[1]
# prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
# suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
# full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
# prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
# position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
# outputs_embeds, _ = self.vlm_with_expert.forward(
# attention_mask=full_att_2d_masks,
# position_ids=position_ids,
# past_key_values=past_key_values,
# inputs_embeds=[None, suffix_embs],
# use_cache=self.config.use_cache,
# fill_kv_cache=False,
# )
# suffix_out = outputs_embeds[1]
# suffix_out = suffix_out[:, -self.config.chunk_size :]
# suffix_out = suffix_out.to(dtype=torch.float32)
# v_t = self.action_out_proj(suffix_out)
# return v_t
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
@@ -1028,7 +76,6 @@ from lerobot.policies.utils import (
)
from lerobot.utils.utils import get_safe_dtype
# OBS_STATE = 'state'
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
@@ -1348,7 +395,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
original_action_dim = 7
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
@@ -1892,4 +938,4 @@ class VLAFlowMatching(nn.Module):
suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
return v_t
return v_t

View File

@@ -1 +0,0 @@
c

View File

@@ -132,10 +132,6 @@ def rollout(
# Reset the policy and environments.
policy.reset()
# added by jade
# for k in list(policy.config.input_features.keys()):
# if k.startswith("observation.image"):
# policy.config.input_features["observation.images." + k.split("observation.", 1)[1]] = policy.config.input_features.pop(k)
observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)
@@ -171,26 +167,6 @@ def rollout(
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
# breakpoint()
# observation = {
# k.replace("observation.images.", "observation.") if k.startswith("observation.images.") else k: v
# for k, v in observation.items()
# # }
# if "observation.image" in observation:
# observation["image"] = observation.pop("observation.image").to(
# device, non_blocking=device.type == "cuda"
# )
# if "observation.image2" in observation:
# observation["wrist_image"] = observation.pop("observation.image2").to(
# device, non_blocking=device.type == "cuda"
# )
# if "observation.state" in observation:
# observation["state"] = observation.pop("observation.state").to(
# device, non_blocking=device.type == "cuda"
# )
with torch.inference_mode():
action = policy.select_action(observation)
# Convert to CPU / numpy.
@@ -550,15 +526,12 @@ def eval_main(cfg: EvalPipelineConfig):
logging.info("Making environment.")
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
breakpoint()
logging.info("Making policy.")
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
breakpoint()
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
# rename "image" -> "observation.image"
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():

View File

@@ -1,345 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any
import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.scripts.eval import eval_policy, eval_policy_multitask
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.utils import (
format_big_number,
get_safe_torch_device,
has_method,
init_logging,
)
from lerobot.utils.wandb_utils import WandBLogger
def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
) -> tuple[MetricsTracker, dict]:
    """Perform a single forward/backward/optimizer step and record its metrics.

    Args:
        train_metrics: Tracker whose `loss`, `grad_norm`, `lr` and `update_s` fields are overwritten.
        policy: Policy to optimise; it is switched to training mode before the forward pass.
        batch: Input passed unchanged to `policy.forward`.
        optimizer: Optimizer stepped through `grad_scaler`.
        grad_clip_norm: Maximum gradient norm handed to `clip_grad_norm_`.
        grad_scaler: Gradient scaler used to scale the loss and step the optimizer.
        lr_scheduler: Optional scheduler, stepped once per call when provided.
        use_amp: When true, the forward pass runs under `torch.autocast`.
        lock: Optional context manager held only while the optimizer is stepped.

    Returns:
        The updated `train_metrics` and the second value returned by `policy.forward`.
    """
    tic = time.perf_counter()
    param_device = get_device_from_parameters(policy)
    policy.train()

    # Only the forward pass runs inside the (optional) autocast region.
    forward_ctx = torch.autocast(device_type=param_device.type) if use_amp else nullcontext()
    with forward_ctx:
        loss, output_dict = policy.forward(batch)
        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
    grad_scaler.scale(loss).backward()

    # Gradients must be unscaled in-place *before* clipping so the norm is measured on real values.
    grad_scaler.unscale_(optimizer)

    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    # The gradients are already unscaled, so `scaler.step` will not unscale them again;
    # it still skips `optimizer.step()` when they contain infs or NaNs.
    step_ctx = nullcontext() if lock is None else lock
    with step_ctx:
        grad_scaler.step(optimizer)
    # Refresh the scale factor for the next iteration.
    grad_scaler.update()

    optimizer.zero_grad()

    # The scheduler advances once per batch rather than once per epoch.
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Lets the policy refresh an internal buffer (for instance an Exponential Moving Average like in TDMPC).
    if has_method(policy, "update"):
        policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - tic
    return train_metrics, output_dict
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
"""Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
from lerobot.policies.normalize import Normalize, Unnormalize
if not hasattr(dataset_meta, "stats") or not dataset_meta.stats:
print("⚠️ Dataset has no stats, skipping normalization injection.")
return
stats = {}
for key, stat_dict in dataset_meta.stats.items():
stats[key] = {
stat_type: torch.as_tensor(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
for stat_type, stat_array in stat_dict.items()
}
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, stats
)
policy.normalize_inputs = normalize_inputs
policy.normalize_targets = normalize_targets
policy.unnormalize_outputs = unnormalize_outputs
print("✅ Normalization layers injected with dataset stats.")
@parser.wrap()
def train(cfg: TrainPipelineConfig):
    """Train a policy offline on a fixed dataset, with periodic logging, checkpointing and evaluation.

    The function builds the dataset, the optional evaluation environment, the policy, the optimizer
    and the scheduler from `cfg`, then runs `update_policy` once per step until `cfg.steps` is
    reached. Depending on the step counter it logs metrics (`cfg.log_freq`), saves checkpoints
    (`cfg.save_freq`, and always on the last step) and evaluates the policy (`cfg.eval_freq`).
    When `cfg.policy.push_to_hub` is set, the trained policy is pushed to the hub at the end.

    Args:
        cfg: Training pipeline configuration; it is validated before anything else happens.
    """
    cfg.validate()
    logging.info(pformat(cfg.to_dict()))

    # Metrics go to Weights & Biases only when it is both enabled and given a project.
    if cfg.wandb.enable and cfg.wandb.project:
        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Check device is available
    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("Creating dataset")
    dataset = make_dataset(cfg)

    # Create environment used for evaluating checkpoints during training on simulation data.
    # On real-world data, no need to create an environment as evaluations are done outside train.py,
    # using the eval.py instead, with gym_dora environment and dora-rs.
    eval_env = None
    if cfg.eval_freq > 0 and cfg.env is not None:
        logging.info("Creating env")
        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

    logging.info("Creating policy")
    policy = make_policy(
        cfg=cfg.policy,
        ds_meta=dataset.meta,
    )

    logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

    step = 0  # number of policy updates (forward + backward + optim)

    # When resuming, the step counter and optimizer/scheduler come from the checkpoint.
    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
    if cfg.env is not None:
        logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
    logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
    logging.info(f"{dataset.num_episodes=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    # create dataloader for offline training
    # When a custom sampler is used, shuffling is delegated to it, so the DataLoader's own
    # `shuffle` flag is turned off.
    if hasattr(cfg.policy, "drop_n_last_frames"):
        shuffle = False
        sampler = EpisodeAwareSampler(
            dataset.episode_data_index,
            drop_n_last_frames=cfg.policy.drop_n_last_frames,
            shuffle=True,
        )
    else:
        shuffle = True
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type == "cuda",
        drop_last=False,
    )
    # `cycle` lets the loop below keep calling `next()` for as many steps as requested.
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }

    train_tracker = MetricsTracker(
        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
    )

    logging.info("Start offline training on a fixed dataset")
    for _ in range(step, cfg.steps):
        # Time spent waiting for the next batch is tracked separately from the update time.
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        # Move tensor entries to the training device; non-tensor entries are left as they are.
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")

        train_tracker, output_dict = update_policy(
            train_tracker,
            policy,
            batch,
            optimizer,
            cfg.optimizer.grad_clip_norm,
            grad_scaler=grad_scaler,
            lr_scheduler=lr_scheduler,
            use_amp=cfg.policy.use_amp,
        )

        # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
        # increment `step` here.
        step += 1
        train_tracker.step()
        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
        is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0

        if is_log_step:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
                    wandb_log_dict.update(output_dict)
                wandb_logger.log_dict(wandb_log_dict, step)
            train_tracker.reset_averages()

        if cfg.save_checkpoint and is_saving_step:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
            update_last_checkpoint(checkpoint_dir)
            if wandb_logger:
                wandb_logger.log_policy(checkpoint_dir)

        if cfg.env and is_eval_step:
            step_id = get_step_identifier(step, cfg.steps)
            logging.info(f"Eval policy at step {step}")
            # Evaluation runs without gradients, and under autocast when AMP is enabled.
            with (
                torch.no_grad(),
                torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
            ):
                if cfg.env.multitask_eval:
                    eval_info = eval_policy_multitask(
                        eval_env,
                        policy,
                        cfg.eval.n_episodes,
                        videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                        max_episodes_rendered=4,
                        start_seed=cfg.seed,
                        max_parallel_tasks=cfg.env.max_parallel_tasks,
                    )
                    # The multitask result keeps its global summary under the "overall" entry.
                    aggregated = eval_info["overall"]["aggregated"]

                    # Print per-suite stats, log?
                    for task_group, task_group_info in eval_info.items():
                        if task_group == "overall":
                            continue  # Skip the overall stats since we already printed it
                        print(f"\nAggregated Metrics for {task_group}:")
                        print(task_group_info["aggregated"])
                else:
                    eval_info = eval_policy(
                        eval_env,
                        policy,
                        cfg.eval.n_episodes,
                        videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                        max_episodes_rendered=4,
                        start_seed=cfg.seed,
                    )
                    aggregated = eval_info["aggregated"]

            eval_metrics = {
                "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
                "pc_success": AverageMeter("success", ":.1f"),
                "eval_s": AverageMeter("eval_s", ":.3f"),
            }
            eval_tracker = MetricsTracker(
                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
            )
            # `pop` removes these three entries from `aggregated` as they move into the tracker.
            eval_tracker.eval_s = aggregated.pop("eval_s")
            eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
            eval_tracker.pc_success = aggregated.pop("pc_success")
            logging.info(eval_tracker)
            if wandb_logger:
                wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
                wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                # NOTE(review): assumes `eval_info` has a top-level "video_paths" entry; confirm this
                # also holds for the result of `eval_policy_multitask`.
                wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")

    # Close the evaluation environment(s); in multitask mode `eval_env` is iterated as a nested
    # mapping of task group -> index -> environment.
    if eval_env:
        if cfg.env.multitask_eval:
            for _task_group, envs_dict in eval_env.items():
                for _idx, env in envs_dict.items():
                    env.close()
        else:
            eval_env.close()
    logging.info("End of training")

    if cfg.policy.push_to_hub:
        policy.push_model_to_hub(cfg)
def main():
    """Call ``init_logging()`` and then run the ``train()`` entry point."""
    init_logging()
    train()
# Only start training when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

View File

@@ -1,366 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from pprint import pformat
from typing import Any
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed as accelerate_set_seed
from termcolor import colored
from torch.optim import Optimizer
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.scripts.eval import eval_policy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.utils import (
format_big_number,
has_method,
init_logging,
)
def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    accelerator: Accelerator,
    lr_scheduler=None,
) -> tuple[MetricsTracker, dict]:
    """Run one forward/backward/optimizer step and record the resulting metrics.

    Args:
        train_metrics: Tracker whose ``loss``, ``grad_norm``, ``lr`` and
            ``update_s`` attributes are overwritten with this step's values.
        policy: Policy to optimise. ``policy.forward(batch)`` must return a
            ``(loss, output_dict)`` pair.
        batch: Batch passed unchanged to ``policy.forward``.
        optimizer: Optimizer stepped and then zeroed after the backward pass.
        grad_clip_norm: Maximum gradient norm. Clipping is skipped when this is
            not strictly positive.
        accelerator: Accelerator providing autocast, backward, clipping and
            cross-process gathering.
        lr_scheduler: Optional scheduler stepped once after the optimizer.

    Returns:
        The updated ``train_metrics`` and the ``output_dict`` produced by the
        policy's forward pass.
    """
    start_time = time.perf_counter()
    policy.train()

    # Forward pass under the accelerator's autocast context.
    with accelerator.autocast():
        loss, output_dict = policy.forward(batch)
        # TODO(rcadene): policy.unnormalize_outputs(out_dict)

    accelerator.backward(loss)

    # NOTE(review): `accelerator.sync_gradients` is only toggled inside
    # `accelerator.accumulate(...)`, which this function does not enter --
    # confirm whether gradient accumulation is meant to be active here.
    if accelerator.sync_gradients and grad_clip_norm > 0:
        grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    else:
        # Create the placeholder on the accelerator's device: it is gathered
        # across processes below, and a CPU tensor cannot take part in a
        # device-backed collective.
        grad_norm = torch.tensor(0.0, device=accelerator.device)

    optimizer.step()
    if lr_scheduler is not None:
        lr_scheduler.step()
    optimizer.zero_grad()

    # Look the optional `update` hook up on the unwrapped policy: once prepared
    # by the accelerator the policy may sit inside a wrapper module that does
    # not expose the custom methods of the module it wraps.
    unwrapped_policy = accelerator.unwrap_model(policy)
    if has_method(unwrapped_policy, "update"):
        unwrapped_policy.update()

    # Average the scalar metrics over all processes before recording them.
    train_metrics.loss = accelerator.gather(loss.detach()).mean().item()
    train_metrics.grad_norm = accelerator.gather(grad_norm).mean().item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict
@parser.wrap()
def train(cfg: TrainPipelineConfig):
    """Train a policy on a fixed dataset using an ``accelerate.Accelerator``.

    The function validates ``cfg``, builds the accelerator, dataset, optional
    evaluation environment, policy, optimizer and dataloader, then runs
    ``update_policy`` until ``cfg.steps`` updates have been performed. Metric
    logging, checkpoint writing and evaluation are performed on the main
    process only, at the frequencies given by ``cfg.log_freq``,
    ``cfg.save_freq`` and ``cfg.eval_freq``.

    Args:
        cfg: Training pipeline configuration.
    """
    cfg.validate()
    logging.info(pformat(cfg.to_dict()))

    # Imported at function scope, as in the original script.
    from accelerate.utils import DistributedDataParallelKwargs

    from lerobot.utils.wandb_utils import cfg_to_group, get_wandb_run_id_from_filesystem

    # Build exactly one Accelerator for the whole run.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision="bf16" if cfg.policy.use_amp else "no",
        gradient_accumulation_steps=cfg.policy.gradient_accumulation_steps,
        log_with="wandb" if cfg.wandb.enable else None,
        kwargs_handlers=[ddp_kwargs],
        project_dir=cfg.output_dir,
    )

    # When resuming, reuse the configured run id or fall back to the one found on disk.
    if cfg.wandb.run_id:
        wandb_run_id = cfg.wandb.run_id
    elif cfg.resume:
        wandb_run_id = get_wandb_run_id_from_filesystem(cfg.output_dir)
    else:
        wandb_run_id = None

    accelerator.init_trackers(
        project_name=cfg.wandb.project,
        init_kwargs={
            "wandb": {
                "entity": cfg.wandb.entity,
                "name": cfg.job_name,
                "notes": cfg.wandb.notes,
                "tags": cfg_to_group(cfg, return_list=True),
                "dir": cfg.output_dir,
                "config": cfg.to_dict(),
                "save_code": False,
                "job_type": "train_eval",
                "mode": cfg.wandb.mode if cfg.wandb.mode in ["online", "offline", "disabled"] else "online",
                "resume": "must" if cfg.resume else None,
                "id": wandb_run_id,
            }
        },
    )

    # Set seed for reproducibility.
    if cfg.seed is not None:
        accelerate_set_seed(cfg.seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Every process builds the dataset; only the main process announces it.
    if accelerator.is_main_process:
        logging.info("Creating dataset")
    dataset = make_dataset(cfg)

    # The evaluation environment is created on the main process only.
    eval_env = None
    if cfg.eval_freq > 0 and cfg.env is not None and accelerator.is_main_process:
        logging.info("Creating env")
        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

    if accelerator.is_main_process:
        logging.info("Creating policy")
    # Let the main process build the policy first; the others follow.
    with accelerator.main_process_first():
        policy = make_policy(
            cfg=cfg.policy,
            ds_meta=dataset.meta,
        )

    if accelerator.is_main_process:
        logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)

    step = 0  # number of policy updates
    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    # Policies declaring `drop_n_last_frames` use an episode-aware sampler that
    # shuffles itself, so the DataLoader's own shuffling is turned off.
    if hasattr(cfg.policy, "drop_n_last_frames"):
        shuffle = False
        sampler = EpisodeAwareSampler(
            dataset.episode_data_index,
            drop_n_last_frames=cfg.policy.drop_n_last_frames,
            shuffle=True,
        )
    else:
        shuffle = True
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=True,
        drop_last=True,  # Important for distributed training
    )

    # Hand everything to the accelerator for distributed training.
    policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        policy, optimizer, dataloader, lr_scheduler
    )

    # Log a summary of the run (main process only).
    if accelerator.is_main_process:
        num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
        num_total_params = sum(p.numel() for p in policy.parameters())

        logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
        if cfg.env is not None:
            logging.info(f"{cfg.env.task=}")
        logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
        logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
        logging.info(f"{dataset.num_episodes=}")
        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
        logging.info(f"Number of processes: {accelerator.num_processes}")
        logging.info(f"Device: {accelerator.device}")
        logging.info(f"Mixed precision: {accelerator.mixed_precision}")

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }
    train_tracker = MetricsTracker(
        cfg.batch_size * accelerator.num_processes,  # Account for all processes
        dataset.num_frames,
        dataset.num_episodes,
        train_metrics,
        initial_step=step,
    )

    policy.train()
    if accelerator.is_main_process:
        logging.info("Start offline training on a fixed dataset")

    dl_iter = iter(dataloader)
    for _ in range(step, cfg.steps):
        start_time = time.perf_counter()
        # Fetch the next batch, restarting the dataloader once it is exhausted.
        try:
            batch = next(dl_iter)
        except StopIteration:
            dl_iter = iter(dataloader)
            batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        train_tracker, output_dict = update_policy(
            train_tracker,
            policy,
            batch,
            optimizer,
            cfg.optimizer.grad_clip_norm,
            accelerator,
            lr_scheduler=lr_scheduler,
        )

        # `step` counts completed updates, so it is incremented before the
        # log/save/eval decisions below are taken.
        step += 1
        train_tracker.step()
        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
        is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0

        # Logging (main process only).
        if is_log_step and accelerator.is_main_process:
            logging.info(train_tracker)
            wandb_log_dict = train_tracker.to_dict()
            if output_dict:
                wandb_log_dict.update(output_dict)
            accelerator.log({f"train/{k}": v for k, v in wandb_log_dict.items()}, step=step)
            train_tracker.reset_averages()

        # Checkpointing: all processes synchronise before and after, but only
        # the main process writes the checkpoint.
        if cfg.save_checkpoint and is_saving_step:
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                logging.info(f"Checkpoint policy after step {step}")
                checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
                unwrapped_policy = accelerator.unwrap_model(policy)
                save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
                update_last_checkpoint(checkpoint_dir)
            accelerator.wait_for_everyone()

        # Evaluation (main process only), run on the unwrapped policy.
        if cfg.env and is_eval_step and accelerator.is_main_process:
            step_id = get_step_identifier(step, cfg.steps)
            logging.info(f"Eval policy at step {step}")

            unwrapped_policy = accelerator.unwrap_model(policy)
            unwrapped_policy.eval()
            with torch.no_grad():
                eval_info = eval_policy(
                    eval_env,
                    unwrapped_policy,
                    cfg.eval.n_episodes,
                    videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                    max_episodes_rendered=4,
                    start_seed=cfg.seed,
                )

            eval_metrics = {
                "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
                "pc_success": AverageMeter("success", ":.1f"),
                "eval_s": AverageMeter("eval_s", ":.3f"),
            }
            eval_tracker = MetricsTracker(
                cfg.batch_size * accelerator.num_processes,
                dataset.num_frames,
                dataset.num_episodes,
                eval_metrics,
                initial_step=step,
            )
            # The three headline metrics are moved out of `aggregated` into the tracker.
            eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
            eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
            eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
            logging.info(eval_tracker)

            wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
            accelerator.log({f"eval/{k}": v for k, v in wandb_log_dict.items()}, step=step)

            # Set back to training mode.
            policy.train()

    # Wait for all processes to finish before cleaning up.
    accelerator.wait_for_everyone()

    if eval_env and accelerator.is_main_process:
        eval_env.close()

    if accelerator.is_main_process:
        logging.info("End of training")

    accelerator.end_training()
# Script entry point: call init_logging() and then start train(); skipped on import.
if __name__ == "__main__":
    init_logging()
    train()