diff --git a/src/lerobot/policies/vla_jepa/__init__.py b/src/lerobot/policies/vla_jepa/__init__.py new file mode 100644 index 000000000..be6a59ba0 --- /dev/null +++ b/src/lerobot/policies/vla_jepa/__init__.py @@ -0,0 +1,10 @@ +from .configuration_vla_jepa import VLAJEPAConfig +from .modeling_vla_jepa import VLAJEPAPolicy +from .processor_vla_jepa import VLAJEPANewLineProcessor, make_vla_jepa_pre_post_processors + +__all__ = [ + "VLAJEPAConfig", + "VLAJEPAPolicy", + "VLAJEPANewLineProcessor", + "make_vla_jepa_pre_post_processors", +] diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py new file mode 100644 index 000000000..2ff34e071 --- /dev/null +++ b/src/lerobot/policies/vla_jepa/action_head.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from diffusers import ConfigMixin, ModelMixin +from diffusers.configuration_utils import register_to_config +from diffusers.models.attention import Attention, FeedForward +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from torch import nn +from torch.distributions import Beta + +from .configuration_vla_jepa import VLAJEPAConfig + + +def swish(x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(x) + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.embedding_dim = embedding_dim + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + timesteps = timesteps.float() + batch_size, seq_len = timesteps.shape + half_dim = self.embedding_dim // 2 + exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) + exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1)) + freqs = timesteps.unsqueeze(-1) * exponent.exp() + return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1) + + +class ActionEncoder(nn.Module): + def __init__(self, action_dim: int, hidden_size: int): + super().__init__() + self.w1 = nn.Linear(action_dim, hidden_size) + self.w2 = nn.Linear(hidden_size * 2, hidden_size) + self.w3 = nn.Linear(hidden_size, hidden_size) + self.pos_encoding = SinusoidalPositionalEncoding(hidden_size) + + def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = actions.shape + if timesteps.ndim != 1 or timesteps.shape[0] != batch_size: + raise ValueError("timesteps must have shape [batch_size].") + timesteps = timesteps.unsqueeze(1).expand(-1, seq_len) + action_emb = self.w1(actions) + time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype) + return self.w3(swish(self.w2(torch.cat([action_emb, time_emb], dim=-1)))) + + +class TimestepEncoder(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype) + return self.timestep_embedder(projected) + + +class AdaLayerNorm(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False) + self.silu = nn.SiLU() + + def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: + scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1) + return self.norm(x) * (1 + scale[:, None]) + shift[:, None] + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float, + cross_attention_dim: int, + ) -> None: + super().__init__() + self.norm1 = AdaLayerNorm(dim) + self.attn = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=True, + cross_attention_dim=cross_attention_dim, + out_bias=True, + ) + self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False) + self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + ) -> torch.Tensor: + attn_input = self.norm1(hidden_states, temb) + hidden_states = hidden_states + self.attn(attn_input, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + self.ff(self.norm2(hidden_states)) + return hidden_states + + +class DiT(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + output_dim: int, + num_layers: int, + dropout: float, + cross_attention_dim: int, + ) -> None: + super().__init__() + self.inner_dim = num_attention_heads * attention_head_dim + self.timestep_encoder = TimestepEncoder(self.inner_dim) + self.blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2) + self.proj_out_2 = nn.Linear(self.inner_dim, output_dim) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + temb = self.timestep_encoder(timestep) + x = hidden_states + for block in self.blocks: + x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb) + shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1) + x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None] + return self.proj_out_2(x) + + +@dataclass +class ActionModelPreset: + hidden_size: int + attention_head_dim: int + num_attention_heads: int + + +DIT_PRESETS = { + "DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12), + "DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32), +} + + +class VLAJEPAActionHead(nn.Module): + def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None: + super().__init__() + preset = DIT_PRESETS[config.action_model_type] + self.config = config + self.input_embedding_dim = preset.hidden_size + self.action_horizon = config.future_action_window_size + 1 + self.num_inference_timesteps = config.num_inference_timesteps + + self.model = DiT( + num_attention_heads=config.action_num_heads or preset.num_attention_heads, + attention_head_dim=config.action_attention_head_dim or preset.attention_head_dim, + output_dim=config.action_hidden_size, + num_layers=config.action_num_layers, + dropout=config.action_dropout, + cross_attention_dim=cross_attention_dim, + ) + self.action_encoder = ActionEncoder(config.action_dim, config.action_hidden_size) + self.action_decoder = nn.Sequential( + nn.Linear(config.action_hidden_size, config.action_hidden_size), + nn.GELU(), + nn.Linear(config.action_hidden_size, config.action_dim), + ) + self.state_encoder = ( + nn.Sequential( + nn.Linear(config.state_dim, config.action_hidden_size), + nn.GELU(), + nn.Linear(config.action_hidden_size, config.action_hidden_size), + ) + if config.state_dim > 0 + else None + ) + self.future_tokens = nn.Embedding(config.num_action_tokens_per_timestep, config.action_hidden_size) + self.position_embedding = nn.Embedding(config.chunk_size + config.num_action_tokens_per_timestep + 4, config.action_hidden_size) + self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta) + + def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype) + return (self.config.action_noise_s - sample) / self.config.action_noise_s + + def _build_inputs( + self, + conditioning_tokens: torch.Tensor, + actions: torch.Tensor, + state: torch.Tensor | None, + timesteps: torch.Tensor, + ) -> torch.Tensor: + action_features = self.action_encoder(actions, timesteps) + pos_ids = torch.arange(action_features.shape[1], device=actions.device) + action_features = action_features + self.position_embedding(pos_ids)[None] + + future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1) + seq = [future_tokens, action_features] + if state is not None and self.state_encoder is not None: + if state.ndim == 2: + state = state.unsqueeze(1) + seq.insert(0, self.state_encoder(state)) + return torch.cat(seq, dim=1) + + def forward( + self, + conditioning_tokens: torch.Tensor, + actions: torch.Tensor, + state: torch.Tensor | None = None, + ) -> torch.Tensor: + noise = torch.randn_like(actions) + t = self.sample_time(actions.shape[0], actions.device, actions.dtype) + noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions + velocity = actions - noise + t_discretized = (t * self.config.action_num_timestep_buckets).long() + + hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized) + pred = self.model( + hidden_states=hidden_states, + encoder_hidden_states=conditioning_tokens, + timestep=t_discretized, + ) + pred_actions = self.action_decoder(pred[:, -actions.shape[1] :]) + return F.mse_loss(pred_actions, velocity, reduction="mean") + + @torch.no_grad() + def predict_action( + self, + conditioning_tokens: torch.Tensor, + state: torch.Tensor | None = None, + ) -> torch.Tensor: + batch_size = conditioning_tokens.shape[0] + actions = torch.randn( + batch_size, + self.action_horizon, + self.config.action_dim, + dtype=conditioning_tokens.dtype, + device=conditioning_tokens.device, + ) + dt = 1.0 / max(self.num_inference_timesteps, 1) + for step in range(self.num_inference_timesteps): + t_cont = step / float(max(self.num_inference_timesteps, 1)) + t_value = int(t_cont * self.config.action_num_timestep_buckets) + timesteps = torch.full((batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long) + hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps) + pred = self.model( + hidden_states=hidden_states, + encoder_hidden_states=conditioning_tokens, + timestep=timesteps, + ) + pred_velocity = self.action_decoder(pred[:, -self.action_horizon :]) + actions = actions + dt * pred_velocity + return actions diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index e69de29bb..62d4f065d 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + + +@PreTrainedConfig.register_subclass("vla_jepa") +@dataclass +class VLAJEPAConfig(PreTrainedConfig): + n_obs_steps: int = 1 + chunk_size: int = 16 + n_action_steps: int = 16 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + qwen_model_name: str = "Qwen/Qwen3-VL-4B-Instruct" + jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256" + + tokenizer_padding_side: str = "left" + prompt_template: str = "{instruction}\n\nPredict {actions} and condition future prediction with {e_actions}." + special_action_token: str = "<|action_{}|>" + embodied_action_token: str = "<|embodied_action|>" + + action_dim: int = 7 + state_dim: int = 7 + future_action_window_size: int = 15 + past_action_window_size: int = 0 + num_action_tokens_per_timestep: int = 4 + num_embodied_action_tokens_per_instruction: int = 8 + num_inference_timesteps: int = 10 + + action_hidden_size: int = 1024 + action_model_type: str = "DiT-B" + action_num_layers: int = 12 + action_num_heads: int = 12 + action_attention_head_dim: int = 64 + action_dropout: float = 0.1 + action_num_timestep_buckets: int = 1000 + action_noise_beta_alpha: float = 1.5 + action_noise_beta_beta: float = 1.0 + action_noise_s: float = 0.999 + + num_video_frames: int = 4 + predictor_depth: int = 6 + predictor_num_heads: int = 8 + predictor_mlp_ratio: float = 4.0 + predictor_dropout: float = 0.0 + world_model_loss_weight: float = 0.1 + enable_world_model: bool = True + + resize_images_to: tuple[int, int] | None = None + torch_dtype: str = "bfloat16" + + optimizer_lr: float = 1e-4 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + optimizer_grad_clip_norm: float = 10.0 + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + def __post_init__(self) -> None: + super().__post_init__() + if self.n_action_steps > self.chunk_size: + raise ValueError("`n_action_steps` must be <= `chunk_size`.") + if self.future_action_window_size + 1 > self.chunk_size: + raise ValueError("`chunk_size` must cover the predicted action horizon.") + if self.num_video_frames < 2: + raise ValueError("`num_video_frames` must be >= 2 for JEPA prediction.") + + def validate_features(self) -> None: + if not self.image_features: + raise ValueError("VLAJEPA requires at least one visual input feature.") + if self.action_feature is None: + raise ValueError("VLAJEPA requires an action output feature.") + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig: + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> list[int]: + return [0] + + @property + def action_delta_indices(self) -> list[int]: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py new file mode 100644 index 000000000..429ddf96f --- /dev/null +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +from collections import deque +from pathlib import Path + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from transformers import AutoModel, AutoVideoProcessor + +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.policies.utils import populate_queues +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + +from .action_head import VLAJEPAActionHead +from .configuration_vla_jepa import VLAJEPAConfig +from .qwen_interface import Qwen3VLInterface +from .world_model import ActionConditionedVideoPredictor + + +class VLAJEPAModel(nn.Module): + def __init__(self, config: VLAJEPAConfig) -> None: + super().__init__() + self.config = config + self.qwen = Qwen3VLInterface(config) + self.action_tokens, self.action_token_ids, self.embodied_action_token_id = self.qwen.expand_tokenizer() + self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size) + + self.video_encoder = AutoModel.from_pretrained( + config.jepa_encoder_name, + torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype), + ) + self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name) + self.video_predictor = ActionConditionedVideoPredictor( + embed_dim=self.video_encoder.config.hidden_size, + action_embed_dim=self.qwen.model.config.hidden_size, + predictor_embed_dim=self.video_encoder.config.hidden_size, + depth=config.predictor_depth, + num_heads=config.predictor_num_heads, + mlp_ratio=config.predictor_mlp_ratio, + num_action_tokens_per_step=config.num_action_tokens_per_timestep, + ) + self.replace_prompt = "".join( + token * self.config.num_action_tokens_per_timestep + for token in self.action_tokens[: self.config.num_video_frames - 1] + ) + self.embodied_replace_prompt = self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction + + def _collect_images(self, batch: dict[str, Tensor]) -> list[list]: + sample_key = self.config.image_features[0] + batch_size = batch[sample_key].shape[0] + images = [[] for _ in range(batch_size)] + for key in self.config.image_features: + tensor = batch[key] + if tensor.ndim == 5: + tensor = tensor[:, -1] + for idx in range(batch_size): + images[idx].append(self.qwen.tensor_to_pil(tensor[idx])) + return images + + def _collect_videos(self, batch: dict[str, Tensor]) -> torch.Tensor: + first_key = self.config.image_features[0] + source = batch[first_key] + if source.ndim == 4: + source = source.unsqueeze(1).repeat(1, self.config.num_video_frames, 1, 1, 1) + elif source.ndim == 5 and source.shape[1] < self.config.num_video_frames: + pad = source[:, -1:].repeat(1, self.config.num_video_frames - source.shape[1], 1, 1, 1) + source = torch.cat([source, pad], dim=1) + elif source.ndim == 5: + source = source[:, -self.config.num_video_frames :] + else: + raise ValueError(f"Unsupported image tensor shape for JEPA: {tuple(source.shape)}") + return source + + def _get_tasks(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]: + tasks = batch.get("task") + if tasks is None: + return ["Execute the robot action."] * next(iter(batch.values())).shape[0] + if isinstance(tasks, str): + return [tasks] + return list(tasks) + + def _extract_qwen_conditioning(self, batch: dict[str, Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + images = self._collect_images(batch) + tasks = self._get_tasks(batch) + qwen_inputs = self.qwen.build_inputs( + images=images, + instructions=tasks, + action_prompt=self.replace_prompt, + embodied_prompt=self.embodied_replace_prompt, + ) + outputs = self.qwen.model( + **qwen_inputs, + output_hidden_states=True, + output_attentions=False, + return_dict=True, + ) + hidden = outputs.hidden_states[-1] + action_mask = torch.isin( + qwen_inputs["input_ids"], + torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device), + ) + action_indices = action_mask.nonzero(as_tuple=True) + action_tokens = hidden[action_indices[0], action_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1]) + + embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id + embodied_indices = embodied_mask.nonzero(as_tuple=True) + embodied_tokens = hidden[embodied_indices[0], embodied_indices[1], :].view(hidden.shape[0], -1, hidden.shape[-1]) + return action_tokens, embodied_tokens + + def _prepare_state(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor | None: + if OBS_STATE not in batch: + return None + state = batch[OBS_STATE] + if state.ndim > 2: + state = state[:, -1, :] + return state.to(device=device, dtype=dtype) + + def _prepare_action_targets(self, batch: dict[str, Tensor], device: torch.device, dtype: torch.dtype) -> torch.Tensor: + actions = batch[ACTION] + if actions.ndim == 2: + actions = actions.unsqueeze(1) + horizon = self.config.future_action_window_size + 1 + if actions.shape[1] < horizon: + pad = actions[:, -1:].repeat(1, horizon - actions.shape[1], 1) + actions = torch.cat([actions, pad], dim=1) + return actions[:, -horizon:].to(device=device, dtype=dtype) + + def _encode_video(self, video_tensor: torch.Tensor) -> torch.Tensor: + processed = [] + for sample in video_tensor: + processed_sample = self.video_processor(videos=sample, return_tensors="pt")["pixel_values_videos"] + processed.append(processed_sample) + pixel_values = torch.cat(processed, dim=0).to(self.video_encoder.device) + return self.video_encoder.get_vision_features(pixel_values_videos=pixel_values) + + def _compute_world_model_loss(self, batch: dict[str, Tensor], action_tokens: torch.Tensor) -> torch.Tensor | None: + if not self.config.enable_world_model: + return None + video_tensor = self._collect_videos(batch) + video_features = self._encode_video(video_tensor) + batch_size = video_tensor.shape[0] + num_frames = video_tensor.shape[1] + tokens_per_frame = video_features.shape[1] // num_frames + video_features = video_features.view(batch_size, num_frames, tokens_per_frame, -1) + input_states = video_features[:, :-1] + gt_states = video_features[:, 1:] + + expected_tokens = (num_frames - 1) * self.config.num_action_tokens_per_timestep + if action_tokens.shape[1] < expected_tokens: + pad = action_tokens[:, -1:].repeat(1, expected_tokens - action_tokens.shape[1], 1) + action_tokens = torch.cat([action_tokens, pad], dim=1) + action_tokens = action_tokens[:, :expected_tokens] + action_tokens = action_tokens.view(batch_size, num_frames - 1, self.config.num_action_tokens_per_timestep, -1) + pred_states = self.video_predictor(input_states, action_tokens) + return F.l1_loss(pred_states, gt_states, reduction="mean") + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: + action_tokens, embodied_tokens = self._extract_qwen_conditioning(batch) + state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype) + target_actions = self._prepare_action_targets(batch, embodied_tokens.device, embodied_tokens.dtype) + action_loss = self.action_model(embodied_tokens, target_actions, state) + + wm_loss = self._compute_world_model_loss(batch, action_tokens) + total_loss = action_loss + logs = {"action_loss": action_loss.detach()} + if wm_loss is not None: + total_loss = total_loss + self.config.world_model_loss_weight * wm_loss + logs["wm_loss"] = wm_loss.detach() + logs["loss"] = total_loss.detach() + return total_loss, logs + + @torch.no_grad() + def predict_action(self, batch: dict[str, Tensor]) -> Tensor: + _, embodied_tokens = self._extract_qwen_conditioning(batch) + state = self._prepare_state(batch, embodied_tokens.device, embodied_tokens.dtype) + return self.action_model.predict_action(embodied_tokens, state) + + +class VLAJEPAPolicy(PreTrainedPolicy): + config_class = VLAJEPAConfig + name = "vla_jepa" + + def __init__(self, config: VLAJEPAConfig, **kwargs) -> None: + super().__init__(config) + config.validate_features() + self.model = VLAJEPAModel(config) + self.reset() + + def reset(self) -> None: + self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)} + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + loss, logs = self.model(batch) + return loss, {key: value.item() for key, value in logs.items()} + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002 + self.eval() + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + return self.model.predict_action(batch) + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002 + self.eval() + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + if len(self._queues[ACTION]) == 0: + actions = self.model.predict_action(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) + return self._queues[ACTION].popleft() + + @classmethod + def from_pretrained( + cls: type[T], + pretrained_name_or_path: str | Path, + **kwargs, + ): + return super().from_pretrained(pretrained_name_or_path, **kwargs) diff --git a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py new file mode 100644 index 000000000..acd6ea2b6 --- /dev/null +++ b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any + +import torch + +from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + ComplementaryDataProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + +def make_vla_jepa_pre_post_processors( + config: VLAJEPAConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + features = {**config.input_features, **config.output_features} + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + VLAJEPANewLineProcessor(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features=features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + + +@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor") +class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep): + def complementary_data(self, complementary_data): + if "task" not in complementary_data: + return complementary_data + + task = complementary_data["task"] + if task is None: + return complementary_data + + new_complementary_data = dict(complementary_data) + if isinstance(task, str): + if not task.endswith("\n"): + new_complementary_data["task"] = f"{task}\n" + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + return new_complementary_data + + def transform_features(self, features): + return features diff --git a/src/lerobot/policies/vla_jepa/qwen_interface.py b/src/lerobot/policies/vla_jepa/qwen_interface.py new file mode 100644 index 000000000..1e1e7a895 --- /dev/null +++ b/src/lerobot/policies/vla_jepa/qwen_interface.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from typing import Sequence + +import numpy as np +import torch +from PIL import Image +from transformers import AutoProcessor, Qwen3VLForConditionalGeneration + +from .configuration_vla_jepa import VLAJEPAConfig + + +class Qwen3VLInterface(torch.nn.Module): + def __init__(self, config: VLAJEPAConfig) -> None: + super().__init__() + self.config = config + self.model = Qwen3VLForConditionalGeneration.from_pretrained( + config.qwen_model_name, + torch_dtype=self._get_torch_dtype(config.torch_dtype), + ) + self.processor = AutoProcessor.from_pretrained(config.qwen_model_name) + self.processor.tokenizer.padding_side = config.tokenizer_padding_side + self.model.config.hidden_size = self.model.config.text_config.hidden_size + + @staticmethod + def _get_torch_dtype(dtype_name: str) -> torch.dtype: + if dtype_name == "float32": + return torch.float32 + if dtype_name == "float16": + return torch.float16 + return torch.bfloat16 + + def expand_tokenizer(self) -> tuple[list[str], list[int], int]: + max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep + tokenizer = self.processor.tokenizer + action_tokens = [] + action_token_ids = [] + for idx in range(max_action_tokens): + token = self.config.special_action_token.format(idx) + action_tokens.append(token) + if token not in tokenizer.get_vocab(): + tokenizer.add_tokens([token], special_tokens=True) + action_token_ids.append(tokenizer.convert_tokens_to_ids(token)) + + embodied_action_token = self.config.embodied_action_token + if embodied_action_token not in tokenizer.get_vocab(): + tokenizer.add_tokens([embodied_action_token], special_tokens=True) + embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token) + + if self.model.get_input_embeddings().weight.size(0) < len(tokenizer): + self.model.resize_token_embeddings(len(tokenizer)) + return action_tokens, action_token_ids, embodied_action_token_id + + def build_inputs( + self, + images: Sequence[Sequence[Image.Image]], + instructions: Sequence[str], + action_prompt: str, + embodied_prompt: str, + ) -> dict[str, torch.Tensor]: + messages = [] + for sample_images, instruction in zip(images, instructions, strict=True): + prompt = self.config.prompt_template.format( + instruction=instruction, + actions=action_prompt, + e_actions=embodied_prompt, + ) + content = [{"type": "image", "image": img} for img in sample_images] + content.append({"type": "text", "text": prompt}) + messages.append([{"role": "user", "content": content}]) + + batch_inputs = self.processor.apply_chat_template( + messages, + tokenize=True, + padding=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + ) + return batch_inputs.to(self.model.device) + + @staticmethod + def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image: + image = image_tensor.detach().cpu() + if image.ndim == 3 and image.shape[0] in (1, 3): + image = image.permute(1, 2, 0) + image = image.float() + if image.max() <= 1.0: + image = image * 255.0 + image = image.clamp(0, 255).to(torch.uint8).numpy() + if image.shape[-1] == 1: + image = np.repeat(image, 3, axis=-1) + return Image.fromarray(image) diff --git a/src/lerobot/policies/vla_jepa/world_model.py b/src/lerobot/policies/vla_jepa/world_model.py new file mode 100644 index 000000000..4e32706eb --- /dev/null +++ b/src/lerobot/policies/vla_jepa/world_model.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import torch +from torch import nn + + +def build_block_causal_attention_mask(num_steps: int, tokens_per_step: int, cond_tokens: int) -> torch.Tensor: + total_tokens = num_steps * (tokens_per_step + cond_tokens) + mask = torch.full((total_tokens, total_tokens), float("-inf")) + for current_step in range(num_steps): + row_start = current_step * (tokens_per_step + cond_tokens) + row_end = row_start + tokens_per_step + cond_tokens + allowed_end = row_end + mask[row_start:row_end, :allowed_end] = 0 + return mask + + +class ActionConditionedVideoPredictor(nn.Module): + def __init__( + self, + embed_dim: int, + action_embed_dim: int, + predictor_embed_dim: int, + depth: int, + num_heads: int, + mlp_ratio: float, + num_action_tokens_per_step: int, + ) -> None: + super().__init__() + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim) + self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim) + encoder_layer = nn.TransformerEncoderLayer( + d_model=predictor_embed_dim, + nhead=num_heads, + dim_feedforward=int(predictor_embed_dim * mlp_ratio), + dropout=0.0, + activation="gelu", + batch_first=True, + ) + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth) + self.norm = nn.LayerNorm(predictor_embed_dim) + self.proj = nn.Linear(predictor_embed_dim, embed_dim) + self.num_action_tokens_per_step = num_action_tokens_per_step + + def forward(self, frame_tokens: torch.Tensor, action_tokens: torch.Tensor) -> torch.Tensor: + batch_size, num_steps, tokens_per_frame, _ = frame_tokens.shape + _, action_steps, _, _ = action_tokens.shape + if action_steps != num_steps: + raise ValueError(f"Expected {num_steps} action steps, got {action_steps}.") + + frame_tokens = self.predictor_embed(frame_tokens) + action_tokens = self.action_encoder(action_tokens) + fused_steps = [] + for step in range(num_steps): + fused_steps.append(torch.cat([action_tokens[:, step], frame_tokens[:, step]], dim=1)) + fused = torch.cat(fused_steps, dim=1) + + attn_mask = build_block_causal_attention_mask( + num_steps=num_steps, + tokens_per_step=tokens_per_frame, + cond_tokens=self.num_action_tokens_per_step, + ).to(device=fused.device, dtype=fused.dtype) + encoded = self.encoder(fused, mask=attn_mask) + encoded = encoded.view(batch_size, num_steps, self.num_action_tokens_per_step + tokens_per_frame, -1) + predicted_frame_tokens = encoded[:, :, self.num_action_tokens_per_step :, :] + return self.proj(self.norm(predicted_frame_tokens))