mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
feat(policies): add VLA-JEPA
This commit is contained in:
10
src/lerobot/policies/vla_jepa/__init__.py
Normal file
10
src/lerobot/policies/vla_jepa/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .modeling_vla_jepa import VLAJEPAPolicy
|
||||
from .processor_vla_jepa import VLAJEPANewLineProcessor, make_vla_jepa_pre_post_processors
|
||||
|
||||
__all__ = [
|
||||
"VLAJEPAConfig",
|
||||
"VLAJEPAPolicy",
|
||||
"VLAJEPANewLineProcessor",
|
||||
"make_vla_jepa_pre_post_processors",
|
||||
]
|
||||
280
src/lerobot/policies/vla_jepa/action_head.py
Normal file
280
src/lerobot/policies/vla_jepa/action_head.py
Normal file
@@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
timesteps = timesteps.float()
|
||||
batch_size, seq_len = timesteps.shape
|
||||
half_dim = self.embedding_dim // 2
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
|
||||
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
|
||||
|
||||
|
||||
class ActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim: int, hidden_size: int):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(action_dim, hidden_size)
|
||||
self.w2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.w3 = nn.Linear(hidden_size, hidden_size)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, _ = actions.shape
|
||||
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||
raise ValueError("timesteps must have shape [batch_size].")
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||
action_emb = self.w1(actions)
|
||||
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||
return self.w3(swish(self.w2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
|
||||
return self.timestep_embedder(projected)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
|
||||
self.silu = nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
|
||||
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
out_bias=True,
|
||||
)
|
||||
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
hidden_states = hidden_states + self.attn(attn_input, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DiT(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.blocks:
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionModelPreset:
|
||||
hidden_size: int
|
||||
attention_head_dim: int
|
||||
num_attention_heads: int
|
||||
|
||||
|
||||
DIT_PRESETS = {
|
||||
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||
}
|
||||
|
||||
|
||||
class VLAJEPAActionHead(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
self.input_embedding_dim = preset.hidden_size
|
||||
self.action_horizon = config.future_action_window_size + 1
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
self.model = DiT(
|
||||
num_attention_heads=config.action_num_heads or preset.num_attention_heads,
|
||||
attention_head_dim=config.action_attention_head_dim or preset.attention_head_dim,
|
||||
output_dim=config.action_hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
self.action_encoder = ActionEncoder(config.action_dim, config.action_hidden_size)
|
||||
self.action_decoder = nn.Sequential(
|
||||
nn.Linear(config.action_hidden_size, config.action_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(config.action_hidden_size, config.action_dim),
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
nn.Linear(config.state_dim, config.action_hidden_size),
|
||||
nn.GELU(),
|
||||
nn.Linear(config.action_hidden_size, config.action_hidden_size),
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(config.num_action_tokens_per_timestep, config.action_hidden_size)
|
||||
self.position_embedding = nn.Embedding(config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
|
||||
return (self.config.action_noise_s - sample) / self.config.action_noise_s
|
||||
|
||||
def _build_inputs(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
timesteps: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
action_features = self.action_encoder(actions, timesteps)
|
||||
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
|
||||
action_features = action_features + self.position_embedding(pos_ids)[None]
|
||||
|
||||
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
|
||||
seq = [future_tokens, action_features]
|
||||
if state is not None and self.state_encoder is not None:
|
||||
if state.ndim == 2:
|
||||
state = state.unsqueeze(1)
|
||||
seq.insert(0, self.state_encoder(state))
|
||||
return torch.cat(seq, dim=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
|
||||
velocity = actions - noise
|
||||
t_discretized = (t * self.config.action_num_timestep_buckets).long()
|
||||
|
||||
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=t_discretized,
|
||||
)
|
||||
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||
return F.mse_loss(pred_actions, velocity, reduction="mean")
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(
|
||||
self,
|
||||
conditioning_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = conditioning_tokens.shape[0]
|
||||
actions = torch.randn(
|
||||
batch_size,
|
||||
self.action_horizon,
|
||||
self.config.action_dim,
|
||||
dtype=conditioning_tokens.dtype,
|
||||
device=conditioning_tokens.device,
|
||||
)
|
||||
dt = 1.0 / max(self.num_inference_timesteps, 1)
|
||||
for step in range(self.num_inference_timesteps):
|
||||
t_cont = step / float(max(self.num_inference_timesteps, 1))
|
||||
t_value = int(t_cont * self.config.action_num_timestep_buckets)
|
||||
timesteps = torch.full((batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long)
|
||||
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
|
||||
pred = self.model(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=conditioning_tokens,
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
|
||||
actions = actions + dt * pred_velocity
|
||||
return actions
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("vla_jepa")
|
||||
@dataclass
|
||||
class VLAJEPAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 16
|
||||
n_action_steps: int = 16
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
qwen_model_name: str = "Qwen/Qwen3-VL-4B-Instruct"
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = "{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
|
||||
special_action_token: str = "<|action_{}|>"
|
||||
embodied_action_token: str = "<|embodied_action|>"
|
||||
|
||||
action_dim: int = 7
|
||||
state_dim: int = 7
|
||||
future_action_window_size: int = 15
|
||||
past_action_window_size: int = 0
|
||||
num_action_tokens_per_timestep: int = 4
|
||||
num_embodied_action_tokens_per_instruction: int = 8
|
||||
num_inference_timesteps: int = 10
|
||||
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 12
|
||||
action_num_heads: int = 12
|
||||
action_attention_head_dim: int = 64
|
||||
action_dropout: float = 0.1
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
|
||||
num_video_frames: int = 4
|
||||
predictor_depth: int = 6
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
enable_world_model: bool = True
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-10
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.future_action_window_size + 1 > self.chunk_size:
|
||||
raise ValueError("`chunk_size` must cover the predicted action horizon.")
|
||||
if self.num_video_frames < 2:
|
||||
raise ValueError("`num_video_frames` must be >= 2 for JEPA prediction.")
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("VLAJEPA requires at least one visual input feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("VLAJEPA requires an action output feature.")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
return [0]
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
218
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
218
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoModel, AutoVideoProcessor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .action_head import VLAJEPAActionHead
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
from .qwen_interface import Qwen3VLInterface
|
||||
from .world_model import ActionConditionedVideoPredictor
|
||||
|
||||
|
||||
class VLAJEPAModel(nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.qwen = Qwen3VLInterface(config)
|
||||
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = self.qwen.expand_tokenizer()
|
||||
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
embed_dim=self.video_encoder.config.hidden_size,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[: self.config.num_video_frames - 1]
|
||||
)
|
||||
self.embodied_replace_prompt = self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
|
||||
def _collect_images(self, batch: dict[str, Tensor]) -> list[list]:
|
||||
sample_key = self.config.image_features[0]
|
||||
batch_size = batch[sample_key].shape[0]
|
||||
images = [[] for _ in range(batch_size)]
|
||||
for key in self.config.image_features:
|
||||
tensor = batch[key]
|
||||
if tensor.ndim == 5:
|
||||
tensor = tensor[:, -1]
|
||||
for idx in range(batch_size):
|
||||
images[idx].append(self.qwen.tensor_to_pil(tensor[idx]))
|
||||
return images
|
||||
|
||||
def _collect_videos(self, batch: dict[str, Tensor]) -> torch.Tensor:
|
||||
first_key = self.config.image_features[0]
|
||||
source = batch[first_key]
|
||||
if source.ndim == 4:
|
||||
source = source.unsqueeze(1).repeat(1, self.config.num_video_frames, 1, 1, 1)
|
||||
elif source.ndim == 5 and source.shape[1] < self.config.num_video_frames:
|
||||
pad = source[:, -1:].repeat(1, self.config.num_video_frames - source.shape[1], 1, 1, 1)
|
||||
source = torch.cat([source, pad], dim=1)
|
||||
elif source.ndim == 5:
|
||||
source = source[:, -self.config.num_video_frames :]
|
||||
else:
|
||||
raise ValueError(f"Unsupported image tensor shape for JEPA: {tuple(source.shape)}")
|
||||
return source
|
||||
|
||||
def _get_tasks(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
return ["Execute the robot action."] * next(iter(batch.values())).shape[0]
|
||||
if isinstance(tasks, str):
|
||||
return [tasks]
|
||||
return list(tasks)
|
||||
|
||||
def _extract_qwen_conditioning(self, batch: dict[str, Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
images = self._collect_images(batch)
|
||||
tasks = self._get_tasks(batch)
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=images,
|
||||
instructions=tasks,
|
||||
action_prompt=self.replace_prompt,
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
outputs = self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
hidden = outputs.hidden_states[-1]
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
action_tokens = hidden[action_indices[0], action_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
|
||||
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
embodied_tokens = hidden[embodied_indices[0], embodied_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
|
||||
return action_tokens, embodied_tokens
|
||||
|
||||
def _prepare_state(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
|
||||
if OBS_STATE not in batch:
|
||||
return None
|
||||
state = batch[OBS_STATE]
|
||||
if state.ndim > 2:
|
||||
state = state[:, -1, :]
|
||||
return state.to(device=device, dtype=dtype)
|
||||
|
||||
def _prepare_action_targets(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
actions = batch[ACTION]
|
||||
if actions.ndim == 2:
|
||||
actions = actions.unsqueeze(1)
|
||||
horizon = self.config.future_action_window_size + 1
|
||||
if actions.shape[1] < horizon:
|
||||
pad = actions[:, -1:].repeat(1, horizon - actions.shape[1], 1)
|
||||
actions = torch.cat([actions, pad], dim=1)
|
||||
return actions[:, -horizon:].to(device=device, dtype=dtype)
|
||||
|
||||
def _encode_video(self, video_tensor: torch.Tensor) -> torch.Tensor:
|
||||
processed = []
|
||||
for sample in video_tensor:
|
||||
processed_sample = self.video_processor(videos=sample, return_tensors="pt")["pixel_values_videos"]
|
||||
processed.append(processed_sample)
|
||||
pixel_values = torch.cat(processed, dim=0).to(self.video_encoder.device)
|
||||
return self.video_encoder.get_vision_features(pixel_values_videos=pixel_values)
|
||||
|
||||
def _compute_world_model_loss(self, batch: dict[str, Tensor], action_tokens: torch.Tensor) -> torch.Tensor | None:
|
||||
if not self.config.enable_world_model:
|
||||
return None
|
||||
video_tensor = self._collect_videos(batch)
|
||||
video_features = self._encode_video(video_tensor)
|
||||
batch_size = video_tensor.shape[0]
|
||||
num_frames = video_tensor.shape[1]
|
||||
tokens_per_frame = video_features.shape[1] // num_frames
|
||||
video_features = video_features.view(batch_size, num_frames, tokens_per_frame, -1)
|
||||
input_states = video_features[:, :-1]
|
||||
gt_states = video_features[:, 1:]
|
||||
|
||||
expected_tokens = (num_frames - 1) * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_tokens:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_tokens - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
action_tokens = action_tokens[:, :expected_tokens]
|
||||
action_tokens = action_tokens.view(batch_size, num_frames - 1, self.config.num_action_tokens_per_timestep, -1)
|
||||
pred_states = self.video_predictor(input_states, action_tokens)
|
||||
return F.l1_loss(pred_states, gt_states, reduction="mean")
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
action_tokens, embodied_tokens = self._extract_qwen_conditioning(batch)
|
||||
state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
|
||||
target_actions = self._prepare_action_targets(batch, embodied_tokens.device, embodied_tokens.dtype)
|
||||
action_loss = self.action_model(embodied_tokens, target_actions, state)
|
||||
|
||||
wm_loss = self._compute_world_model_loss(batch, action_tokens)
|
||||
total_loss = action_loss
|
||||
logs = {"action_loss": action_loss.detach()}
|
||||
if wm_loss is not None:
|
||||
total_loss = total_loss + self.config.world_model_loss_weight * wm_loss
|
||||
logs["wm_loss"] = wm_loss.detach()
|
||||
logs["loss"] = total_loss.detach()
|
||||
return total_loss, logs
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
_, embodied_tokens = self._extract_qwen_conditioning(batch)
|
||||
state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
|
||||
return self.action_model.predict_action(embodied_tokens, state)
|
||||
|
||||
|
||||
class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
config_class = VLAJEPAConfig
|
||||
name = "vla_jepa"
|
||||
|
||||
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.model = VLAJEPAModel(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
loss, logs = self.model(batch)
|
||||
return loss, {key: value.item() for key, value in logs.items()}
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
return self.model.predict_action(batch)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
actions = self.model.predict_action(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
**kwargs,
|
||||
):
|
||||
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||
83
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
83
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_vla_jepa_pre_post_processors(
|
||||
config: VLAJEPAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
VLAJEPANewLineProcessor(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor")
|
||||
class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep):
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
if isinstance(task, str):
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
93
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
93
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||
|
||||
from .configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
|
||||
class Qwen3VLInterface(torch.nn.Module):
|
||||
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||
config.qwen_model_name,
|
||||
torch_dtype=self._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
|
||||
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
|
||||
self.model.config.hidden_size = self.model.config.text_config.hidden_size
|
||||
|
||||
@staticmethod
|
||||
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||
if dtype_name == "float32":
|
||||
return torch.float32
|
||||
if dtype_name == "float16":
|
||||
return torch.float16
|
||||
return torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
||||
tokenizer = self.processor.tokenizer
|
||||
action_tokens = []
|
||||
action_token_ids = []
|
||||
for idx in range(max_action_tokens):
|
||||
token = self.config.special_action_token.format(idx)
|
||||
action_tokens.append(token)
|
||||
if token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([token], special_tokens=True)
|
||||
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
|
||||
|
||||
embodied_action_token = self.config.embodied_action_token
|
||||
if embodied_action_token not in tokenizer.get_vocab():
|
||||
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
|
||||
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
|
||||
|
||||
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
|
||||
self.model.resize_token_embeddings(len(tokenizer))
|
||||
return action_tokens, action_token_ids, embodied_action_token_id
|
||||
|
||||
def build_inputs(
|
||||
self,
|
||||
images: Sequence[Sequence[Image.Image]],
|
||||
instructions: Sequence[str],
|
||||
action_prompt: str,
|
||||
embodied_prompt: str,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
messages = []
|
||||
for sample_images, instruction in zip(images, instructions, strict=True):
|
||||
prompt = self.config.prompt_template.format(
|
||||
instruction=instruction,
|
||||
actions=action_prompt,
|
||||
e_actions=embodied_prompt,
|
||||
)
|
||||
content = [{"type": "image", "image": img} for img in sample_images]
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages.append([{"role": "user", "content": content}])
|
||||
|
||||
batch_inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
padding=True,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return batch_inputs.to(self.model.device)
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||
image = image_tensor.detach().cpu()
|
||||
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||
image = image.permute(1, 2, 0)
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
66
src/lerobot/policies/vla_jepa/world_model.py
Normal file
66
src/lerobot/policies/vla_jepa/world_model.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def build_block_causal_attention_mask(num_steps: int, tokens_per_step: int, cond_tokens: int) -> torch.Tensor:
|
||||
total_tokens = num_steps * (tokens_per_step + cond_tokens)
|
||||
mask = torch.full((total_tokens, total_tokens), float("-inf"))
|
||||
for current_step in range(num_steps):
|
||||
row_start = current_step * (tokens_per_step + cond_tokens)
|
||||
row_end = row_start + tokens_per_step + cond_tokens
|
||||
allowed_end = row_end
|
||||
mask[row_start:row_end, :allowed_end] = 0
|
||||
return mask
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
action_embed_dim: int,
|
||||
predictor_embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
num_action_tokens_per_step: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim)
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=predictor_embed_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=int(predictor_embed_dim * mlp_ratio),
|
||||
dropout=0.0,
|
||||
activation="gelu",
|
||||
batch_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
|
||||
self.norm = nn.LayerNorm(predictor_embed_dim)
|
||||
self.proj = nn.Linear(predictor_embed_dim, embed_dim)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
def forward(self, frame_tokens: torch.Tensor, action_tokens: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_steps, tokens_per_frame, _ = frame_tokens.shape
|
||||
_, action_steps, _, _ = action_tokens.shape
|
||||
if action_steps != num_steps:
|
||||
raise ValueError(f"Expected {num_steps} action steps, got {action_steps}.")
|
||||
|
||||
frame_tokens = self.predictor_embed(frame_tokens)
|
||||
action_tokens = self.action_encoder(action_tokens)
|
||||
fused_steps = []
|
||||
for step in range(num_steps):
|
||||
fused_steps.append(torch.cat([action_tokens[:, step], frame_tokens[:, step]], dim=1))
|
||||
fused = torch.cat(fused_steps, dim=1)
|
||||
|
||||
attn_mask = build_block_causal_attention_mask(
|
||||
num_steps=num_steps,
|
||||
tokens_per_step=tokens_per_frame,
|
||||
cond_tokens=self.num_action_tokens_per_step,
|
||||
).to(device=fused.device, dtype=fused.dtype)
|
||||
encoded = self.encoder(fused, mask=attn_mask)
|
||||
encoded = encoded.view(batch_size, num_steps, self.num_action_tokens_per_step + tokens_per_frame, -1)
|
||||
predicted_frame_tokens = encoded[:, :, self.num_action_tokens_per_step :, :]
|
||||
return self.proj(self.norm(predicted_frame_tokens))
|
||||
Reference in New Issue
Block a user