merge all modules files into the main modeling file

This commit is contained in:
Bryson Jones
2025-12-10 13:44:30 -08:00
parent b92dc82ddd
commit 3b2a4f548c
6 changed files with 700 additions and 1003 deletions

View File

@@ -1 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_multi_task_dit import MultiTaskDiTConfig
from .modeling_multi_task_dit import MultiTaskDiTPolicy
from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors
__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"]

View File

@@ -25,18 +25,654 @@ References:
- https://brysonkjones.substack.com/p/dissecting-multitask-diffusion-transformer-policy
"""
import math
from abc import ABC, abstractmethod
from collections import deque
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.policies.multi_task_dit.modules.objectives import DiffusionObjective, FlowMatchingObjective
from lerobot.policies.multi_task_dit.modules.observation_encoder import ObservationEncoder
from lerobot.policies.multi_task_dit.modules.transformer import DiffusionTransformer
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
# -- Observation Encoders --
class CLIPVisionEncoder(nn.Module):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, model_name: str):
super().__init__()
self.model_name = model_name
self.model = CLIPVisionModel.from_pretrained(self.model_name)
self.num_non_spatial_tokens = 1
self.embed_dim = self.model.config.hidden_size
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token."""
outputs = self.model(pixel_values=x, output_hidden_states=False)
cls_token = outputs.last_hidden_state[:, 0]
b, embed_dim = cls_token.shape
return cls_token.reshape(b, embed_dim, 1, 1)
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and a learnable projection layer."""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
self.model_name = model_name
self.projection_dim = projection_dim
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_embed_dim = self.text_encoder.config.hidden_size
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, text: str | list[str]) -> Tensor:
"""Encode text to feature vectors."""
if isinstance(text, str):
text = [text]
text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()}
with torch.no_grad():
outputs = self.text_encoder(**text_inputs)
clip_features = outputs.pooler_output
return self.projection(clip_features)
class ObservationEncoder(nn.Module):
"""Handles all observation processing for the conditioning vector."""
def __init__(self, config):
super().__init__()
self.config = config
self._setup_preprocessing(config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys())
if config.use_separate_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None
else:
self.vision_encoder = None
self.vision_encoders = None
self.camera_names = []
self.num_cameras = 0
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
self.robot_state_dim = config.robot_state_feature.shape[0]
else:
self.robot_state_dim = 0
if hasattr(config, "env_state_feature") and config.env_state_feature:
self.env_state_dim = config.env_state_feature.shape[0]
else:
self.env_state_dim = 0
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self._setup_vector_output()
def _apply_preprocessing(self, images: Tensor) -> Tensor:
if self.do_resize:
images = self.resize(images)
if self.do_crop:
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
return images
def _setup_preprocessing(self, config):
if config.image_resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if config.image_crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
def _setup_vector_output(self):
total_dim = 0
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
feature_map_shape = encoder_to_check.get_output_shape()
c, h, w = feature_map_shape
spatial_feature_dim = c * h * w
total_dim += spatial_feature_dim * self.num_cameras
total_dim += self.robot_state_dim
total_dim += self.env_state_dim
total_dim += self.text_dim
self.conditioning_dim = total_dim * self.config.n_obs_steps
def encode(self, batch: dict) -> Tensor:
"""Encode observations to vector format."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
conditioning_feats = []
conditioning_feats.append(batch[OBS_STATE])
if self.vision_encoder is not None or self.vision_encoders is not None:
images = batch[OBS_IMAGES]
if len(images.shape) == 5:
images = images.unsqueeze(1)
if self.config.use_separate_encoder_per_camera:
camera_features = []
for cam_idx in range(self.num_cameras):
cam_images = images[:, :, cam_idx]
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
cam_images_flat = self._apply_preprocessing(cam_images_flat)
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
cam_visual_features = cam_features.flatten(start_dim=1)
cam_features_reshaped = einops.rearrange(
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
)
camera_features.append(cam_features_reshaped)
img_features = torch.cat(camera_features, dim=-1)
conditioning_feats.append(img_features)
else:
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
images_flat = self._apply_preprocessing(images_flat)
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
img_features = einops.rearrange(
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
)
conditioning_feats.append(img_features)
if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
conditioning_feats.append(batch[OBS_ENV_STATE])
if self.text_encoder is not None and "task" in batch:
text_features = self.text_encoder(batch["task"])
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
conditioning_feats.append(text_features)
combined_features = torch.cat(conditioning_feats, dim=-1)
return combined_features.flatten(start_dim=1)
# -- Transformer Components --
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
"""Modulate input with shift and scale for AdaLN-Zero."""
return x * (1 + scale) + shift
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embeddings for timesteps."""
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) for transformers."""
def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
super().__init__()
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
def _rotate_half(self, x: Tensor) -> Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
seq_len = q.shape[2]
if seq_len > self.max_seq_len:
raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.")
cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
return q_rotated, k_rotated
class RoPEAttention(nn.Module):
"""Multi-head self-attention with Rotary Position Embedding (RoPE)."""
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout: float = 0.0,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
def forward(self, x: Tensor) -> Tensor:
B, T, _ = x.shape # noqa: N806
qkv = self.qkv_proj(x)
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k = self.rope(q, k)
attn_out = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
)
attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)
return self.out_proj(attn_out)
class TransformerBlock(nn.Module):
"""DiT-style transformer block with AdaLN-Zero."""
def __init__(
self,
hidden_size: int = 128,
num_heads: int = 4,
num_features: int = 128,
dropout: float = 0.0,
use_rope: bool = False,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
super().__init__()
self.use_rope = use_rope
if use_rope:
self.attn = RoPEAttention(
hidden_size=hidden_size,
num_heads=num_heads,
dropout=dropout,
max_seq_len=max_seq_len,
rope_base=rope_base,
)
else:
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size * 4, hidden_size),
)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
def forward(self, x: Tensor, features: Tensor) -> Tensor:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
features
).chunk(6, dim=1)
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
if self.use_rope:
attn_out = self.attn(attn_input)
else:
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
x = x + gate_msa.unsqueeze(1) * attn_out
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
mlp_out = self.mlp(mlp_input)
x = x + gate_mlp.unsqueeze(1) * mlp_out
return x
class DiffusionTransformer(nn.Module):
"""Transformer-based diffusion noise prediction model."""
def __init__(self, config, conditioning_dim: int):
super().__init__()
self.config = config
self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon
self.hidden_size = config.hidden_dim
self.num_layers = config.num_layers
self.num_heads = config.num_heads
self.dropout = config.dropout
self.use_rope = config.use_rope
self.timestep_embed_dim = config.timestep_embed_dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
nn.GELU(),
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
nn.GELU(),
)
self.cond_dim = self.timestep_embed_dim + conditioning_dim
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if config.use_positional_encoding:
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
)
else:
self.pos_embedding = None
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
num_features=self.cond_dim,
dropout=self.dropout,
use_rope=self.use_rope,
max_seq_len=self.horizon,
rope_base=config.rope_base,
)
for _ in range(self.num_layers)
]
)
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
self._initialize_weights()
def _initialize_weights(self):
for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
_, seq_len, _ = x.shape
timestep_features = self.time_mlp(timestep)
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1)
hidden_seq = self.input_proj(x)
if self.pos_embedding is not None:
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :]
for block in self.transformer_blocks:
hidden_seq = block(hidden_seq, cond_features)
return self.output_proj(hidden_seq)
# -- Objectives --
class BaseObjective(ABC):
"""Base class for objectives used in Multi-Task DiT policy."""
def __init__(self, config, action_dim: int, horizon: int):
self.config = config
self.action_dim = action_dim
self.horizon = horizon
@abstractmethod
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
pass
@abstractmethod
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
pass
class DiffusionObjective(BaseObjective):
"""Standard diffusion (DDPM/DDIM) objective implementation."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
self.do_mask_loss_for_padding = do_mask_loss_for_padding
scheduler_kwargs = {
"num_train_timesteps": config.num_train_timesteps,
"beta_start": config.beta_start,
"beta_end": config.beta_end,
"beta_schedule": config.beta_schedule,
"clip_sample": config.clip_sample,
"clip_sample_range": config.clip_sample_range,
"prediction_type": config.prediction_type,
}
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
self.num_inference_steps = (
config.num_inference_steps
if config.num_inference_steps is not None
else self.noise_scheduler.config.num_train_timesteps
)
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
clean_actions = batch["action"]
noise = torch.randn_like(clean_actions)
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(clean_actions.shape[0],),
device=clean_actions.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
prediction_type = self.noise_scheduler.config.prediction_type
if prediction_type == "epsilon":
target = noise
elif prediction_type == "sample":
target = clean_actions
else:
raise ValueError(f"Unsupported prediction type: {prediction_type}")
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"]
loss = loss * valid_actions.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
model_output = model(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
conditioning_vec=conditioning_vec,
)
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
return sample
class FlowMatchingObjective(BaseObjective):
"""Flow matching objective: trains a model to predict velocity fields."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
if self.config.timestep_sampling_strategy == "uniform":
return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling_strategy == "beta":
beta_dist = torch.distributions.Beta(
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
)
u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling_s * (1.0 - u)
else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
data = batch["action"]
batch_size = data.shape[0]
device = data.device
noise = torch.randn_like(data)
t = self._sample_timesteps(batch_size, device)
t_expanded = t.view(-1, 1, 1)
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
target_velocity = data - (1 - self.config.sigma_min) * noise
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"]
loss = loss * valid_mask.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
num_steps = self.config.num_integration_steps
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
if self.config.integration_method == "euler":
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
elif self.config.integration_method == "rk4":
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
else:
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
return x
def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
for i in range(len(time_grid) - 1):
t_scalar = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
with torch.no_grad():
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
x = x + dt * velocity
return x
def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
for i in range(len(time_grid) - 1):
t = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
k1 = dynamics(x, t)
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
k4 = dynamics(x + dt * k3, t + dt)
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
return x
# -- Policy --
class MultiTaskDiTPolicy(PreTrainedPolicy):
@@ -77,7 +713,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
self.reset()
def get_optim_params(self) -> list:
"""Returns parameter groups with different learning rates for vision vs non-vision parameters."""
"""Returns parameter groups with different learning rates for vision vs non-vision parameters"""
non_vision_params = []
vision_encoder_params = []
@@ -110,7 +746,7 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
return actions[:, start_idx:end_idx]
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`."""
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
@@ -122,13 +758,47 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
# Always include task queue for text conditioning
self._queues["task"] = deque(maxlen=self.config.n_obs_steps)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training or validation."""
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations"""
self.eval()
original_batch_keys = set(batch.keys())
new_batch = {}
for k in self._queues:
if k in original_batch_keys:
if self._queues[k] and isinstance(self._queues[k][-1][0], str):
new_batch[k] = self._queues[k][-1]
else:
queue_values = list(self._queues[k])
new_batch[k] = torch.stack(queue_values, dim=1)
batch = new_batch
actions = self._generate_actions(batch)
return actions
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations"""
if ACTION in batch:
batch.pop(ACTION)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch = dict(batch)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
return self._queues[ACTION].popleft()
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""Run the batch through the model and compute the loss for training"""
if self.config.image_features:
batch = dict(batch)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
n_obs_steps = batch["observation.state"].shape[1]
@@ -140,43 +810,3 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)
return loss, None
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
original_batch_keys = set(batch.keys())
new_batch = {}
for k in self._queues:
if k in original_batch_keys:
if self._queues[k] and isinstance(self._queues[k][-1][0], str):
# for task description which is a list of strings
new_batch[k] = self._queues[k][-1]
else:
queue_values = list(self._queues[k])
new_batch[k] = torch.stack(queue_values, dim=1)
batch = new_batch
actions = self._generate_actions(batch)
return actions
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method manages caching of observations and actions by generating an action chunk
and returning actions from the cache until it's depleted.
"""
if ACTION in batch:
batch.pop(ACTION)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
return self._queues[ACTION].popleft()

View File

@@ -1,248 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objective implementations for Multi-Task DiT policy.
- DiffusionObjective: Standard DDPM/DDIM diffusion
- FlowMatchingObjective: Flow matching with ODE integration
"""
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
class BaseObjective(ABC):
"""Base class for objectives used in Multi-Task DiT policy."""
def __init__(self, config, action_dim: int, horizon: int):
self.config = config
self.action_dim = action_dim
self.horizon = horizon
@abstractmethod
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
"""Compute training loss."""
pass
@abstractmethod
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
"""Generate action samples conditioned on observations."""
pass
class DiffusionObjective(BaseObjective):
"""Standard diffusion (DDPM/DDIM) objective implementation."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
self.do_mask_loss_for_padding = do_mask_loss_for_padding
# Build noise scheduler
scheduler_kwargs = {
"num_train_timesteps": config.num_train_timesteps,
"beta_start": config.beta_start,
"beta_end": config.beta_end,
"beta_schedule": config.beta_schedule,
"clip_sample": config.clip_sample,
"clip_sample_range": config.clip_sample_range,
"prediction_type": config.prediction_type,
}
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")
# Inference steps default to training steps if not provided
self.num_inference_steps = (
config.num_inference_steps
if config.num_inference_steps is not None
else self.noise_scheduler.config.num_train_timesteps
)
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
clean_actions = batch["action"]
noise = torch.randn_like(clean_actions)
timesteps = torch.randint(
low=0,
high=self.noise_scheduler.config.num_train_timesteps,
size=(clean_actions.shape[0],),
device=clean_actions.device,
).long()
noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)
# Target depends on prediction type
prediction_type = self.noise_scheduler.config.prediction_type
if prediction_type == "epsilon":
target = noise
elif prediction_type == "sample":
target = clean_actions
else:
raise ValueError(f"Unsupported prediction type: {prediction_type}")
predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted, target, reduction="none")
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_actions = ~batch["action_is_pad"] # (B, T)
loss = loss * valid_actions.unsqueeze(-1)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
sample = torch.randn(
size=(batch_size, self.horizon, self.action_dim),
dtype=dtype,
device=device,
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
for t in self.noise_scheduler.timesteps:
model_output = model(
sample,
torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
conditioning_vec=conditioning_vec,
)
sample = self.noise_scheduler.step(model_output, t, sample).prev_sample
return sample
class FlowMatchingObjective(BaseObjective):
"""Flow matching objective: trains a model to predict velocity fields."""
def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
super().__init__(config, action_dim, horizon)
self.do_mask_loss_for_padding = do_mask_loss_for_padding
def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
"""Sample timesteps according to configured strategy."""
if self.config.timestep_sampling_strategy == "uniform":
return torch.rand(batch_size, device=device)
elif self.config.timestep_sampling_strategy == "beta":
# Sample u ~ Beta(α, β) then transform: t = s(1-u)
# This emphasizes t near 0 (high noise) when α > β
beta_dist = torch.distributions.Beta(
self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
)
u = beta_dist.sample((batch_size,)).to(device)
return self.config.timestep_sampling_s * (1.0 - u)
else:
raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")
def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
"""Compute flow matching training loss."""
data = batch["action"] # Clean action sequences (B, T, D)
batch_size = data.shape[0]
device = data.device
noise = torch.randn_like(data)
t = self._sample_timesteps(batch_size, device)
t_expanded = t.view(-1, 1, 1) # (B, 1, 1) for broadcasting
x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise
# The velocity we want the model to learn: v = data - (1-σ)·noise
target_velocity = data - (1 - self.config.sigma_min) * noise
predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
# Optionally mask padded actions
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
valid_mask = ~batch["action_is_pad"] # (B, T)
loss = loss * valid_mask.unsqueeze(-1) # (B, T, D)
return loss.mean()
def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
"""Generate actions by integrating the learned velocity field via ODE."""
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
# Start from random noise at t=0
x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)
# Time grid from 0 to 1
num_steps = self.config.num_integration_steps
time_grid = torch.linspace(0, 1, num_steps + 1, device=device)
# Integrate ODE using chosen method
if self.config.integration_method == "euler":
x = self._euler_integrate(model, x, time_grid, conditioning_vec)
elif self.config.integration_method == "rk4":
x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
else:
raise ValueError(f"Unknown integration method: {self.config.integration_method}")
return x
def _euler_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
"""Euler integration: x_{n+1} = x_n + dt * v_θ(x_n, t_n)"""
x = x_init
for i in range(len(time_grid) - 1):
t_scalar = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
# Create time tensor for batch
t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
# Get velocity at current point
with torch.no_grad():
velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
# Euler step
x = x + dt * velocity
return x
def _rk4_integrate(
self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
) -> Tensor:
"""4th-order Runge-Kutta integration."""
x = x_init
def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
with torch.no_grad():
return model(x_val, t_batch, conditioning_vec=conditioning_vec)
for i in range(len(time_grid) - 1):
t = time_grid[i].item()
dt = (time_grid[i + 1] - time_grid[i]).item()
# RK4 stages
k1 = dynamics(x, t)
k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
k4 = dynamics(x + dt * k3, t + dt)
# Weighted combination
x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
return x

View File

@@ -1,291 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Observation encoding for Multi-Task DiT policy.
Handles vision encoding, text encoding, robot state, and environment state.
"""
import einops
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class CLIPVisionEncoder(nn.Module):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, model_name: str):
super().__init__()
self.model_name = model_name
# Load CLIP vision model from transformers
self.model = CLIPVisionModel.from_pretrained(self.model_name)
# CLIP models have 1 CLS token
self.num_non_spatial_tokens = 1
# Get embed_dim from model config
self.embed_dim = self.model.config.hidden_size
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token.
Preprocessing (resize, crop) is handled by ObservationEncoder
"""
# Extract features using CLIPVisionModel
# Input: (B, C, H, W) - already preprocessed
outputs = self.model(pixel_values=x, output_hidden_states=False)
# Extract CLS token from last_hidden_state (first token)
# last_hidden_state shape: (B, sequence_length, hidden_size)
cls_token = outputs.last_hidden_state[:, 0] # (B, embed_dim)
b, embed_dim = cls_token.shape
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
cls_features = cls_token.reshape(b, embed_dim, 1, 1)
return cls_features
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and a learnable projection layer."""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
self.model_name = model_name
self.projection_dim = projection_dim
# Load CLIP text encoder and tokenizer
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
# Freeze all CLIP text encoder parameters
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_embed_dim = self.text_encoder.config.hidden_size
# Learnable projection layer (always present, only trainable component)
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
def forward(self, text: str | list[str]) -> Tensor:
"""Encode text to feature vectors.
Args:
text: Single string or list of strings
Returns:
Text features of shape (B, projection_dim)
"""
# handle single string input
if isinstance(text, str):
text = [text]
text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()}
# encode text through CLIP (frozen)
with torch.no_grad():
outputs = self.text_encoder(**text_inputs)
# Extract pooled output (EOS token embedding)
clip_features = outputs.pooler_output # (B, text_embed_dim)
# project to desired dimension (trainable)
projected_features = self.projection(clip_features) # (B, projection_dim)
return projected_features
class ObservationEncoder(nn.Module):
"""Handles all observation processing for the conditioning vector."""
def __init__(self, config):
super().__init__()
self.config = config
self._setup_preprocessing(config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys())
if config.use_separate_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None
else:
self.vision_encoder = None
self.vision_encoders = None
self.camera_names = []
self.num_cameras = 0
if hasattr(config, "robot_state_feature") and config.robot_state_feature:
self.robot_state_dim = config.robot_state_feature.shape[0]
else:
self.robot_state_dim = 0
if hasattr(config, "env_state_feature") and config.env_state_feature:
self.env_state_dim = config.env_state_feature.shape[0]
else:
self.env_state_dim = 0
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self._setup_vector_output()
def _apply_preprocessing(self, images: Tensor) -> Tensor:
"""Apply preprocessing transforms to images."""
if self.do_resize:
images = self.resize(images)
if self.do_crop:
images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
return images
def _setup_preprocessing(self, config):
"""Setup image preprocessing transforms."""
if config.image_resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if config.image_crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
self.do_crop = False
def _setup_vector_output(self):
total_dim = 0
# Vision features - get CLS token feature dimension
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
# Get output shape from encoder (deterministic for CLS tokens)
feature_map_shape = encoder_to_check.get_output_shape()
c, h, w = feature_map_shape
spatial_feature_dim = c * h * w # For CLS token: embed_dim * 1 * 1 = embed_dim
total_dim += spatial_feature_dim * self.num_cameras
# State features
total_dim += self.robot_state_dim
total_dim += self.env_state_dim
# Text features
total_dim += self.text_dim
# Account for temporal stacking
self.conditioning_dim = total_dim * self.config.n_obs_steps
def encode(self, batch: dict) -> Tensor:
"""Encode observations to vector format."""
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
conditioning_feats = []
conditioning_feats.append(batch[OBS_STATE])
if self.vision_encoder is not None or self.vision_encoders is not None:
images = batch[OBS_IMAGES] # (B, n_obs_steps, num_cameras, C, H, W)
# Handle case when n_obs=1 and time dimension might be squeezed
if len(images.shape) == 5:
# Shape is (B, N, C, H, W) - add time dimension
images = images.unsqueeze(1) # (B, 1, N, C, H, W)
if self.config.use_separate_encoder_per_camera:
# Process each camera with its own encoder
camera_features = []
for cam_idx in range(self.num_cameras):
# Extract images for this camera: (B, n_obs_steps, C, H, W)
cam_images = images[:, :, cam_idx]
# Rearrange to: (B*n_obs_steps, C, H, W)
cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
# Apply preprocessing
cam_images_flat = self._apply_preprocessing(cam_images_flat)
# Process with camera-specific encoder (direct index access)
cam_features = self.vision_encoders[cam_idx](cam_images_flat)
# Apply spatial vectorization (flatten CLS token features)
cam_visual_features = cam_features.flatten(start_dim=1)
# Reshape back: (B*n_obs_steps, feature_dim) → (B, n_obs_steps, feature_dim)
cam_features_reshaped = einops.rearrange(
cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
)
camera_features.append(cam_features_reshaped)
# Concatenate features from all cameras: (B, n_obs_steps, total_feature_dim)
img_features = torch.cat(camera_features, dim=-1)
conditioning_feats.append(img_features)
else:
# Shared encoder for all cameras
# Rearrange to: (B*n_obs_steps*num_cameras, C, H, W)
images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
images_flat = self._apply_preprocessing(images_flat)
visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
# Reshape back and concatenate camera features
# (B*n_obs_steps*num_cameras, feature_dim) → (B, n_obs_steps, num_cameras*feature_dim)
img_features = einops.rearrange(
visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
)
conditioning_feats.append(img_features)
if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
conditioning_feats.append(batch[OBS_ENV_STATE])
if self.text_encoder is not None and "task" in batch:
text_features = self.text_encoder(batch["task"]) # (B, text_dim)
# Expand across temporal dimension to match other features
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) # (B, T, text_dim)
conditioning_feats.append(text_features)
combined_features = torch.cat(conditioning_feats, dim=-1) # (B, n_obs_steps, total_feature_dim)
return combined_features.flatten(start_dim=1) # (B, n_obs_steps * total_feature_dim)

View File

@@ -1,414 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer backbone for noise prediction in Multi-Task DiT policy.
Adapted from DiT (Diffusion Transformer: https://github.com/facebookresearch/DiT) for 1D trajectory data.
"""
import math
import torch
from torch import Tensor, nn
def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
"""Modulate input with shift and scale for AdaLN-Zero.
Args:
x: Input tensor
shift: Shift parameter
scale: Scale parameter
Returns:
Modulated tensor: x * (1 + scale) + shift
"""
return x * (1 + scale) + shift
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embeddings for timesteps.
Identical to the reference implementation - generates smooth embeddings
for diffusion timestep values.
"""
def __init__(self, dim: int):
"""
Args:
dim: Embedding dimension
"""
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B,) tensor of timestep values
Returns:
(B, dim) positional embeddings
"""
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RotaryPositionalEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) for transformers.
RoPE encodes position information by rotating query and key vectors,
which naturally captures relative positions through the dot product.
Applied at every attention layer rather than once at input.
To do this, we need to reimplement the attention mechanism to apply RoPE
to Q and K before computing the attention scores, so we cannot use the
the built-in MultiheadAttention module.
Original RoPE Paper: https://arxiv.org/abs/2104.09864 (RoFormer)
"""
def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
super().__init__()
assert head_dim % 2 == 0, "head_dim must be even for RoPE"
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
# Precompute inverse frequencies: theta_i = 1 / (base^(2i/d))
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._precompute_cache(max_seq_len)
def _precompute_cache(self, seq_len: int):
t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)
def _rotate_half(self, x: Tensor) -> Tensor:
"""Rotate half the hidden dims of the input.
For x = [x1, x2], returns [-x2, x1]
"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
"""Apply rotary embeddings to query and key tensors."""
seq_len = q.shape[2]
if seq_len > self.max_seq_len:
raise ValueError(
f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}. "
f"Increase max_seq_len in RoPE config."
)
# Slice precomputed cache to actual sequence length
cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)
# Apply rotation: q_rot = q * cos + rotate_half(q) * sin
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
return q_rotated, k_rotated
class RoPEAttention(nn.Module):
"""Multi-head self-attention with Rotary Position Embedding (RoPE).
Custom attention implementation that applies RoPE to Q and K before
computing attention scores. This allows position information to be
encoded at every attention layer.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
dropout: float = 0.0,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
"""
Args:
hidden_size: Total hidden dimension
num_heads: Number of attention heads
dropout: Attention dropout rate
max_seq_len: Maximum sequence length for RoPE cache
rope_base: Base for RoPE frequency computation
"""
super().__init__()
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = self.head_dim**-0.5
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)
def forward(self, x: Tensor) -> Tensor:
"""
Args:
x: (B, T, hidden_size) input sequence
Returns:
(B, T, hidden_size) attention output
"""
B, T, _ = x.shape # noqa: N806
# Compute Q, K, V
qkv = self.qkv_proj(x) # (B, T, 3 * hidden_size)
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, num_heads, T, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2] # Each: (B, num_heads, T, head_dim)
# Apply RoPE to Q and K
q, k = self.rope(q, k)
# Scaled dot-product attention
# Using PyTorch's efficient attention when available
attn_out = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
) # (B, num_heads, T, head_dim)
# Reshape and project output
attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size) # (B, T, hidden_size)
output = self.out_proj(attn_out)
return output
class TransformerBlock(nn.Module):
"""DiT-style transformer block with AdaLN-Zero.
Official DiT implementation with 6-parameter adaptive layer normalization:
- shift_msa, scale_msa, gate_msa: for attention block
- shift_mlp, scale_mlp, gate_mlp: for MLP block
Supports both standard attention and RoPE attention.
Reference: https://github.com/facebookresearch/DiT
"""
def __init__(
self,
hidden_size: int = 128,
num_heads: int = 4,
num_features: int = 128,
dropout: float = 0.0,
use_rope: bool = False,
max_seq_len: int = 512,
rope_base: float = 10000.0,
):
"""
Args:
hidden_size: Hidden dimension of transformer
num_heads: Number of attention heads
num_features: Size of conditioning features
dropout: Dropout rate
use_rope: Whether to use Rotary Position Embedding
max_seq_len: Maximum sequence length (for RoPE cache)
rope_base: Base frequency for RoPE
"""
super().__init__()
self.use_rope = use_rope
if use_rope:
self.attn = RoPEAttention(
hidden_size=hidden_size,
num_heads=num_heads,
dropout=dropout,
max_seq_len=max_seq_len,
rope_base=rope_base,
)
else:
self.multihead_attn = nn.MultiheadAttention(
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
)
# Layer normalizations (no learnable affine parameters, all adaptation via conditioning)
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# Feed-forward network (MLP)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.GELU(approximate="tanh"),
nn.Linear(hidden_size * 4, hidden_size),
)
# AdaLN-Zero modulation: produces 6 parameters (shift, scale, gate for attn and mlp)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
def forward(self, x: Tensor, features: Tensor) -> Tensor:
"""
Args:
x: (B, T, hidden_size) input sequence
features: (B, num_features) conditioning features
Returns:
(B, T, hidden_size) processed sequence
"""
# Generate 6 modulation parameters from conditioning
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
features
).chunk(6, dim=1)
# Attention block: norm → modulate → attn → gate × output → residual
# modulate requires unsqueeze(1) to add sequence dimension for broadcasting
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
if self.use_rope:
attn_out = self.attn(attn_input)
else:
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
x = x + gate_msa.unsqueeze(1) * attn_out
# MLP block: norm → modulate → mlp → gate × output → residual
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
mlp_out = self.mlp(mlp_input)
x = x + gate_mlp.unsqueeze(1) * mlp_out
return x
class DiffusionTransformer(nn.Module):
"""Transformer-based diffusion noise prediction model."""
def __init__(self, config, conditioning_dim: int):
"""Initialize transformer for noise prediction.
Args:
config: MultiTaskDiTConfig with transformer parameters
conditioning_dim: Dimension of concatenated observation features
"""
super().__init__()
self.config = config
self.conditioning_dim = conditioning_dim
self.action_dim = config.action_feature.shape[0]
self.horizon = config.horizon
self.hidden_size = config.hidden_dim
self.num_layers = config.num_layers
self.num_heads = config.num_heads
self.dropout = config.dropout
self.use_rope = config.use_rope
self.timestep_embed_dim = config.timestep_embed_dim
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(self.timestep_embed_dim),
nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
nn.GELU(),
nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
nn.GELU(),
)
self.cond_dim = self.timestep_embed_dim + conditioning_dim
# Project action dimensions to hidden size
self.input_proj = nn.Linear(self.action_dim, self.hidden_size)
if config.use_positional_encoding:
# Learnable positional embeddings for sequence positions (absolute encoding)
self.pos_embedding = nn.Parameter(
torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
)
else:
self.pos_embedding = None
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock(
hidden_size=self.hidden_size,
num_heads=self.num_heads,
num_features=self.cond_dim,
dropout=self.dropout,
use_rope=self.use_rope,
max_seq_len=self.horizon,
rope_base=config.rope_base,
)
for _ in range(self.num_layers)
]
)
# Project back to action dimensions
self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
# Zero-initialize adaLN_modulation layers for AdaLN-Zero
self._initialize_weights()
def _initialize_weights(self):
"""Zero-initialize final linear layer of adaLN_modulation for training stability."""
for block in self.transformer_blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
"""Predict noise to remove from noisy actions.
Args:
x: (B, T, action_dim) noisy action sequences
timestep: (B,) diffusion timesteps
conditioning_vec: (B, conditioning_dim) observation features (required)
Returns:
(B, T, action_dim) predicted noise
"""
_, seq_len, _ = x.shape
timestep_features = self.time_mlp(timestep) # (B, timestep_embed_dim)
# conditioning_vec is now required
cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1) # (B, cond_dim)
# Project action sequence to hidden dimension
hidden_seq = self.input_proj(x) # (B, T, hidden_size)
if self.pos_embedding is not None:
# Add learned positional embeddings
hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :] # (B, T, hidden_size)
# Pass through transformer layers with conditioning
for block in self.transformer_blocks:
hidden_seq = block(hidden_seq, cond_features) # (B, T, hidden_size)
# Project back to action space
output = self.output_proj(hidden_seq) # (B, T, action_dim)
return output