mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
remove files
This commit is contained in:
@@ -62,7 +62,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
gradient_accumulation_steps: int = 1
|
||||
|
||||
push_to_hub: bool = True
|
||||
repo_id: str | None = None
|
||||
|
||||
@@ -597,7 +597,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
path = str(self.root / "data")
|
||||
# added by jade
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
else:
|
||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
|
||||
@@ -455,8 +455,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == "observation.environment_state":
|
||||
type = FeatureType.ENV
|
||||
# changed by jade
|
||||
elif key.startswith("observation") or key.startswith("state"):
|
||||
elif key.startswith("observation"):
|
||||
type = FeatureType.STATE
|
||||
elif key.startswith("action"):
|
||||
type = FeatureType.ACTION
|
||||
|
||||
@@ -31,7 +31,6 @@ from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolpi0.configuration_smolpi0 import SMOLPI0Config
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
@@ -75,10 +74,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "smolpi0":
|
||||
from lerobot.policies.smolpi0.modeling_smolpi0 import SMOLPI0Policy
|
||||
|
||||
return SMOLPI0Policy
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -102,8 +97,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "smolpi0":
|
||||
return SMOLPI0Config(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -255,85 +255,6 @@ class Unnormalize(nn.Module):
|
||||
return batch
|
||||
|
||||
|
||||
class NormalizePerRobotType(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
for robot_type in stats.keys():
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats[robot_type])
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, f"{robot_type}_buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
assert "robot_type" in batch, "robot_type is not in the batch"
|
||||
robot_types = batch["robot_type"]
|
||||
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
# FIXME(mshukor): make it more efficient
|
||||
buffers = [
|
||||
getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types
|
||||
]
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))], dim=0)
|
||||
std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))], dim=0)
|
||||
if batch[key].ndim == 3:
|
||||
mean = mean.unsqueeze(1)
|
||||
std = std.unsqueeze(1)
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.stack([buffers[i]["min"] for i in range(len(robot_types))], dim=0)
|
||||
max = torch.stack([buffers[i]["max"] for i in range(len(robot_types))], dim=0)
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
if batch[key].ndim == 3:
|
||||
min = min.unsqueeze(1)
|
||||
max = max.unsqueeze(1)
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
@@ -497,87 +418,3 @@ class UnnormalizeBuffer(nn.Module):
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizePerRobotType(nn.Module):
|
||||
"""
|
||||
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
|
||||
original range used by the environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
for robot_type in stats.keys():
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats[robot_type])
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, f"{robot_type}_buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
assert "robot_type" in batch, "robot_type is not in the batch"
|
||||
robot_types = batch["robot_type"]
|
||||
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
# buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
buffers = [
|
||||
getattr(self, f"{robot_type}_buffer_" + key.replace(".", "_")) for robot_type in robot_types
|
||||
]
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.stack([buffers[i]["mean"] for i in range(len(robot_types))], dim=0)
|
||||
std = torch.stack([buffers[i]["std"] for i in range(len(robot_types))], dim=0)
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
if batch[key].ndim == 3:
|
||||
mean = mean.unsqueeze(1)
|
||||
std = std.unsqueeze(1)
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.stack([buffers[i]["min"] for i in range(len(robot_types))], dim=0)
|
||||
max = torch.stack([buffers[i]["max"] for i in range(len(robot_types))], dim=0)
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
if batch[key].ndim == 3:
|
||||
min = min.unsqueeze(1)
|
||||
max = max.unsqueeze(1)
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
@@ -1,210 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PEFTConfig:
|
||||
r: int = 4
|
||||
lora_alpha: int = 16
|
||||
lora_dropout: float = 0.1
|
||||
target_modules: str = "q_proj,v_proj"
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("smolpi0")
|
||||
@dataclass
|
||||
class SMOLPI0Config(PreTrainedConfig):
|
||||
# Input / output structure.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
n_obs_gap: int = 1
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Shorter state and action vectors will be padded
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Image preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] = (512, 512) # (224, 224)
|
||||
|
||||
# Add empty images. Used by pi0_aloha_sim which adds the empty
|
||||
# left and right wrist cameras in addition to the top camera.
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Converts the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 48
|
||||
|
||||
# Projector
|
||||
proj_width: int = 480
|
||||
|
||||
# Decoding
|
||||
num_steps: int = 10
|
||||
|
||||
# Attention utils
|
||||
use_cache: bool = True
|
||||
attention_implementation: str = "eager" # or fa2, flex
|
||||
|
||||
# Finetuning settings
|
||||
freeze_vision_encoder: bool = True
|
||||
train_expert_only: bool = False
|
||||
train_state_proj: bool = True
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 2.5e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10
|
||||
optimizer_lr_vlm: float = 0
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
# TODO: Add EMA
|
||||
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
|
||||
checkpoint_path: str = None
|
||||
load_vlm_weights: bool = False
|
||||
|
||||
peft_method: str = ""
|
||||
peft_config: PEFTConfig = PEFTConfig()
|
||||
peft_target_model: str = ""
|
||||
|
||||
add_image_special_tokens: bool = False
|
||||
add_prompt_template: bool = False
|
||||
prefix_prompt_template: str = "<|im_start|>User: What action should the robot take to"
|
||||
suffix_prompt_template: str = "?\nAssistant:"
|
||||
|
||||
attention_mode: str = "self_attn"
|
||||
|
||||
prefix_length: int = -1 # n_obs_steps * num_cameras * num_image_token_per_image + tokenizer_max_length
|
||||
|
||||
past_obs_keys: str = "image"
|
||||
|
||||
add_local_special_image_tokens: bool = False
|
||||
|
||||
reverse_images_order: bool = False
|
||||
|
||||
state_to_prefix: bool = False
|
||||
|
||||
pad_language_to: str = "longest" # "max_length"
|
||||
|
||||
num_expert_layers: int = -1
|
||||
num_vlm_layers: int = -1
|
||||
|
||||
causal_action_attention_mask: bool = False
|
||||
|
||||
self_attn_every_n_layers: int = -1
|
||||
|
||||
expert_width_multiplier: float = 0.5
|
||||
|
||||
robot_type: str = ""
|
||||
|
||||
self_attn_only_actions: bool = False
|
||||
|
||||
causal_attention_on_history: bool = False
|
||||
|
||||
predict_relative_actions: bool = False
|
||||
relative_actions_mode: str = "first"
|
||||
|
||||
shuffle_camera_positions: bool = False
|
||||
vlm_img_size: int = -1
|
||||
|
||||
regression_loss: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.vlm_img_size > 0:
|
||||
self.resize_imgs_with_padding = (self.vlm_img_size, self.vlm_img_size)
|
||||
"""Input validation (not exhaustive)."""
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||||
)
|
||||
# if self.n_obs_steps != 1:
|
||||
# raise ValueError(
|
||||
# f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||||
# )
|
||||
|
||||
if self.use_delta_joint_actions_aloha:
|
||||
raise NotImplementedError(
|
||||
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# TODO: implement value error
|
||||
# if not self.image_features and not self.env_state_feature:
|
||||
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = f"observation.images.empty_camera_{i}"
|
||||
empty_camera = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 480, 640),
|
||||
)
|
||||
self.input_features[key] = empty_camera
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list: # FIXME(mshukor): support spacing between observations
|
||||
return [-k for k in range(0, self.n_obs_steps * self.n_obs_gap, self.n_obs_gap)][::-1]
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,145 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from packaging.version import Version
|
||||
|
||||
if Version(torch.__version__) > Version("2.5.0"):
|
||||
# Ffex attention is only available from torch 2.5 onwards
|
||||
from torch.nn.attention.flex_attention import (
|
||||
_mask_mod_signature,
|
||||
_round_up_to_multiple,
|
||||
create_block_mask,
|
||||
create_mask,
|
||||
flex_attention,
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=False)
|
||||
def flex_attention_forward(
|
||||
attention_mask: torch.Tensor,
|
||||
batch_size: int,
|
||||
head_dim: int,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
scaling=None,
|
||||
num_att_heads: int = 8,
|
||||
num_key_value_heads: int = 1,
|
||||
):
|
||||
"""
|
||||
This is defined out of classes to make compile happy.
|
||||
"""
|
||||
|
||||
original_dtype = query_states.dtype
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
key_states = key_states[:, :, :, None, :]
|
||||
key_states = key_states.expand(
|
||||
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :]
|
||||
value_states = value_states.expand(
|
||||
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
# query_states = query_states.to(torch.float32)
|
||||
# key_states = key_states.to(torch.float32)
|
||||
# value_states = value_states.to(torch.float32)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if causal_mask is not None:
|
||||
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
|
||||
|
||||
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
|
||||
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
|
||||
|
||||
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
|
||||
def mask_mod(b, h, q_idx, kv_idx):
|
||||
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
|
||||
return precomputed_mask[b][h][q_idx][kv_idx]
|
||||
|
||||
return mask_mod
|
||||
|
||||
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
|
||||
|
||||
block_size = 128 # limitation of flex attention
|
||||
q_len_rounded = _round_up_to_multiple(q_len, block_size)
|
||||
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
|
||||
|
||||
# *CRITICAL* we do need to expand here, else we get a CUDA index error
|
||||
|
||||
pad_q = q_len_rounded - q_len
|
||||
pad_k = kv_len_rounded - kv_len
|
||||
if pad_q > 0 or pad_k > 0:
|
||||
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
|
||||
else:
|
||||
padded_causal_mask = causal_mask
|
||||
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
|
||||
|
||||
mask_4d = create_mask(
|
||||
mod_fn=mask_mod_fn_orig,
|
||||
B=b_mask,
|
||||
H=h_mask,
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
device=causal_mask.device,
|
||||
)
|
||||
|
||||
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
|
||||
# FIXME(mshukor): compile mask torch.compile(create_block_mask)
|
||||
create_block_mask_compiled = torch.compile(create_block_mask)
|
||||
block_mask = create_block_mask_compiled(
|
||||
mask_mod=mask_mod_fn_padded,
|
||||
B=b_mask,
|
||||
H=None, #
|
||||
Q_LEN=q_len_rounded,
|
||||
KV_LEN=kv_len_rounded,
|
||||
BLOCK_SIZE=block_size,
|
||||
device=causal_mask.device,
|
||||
_compile=False,
|
||||
)
|
||||
padded_query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) if pad_q > 0 else query_states
|
||||
padded_key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else key_states
|
||||
padded_value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0) if pad_k > 0 else value_states
|
||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||
attn_output, attention_weights = flex_attention(
|
||||
padded_query_states,
|
||||
padded_key_states,
|
||||
padded_value_states,
|
||||
block_mask=block_mask,
|
||||
enable_gqa=True, # because we shaped query/key states for GQA
|
||||
scale=head_dim**-0.5 if scaling is None else scaling,
|
||||
return_lse=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.to(dtype=original_dtype)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
|
||||
attn_output = attn_output.reshape(
|
||||
batch_size,
|
||||
-1,
|
||||
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
|
||||
)
|
||||
return attn_output[:, :-pad_k, :] if pad_k > 0 else attn_output
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,920 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torch.version
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
from pytest import Cache
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
SmolVLMForConditionalGeneration,
|
||||
)
|
||||
|
||||
from lerobot.policies.smolpi0.flex_attention import flex_attention_forward
|
||||
|
||||
|
||||
def _round_up_to_multiple(x, multiple):
|
||||
return (x + multiple - 1) // multiple * multiple
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
# class SmolVLMWithExpertConfig(PretrainedConfig):
|
||||
# model_type = "SmolVLMWithExpertModel"
|
||||
# sub_configs = {"smolvlm_config": AutoConfig, "lm_expert_config": AutoConfig}
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# smolvlm_config: dict | None = None,
|
||||
# lm_expert_config: dict | None = None,
|
||||
# freeze_vision_encoder: bool = True,
|
||||
# train_expert_only: bool = True,
|
||||
# attention_implementation: str = "eager",
|
||||
# load_vlm_weights: bool = False,
|
||||
# **kwargs,
|
||||
# ):
|
||||
# self.load_vlm_weights = load_vlm_weights
|
||||
# self.freeze_vision_encoder = freeze_vision_encoder
|
||||
# self.train_expert_only = train_expert_only
|
||||
# self.attention_implementation = attention_implementation
|
||||
|
||||
# if smolvlm_config is None:
|
||||
# # Default config from Pi0
|
||||
# self.smolvlm_config = CONFIG_MAPPING["smolvlm"](
|
||||
# transformers_version="4.48.1",
|
||||
# _vocab_size=257152,
|
||||
# bos_token_id=2,
|
||||
# eos_token_id=1,
|
||||
# hidden_size=2048,
|
||||
# image_token_index=257152,
|
||||
# model_type="smolvlm",
|
||||
# pad_token_id=0,
|
||||
# projection_dim=2048,
|
||||
# text_config={
|
||||
# "hidden_activation": "gelu_pytorch_tanh",
|
||||
# "hidden_size": 2048,
|
||||
# "intermediate_size": 16384,
|
||||
# "model_type": "gemma",
|
||||
# "num_attention_heads": 8,
|
||||
# "num_hidden_layers": 18,
|
||||
# "num_image_tokens": 256,
|
||||
# "num_key_value_heads": 1,
|
||||
# "torch_dtype": "float32",
|
||||
# "vocab_size": 257152,
|
||||
# },
|
||||
# vision_config={
|
||||
# "hidden_size": 1152,
|
||||
# "intermediate_size": 4304,
|
||||
# "model_type": "siglip_vision_model",
|
||||
# "num_attention_heads": 16,
|
||||
# "num_hidden_layers": 27,
|
||||
# "num_image_tokens": 256,
|
||||
# "patch_size": 14,
|
||||
# "projection_dim": 2048,
|
||||
# "projector_hidden_act": "gelu_fast",
|
||||
# "torch_dtype": "float32",
|
||||
# "vision_use_head": False,
|
||||
# },
|
||||
# )
|
||||
# elif isinstance(self.paligemma_config, dict):
|
||||
# # Override Pi0 default config for PaliGemma
|
||||
# if "model_type" not in gemma_expert_config:
|
||||
# paligemma_config["model_type"] = "paligemma"
|
||||
|
||||
# cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
# self.paligemma_config = cfg_cls(**paligemma_config)
|
||||
|
||||
# if gemma_expert_config is None:
|
||||
# # Default config from Pi0
|
||||
# self.gemma_expert_config = CONFIG_MAPPING["gemma"](
|
||||
# attention_bias=False,
|
||||
# attention_dropout=0.0,
|
||||
# bos_token_id=2,
|
||||
# eos_token_id=1,
|
||||
# head_dim=256,
|
||||
# hidden_act="gelu_pytorch_tanh",
|
||||
# hidden_activation="gelu_pytorch_tanh",
|
||||
# hidden_size=1024,
|
||||
# initializer_range=0.02,
|
||||
# intermediate_size=4096,
|
||||
# max_position_embeddings=8192,
|
||||
# model_type="gemma",
|
||||
# num_attention_heads=8,
|
||||
# num_hidden_layers=18,
|
||||
# num_key_value_heads=1,
|
||||
# pad_token_id=0,
|
||||
# rms_norm_eps=1e-06,
|
||||
# rope_theta=10000.0,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.48.1",
|
||||
# use_cache=True,
|
||||
# vocab_size=257152,
|
||||
# )
|
||||
# elif isinstance(self.gemma_expert_config, dict):
|
||||
# # Override Pi0 default config for Gemma Expert
|
||||
# if "model_type" not in gemma_expert_config:
|
||||
# gemma_expert_config["model_type"] = "gemma"
|
||||
|
||||
# cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
|
||||
# self.gemma_expert_config = cfg_cls(**gemma_expert_config)
|
||||
|
||||
# super().__init__(**kwargs)
|
||||
|
||||
# def __post_init__(self):
|
||||
# super().__post_init__()
|
||||
# if self.train_expert_only and not self.freeze_vision_encoder:
|
||||
# raise ValueError(
|
||||
# "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
|
||||
# )
|
||||
|
||||
# if self.attention_implementation not in ["eager", "fa2", "flex"]:
|
||||
# raise ValueError(
|
||||
# f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
|
||||
# )
|
||||
|
||||
|
||||
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
return hidden_dim
|
||||
|
||||
|
||||
class SmolVLMWithExpertModel(nn.Module):
|
||||
# config_class = PaliGemmaWithExpertConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
||||
load_vlm_weights: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
freeze_vision_encoder: bool = False,
|
||||
attention_implementation: str = "eager",
|
||||
attention_mode: str = "self_attn",
|
||||
num_expert_layers: int = -1,
|
||||
num_vlm_layers: int = -1,
|
||||
self_attn_every_n_layers: int = -1,
|
||||
expert_width_multiplier: float = 0.5,
|
||||
self_attn_only_actions: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if load_vlm_weights:
|
||||
print(f"Loading {model_id} weights ...")
|
||||
if "SmolVLM-" in model_id:
|
||||
self.vlm = AutoModelForVision2Seq.from_pretrained(
|
||||
model_id,
|
||||
device_map="cuda",
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
# model_id = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map="cuda",
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
# attn_implementation="eager",
|
||||
# attn_implementation="flash_attention_2"
|
||||
)
|
||||
config = self.vlm.config
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
self.vlm = SmolVLMForConditionalGeneration(config=config)
|
||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||
if num_vlm_layers > 0:
|
||||
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
|
||||
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
|
||||
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
|
||||
self.config = config
|
||||
# Smaller lm expert
|
||||
lm_expert_config = copy.deepcopy(config.text_config)
|
||||
hidden_size = lm_expert_config.hidden_size
|
||||
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
|
||||
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
|
||||
lm_expert_config.num_hidden_layers = self.num_vlm_layers
|
||||
if num_expert_layers > 0:
|
||||
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
|
||||
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
||||
)
|
||||
lm_expert_config.num_hidden_layers = num_expert_layers
|
||||
# lm_expert_config.head_dim = lm_expert_config.head_dim * 2
|
||||
self.lm_expert = AutoModel.from_config(lm_expert_config)
|
||||
|
||||
self.num_expert_layers = len(self.lm_expert.layers)
|
||||
self.self_attn_every_n_layers = self_attn_every_n_layers
|
||||
self.self_attn_only_actions = self_attn_only_actions
|
||||
if "cross" in attention_mode:
|
||||
# Reshape qkv projections to have the same input dimension as the vlm
|
||||
for layer_idx in range(len(self.lm_expert.layers)):
|
||||
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
|
||||
continue
|
||||
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
# Remove unused embed_tokens
|
||||
self.lm_expert.embed_tokens = None
|
||||
|
||||
self.num_attention_heads = self.config.text_config.num_attention_heads
|
||||
self.num_key_value_heads = self.config.text_config.num_key_value_heads
|
||||
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_implementation = attention_implementation
|
||||
self.attention_mode = attention_mode
|
||||
self.expert_hidden_size = lm_expert_config.hidden_size
|
||||
# self.to_bfloat16_like_physical_intelligence()
|
||||
self.set_requires_grad()
|
||||
|
||||
def configure_peft(self, config):
|
||||
# return model
|
||||
self.peft_method = config.peft_method
|
||||
self.peft_target_model = config.peft_target_model
|
||||
if "lora" in self.peft_method:
|
||||
peft_config = config.peft_config
|
||||
target_modules = peft_config.target_modules
|
||||
if not isinstance(target_modules, list):
|
||||
target_modules = target_modules.split(",")
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM, # Based on the task type (e.g., language modeling, etc.)
|
||||
r=peft_config.r, # The rank of the low-rank adaptation
|
||||
lora_alpha=peft_config.lora_alpha, # Scaling factor
|
||||
lora_dropout=peft_config.lora_dropout, # Dropout applied to LoRA layers
|
||||
target_modules=target_modules, # The components where LoRA is applied
|
||||
exclude_modules=[
|
||||
"lm_expert",
|
||||
"model.lm_expert.model.layers",
|
||||
], # FIXME(mshukor): this does not work for now
|
||||
)
|
||||
self.lora_config = lora_config
|
||||
# Apply LoRA and ensure only LoRA parameters are trainable
|
||||
if "text" in self.peft_target_model:
|
||||
self.get_vlm_model().text_model = get_peft_model(self.get_vlm_model().text_model, lora_config)
|
||||
else:
|
||||
self.vlm = get_peft_model(self.vlm, lora_config)
|
||||
# assert config.train_expert_only, "Backbone should be frozen and only lora parameters are " # FIXME(mshukor): handle this here?
|
||||
for name, param in self.vlm.named_parameters():
|
||||
if (
|
||||
"lora" in name and "text_model.model.layers.17" not in name
|
||||
): # lm_head is not a parameter in most LLMs becasue it's tied to the embedding layer
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
def merge_lora_weights(self):
|
||||
"""
|
||||
Merge LoRA weights into the base model.
|
||||
"""
|
||||
if "text" in self.peft_target_model:
|
||||
self.get_vlm_model().text_model = self.get_vlm_model().text_model.merge_and_unload()
|
||||
else:
|
||||
self.vlm = self.vlm.merge_and_unload()
|
||||
|
||||
def get_vlm_model(
|
||||
self,
|
||||
):
|
||||
if hasattr(self.vlm.model, "model"): # When using peft
|
||||
return self.vlm.model.model
|
||||
else:
|
||||
return self.vlm.model
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
for params in self.get_vlm_model().vision_model.parameters():
|
||||
params.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
for params in self.vlm.parameters():
|
||||
params.requires_grad = False
|
||||
else:
|
||||
# To avoid unused params issue with distributed training
|
||||
last_layers = [self.num_vlm_layers - 1]
|
||||
if (
|
||||
self.num_vlm_layers != self.num_expert_layers
|
||||
and self.num_vlm_layers % self.num_expert_layers == 0
|
||||
):
|
||||
last_layers.append(self.num_vlm_layers - 2)
|
||||
frozen_layers = [
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
]
|
||||
for layer in last_layers:
|
||||
frozen_layers.append(f"text_model.model.layers.{layer}.")
|
||||
|
||||
for name, params in self.vlm.named_parameters():
|
||||
if any([k in name for k in frozen_layers]):
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
for name, params in self.lm_expert.named_parameters():
|
||||
if any(
|
||||
[
|
||||
k in name
|
||||
for k in [
|
||||
"lm_head",
|
||||
]
|
||||
]
|
||||
):
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
|
||||
# def to_bfloat16_like_physical_intelligence(self):
|
||||
# self.vlm = self.vlm.to(dtype=torch.bfloat16)
|
||||
|
||||
# params_to_change_dtype = [
|
||||
# "language_model.model.layers",
|
||||
# "gemma_expert.model.layers",
|
||||
# "vision_tower",
|
||||
# "multi_modal",
|
||||
# ]
|
||||
# for name, param in self.named_parameters():
|
||||
# if any(selector in name for selector in params_to_change_dtype):
|
||||
# param.data = param.data.to(dtype=torch.bfloat16)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
patch_attention_mask = None
|
||||
# # FIXME(mshukor): probably not needed as we don't have padded images here
|
||||
# pixel_values = image.unsqueeze(1)
|
||||
# batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
# pixel_values = pixel_values
|
||||
# pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
|
||||
|
||||
# # Remove padding images - padding images are full 0.
|
||||
# nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
# real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
|
||||
|
||||
# if not any(real_images_inds):
|
||||
# # no images, leave one empty image.
|
||||
# real_images_inds[0] = True
|
||||
|
||||
# pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
|
||||
# # Handle the vision attention mask
|
||||
|
||||
# pixel_attention_mask = torch.ones(
|
||||
# size=[pixel_values.shape[i] for i in (0, 2, 3)],
|
||||
# dtype=torch.bool,
|
||||
# device=pixel_values.device,
|
||||
# )
|
||||
|
||||
# patch_size = self.vlm.config.vision_config.patch_size
|
||||
# patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
|
||||
# patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
|
||||
# patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
|
||||
# FIXME(mshukor): add special image tokens specific to smolvlm
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = (
|
||||
self.get_vlm_model()
|
||||
.vision_model(
|
||||
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
.last_hidden_state
|
||||
)
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = model_layers[i][layer_idx]
|
||||
if hidden_states is None or layer is None:
|
||||
continue
|
||||
|
||||
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# hidden_states = hidden_states * normalizer
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# FIXME(mshukor): self attention always when having only the prefix
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
# FIXME(mshukor): seq should be B, H, L, D ?
|
||||
seq_len = query_states.shape[1]
|
||||
if seq_len < position_ids.shape[1]:
|
||||
_position_ids = position_ids[:, :seq_len]
|
||||
_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
else:
|
||||
_position_ids = position_ids
|
||||
_attention_mask = attention_mask
|
||||
|
||||
if self.self_attn_only_actions:
|
||||
attention_mask_ = _attention_mask.clone()
|
||||
position_ids_ = _position_ids.clone()
|
||||
if inputs_embeds[1] is not None:
|
||||
suffix_len = inputs_embeds[1].shape[1]
|
||||
attention_mask_[:, -suffix_len:, :-suffix_len] = False
|
||||
position_ids_[:, -suffix_len:] = (
|
||||
_position_ids[:, -suffix_len:] - _position_ids[:, -suffix_len][:, None]
|
||||
)
|
||||
else:
|
||||
attention_mask_ = _attention_mask
|
||||
position_ids_ = _position_ids
|
||||
|
||||
query_states = apply_rope(
|
||||
query_states, position_ids_
|
||||
) # FIXME(mshukor): this assumes we have always the vlm features?
|
||||
key_states = apply_rope(key_states, position_ids_)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_output = attention_interface(
|
||||
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
# att_output = att_output.to(dtype=models[i].dtype)
|
||||
|
||||
return [att_output], past_key_values
|
||||
|
||||
def forward_cross_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_outputs = []
|
||||
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
|
||||
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
|
||||
)
|
||||
|
||||
if len(inputs_embeds) == 2 and not past_key_values:
|
||||
# Prefix attention
|
||||
seq_len = inputs_embeds[0].shape[1]
|
||||
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
|
||||
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
|
||||
layer = model_layers[0][layer_idx]
|
||||
|
||||
hidden_states = layer.input_layernorm(inputs_embeds[0])
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
query_states = apply_rope(query_state, position_id)
|
||||
key_states = apply_rope(key_state, position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
expert_position_id = position_ids
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = past_key_values[layer_idx]["key_states"]
|
||||
value_states = past_key_values[layer_idx]["value_states"]
|
||||
|
||||
# Expert
|
||||
expert_layer = model_layers[1][layer_idx]
|
||||
if expert_layer is not None:
|
||||
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
|
||||
|
||||
expert_input_shape = expert_hidden_states.shape[:-1]
|
||||
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
|
||||
|
||||
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
|
||||
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
|
||||
|
||||
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
|
||||
*key_states.shape[:2], -1
|
||||
)
|
||||
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
|
||||
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
) # k_proj should have same dim as kv
|
||||
|
||||
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
|
||||
*value_states.shape[:2], -1
|
||||
)
|
||||
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
|
||||
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
)
|
||||
|
||||
expert_position_id = (
|
||||
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
|
||||
) # start from 0
|
||||
expert_attention_mask = attention_mask[
|
||||
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
|
||||
] # take into account kv
|
||||
|
||||
expert_query_states = apply_rope(expert_query_state, expert_position_id)
|
||||
# expert_key_states = apply_rope(expert_key_state, expert_position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
expert_attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
expert_query_states,
|
||||
expert_key_states,
|
||||
expert_value_states,
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
att_outputs.append(None)
|
||||
|
||||
# att_output = att_output.to(dtype=models[i].dtype)
|
||||
return att_outputs, past_key_values
|
||||
|
||||
def get_model_layers(self, models: list) -> list: # FIXME(mshukor): is this efficient?
|
||||
vlm_layers = []
|
||||
expert_layers = []
|
||||
multiple_of = self.num_vlm_layers // self.num_expert_layers
|
||||
for i in range(self.num_vlm_layers):
|
||||
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
|
||||
expert_layer = None
|
||||
else:
|
||||
expert_layer_index = i // multiple_of if multiple_of > 0 else i
|
||||
expert_layer = models[1].layers[expert_layer_index]
|
||||
vlm_layers.append(models[0].layers[i])
|
||||
expert_layers.append(expert_layer)
|
||||
return [vlm_layers, expert_layers]
|
||||
|
||||
# TODO: break down this huge forward into modules or functions
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | Cache | None = None,
|
||||
inputs_embeds: list[torch.FloatTensor] = None,
|
||||
use_cache: bool | None = None,
|
||||
fill_kv_cache: bool | None = None,
|
||||
):
|
||||
models = [self.get_vlm_model().text_model, self.lm_expert]
|
||||
model_layers = self.get_model_layers(models)
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# # Pad prefix embds so that prefix_embs + prefix_embs len are multiple of 128, pad left or right depending on the gen or train
|
||||
if self.attention_implementation == "flex":
|
||||
if (
|
||||
inputs_embeds[0] is not None
|
||||
and inputs_embeds[1] is not None
|
||||
and attention_mask.shape[-1] == attention_mask.shape[-2]
|
||||
and past_key_values is None
|
||||
): # Now only during training
|
||||
seq_len = inputs_embeds[0].shape[1] + inputs_embeds[1].shape[1]
|
||||
padded_seq_len = _round_up_to_multiple(
|
||||
seq_len, 128
|
||||
) # FIXME(mshukor): more efficient to have a fixed seq len?
|
||||
b_mask, q_len, kv_len = attention_mask.shape # The shape of your mask
|
||||
pad = padded_seq_len - q_len
|
||||
attention_mask = F.pad(attention_mask, (0, pad, 0, pad), value=True)
|
||||
inputs_embeds[0] = F.pad(inputs_embeds[0], (0, 0, 0, pad), value=0.0)
|
||||
position_ids = F.pad(position_ids, (0, pad), value=0)
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.num_vlm_layers
|
||||
head_dim = self.vlm.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
if (
|
||||
fill_kv_cache
|
||||
or "cross" not in self.attention_mode
|
||||
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
|
||||
):
|
||||
att_outputs, past_key_values = self.forward_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
att_outputs, past_key_values = self.forward_cross_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
# query_states = []
|
||||
# key_states = []
|
||||
# value_states = []
|
||||
# for i, hidden_states in enumerate(inputs_embeds):
|
||||
# if hidden_states is None:
|
||||
# continue
|
||||
# layer = models[i].layers[layer_idx]
|
||||
# # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
|
||||
# # hidden_states = hidden_states * normalizer
|
||||
# hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
# input_shape = hidden_states.shape[:-1]
|
||||
# hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
# hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
# query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
# key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
# value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
# query_states.append(query_state)
|
||||
# key_states.append(key_state)
|
||||
# value_states.append(value_state)
|
||||
|
||||
# # FIXME(mshukor): self attention always when having only the prefix
|
||||
# # B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# # concatenate on the number of embeddings/tokens
|
||||
# query_states = torch.cat(query_states, dim=1)
|
||||
# key_states = torch.cat(key_states, dim=1)
|
||||
# value_states = torch.cat(value_states, dim=1)
|
||||
# # FIXME(mshukor): seq should be B, H, L, D ?
|
||||
# query_states = apply_rope(query_states, position_ids)
|
||||
# key_states = apply_rope(key_states, position_ids)
|
||||
|
||||
# if use_cache and past_key_values is None:
|
||||
# past_key_values = {}
|
||||
|
||||
# if use_cache:
|
||||
# if fill_kv_cache:
|
||||
# past_key_values[layer_idx] = {
|
||||
# "key_states": key_states,
|
||||
# "value_states": value_states,
|
||||
# }
|
||||
# else:
|
||||
# # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# # the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# # in `transformers`. (molbap)
|
||||
# key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
# value_states = torch.cat(
|
||||
# [past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
# )
|
||||
|
||||
# attention_interface = self.get_attention_interface()
|
||||
# att_output = attention_interface(
|
||||
# attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
# )
|
||||
|
||||
# att_output = att_output.to(dtype=models[i].dtype)
|
||||
|
||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
# layer = models[i].layers[layer_idx]
|
||||
layer = model_layers[i][layer_idx]
|
||||
att_output = (
|
||||
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
|
||||
) # in case of self_attn
|
||||
if hidden_states is not None:
|
||||
if layer is None:
|
||||
outputs_embeds.append(hidden_states)
|
||||
continue
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
att_out = att_output[:, start:end]
|
||||
out_emb = layer.self_attn.o_proj(att_out)
|
||||
|
||||
# TODO: first dropout (by default 0.0)
|
||||
# first residual
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
# TODO: second dropout (by default 0.0)
|
||||
|
||||
# second residual
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end if len(att_outputs) == 1 else 0
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
if self.attention_implementation == "fa2":
|
||||
attention_interface = self.flash_attention_forward
|
||||
elif self.attention_implementation == "flex":
|
||||
attention_interface = partial(
|
||||
flex_attention_forward,
|
||||
num_att_heads=self.num_attention_heads,
|
||||
num_key_value_heads=self.num_key_value_heads,
|
||||
)
|
||||
else:
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def flash_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.num_attention_heads
|
||||
num_key_value_heads = self.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
# query_states: batch_size, sequence_length, num_att_head, head_dim
|
||||
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
|
||||
att_weights = att_weights.to(dtype=torch.float32)
|
||||
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
|
||||
# value_states: batch_size, sequence_length, num_att_heads, head_dim
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
|
||||
@@ -1,955 +1,3 @@
|
||||
# #!/usr/bin/env python
|
||||
|
||||
# # Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
# #
|
||||
# # Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# # you may not use this file except in compliance with the License.
|
||||
# # You may obtain a copy of the License at
|
||||
# #
|
||||
# # http://www.apache.org/licenses/LICENSE-2.0
|
||||
# #
|
||||
# # Unless required by applicable law or agreed to in writing, software
|
||||
# # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# # See the License for the specific language governing permissions and
|
||||
# # limitations under the License.
|
||||
|
||||
# """
|
||||
# SmolVLA:
|
||||
|
||||
# [Paper](https://huggingface.co/papers/2506.01844)
|
||||
|
||||
# Designed by Hugging Face.
|
||||
|
||||
# Install smolvla extra dependencies:
|
||||
# ```bash
|
||||
# pip install -e ".[smolvla]"
|
||||
# ```
|
||||
|
||||
# Example of finetuning the smolvla pretrained model (`smolvla_base`):
|
||||
# ```bash
|
||||
# lerobot-train \
|
||||
# --policy.path=lerobot/smolvla_base \
|
||||
# --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
# --batch_size=64 \
|
||||
# --steps=200000
|
||||
# ```
|
||||
|
||||
# Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
|
||||
# and an action expert.
|
||||
# ```bash
|
||||
# lerobot-train \
|
||||
# --policy.type=smolvla \
|
||||
# --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
|
||||
# --batch_size=64 \
|
||||
# --steps=200000
|
||||
# ```
|
||||
|
||||
# Example of using the smolvla pretrained model outside LeRobot training framework:
|
||||
# ```python
|
||||
# policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
# ```
|
||||
|
||||
# """
|
||||
|
||||
# import math
|
||||
# import os
|
||||
# import re
|
||||
# from collections import deque
|
||||
|
||||
# import safetensors
|
||||
# import torch
|
||||
# import torch.nn.functional as F # noqa: N812
|
||||
# from torch import Tensor, nn
|
||||
# from transformers import AutoProcessor
|
||||
|
||||
# from lerobot.constants import ACTION
|
||||
# from lerobot.policies.normalize import (
|
||||
# Normalize,
|
||||
# Unnormalize,
|
||||
# )
|
||||
# from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
# from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
# from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
# from lerobot.policies.utils import (
|
||||
# populate_queues,
|
||||
# )
|
||||
# from lerobot.utils.utils import get_safe_dtype
|
||||
# OBS_STATE = 'state'
|
||||
# ACTION = 'actions'
|
||||
# # Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||
# _VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||
|
||||
|
||||
# def canonicalise(k: str) -> str:
|
||||
# """
|
||||
# Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
|
||||
# normalisation-buffer key.
|
||||
# """
|
||||
# return _VARIANT_RE.sub(".buffer_", k)
|
||||
|
||||
|
||||
# def standardise_state_dict(
|
||||
# checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
|
||||
# ) -> tuple[dict[str, torch.Tensor], list[str]]:
|
||||
# """
|
||||
# • Re-keys `checkpoint ` so that every entry matches the *reference* key set.
|
||||
# • If several variant keys collapse to the same canonical name we keep the
|
||||
# first one and log the collision.
|
||||
# • Returns the new dict + a list of entries that could not be matched.
|
||||
# """
|
||||
# out, collisions, unmatched = {}, {}, []
|
||||
|
||||
# for k, v in checkpoint.items():
|
||||
# canon = canonicalise(k)
|
||||
# if canon in ref_keys:
|
||||
# if canon in out: # duplicate after collapsing
|
||||
# collisions.setdefault(canon, []).append(k)
|
||||
# else:
|
||||
# out[canon] = v
|
||||
# else:
|
||||
# unmatched.append(k)
|
||||
|
||||
# if verbose:
|
||||
# for canon, variants in collisions.items():
|
||||
# print(f"[standardise_state_dict] '{canon}' ← {variants}")
|
||||
# if unmatched:
|
||||
# print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
|
||||
|
||||
# out.update({k: checkpoint[k] for k in unmatched})
|
||||
# return out, unmatched
|
||||
|
||||
|
||||
# def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
||||
# """
|
||||
# Renames keys in a checkpoint dictionary based on the given rename string.
|
||||
|
||||
# Args:
|
||||
# checkpoint (dict): The checkpoint dictionary.
|
||||
# rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
|
||||
|
||||
# Returns:
|
||||
# dict: The modified checkpoint with renamed keys.
|
||||
# """
|
||||
|
||||
# rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
|
||||
|
||||
# new_checkpoint = {}
|
||||
# for k, v in checkpoint.items():
|
||||
# for old_key, new_key in rename_dict.items():
|
||||
# if old_key in k:
|
||||
# k = k.replace(old_key, new_key)
|
||||
# new_checkpoint[k] = v
|
||||
# return new_checkpoint
|
||||
|
||||
|
||||
# def load_smolvla(
|
||||
# model: torch.nn.Module,
|
||||
# filename: str | os.PathLike,
|
||||
# *,
|
||||
# device: str = "cpu",
|
||||
# checkpoint_keys_mapping: str = "",
|
||||
# ) -> torch.nn.Module:
|
||||
# state_dict = safetensors.torch.load_file(filename, device=device)
|
||||
|
||||
# # Optional user-supplied renames (e.g. "model._orig_mod.//model.")
|
||||
# if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
|
||||
# state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
|
||||
|
||||
# state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
|
||||
|
||||
# # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
|
||||
# norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
|
||||
# state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
|
||||
|
||||
# missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# if not all(key.startswith(norm_keys) for key in missing) or unexpected:
|
||||
# raise RuntimeError(
|
||||
# "SmolVLA %d missing / %d unexpected keys",
|
||||
# len(missing),
|
||||
# len(unexpected),
|
||||
# )
|
||||
|
||||
# return model
|
||||
|
||||
|
||||
# def create_sinusoidal_pos_embedding(
|
||||
# time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
# ) -> Tensor:
|
||||
# """Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
# if dimension % 2 != 0:
|
||||
# raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
# if time.ndim != 1:
|
||||
# raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
# dtype = get_safe_dtype(torch.float64, device.type)
|
||||
# fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
# period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# # Compute the outer product
|
||||
# scaling_factor = 1.0 / period * 2 * math.pi
|
||||
# sin_input = scaling_factor[None, :] * time[:, None]
|
||||
# pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
# return pos_emb
|
||||
|
||||
|
||||
# def make_att_2d_masks(pad_masks, att_masks):
|
||||
# """Copied from big_vision.
|
||||
|
||||
# Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
# smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
# setup several types of attention, for example:
|
||||
|
||||
# [[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
# [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
# themselves and the last 3 tokens have a causal attention. The first
|
||||
# entry could also be a 1 without changing behaviour.
|
||||
|
||||
# [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
# block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
# Args:
|
||||
# input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
# mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
# it and 0 where it shares the same attention mask as the previous token.
|
||||
# """
|
||||
# if att_masks.ndim != 2:
|
||||
# raise ValueError(att_masks.ndim)
|
||||
# if pad_masks.ndim != 2:
|
||||
# raise ValueError(pad_masks.ndim)
|
||||
|
||||
# cumsum = torch.cumsum(att_masks, dim=1)
|
||||
# att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
# pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
# att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
# return att_2d_masks
|
||||
|
||||
|
||||
# def resize_with_pad(img, width, height, pad_value=-1):
|
||||
# # assume no-op when width height fits already
|
||||
# if img.ndim != 4:
|
||||
# raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
# cur_height, cur_width = img.shape[2:]
|
||||
|
||||
# ratio = max(cur_width / width, cur_height / height)
|
||||
# resized_height = int(cur_height / ratio)
|
||||
# resized_width = int(cur_width / ratio)
|
||||
# resized_img = F.interpolate(
|
||||
# img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
# )
|
||||
|
||||
# pad_height = max(0, int(height - resized_height))
|
||||
# pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# # pad on left and top of image
|
||||
# padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
# return padded_img
|
||||
|
||||
|
||||
# def pad_vector(vector, new_dim):
|
||||
# """Can be (batch_size x sequence_length x features_dimension)
|
||||
# or (batch_size x features_dimension)
|
||||
# """
|
||||
# if vector.shape[-1] == new_dim:
|
||||
# return vector
|
||||
# shape = list(vector.shape)
|
||||
# current_dim = shape[-1]
|
||||
# shape[-1] = new_dim
|
||||
# new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
||||
# new_vector[..., :current_dim] = vector
|
||||
# return new_vector
|
||||
|
||||
|
||||
# def normalize(x, min_val, max_val):
|
||||
# return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
# def unnormalize(x, min_val, max_val):
|
||||
# return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
# def safe_arcsin(value):
|
||||
# # This ensures that the input stays within
|
||||
# # [−1,1] to avoid invalid values for arcsin
|
||||
# return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
# def aloha_gripper_to_angular(value):
|
||||
# # Aloha transforms the gripper positions into a linear space. The following code
|
||||
# # reverses this transformation to be consistent with smolvla which is pretrained in
|
||||
# # angular space.
|
||||
# #
|
||||
# # These values are coming from the Aloha code:
|
||||
# # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
# value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# # This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
# def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
# value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
# return safe_arcsin(value)
|
||||
|
||||
# # The constants are taken from the Interbotix code.
|
||||
# value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# # Normalize to [0, 1].
|
||||
# # The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
# return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
# def aloha_gripper_from_angular(value):
|
||||
# # Convert from the gripper position used by smolvla to the gripper position that is used by Aloha.
|
||||
# # Note that the units are still angular but the range is different.
|
||||
|
||||
# # The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
# value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# # These values are coming from the Aloha code:
|
||||
# # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
# return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
# def aloha_gripper_from_angular_inv(value):
|
||||
# # Directly inverts the gripper_from_angular function.
|
||||
# value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
# return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
# class SmolVLAPolicy(PreTrainedPolicy):
|
||||
# """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
|
||||
|
||||
# config_class = SmolVLAConfig
|
||||
# name = "smolvla"
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# config: SmolVLAConfig,
|
||||
# dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
# ):
|
||||
# """
|
||||
# Args:
|
||||
# config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
# the configuration class is used.
|
||||
# dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
# that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
# """
|
||||
|
||||
# super().__init__(config)
|
||||
# config.validate_features()
|
||||
# self.config = config
|
||||
# self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
# self.normalize_targets = Normalize(
|
||||
# config.output_features, config.normalization_mapping, dataset_stats
|
||||
# )
|
||||
# self.unnormalize_outputs = Unnormalize(
|
||||
# config.output_features, config.normalization_mapping, dataset_stats
|
||||
# )
|
||||
|
||||
# self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||
# self.model = VLAFlowMatching(config)
|
||||
# self.reset()
|
||||
|
||||
# def reset(self):
|
||||
# """This should be called whenever the environment is reset."""
|
||||
# self._queues = {
|
||||
# ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
# }
|
||||
|
||||
# # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
|
||||
# @classmethod
|
||||
# def _load_as_safetensor(
|
||||
# cls,
|
||||
# model: "SmolVLAPolicy",
|
||||
# model_file: str,
|
||||
# map_location: str,
|
||||
# strict: bool,
|
||||
# ):
|
||||
# safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
# return load_smolvla(
|
||||
# model,
|
||||
# model_file,
|
||||
# device=map_location,
|
||||
# checkpoint_keys_mapping="model._orig_mod.//model.",
|
||||
# )
|
||||
|
||||
# def get_optim_params(self) -> dict:
|
||||
# return self.parameters()
|
||||
|
||||
# def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
# # TODO: Check if this for loop is needed.
|
||||
# # Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
|
||||
# # In the case of offline inference, we have the action in the batch
|
||||
# # that why without the k != ACTION check, it will raise an error because we are trying to stack
|
||||
# # on an empty container.
|
||||
# for k in batch:
|
||||
# if k in self._queues and k != ACTION:
|
||||
# batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
||||
|
||||
# images, img_masks = self.prepare_images(batch)
|
||||
# state = self.prepare_state(batch)
|
||||
# lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
# actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||
|
||||
# # Unpad actions
|
||||
# original_action_dim = self.config.action_feature.shape[0]
|
||||
# actions = actions[:, :, :original_action_dim]
|
||||
|
||||
# actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
|
||||
# if self.config.adapt_to_pi_aloha:
|
||||
# actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# return actions
|
||||
|
||||
# def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# if self.config.adapt_to_pi_aloha:
|
||||
# batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
# batch = self.normalize_inputs(batch)
|
||||
|
||||
# return batch
|
||||
|
||||
# @torch.no_grad()
|
||||
# def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
# self.eval()
|
||||
|
||||
# batch = self._prepare_batch(batch)
|
||||
# self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
# actions = self._get_action_chunk(batch, noise)
|
||||
# return actions
|
||||
|
||||
# @torch.no_grad()
|
||||
# def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
# """Select a single action given environment observations.
|
||||
|
||||
# This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
# environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
# queue is empty.
|
||||
# """
|
||||
# self.eval()
|
||||
# batch = self._prepare_batch(batch)
|
||||
# self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
# # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# # querying the policy.
|
||||
# if len(self._queues[ACTION]) == 0:
|
||||
# actions = self._get_action_chunk(batch, noise)
|
||||
|
||||
# # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
# self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
|
||||
# return self._queues[ACTION].popleft()
|
||||
|
||||
# def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
# """Do a full training forward pass to compute the loss"""
|
||||
# if self.config.adapt_to_pi_aloha:
|
||||
# batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
# batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
# batch = self.normalize_inputs(batch)
|
||||
# batch = self.normalize_targets(batch)
|
||||
# images, img_masks = self.prepare_images(batch)
|
||||
# state = self.prepare_state(batch)
|
||||
# lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
# actions = self.prepare_action(batch)
|
||||
# actions_is_pad = batch.get("actions_id_pad")
|
||||
# loss_dict = {}
|
||||
# losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
# loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
# if actions_is_pad is not None:
|
||||
# in_episode_bound = ~actions_is_pad
|
||||
# losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
# loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
|
||||
# # Remove padding
|
||||
# losses = losses[:, :, : self.config.max_action_dim]
|
||||
# loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
# # For backward pass
|
||||
# loss = losses.mean()
|
||||
# # For backward pass
|
||||
# loss_dict["loss"] = loss.item()
|
||||
# return loss, loss_dict
|
||||
|
||||
# def prepare_images(self, batch):
|
||||
# """Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
# convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
||||
# """
|
||||
# images = []
|
||||
# img_masks = []
|
||||
# present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
# missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
# if len(present_img_keys) == 0:
|
||||
# raise ValueError(
|
||||
# f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
# )
|
||||
# # Preprocess image features present in the batch
|
||||
# for key in present_img_keys:
|
||||
# img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
|
||||
# if self.config.resize_imgs_with_padding is not None:
|
||||
# img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# # Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
# img = img * 2.0 - 1.0
|
||||
|
||||
# bsize = img.shape[0]
|
||||
# device = img.device
|
||||
# if f"{key}_padding_mask" in batch:
|
||||
# mask = batch[f"{key}_padding_mask"].bool()
|
||||
# else:
|
||||
# mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
# images.append(img)
|
||||
# img_masks.append(mask)
|
||||
|
||||
# # Create image features not present in the batch
|
||||
# # as fully 0 padded images.
|
||||
# for num_empty_cameras in range(len(missing_img_keys)):
|
||||
# if num_empty_cameras >= self.config.empty_cameras:
|
||||
# break
|
||||
# img = torch.ones_like(img) * -1
|
||||
# mask = torch.zeros_like(mask)
|
||||
# images.append(img)
|
||||
# img_masks.append(mask)
|
||||
# return images, img_masks
|
||||
|
||||
# def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
# """Tokenize the text input"""
|
||||
# device = batch[OBS_STATE].device
|
||||
# tasks = batch["task"]
|
||||
# if isinstance(tasks, str):
|
||||
# tasks = [tasks]
|
||||
|
||||
# if len(tasks) == 1:
|
||||
# tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
||||
|
||||
# tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
# tokenized_prompt = self.language_tokenizer.__call__(
|
||||
# tasks,
|
||||
# padding=self.config.pad_language_to,
|
||||
# padding_side="right",
|
||||
# max_length=self.config.tokenizer_max_length,
|
||||
# return_tensors="pt",
|
||||
# )
|
||||
# lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
# lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
# return lang_tokens, lang_masks
|
||||
|
||||
# def _pi_aloha_decode_state(self, state):
|
||||
# # Flip the joints.
|
||||
# for motor_idx in [1, 2, 8, 9]:
|
||||
# state[:, motor_idx] *= -1
|
||||
# # Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
# for motor_idx in [6, 13]:
|
||||
# state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
# return state
|
||||
|
||||
# def _pi_aloha_encode_actions(self, actions):
|
||||
# # Flip the joints.
|
||||
# for motor_idx in [1, 2, 8, 9]:
|
||||
# actions[:, :, motor_idx] *= -1
|
||||
# # Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
# for motor_idx in [6, 13]:
|
||||
# actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
# return actions
|
||||
|
||||
# def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# # Flip the joints again.
|
||||
# for motor_idx in [1, 2, 8, 9]:
|
||||
# actions[:, :, motor_idx] *= -1
|
||||
# # Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
# for motor_idx in [6, 13]:
|
||||
# actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
# return actions
|
||||
|
||||
# def prepare_state(self, batch):
|
||||
# """Pad state"""
|
||||
# state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE]
|
||||
# state = pad_vector(state, self.config.max_state_dim)
|
||||
# return state
|
||||
|
||||
# def prepare_action(self, batch):
|
||||
# """Pad action"""
|
||||
# actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
# return actions
|
||||
|
||||
|
||||
# def pad_tensor(tensor, max_len, pad_value=0):
|
||||
# """
|
||||
# Efficiently pads a tensor along sequence dimension to match max_len.
|
||||
|
||||
# Args:
|
||||
# tensor (torch.Tensor): Shape (B, L, ...) or (B, L).
|
||||
# max_len (int): Fixed sequence length.
|
||||
# pad_value (int/float): Value for padding.
|
||||
|
||||
# Returns:
|
||||
# torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
|
||||
# """
|
||||
# b, d = tensor.shape[:2]
|
||||
|
||||
# # Create a padded tensor of max_len and copy the existing values
|
||||
# padded_tensor = torch.full(
|
||||
# (b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
|
||||
# )
|
||||
# padded_tensor[:, :d] = tensor # Efficient in-place copy
|
||||
|
||||
# return padded_tensor
|
||||
|
||||
|
||||
# class VLAFlowMatching(nn.Module):
|
||||
# """
|
||||
# SmolVLA
|
||||
|
||||
# [Paper]()
|
||||
|
||||
# Designed by Hugging Face.
|
||||
# ┌──────────────────────────────┐
|
||||
# │ actions │
|
||||
# │ ▲ │
|
||||
# │ ┌─────────┐ ┌─|────┐ │
|
||||
# │ | │────► │ │ │
|
||||
# │ | │ kv │ │ │
|
||||
# │ | │────► │Action│ │
|
||||
# │ | VLM │cache │Expert│ |
|
||||
# │ │ │────► | │ │
|
||||
# │ │ │ │ │ │
|
||||
# │ └▲──▲───▲─┘ └───▲──┘ |
|
||||
# │ │ | | │ |
|
||||
# │ | | | noise │
|
||||
# │ │ │ state │
|
||||
# │ │ language tokens │
|
||||
# │ image(s) │
|
||||
# └──────────────────────────────┘
|
||||
# """
|
||||
|
||||
# def __init__(self, config: SmolVLAConfig):
|
||||
# super().__init__()
|
||||
# self.config = config
|
||||
|
||||
# self.vlm_with_expert = SmolVLMWithExpertModel(
|
||||
# model_id=self.config.vlm_model_name,
|
||||
# freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||
# train_expert_only=self.config.train_expert_only,
|
||||
# load_vlm_weights=self.config.load_vlm_weights,
|
||||
# attention_mode=self.config.attention_mode,
|
||||
# num_expert_layers=self.config.num_expert_layers,
|
||||
# num_vlm_layers=self.config.num_vlm_layers,
|
||||
# self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
# expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
# )
|
||||
# self.state_proj = nn.Linear(
|
||||
# self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
# )
|
||||
# self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
|
||||
# self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
|
||||
|
||||
# self.action_time_mlp_in = nn.Linear(
|
||||
# self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
|
||||
# )
|
||||
# self.action_time_mlp_out = nn.Linear(
|
||||
# self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
|
||||
# )
|
||||
|
||||
# self.set_requires_grad()
|
||||
# self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
|
||||
# self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
|
||||
# self.global_image_start_token = torch.tensor(
|
||||
# [self.fake_image_token, self.global_image_token], dtype=torch.long
|
||||
# )
|
||||
|
||||
# self.add_image_special_tokens = self.config.add_image_special_tokens
|
||||
# self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
||||
# self.prefix_length = self.config.prefix_length
|
||||
|
||||
# def set_requires_grad(self):
|
||||
# for params in self.state_proj.parameters():
|
||||
# params.requires_grad = self.config.train_state_proj
|
||||
|
||||
# def sample_noise(self, shape, device):
|
||||
# noise = torch.normal(
|
||||
# mean=0.0,
|
||||
# std=1.0,
|
||||
# size=shape,
|
||||
# dtype=torch.float32,
|
||||
# device=device,
|
||||
# )
|
||||
# return noise
|
||||
|
||||
# def sample_time(self, bsize, device):
|
||||
# beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||
# time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
|
||||
# time = time_beta * 0.999 + 0.001
|
||||
# return time
|
||||
|
||||
# def embed_prefix(
|
||||
# self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
|
||||
# ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# """Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
# for SmolVLM transformer processing.
|
||||
# """
|
||||
# embs = []
|
||||
# pad_masks = []
|
||||
# att_masks = []
|
||||
# for _img_idx, (
|
||||
# img,
|
||||
# img_mask,
|
||||
# ) in enumerate(zip(images, img_masks, strict=False)):
|
||||
# if self.add_image_special_tokens:
|
||||
# image_start_token = (
|
||||
# self.vlm_with_expert.embed_language_tokens(
|
||||
# self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
# )
|
||||
# .unsqueeze(0)
|
||||
# .expand(img.shape[0], -1, -1)
|
||||
# )
|
||||
# image_start_mask = torch.ones_like(
|
||||
# image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
|
||||
# )
|
||||
# att_masks += [0] * (image_start_mask.shape[-1])
|
||||
# embs.append(image_start_token)
|
||||
# pad_masks.append(image_start_mask)
|
||||
|
||||
# img_emb = self.vlm_with_expert.embed_image(img)
|
||||
# img_emb = img_emb
|
||||
|
||||
# # Normalize image embeddings
|
||||
# img_emb_dim = img_emb.shape[-1]
|
||||
# img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
|
||||
# bsize, num_img_embs = img_emb.shape[:2]
|
||||
# img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
|
||||
# embs.append(img_emb)
|
||||
# pad_masks.append(img_mask)
|
||||
|
||||
# att_masks += [0] * (num_img_embs)
|
||||
# if self.add_image_special_tokens:
|
||||
# image_end_token = (
|
||||
# self.vlm_with_expert.embed_language_tokens(
|
||||
# self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
# )
|
||||
# .unsqueeze(0)
|
||||
# .expand(img.shape[0], -1, -1)
|
||||
# )
|
||||
# image_end_mask = torch.ones_like(
|
||||
# image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
|
||||
# )
|
||||
# embs.append(image_end_token)
|
||||
# pad_masks.append(image_end_mask)
|
||||
# att_masks += [0] * (image_end_mask.shape[1])
|
||||
# lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
|
||||
# # Normalize language embeddings
|
||||
# lang_emb_dim = lang_emb.shape[-1]
|
||||
# lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
# embs.append(lang_emb)
|
||||
# pad_masks.append(lang_masks)
|
||||
|
||||
# num_lang_embs = lang_emb.shape[1]
|
||||
# att_masks += [0] * num_lang_embs
|
||||
|
||||
# state_emb = self.state_proj(state)
|
||||
# state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||
# embs.append(state_emb)
|
||||
# bsize = state_emb.shape[0]
|
||||
# device = state_emb.device
|
||||
|
||||
# states_seq_len = state_emb.shape[1]
|
||||
# state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
|
||||
# pad_masks.append(state_mask)
|
||||
|
||||
# # Set attention masks so that image and language inputs do not attend to state or actions
|
||||
# att_masks += [1] * (states_seq_len)
|
||||
# embs = torch.cat(embs, dim=1)
|
||||
# pad_masks = torch.cat(pad_masks, dim=1)
|
||||
# att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
# att_masks = att_masks[None, :]
|
||||
|
||||
# seq_len = pad_masks.shape[1]
|
||||
# if seq_len < self.prefix_length:
|
||||
# embs = pad_tensor(embs, self.prefix_length, pad_value=0)
|
||||
# pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0)
|
||||
# att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0)
|
||||
|
||||
# att_masks = att_masks.expand(bsize, -1)
|
||||
|
||||
# return embs, pad_masks, att_masks
|
||||
|
||||
# def embed_suffix(self, noisy_actions, timestep):
|
||||
# """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
# embs = []
|
||||
# pad_masks = []
|
||||
# att_masks = []
|
||||
|
||||
# # Fuse timestep + action information using an MLP
|
||||
# action_emb = self.action_in_proj(noisy_actions)
|
||||
# device = action_emb.device
|
||||
# bsize = action_emb.shape[0]
|
||||
# dtype = action_emb.dtype
|
||||
# # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
# time_emb = create_sinusoidal_pos_embedding(
|
||||
# timestep,
|
||||
# self.vlm_with_expert.expert_hidden_size,
|
||||
# self.config.min_period,
|
||||
# self.config.max_period,
|
||||
# device=device,
|
||||
# )
|
||||
# time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
# time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
# action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
# action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
# action_time_emb = F.silu(action_time_emb) # swish == silu
|
||||
# action_time_emb = self.action_time_mlp_out(action_time_emb)
|
||||
|
||||
# # Add to input tokens
|
||||
# embs.append(action_time_emb)
|
||||
|
||||
# bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
# action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
# pad_masks.append(action_time_mask)
|
||||
|
||||
# # Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
# att_masks += [1] * self.config.chunk_size
|
||||
# embs = torch.cat(embs, dim=1)
|
||||
# pad_masks = torch.cat(pad_masks, dim=1)
|
||||
# att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
# att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
# # added by jade
|
||||
# seq_len = pad_masks.shape[1]
|
||||
# if seq_len < self.config.chunk_size:
|
||||
# embs = pad_tensor(embs, self.config.chunk_size, pad_value=0)
|
||||
# pad_masks = pad_tensor(pad_masks, self.config.chunk_size, pad_value=0)
|
||||
# att_masks = pad_tensor(att_masks, self.config.chunk_size, pad_value=0)
|
||||
# return embs, pad_masks, att_masks
|
||||
|
||||
# def forward(
|
||||
# self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
# ) -> Tensor:
|
||||
# """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
# #added by jade
|
||||
# if actions.ndim == 2:
|
||||
# actions = actions[:, None, :].expand(-1, self.config.chunk_size, -1)
|
||||
# if noise is None:
|
||||
# noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
# if time is None:
|
||||
# time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
# time_expanded = time[:, None, None]
|
||||
# x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
# u_t = noise - actions
|
||||
# prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
# images, img_masks, lang_tokens, lang_masks, state=state
|
||||
# )
|
||||
# suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time)
|
||||
|
||||
# pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
# att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
# att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
# position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
# (_, suffix_out), _ = self.vlm_with_expert.forward(
|
||||
# attention_mask=att_2d_masks,
|
||||
# position_ids=position_ids,
|
||||
# past_key_values=None,
|
||||
# inputs_embeds=[prefix_embs, suffix_embs],
|
||||
# use_cache=False,
|
||||
# fill_kv_cache=False,
|
||||
# )
|
||||
# # suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
# suffix_out = suffix_out[:, -self.config.chunk_size:, :]
|
||||
# # Original openpi code, upcast attention output
|
||||
# suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
# v_t = self.action_out_proj(suffix_out)
|
||||
# losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
# return losses
|
||||
|
||||
# def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
# """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
# bsize = state.shape[0]
|
||||
# device = state.device
|
||||
|
||||
# if noise is None:
|
||||
# actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
||||
# noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
# prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
# images, img_masks, lang_tokens, lang_masks, state=state
|
||||
# )
|
||||
# prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
# prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
# # Compute image and language key value cache
|
||||
# _, past_key_values = self.vlm_with_expert.forward(
|
||||
# attention_mask=prefix_att_2d_masks,
|
||||
# position_ids=prefix_position_ids,
|
||||
# past_key_values=None,
|
||||
# inputs_embeds=[prefix_embs, None],
|
||||
# use_cache=self.config.use_cache,
|
||||
# fill_kv_cache=True,
|
||||
# )
|
||||
# dt = -1.0 / self.config.num_steps
|
||||
# dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
# x_t = noise
|
||||
# time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
# while time >= -dt / 2:
|
||||
# expanded_time = time.expand(bsize)
|
||||
# v_t = self.denoise_step(
|
||||
# prefix_pad_masks,
|
||||
# past_key_values,
|
||||
# x_t,
|
||||
# expanded_time,
|
||||
# )
|
||||
# # Euler step
|
||||
# x_t += dt * v_t
|
||||
# time += dt
|
||||
# return x_t
|
||||
|
||||
# def denoise_step(
|
||||
# self,
|
||||
# prefix_pad_masks,
|
||||
# past_key_values,
|
||||
# x_t,
|
||||
# timestep,
|
||||
# ):
|
||||
# """Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
# suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
|
||||
|
||||
# suffix_len = suffix_pad_masks.shape[1]
|
||||
# batch_size = prefix_pad_masks.shape[0]
|
||||
# prefix_len = prefix_pad_masks.shape[1]
|
||||
# prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
# suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
# full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
# prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
# position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
# outputs_embeds, _ = self.vlm_with_expert.forward(
|
||||
# attention_mask=full_att_2d_masks,
|
||||
# position_ids=position_ids,
|
||||
# past_key_values=past_key_values,
|
||||
# inputs_embeds=[None, suffix_embs],
|
||||
# use_cache=self.config.use_cache,
|
||||
# fill_kv_cache=False,
|
||||
# )
|
||||
# suffix_out = outputs_embeds[1]
|
||||
# suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
# suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
# v_t = self.action_out_proj(suffix_out)
|
||||
# return v_t
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
@@ -1028,7 +76,6 @@ from lerobot.policies.utils import (
|
||||
)
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
# OBS_STATE = 'state'
|
||||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||
|
||||
@@ -1348,7 +395,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
original_action_dim = 7
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
@@ -1892,4 +938,4 @@ class VLAFlowMatching(nn.Module):
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
||||
return v_t
|
||||
@@ -1 +0,0 @@
|
||||
c
|
||||
@@ -132,10 +132,6 @@ def rollout(
|
||||
|
||||
# Reset the policy and environments.
|
||||
policy.reset()
|
||||
# added by jade
|
||||
# for k in list(policy.config.input_features.keys()):
|
||||
# if k.startswith("observation.image"):
|
||||
# policy.config.input_features["observation.images." + k.split("observation.", 1)[1]] = policy.config.input_features.pop(k)
|
||||
observation, info = env.reset(seed=seeds)
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
@@ -171,26 +167,6 @@ def rollout(
|
||||
# Infer "task" from attributes of environments.
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
# breakpoint()
|
||||
# observation = {
|
||||
# k.replace("observation.images.", "observation.") if k.startswith("observation.images.") else k: v
|
||||
# for k, v in observation.items()
|
||||
# # }
|
||||
# if "observation.image" in observation:
|
||||
# observation["image"] = observation.pop("observation.image").to(
|
||||
# device, non_blocking=device.type == "cuda"
|
||||
# )
|
||||
|
||||
# if "observation.image2" in observation:
|
||||
# observation["wrist_image"] = observation.pop("observation.image2").to(
|
||||
# device, non_blocking=device.type == "cuda"
|
||||
# )
|
||||
|
||||
# if "observation.state" in observation:
|
||||
# observation["state"] = observation.pop("observation.state").to(
|
||||
# device, non_blocking=device.type == "cuda"
|
||||
# )
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
# Convert to CPU / numpy.
|
||||
@@ -550,15 +526,12 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
|
||||
logging.info("Making environment.")
|
||||
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
breakpoint()
|
||||
|
||||
logging.info("Making policy.")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
breakpoint()
|
||||
# policy, _ = load_smolvla(cfg.policy, "physical-intelligence/libero", policy)
|
||||
# rename "image" -> "observation.image"
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
|
||||
@@ -1,345 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.scripts.eval import eval_policy, eval_policy_multitask
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
has_method,
|
||||
init_logging,
|
||||
)
|
||||
from lerobot.utils.wandb_utils import WandBLogger
|
||||
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
policy: PreTrainedPolicy,
|
||||
batch: Any,
|
||||
optimizer: Optimizer,
|
||||
grad_clip_norm: float,
|
||||
grad_scaler: GradScaler,
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
grad_scaler.scale(loss).backward()
|
||||
|
||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
||||
grad_scaler.unscale_(optimizer)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||
with lock if lock is not None else nullcontext():
|
||||
grad_scaler.step(optimizer)
|
||||
# Updates the scale for next iteration.
|
||||
grad_scaler.update()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Step through pytorch scheduler at every batch instead of epoch
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if has_method(policy, "update"):
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
|
||||
train_metrics.loss = loss.item()
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
||||
"""Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
if not hasattr(dataset_meta, "stats") or not dataset_meta.stats:
|
||||
print("⚠️ Dataset has no stats, skipping normalization injection.")
|
||||
return
|
||||
|
||||
stats = {}
|
||||
for key, stat_dict in dataset_meta.stats.items():
|
||||
stats[key] = {
|
||||
stat_type: torch.as_tensor(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
|
||||
for stat_type, stat_array in stat_dict.items()
|
||||
}
|
||||
|
||||
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
|
||||
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||
unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
policy.normalize_inputs = normalize_inputs
|
||||
policy.normalize_targets = normalize_targets
|
||||
policy.unnormalize_outputs = unnormalize_outputs
|
||||
|
||||
print("✅ Normalization layers injected with dataset stats.")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
if cfg.wandb.enable and cfg.wandb.project:
|
||||
wandb_logger = WandBLogger(cfg)
|
||||
else:
|
||||
wandb_logger = None
|
||||
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
|
||||
|
||||
if cfg.seed is not None:
|
||||
set_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.episode_data_index,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
|
||||
train_tracker = MetricsTracker(
|
||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
|
||||
)
|
||||
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.policy.use_amp,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
if is_log_step:
|
||||
logging.info(train_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
if cfg.env.multitask_eval:
|
||||
eval_info = eval_policy_multitask(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
)
|
||||
aggregated = eval_info["overall"]["aggregated"]
|
||||
# Print per-suite stats, log?
|
||||
for task_group, task_group_info in eval_info.items():
|
||||
if task_group == "overall":
|
||||
continue # Skip the overall stats since we already printed it
|
||||
print(f"\nAggregated Metrics for {task_group}:")
|
||||
print(task_group_info["aggregated"])
|
||||
else:
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
aggregated = eval_info["aggregated"]
|
||||
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
"pc_success": AverageMeter("success", ":.1f"),
|
||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||
}
|
||||
eval_tracker = MetricsTracker(
|
||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
||||
)
|
||||
eval_tracker.eval_s = aggregated.pop("eval_s")
|
||||
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
|
||||
eval_tracker.pc_success = aggregated.pop("pc_success")
|
||||
logging.info(eval_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
|
||||
if eval_env:
|
||||
if cfg.env.multitask_eval:
|
||||
for _task_group, envs_dict in eval_env.items():
|
||||
for _idx, env in envs_dict.items():
|
||||
env.close()
|
||||
else:
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
policy.push_model_to_hub(cfg)
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,366 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import set_seed as accelerate_set_seed
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
has_method,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
policy: PreTrainedPolicy,
|
||||
batch: Any,
|
||||
optimizer: Optimizer,
|
||||
grad_clip_norm: float,
|
||||
accelerator: Accelerator,
|
||||
lr_scheduler=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
# Use accelerator's autocast context if mixed precision is enabled
|
||||
with accelerator.autocast():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
# Use accelerator for backward pass
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Gradient clipping - accelerator handles unscaling automatically
|
||||
if accelerator.sync_gradients and grad_clip_norm > 0:
|
||||
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
|
||||
else:
|
||||
grad_norm = torch.tensor(0.0)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step() if lr_scheduler is not None else None
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Update policy-specific buffers if needed
|
||||
if has_method(policy, "update"):
|
||||
policy.update()
|
||||
|
||||
# Gather metrics across all processes
|
||||
loss_value = accelerator.gather(loss.detach()).mean().item()
|
||||
grad_norm_value = accelerator.gather(grad_norm).mean().item()
|
||||
|
||||
train_metrics.loss = loss_value
|
||||
train_metrics.grad_norm = grad_norm_value
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
# Initialize accelerator
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
# added by jade 2 lines
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
|
||||
accelerator = Accelerator(..., kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
from lerobot.utils.wandb_utils import cfg_to_group, get_wandb_run_id_from_filesystem
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="bf16" if cfg.policy.use_amp else "no",
|
||||
gradient_accumulation_steps=cfg.policy.gradient_accumulation_steps,
|
||||
log_with="wandb" if cfg.wandb.enable else None,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
project_dir=cfg.output_dir,
|
||||
)
|
||||
|
||||
accelerator.init_trackers(
|
||||
project_name=cfg.wandb.project,
|
||||
init_kwargs={
|
||||
"wandb": {
|
||||
"entity": cfg.wandb.entity,
|
||||
"name": cfg.job_name,
|
||||
"notes": cfg.wandb.notes,
|
||||
"tags": cfg_to_group(cfg, return_list=True),
|
||||
"dir": cfg.output_dir,
|
||||
"config": cfg.to_dict(),
|
||||
"save_code": False,
|
||||
"job_type": "train_eval",
|
||||
"mode": cfg.wandb.mode if cfg.wandb.mode in ["online", "offline", "disabled"] else "online",
|
||||
"resume": "must" if cfg.resume else None,
|
||||
"id": cfg.wandb.run_id
|
||||
if cfg.wandb.run_id
|
||||
else (get_wandb_run_id_from_filesystem(cfg.output_dir) if cfg.resume else None),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Set seed for reproducibility
|
||||
if cfg.seed is not None:
|
||||
accelerate_set_seed(cfg.seed)
|
||||
|
||||
# Setup device - accelerator handles device placement
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Create dataset
|
||||
if accelerator.is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
print("c")
|
||||
# Create evaluation environment (only on main process)
|
||||
eval_env = None
|
||||
if cfg.eval_freq > 0 and cfg.env is not None and accelerator.is_main_process:
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
# Create policy
|
||||
if accelerator.is_main_process:
|
||||
logging.info("Creating policy")
|
||||
|
||||
# Use accelerator's device instead of cfg.policy.device
|
||||
with accelerator.main_process_first():
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
|
||||
# Create optimizer and scheduler
|
||||
if accelerator.is_main_process:
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
step = 0 # number of policy updates
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
|
||||
# Prepare dataloader
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.episode_data_index,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=True,
|
||||
drop_last=True, # Important for distributed training
|
||||
)
|
||||
|
||||
# Prepare for distributed training
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# Log training info (only on main process)
|
||||
if accelerator.is_main_process:
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
logging.info(f"Number of processes: {accelerator.num_processes}")
|
||||
logging.info(f"Device: {accelerator.device}")
|
||||
logging.info(f"Mixed precision: {accelerator.mixed_precision}")
|
||||
|
||||
# Create metrics trackers
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
"lr": AverageMeter("lr", ":0.1e"),
|
||||
"update_s": AverageMeter("updt_s", ":.3f"),
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
|
||||
train_tracker = MetricsTracker(
|
||||
cfg.batch_size * accelerator.num_processes, # Account for all processes
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
train_metrics,
|
||||
initial_step=step,
|
||||
)
|
||||
|
||||
# Training loop
|
||||
policy.train()
|
||||
if accelerator.is_main_process:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
# Create iterator from dataloader
|
||||
dl_iter = iter(dataloader)
|
||||
|
||||
for current_step in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
# Get next batch, cycling through dataloader if needed
|
||||
try:
|
||||
batch = next(dl_iter)
|
||||
print("data laoder batch keys: ", batch.keys())
|
||||
breakpoint()
|
||||
except StopIteration:
|
||||
dl_iter = iter(dataloader)
|
||||
batch = next(dl_iter)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
# Update policy
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
batch,
|
||||
optimizer,
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
|
||||
# Increment step counter
|
||||
step += 1
|
||||
train_tracker.step()
|
||||
|
||||
# Determine if we should log, save, or evaluate
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
||||
|
||||
# Logging (only on main process)
|
||||
if is_log_step and accelerator.is_main_process:
|
||||
logging.info(train_tracker)
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
for k, v in wandb_log_dict.items():
|
||||
accelerator.log({f"{'train'}/{k}": v}, step=step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
# Checkpointing (only on main process)
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
# ✅ all processes wait here
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if accelerator.is_main_process:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
|
||||
unwrapped_policy = accelerator.unwrap_model(policy)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
|
||||
# ✅ all processes sync again after saving
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# if wandb_logger:
|
||||
# wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
# Evaluation (only on main process)
|
||||
if cfg.env and is_eval_step and accelerator.is_main_process:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
|
||||
# Unwrap model for evaluation
|
||||
unwrapped_policy = accelerator.unwrap_model(policy)
|
||||
unwrapped_policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
unwrapped_policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
"pc_success": AverageMeter("success", ":.1f"),
|
||||
"eval_s": AverageMeter("eval_s", ":.3f"),
|
||||
}
|
||||
eval_tracker = MetricsTracker(
|
||||
cfg.batch_size * accelerator.num_processes,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
eval_metrics,
|
||||
initial_step=step,
|
||||
)
|
||||
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
||||
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
||||
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
|
||||
logging.info(eval_tracker)
|
||||
|
||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||
for k, v in wandb_log_dict.items():
|
||||
accelerator.log({f"{'eval'}/{k}": v}, step=step)
|
||||
|
||||
# Set back to training mode
|
||||
policy.train()
|
||||
|
||||
# Wait for all processes to finish
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Cleanup
|
||||
if eval_env and accelerator.is_main_process:
|
||||
eval_env.close()
|
||||
|
||||
if accelerator.is_main_process:
|
||||
logging.info("End of training")
|
||||
accelerator.end_training() # added by jade
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
train()
|
||||
Reference in New Issue
Block a user