feat(policies): add VLA-JEPA

This commit is contained in:
ginwind
2026-04-14 04:03:34 +00:00
committed by Maximellerbach
parent d5944c410c
commit 0e18bdaf7a
7 changed files with 865 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
from .configuration_vla_jepa import VLAJEPAConfig
from .modeling_vla_jepa import VLAJEPAPolicy
from .processor_vla_jepa import VLAJEPANewLineProcessor, make_vla_jepa_pre_post_processors
__all__ = [
"VLAJEPAConfig",
"VLAJEPAPolicy",
"VLAJEPANewLineProcessor",
"make_vla_jepa_pre_post_processors",
]

View File

@@ -0,0 +1,280 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from torch import nn
from torch.distributions import Beta
from .configuration_vla_jepa import VLAJEPAConfig
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
batch_size, seq_len = timesteps.shape
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
class ActionEncoder(nn.Module):
def __init__(self, action_dim: int, hidden_size: int):
super().__init__()
self.w1 = nn.Linear(action_dim, hidden_size)
self.w2 = nn.Linear(hidden_size * 2, hidden_size)
self.w3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = actions.shape
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
raise ValueError("timesteps must have shape [batch_size].")
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.w1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.w3(swish(self.w2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
return self.timestep_embedder(projected)
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int,
) -> None:
super().__init__()
self.norm1 = AdaLayerNorm(dim)
self.attn = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=True,
cross_attention_dim=cross_attention_dim,
out_bias=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
hidden_states = hidden_states + self.attn(attn_input, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
output_dim: int,
num_layers: int,
dropout: float,
cross_attention_dim: int,
) -> None:
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
self.timestep_encoder = TimestepEncoder(self.inner_dim)
self.blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
)
for _ in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
) -> torch.Tensor:
temb = self.timestep_encoder(timestep)
x = hidden_states
for block in self.blocks:
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(x)
@dataclass
class ActionModelPreset:
hidden_size: int
attention_head_dim: int
num_attention_heads: int
DIT_PRESETS = {
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
}
class VLAJEPAActionHead(nn.Module):
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
self.input_embedding_dim = preset.hidden_size
self.action_horizon = config.future_action_window_size + 1
self.num_inference_timesteps = config.num_inference_timesteps
self.model = DiT(
num_attention_heads=config.action_num_heads or preset.num_attention_heads,
attention_head_dim=config.action_attention_head_dim or preset.attention_head_dim,
output_dim=config.action_hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
self.action_encoder = ActionEncoder(config.action_dim, config.action_hidden_size)
self.action_decoder = nn.Sequential(
nn.Linear(config.action_hidden_size, config.action_hidden_size),
nn.GELU(),
nn.Linear(config.action_hidden_size, config.action_dim),
)
self.state_encoder = (
nn.Sequential(
nn.Linear(config.state_dim, config.action_hidden_size),
nn.GELU(),
nn.Linear(config.action_hidden_size, config.action_hidden_size),
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(config.num_action_tokens_per_timestep, config.action_hidden_size)
self.position_embedding = nn.Embedding(config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
return (self.config.action_noise_s - sample) / self.config.action_noise_s
def _build_inputs(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None,
timesteps: torch.Tensor,
) -> torch.Tensor:
action_features = self.action_encoder(actions, timesteps)
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
action_features = action_features + self.position_embedding(pos_ids)[None]
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
seq = [future_tokens, action_features]
if state is not None and self.state_encoder is not None:
if state.ndim == 2:
state = state.unsqueeze(1)
seq.insert(0, self.state_encoder(state))
return torch.cat(seq, dim=1)
def forward(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
velocity = actions - noise
t_discretized = (t * self.config.action_num_timestep_buckets).long()
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=t_discretized,
)
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
return F.mse_loss(pred_actions, velocity, reduction="mean")
@torch.no_grad()
def predict_action(
self,
conditioning_tokens: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = conditioning_tokens.shape[0]
actions = torch.randn(
batch_size,
self.action_horizon,
self.config.action_dim,
dtype=conditioning_tokens.dtype,
device=conditioning_tokens.device,
)
dt = 1.0 / max(self.num_inference_timesteps, 1)
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full((batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=timesteps,
)
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
actions = actions + dt * pred_velocity
return actions

View File

@@ -0,0 +1,115 @@
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("vla_jepa")
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 16
n_action_steps: int = 16
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
qwen_model_name: str = "Qwen/Qwen3-VL-4B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
tokenizer_padding_side: str = "left"
prompt_template: str = "{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 7
future_action_window_size: int = 15
past_action_window_size: int = 0
num_action_tokens_per_timestep: int = 4
num_embodied_action_tokens_per_instruction: int = 8
num_inference_timesteps: int = 10
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 12
action_num_heads: int = 12
action_attention_head_dim: int = 64
action_dropout: float = 0.1
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
num_video_frames: int = 4
predictor_depth: int = 6
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
enable_world_model: bool = True
resize_images_to: tuple[int, int] | None = None
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.future_action_window_size + 1 > self.chunk_size:
raise ValueError("`chunk_size` must cover the predicted action horizon.")
if self.num_video_frames < 2:
raise ValueError("`num_video_frames` must be >= 2 for JEPA prediction.")
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("VLAJEPA requires at least one visual input feature.")
if self.action_feature is None:
raise ValueError("VLAJEPA requires an action output feature.")
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
return [0]
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,218 @@
from __future__ import annotations
from collections import deque
from pathlib import Path
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import AutoModel, AutoVideoProcessor
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
class VLAJEPAModel(nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.qwen = Qwen3VLInterface(config)
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = self.qwen.expand_tokenizer()
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
self.video_predictor = ActionConditionedVideoPredictor(
embed_dim=self.video_encoder.config.hidden_size,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[: self.config.num_video_frames - 1]
)
self.embodied_replace_prompt = self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
def _collect_images(self, batch: dict[str, Tensor]) -> list[list]:
sample_key = self.config.image_features[0]
batch_size = batch[sample_key].shape[0]
images = [[] for _ in range(batch_size)]
for key in self.config.image_features:
tensor = batch[key]
if tensor.ndim == 5:
tensor = tensor[:, -1]
for idx in range(batch_size):
images[idx].append(self.qwen.tensor_to_pil(tensor[idx]))
return images
def _collect_videos(self, batch: dict[str, Tensor]) -> torch.Tensor:
first_key = self.config.image_features[0]
source = batch[first_key]
if source.ndim == 4:
source = source.unsqueeze(1).repeat(1, self.config.num_video_frames, 1, 1, 1)
elif source.ndim == 5 and source.shape[1] < self.config.num_video_frames:
pad = source[:, -1:].repeat(1, self.config.num_video_frames - source.shape[1], 1, 1, 1)
source = torch.cat([source, pad], dim=1)
elif source.ndim == 5:
source = source[:, -self.config.num_video_frames :]
else:
raise ValueError(f"Unsupported image tensor shape for JEPA: {tuple(source.shape)}")
return source
def _get_tasks(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
tasks = batch.get("task")
if tasks is None:
return ["Execute the robot action."] * next(iter(batch.values())).shape[0]
if isinstance(tasks, str):
return [tasks]
return list(tasks)
def _extract_qwen_conditioning(self, batch: dict[str, Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
images = self._collect_images(batch)
tasks = self._get_tasks(batch)
qwen_inputs = self.qwen.build_inputs(
images=images,
instructions=tasks,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
outputs = self.qwen.model(
**qwen_inputs,
output_hidden_states=True,
output_attentions=False,
return_dict=True,
)
hidden = outputs.hidden_states[-1]
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
action_tokens = hidden[action_indices[0], action_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
embodied_tokens = hidden[embodied_indices[0], embodied_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1])
return action_tokens, embodied_tokens
def _prepare_state(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor | None:
if OBS_STATE not in batch:
return None
state = batch[OBS_STATE]
if state.ndim > 2:
state = state[:, -1, :]
return state.to(device=device, dtype=dtype)
def _prepare_action_targets(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor:
actions = batch[ACTION]
if actions.ndim == 2:
actions = actions.unsqueeze(1)
horizon = self.config.future_action_window_size + 1
if actions.shape[1] < horizon:
pad = actions[:, -1:].repeat(1, horizon - actions.shape[1], 1)
actions = torch.cat([actions, pad], dim=1)
return actions[:, -horizon:].to(device=device, dtype=dtype)
def _encode_video(self, video_tensor: torch.Tensor) -> torch.Tensor:
processed = []
for sample in video_tensor:
processed_sample = self.video_processor(videos=sample, return_tensors="pt")["pixel_values_videos"]
processed.append(processed_sample)
pixel_values = torch.cat(processed, dim=0).to(self.video_encoder.device)
return self.video_encoder.get_vision_features(pixel_values_videos=pixel_values)
def _compute_world_model_loss(self, batch: dict[str, Tensor], action_tokens: torch.Tensor) -> torch.Tensor | None:
if not self.config.enable_world_model:
return None
video_tensor = self._collect_videos(batch)
video_features = self._encode_video(video_tensor)
batch_size = video_tensor.shape[0]
num_frames = video_tensor.shape[1]
tokens_per_frame = video_features.shape[1] // num_frames
video_features = video_features.view(batch_size, num_frames, tokens_per_frame, -1)
input_states = video_features[:, :-1]
gt_states = video_features[:, 1:]
expected_tokens = (num_frames - 1) * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_tokens:
pad = action_tokens[:, -1:].repeat(1, expected_tokens - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
action_tokens = action_tokens[:, :expected_tokens]
action_tokens = action_tokens.view(batch_size, num_frames - 1, self.config.num_action_tokens_per_timestep, -1)
pred_states = self.video_predictor(input_states, action_tokens)
return F.l1_loss(pred_states, gt_states, reduction="mean")
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
action_tokens, embodied_tokens = self._extract_qwen_conditioning(batch)
state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
target_actions = self._prepare_action_targets(batch, embodied_tokens.device, embodied_tokens.dtype)
action_loss = self.action_model(embodied_tokens, target_actions, state)
wm_loss = self._compute_world_model_loss(batch, action_tokens)
total_loss = action_loss
logs = {"action_loss": action_loss.detach()}
if wm_loss is not None:
total_loss = total_loss + self.config.world_model_loss_weight * wm_loss
logs["wm_loss"] = wm_loss.detach()
logs["loss"] = total_loss.detach()
return total_loss, logs
@torch.no_grad()
def predict_action(self, batch: dict[str, Tensor]) -> Tensor:
_, embodied_tokens = self._extract_qwen_conditioning(batch)
state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype)
return self.action_model.predict_action(embodied_tokens, state)
class VLAJEPAPolicy(PreTrainedPolicy):
config_class = VLAJEPAConfig
name = "vla_jepa"
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
self.model = VLAJEPAModel(config)
self.reset()
def reset(self) -> None:
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
loss, logs = self.model(batch)
return loss, {key: value.item() for key, value in logs.items()}
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
return self.model.predict_action(batch)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.model.predict_action(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: type[T],
pretrained_name_or_path: str | Path,
**kwargs,
):
return super().from_pretrained(pretrained_name_or_path, **kwargs)

View File

@@ -0,0 +1,83 @@
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_vla_jepa_pre_post_processors(
config: VLAJEPAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
VLAJEPANewLineProcessor(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor")
class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep):
def complementary_data(self, complementary_data):
if "task" not in complementary_data:
return complementary_data
task = complementary_data["task"]
if task is None:
return complementary_data
new_complementary_data = dict(complementary_data)
if isinstance(task, str):
if not task.endswith("\n"):
new_complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
return new_complementary_data
def transform_features(self, features):
return features

View File

@@ -0,0 +1,93 @@
from __future__ import annotations
from typing import Sequence
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,
torch_dtype=self._get_torch_dtype(config.torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
self.model.config.hidden_size = self.model.config.text_config.hidden_size
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
if dtype_name == "float32":
return torch.float32
if dtype_name == "float16":
return torch.float16
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []
for idx in range(max_action_tokens):
token = self.config.special_action_token.format(idx)
action_tokens.append(token)
if token not in tokenizer.get_vocab():
tokenizer.add_tokens([token], special_tokens=True)
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
embodied_action_token = self.config.embodied_action_token
if embodied_action_token not in tokenizer.get_vocab():
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
self.model.resize_token_embeddings(len(tokenizer))
return action_tokens, action_token_ids, embodied_action_token_id
def build_inputs(
self,
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, torch.Tensor]:
messages = []
for sample_images, instruction in zip(images, instructions, strict=True):
prompt = self.config.prompt_template.format(
instruction=instruction,
actions=action_prompt,
e_actions=embodied_prompt,
)
content = [{"type": "image", "image": img} for img in sample_images]
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
padding=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
)
return batch_inputs.to(self.model.device)
@staticmethod
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
import torch
from torch import nn
def build_block_causal_attention_mask(num_steps: int, tokens_per_step: int, cond_tokens: int) -> torch.Tensor:
total_tokens = num_steps * (tokens_per_step + cond_tokens)
mask = torch.full((total_tokens, total_tokens), float("-inf"))
for current_step in range(num_steps):
row_start = current_step * (tokens_per_step + cond_tokens)
row_end = row_start + tokens_per_step + cond_tokens
allowed_end = row_end
mask[row_start:row_end, :allowed_end] = 0
return mask
class ActionConditionedVideoPredictor(nn.Module):
def __init__(
self,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
) -> None:
super().__init__()
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim)
encoder_layer = nn.TransformerEncoderLayer(
d_model=predictor_embed_dim,
nhead=num_heads,
dim_feedforward=int(predictor_embed_dim * mlp_ratio),
dropout=0.0,
activation="gelu",
batch_first=True,
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)
self.norm = nn.LayerNorm(predictor_embed_dim)
self.proj = nn.Linear(predictor_embed_dim, embed_dim)
self.num_action_tokens_per_step = num_action_tokens_per_step
def forward(self, frame_tokens: torch.Tensor, action_tokens: torch.Tensor) -> torch.Tensor:
batch_size, num_steps, tokens_per_frame, _ = frame_tokens.shape
_, action_steps, _, _ = action_tokens.shape
if action_steps != num_steps:
raise ValueError(f"Expected {num_steps} action steps, got {action_steps}.")
frame_tokens = self.predictor_embed(frame_tokens)
action_tokens = self.action_encoder(action_tokens)
fused_steps = []
for step in range(num_steps):
fused_steps.append(torch.cat([action_tokens[:, step], frame_tokens[:, step]], dim=1))
fused = torch.cat(fused_steps, dim=1)
attn_mask = build_block_causal_attention_mask(
num_steps=num_steps,
tokens_per_step=tokens_per_frame,
cond_tokens=self.num_action_tokens_per_step,
).to(device=fused.device, dtype=fused.dtype)
encoded = self.encoder(fused, mask=attn_mask)
encoded = encoded.view(batch_size, num_steps, self.num_action_tokens_per_step + tokens_per_frame, -1)
predicted_frame_tokens = encoded[:, :, self.num_action_tokens_per_step :, :]
return self.proj(self.norm(predicted_frame_tokens))