diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index ce36fad5d..0d4e36172 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -63,6 +63,8 @@ title: π₀.₅ (Pi05) - local: molmoact2 title: MolmoAct2 + - local: vla_jepa + title: VLA-JEPA - local: eo1 title: EO-1 - local: groot diff --git a/docs/source/policy_vla_jepa_README.md b/docs/source/policy_vla_jepa_README.md new file mode 100644 index 000000000..70cdbd6b5 --- /dev/null +++ b/docs/source/policy_vla_jepa_README.md @@ -0,0 +1,39 @@ +# VLA-JEPA + +This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head. + +Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA). + +--- + +## Architecture Overview + +| Component | Module | Role | +| ----------------------- | --------------------------------- | ------------------------------------------------------- | +| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens | +| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk | +| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) | + +At inference time only the Qwen backbone and action head are used; the world model is not needed. + +--- + +## Citation + +```bibtex +@misc{sun2026vlajepaenhancingvisionlanguageactionmodel, + title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model}, + author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen}, + year = {2026}, + eprint = {2602.10098}, + archivePrefix = {arXiv}, + primaryClass = {cs.RO}, + url = {https://arxiv.org/abs/2602.10098}, +} +``` + +--- + +## License + +Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**. diff --git a/docs/source/vla_jepa.mdx b/docs/source/vla_jepa.mdx new file mode 100644 index 000000000..ad37b2349 --- /dev/null +++ b/docs/source/vla_jepa.mdx @@ -0,0 +1,235 @@ +# VLA-JEPA + +This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head. + +--- + +## Architecture Overview + +VLA-JEPA has three main components: + +| Component | Module | Role | +| ----------------------- | --------------------------------- | ------------------------------------------------------- | +| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens | +| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk | +| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) | + +### Data flow + +**Training:** + +1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens. +2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens. +3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching. +4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss. + +**Inference:** +Only Qwen + the action head are used. The world model is not needed at inference time. + +### Action head details + +Available presets via `action_model_type`: + +| Preset | Hidden dim | Heads | Head dim | +| ------- | ---------- | ----- | -------- | +| `DiT-B` | 768 | 12 | 64 | +| `DiT-L` | 1536 | 32 | 48 | + +### World model details + +The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes: + +- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim` +- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim` + +It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints). + +--- + +## Pretrained Checkpoints + +Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA): + +| Checkpoint | Dataset | Cameras | World model | Action dim | +| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- | +| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 | +| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 | +| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 | + +All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone. + +--- + +## Configuration + +Key parameters in `VLAJEPAConfig`: + +| Parameter | Default | Description | +| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `chunk_size` | 7 | Number of actions predicted per inference call | +| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning | +| `num_video_frames` | 8 | Video clip length fed to the world model | +| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor | +| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss | +| `num_inference_timesteps` | 4 | Euler integration steps for action denoising | +| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head | +| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) | +| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) | +| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension | +| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper | +| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper | + +--- + +## Training + +Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints. + +### Full training from scratch + +```bash +lerobot-train \ + policy.type=vla_jepa \ + policy.repo_id=your_org/your_repo \ + dataset.repo_id=your_org/your_dataset +``` + +### Fine-tuning from a pretrained checkpoint + +```bash +lerobot-train \ + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.repo_id=your_org/your_repo \ + --dataset.repo_id=your_org/your_dataset +``` + +If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`: + +```bash +lerobot-train \ + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.repo_id=your_org/your_repo \ + --policy.freeze_qwen=true \ + --dataset.repo_id=your_org/your_dataset +``` + +### Fine-tuning on a different embodiment + +When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error. + +The layers that depend on `action_dim` and `state_dim` are: + +| Layer | Key prefix | +| ----------------------------------------- | ----------------------------------- | +| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` | +| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` | +| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` | + +```bash +lerobot-train \ + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.repo_id=your_org/your_repo \ + --policy.freeze_qwen=true \ + --policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \ + --dataset.repo_id=your_org/your_dataset +``` + +If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list. + +### Reproducing the LIBERO results + +**Training on LIBERO:** +starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset. +Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256. + +```bash +lerobot-train \ + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.repo_id=your_org/your_repo \ + --dataset.repo_id=HuggingFaceVLA/libero \ + --steps=30000 +``` + +**Evaluating the pretrained LIBERO-10 checkpoint:** + +```bash +lerobot-eval \ + --policy.path=lerobot/VLA-JEPA-LIBERO \ + --env.type=libero \ + --env.task=libero_spatial,libero_object,libero_goal,libero_10 \ + --eval.n_episodes=10 \ + --eval.batch_size=5 +``` + +To evaluate a subset of tasks only: + +```bash +lerobot-eval \ + --policy.path=lerobot/VLA-JEPA-LIBERO \ + --env.type=libero \ + --env.task=libero_10 \ + --env.task_ids='[0,1,2]' \ + --eval.n_episodes=10 \ + --eval.batch_size=5 +``` + +**Expected results:** + +| Suite | Episodes | Successes | Success Rate | +| -------------- | -------- | --------- | ------------ | +| libero_spatial | 100 | 93 | **95.0%** | +| libero_object | 100 | 100 | **100.0%** | +| libero_goal | 100 | 98 | **98.0%** | +| libero_10 | 100 | 96 | **93.0%** | +| **Overall** | **400** | **387** | **96.5%** | + +--- + +## Fine-tuning on datasets with a different number of cameras + +The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`). + +**Default behaviour — view padding / trimming (no action required)** + +When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`: + +- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch. +- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model. + +**Option 1 — Disable the world model** + +Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance. + +```bash +lerobot-train \ + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.enable_world_model=false \ + --policy.repo_id=your_org/your_repo \ + --dataset.repo_id=your_org/single_camera_dataset +``` + +**Option 2 — Reinitialize the predictor input projection** + +If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint. + +--- + +## Citation + +```bibtex +@misc{sun2026vlajepaenhancingvisionlanguageactionmodel, + title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model}, + author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen}, + year = {2026}, + eprint = {2602.10098}, + archivePrefix = {arXiv}, + primaryClass = {cs.RO}, + url = {https://arxiv.org/abs/2602.10098}, +} +``` + +--- + +## License + +Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**. diff --git a/pyproject.toml b/pyproject.toml index ef7a36873..2b4c22f12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -217,6 +217,7 @@ topreward = ["lerobot[transformers-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"] # Features async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] @@ -282,6 +283,7 @@ all = [ # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn "lerobot[xvla]", "lerobot[hilserl]", + "lerobot[vla_jepa]", "lerobot[async]", "lerobot[dev]", "lerobot[test]", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 05fda05d8..a42b38ba4 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -57,6 +57,7 @@ from .pretrained import PreTrainedPolicy from .smolvla.configuration_smolvla import SmolVLAConfig from .tdmpc.configuration_tdmpc import TDMPCConfig from .utils import validate_visual_features_consistency +from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig from .vqbet.configuration_vqbet import VQBeTConfig from .wall_x.configuration_wall_x import WallXConfig from .xvla.configuration_xvla import XVLAConfig @@ -157,6 +158,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from .molmoact2.modeling_molmoact2 import MolmoAct2Policy return MolmoAct2Policy + elif name == "vla_jepa": + from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy + + return VLAJEPAPolicy else: try: return _get_policy_cls_from_policy_name(name=name) @@ -211,6 +216,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return EO1Config(**kwargs) elif policy_type == "molmoact2": return MolmoAct2Config(**kwargs) + elif policy_type == "vla_jepa": + return VLAJEPAConfig(**kwargs) else: try: config_cls = PreTrainedConfig.get_choice_class(policy_type) @@ -415,6 +422,7 @@ def make_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, EO1Config): from .eo1.processor_eo1 import make_eo1_pre_post_processors @@ -432,6 +440,14 @@ def make_pre_post_processors( dataset_meta=kwargs.get("dataset_meta"), ) + elif isinstance(policy_cfg, VLAJEPAConfig): + from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors + + processors = make_vla_jepa_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + else: try: processors = _make_processors_from_policy_config( diff --git a/src/lerobot/policies/vla_jepa/README.md b/src/lerobot/policies/vla_jepa/README.md new file mode 120000 index 000000000..2aaa91cca --- /dev/null +++ b/src/lerobot/policies/vla_jepa/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_vla_jepa_README.md \ No newline at end of file diff --git a/src/lerobot/policies/vla_jepa/__init__.py b/src/lerobot/policies/vla_jepa/__init__.py new file mode 100644 index 000000000..453a4c9a4 --- /dev/null +++ b/src/lerobot/policies/vla_jepa/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_vla_jepa import VLAJEPAConfig +from .modeling_vla_jepa import VLAJEPAPolicy +from .processor_vla_jepa import make_vla_jepa_pre_post_processors + +__all__ = [ + "VLAJEPAConfig", + "VLAJEPAPolicy", + "make_vla_jepa_pre_post_processors", +] diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py new file mode 100644 index 000000000..d62953abf --- /dev/null +++ b/src/lerobot/policies/vla_jepa/action_head.py @@ -0,0 +1,337 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import OrderedDict +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import nn +from torch.distributions import Beta + +from lerobot.utils.import_utils import _diffusers_available, require_package + +if TYPE_CHECKING or _diffusers_available: + from diffusers import ConfigMixin, ModelMixin + from diffusers.configuration_utils import register_to_config + from diffusers.models.attention import Attention, FeedForward + from diffusers.models.embeddings import TimestepEmbedding, Timesteps +else: + + class ModelMixin: # type: ignore[no-redef] + pass + + class ConfigMixin: # type: ignore[no-redef] + pass + + register_to_config = lambda f: f # noqa: E731 + Attention = FeedForward = TimestepEmbedding = Timesteps = None + +from .configuration_vla_jepa import VLAJEPAConfig + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.embedding_dim = embedding_dim + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + timesteps = timesteps.float() + batch_size, seq_len = timesteps.shape + half_dim = self.embedding_dim // 2 + exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) + exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1)) + freqs = timesteps.unsqueeze(-1) * exponent.exp() + return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1) + + +class ActionEncoder(nn.Module): + def __init__(self, action_dim: int, hidden_size: int): + super().__init__() + self.layer1 = nn.Linear(action_dim, hidden_size) + self.layer2 = nn.Linear(hidden_size * 2, hidden_size) + self.layer3 = nn.Linear(hidden_size, hidden_size) + self.pos_encoding = SinusoidalPositionalEncoding(hidden_size) + + def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, _ = actions.shape + if timesteps.ndim != 1 or timesteps.shape[0] != batch_size: + raise ValueError("timesteps must have shape [batch_size].") + timesteps = timesteps.unsqueeze(1).expand(-1, seq_len) + action_emb = self.layer1(actions) + time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype) + return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1)))) + + +class TimestepEncoder(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + require_package("diffusers", extra="vla_jepa") + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype) + return self.timestep_embedder(projected) + + +class AdaLayerNorm(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False) + self.silu = nn.SiLU() + + def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor: + scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1) + return self.norm(x) * (1 + scale[:, None]) + shift[:, None] + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float, + cross_attention_dim: int, + is_cross_attention: bool = True, + ) -> None: + super().__init__() + self.is_cross_attention = is_cross_attention + self.norm1 = AdaLayerNorm(dim) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=True, + cross_attention_dim=cross_attention_dim, + out_bias=True, + ) + self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False) + self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + temb: torch.Tensor, + ) -> torch.Tensor: + attn_input = self.norm1(hidden_states, temb) + attention_context = encoder_hidden_states if self.is_cross_attention else None + hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context) + hidden_states = hidden_states + self.ff(self.norm2(hidden_states)) + return hidden_states + + +class DiT(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + output_dim: int, + num_layers: int, + dropout: float, + cross_attention_dim: int, + ) -> None: + super().__init__() + self.inner_dim = num_attention_heads * attention_head_dim + self.timestep_encoder = TimestepEncoder(self.inner_dim) + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim, + is_cross_attention=layer_idx % 2 == 0, + ) + for layer_idx in range(num_layers) + ] + ) + self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2) + self.proj_out_2 = nn.Linear(self.inner_dim, output_dim) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + ) -> torch.Tensor: + temb = self.timestep_encoder(timestep) + x = hidden_states + for block in self.transformer_blocks: + x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb) + shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1) + x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None] + return self.proj_out_2(x) + + +@dataclass +class ActionModelPreset: + hidden_size: int + attention_head_dim: int + num_attention_heads: int + + +DIT_PRESETS = { + "DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12), + "DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32), + "DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2), +} + + +class VLAJEPAActionHead(nn.Module): + def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None: + super().__init__() + preset = DIT_PRESETS[config.action_model_type] + self.config = config + num_heads = config.action_num_heads or preset.num_attention_heads + head_dim = config.action_attention_head_dim or preset.attention_head_dim + inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768 + + self.input_embedding_dim = inner_dim + self.action_horizon = config.chunk_size + self.num_inference_timesteps = config.num_inference_timesteps + + hidden_size = config.action_hidden_size + self.model = DiT( + num_attention_heads=num_heads, + attention_head_dim=head_dim, + output_dim=hidden_size, + num_layers=config.action_num_layers, + dropout=config.action_dropout, + cross_attention_dim=cross_attention_dim, + ) + self.action_encoder = ActionEncoder(config.action_dim, inner_dim) + self.action_decoder = nn.Sequential( + OrderedDict( + [ + ("layer1", nn.Linear(hidden_size, hidden_size)), + ("relu", nn.ReLU()), + ("layer2", nn.Linear(hidden_size, config.action_dim)), + ] + ) + ) + self.state_encoder = ( + nn.Sequential( + OrderedDict( + [ + ("layer1", nn.Linear(config.state_dim, hidden_size)), + ("relu", nn.ReLU()), + ("layer2", nn.Linear(hidden_size, inner_dim)), + ] + ) + ) + if config.state_dim > 0 + else None + ) + self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim) + self.position_embedding = nn.Embedding( + max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4), + inner_dim, + ) + self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta) + + def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype) + return (self.config.action_noise_s - sample) / self.config.action_noise_s + + def _build_inputs( + self, + conditioning_tokens: torch.Tensor, + actions: torch.Tensor, + state: torch.Tensor | None, + timesteps: torch.Tensor, + ) -> torch.Tensor: + action_features = self.action_encoder(actions, timesteps) + pos_ids = torch.arange(action_features.shape[1], device=actions.device) + action_features = action_features + self.position_embedding(pos_ids)[None] + + future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1) + seq = [future_tokens, action_features] + if state is not None and self.state_encoder is not None: + if state.ndim == 2: + state = state.unsqueeze(1) + seq.insert(0, self.state_encoder(state)) + return torch.cat(seq, dim=1) + + def forward( + self, + conditioning_tokens: torch.Tensor, + actions: torch.Tensor, + state: torch.Tensor | None = None, + action_is_pad: torch.Tensor | None = None, + ) -> torch.Tensor: + noise = torch.randn_like(actions) + t = self.sample_time(actions.shape[0], actions.device, actions.dtype) + noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions + velocity = actions - noise + t_discretized = (t * self.config.action_num_timestep_buckets).long() + + hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized) + pred = self.model( + hidden_states=hidden_states, + encoder_hidden_states=conditioning_tokens, + timestep=t_discretized, + ) + pred_actions = self.action_decoder(pred[:, -actions.shape[1] :]) + + if action_is_pad is None: + action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device) + + loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim] + valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1] + num_valid = valid_mask.sum() * loss.shape[-1] + return (loss * valid_mask).sum() / num_valid.clamp_min(1) + + @torch.no_grad() + def predict_action( + self, + conditioning_tokens: torch.Tensor, + state: torch.Tensor | None = None, + ) -> torch.Tensor: + batch_size = conditioning_tokens.shape[0] + actions = torch.randn( + batch_size, + self.action_horizon, + self.config.action_dim, + dtype=conditioning_tokens.dtype, + device=conditioning_tokens.device, + ) + dt = 1.0 / max(self.num_inference_timesteps, 1) + for step in range(self.num_inference_timesteps): + t_cont = step / float(max(self.num_inference_timesteps, 1)) + t_value = int(t_cont * self.config.action_num_timestep_buckets) + timesteps = torch.full( + (batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long + ) + hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps) + pred = self.model( + hidden_states=hidden_states, + encoder_hidden_states=conditioning_tokens, + timestep=timesteps, + ) + pred_velocity = self.action_decoder(pred[:, -self.action_horizon :]) + actions = actions + dt * pred_velocity + return actions diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py new file mode 100644 index 000000000..8a30ee374 --- /dev/null +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -0,0 +1,154 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + + +@PreTrainedConfig.register_subclass("vla_jepa") +@dataclass +class VLAJEPAConfig(PreTrainedConfig): + n_obs_steps: int = 1 + chunk_size: int = 7 + n_action_steps: int = 7 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct" + jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256" + freeze_qwen: bool = False + enable_world_model: bool = True + # Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a + # different action or state dimensionality, the input/output projection layers must be + # re-initialised from scratch while the rest of the network keeps its pretrained weights. + # List the key prefixes that are allowed to have shape mismatches; anything else raises an error. + # e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"] + reinit_modules: list[str] | None = None + + tokenizer_padding_side: str = "left" + prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}." + special_action_token: str = "<|action_{}|>" + embodied_action_token: str = "<|embodied_action|>" + + action_dim: int = 7 + state_dim: int = 8 + + num_action_tokens_per_timestep: int = 8 + num_embodied_action_tokens_per_instruction: int = 32 + num_inference_timesteps: int = 4 + + action_hidden_size: int = 1024 + action_model_type: str = "DiT-B" + action_num_layers: int = 16 + action_num_heads: int | None = None + action_attention_head_dim: int | None = None + action_dropout: float = 0.2 + action_num_timestep_buckets: int = 1000 + action_noise_beta_alpha: float = 1.5 + action_noise_beta_beta: float = 1.0 + action_noise_s: float = 0.999 + num_target_vision_tokens: int = 32 + action_max_seq_len: int = 1024 + + # total video frames loaded per sample + num_video_frames: int = 8 + predictor_depth: int = 12 + predictor_num_heads: int = 8 + predictor_mlp_ratio: float = 4.0 + predictor_dropout: float = 0.0 + world_model_loss_weight: float = 0.1 + jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256) + repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style) + + resize_images_to: tuple[int, int] | None = None + binarize_gripper_action: bool = True + pre_snap_gripper_action: bool = True + clip_normalized_actions: bool = True + gripper_dim: int = 6 + gripper_threshold: float = 0.5 + torch_dtype: str = "bfloat16" + + optimizer_lr: float = 1e-4 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + optimizer_grad_clip_norm: float = 10.0 + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + def __post_init__(self) -> None: + super().__post_init__() + if self.freeze_qwen and self.enable_world_model: + # freezing qwen backbone makes world model training irrelevant since no grad flows + self.enable_world_model = False + if self.n_action_steps > self.chunk_size: + raise ValueError("`n_action_steps` must be <= `chunk_size`.") + if self.num_video_frames < 2 * self.jepa_tubelet_size: + raise ValueError( + f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` " + f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position." + ) + + def validate_features(self) -> None: + if not self.image_features: + raise ValueError("VLAJEPA requires at least one visual input feature.") + if self.action_feature is None: + raise ValueError("VLAJEPA requires an action output feature.") + self.action_dim = self.action_feature.shape[0] + if self.robot_state_feature is not None: + self.state_dim = self.robot_state_feature.shape[0] + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig: + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> list[int]: + # load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1] + # matches original repo's observation_indices=list(range(video_horizon)) + return list(range(self.num_video_frames)) + + @property + def action_delta_indices(self) -> list[int]: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py new file mode 100644 index 000000000..45d83e652 --- /dev/null +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -0,0 +1,629 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from collections import deque +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +from PIL import Image +from torch import Tensor, nn + +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.policies.utils import populate_queues +from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.import_utils import _transformers_available, require_package + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoModel, AutoVideoProcessor +else: + AutoModel = None + AutoVideoProcessor = None + +from .action_head import VLAJEPAActionHead +from .configuration_vla_jepa import VLAJEPAConfig +from .qwen_interface import Qwen3VLInterface +from .world_model import ActionConditionedVideoPredictor + +# ============================================================================ +# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation +# ============================================================================ + + +class VLAJEPAModel(nn.Module): + """ + Native VLA-JEPA model following the original starVLA VLA_JEPA.py. + + Components: + - Qwen3-VL: vision-language backbone for fused embeddings + - DiT-B: flow-matching action head for future action prediction + - V-JEPA: world model for video frame prediction + + Input: List[dict] native format (same as original starVLA) + - "image": List[PIL.Image] (multi-view images) + - "video": np.ndarray [V, T, H, W, 3] + - "lang": str (task instruction) + - "action": np.ndarray [T, action_dim] (optional, training only) + - "state": np.ndarray [1, state_dim] (optional) + """ + + def __init__(self, config: VLAJEPAConfig) -> None: + super().__init__() + require_package("transformers", extra="vla_jepa") + self.config = config + + # Vision-language backbone + self.qwen = Qwen3VLInterface(config) + + # Tokenizer expansion for special action tokens + self.action_tokens, self.action_token_ids, self.embodied_action_token_id = ( + self.qwen.expand_tokenizer() + ) + + # Action head (flow-matching DiT) + self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size) + + # JEPA world model components + if config.enable_world_model: + self.video_encoder = AutoModel.from_pretrained( + config.jepa_encoder_name, + torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype), + ) + self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name) + num_views = config.jepa_tubelet_size + tubelet_size = self.video_encoder.config.tubelet_size + image_size = getattr(self.video_encoder.config, "image_size", None) + if image_size is None: + first_image_shape = next(iter(config.image_features.values())).shape + image_size = first_image_shape[-1] + self.video_predictor = ActionConditionedVideoPredictor( + num_frames=config.num_video_frames // tubelet_size, + img_size=(image_size, image_size), + patch_size=16, + tubelet_size=1, + embed_dim=self.video_encoder.config.hidden_size * num_views, + action_embed_dim=self.qwen.model.config.hidden_size, + predictor_embed_dim=self.video_encoder.config.hidden_size, + depth=config.predictor_depth, + num_heads=config.predictor_num_heads, + mlp_ratio=config.predictor_mlp_ratio, + num_action_tokens_per_step=config.num_action_tokens_per_timestep, + ) + else: + self.video_encoder = None + self.video_processor = None + self.video_predictor = None + + if config.freeze_qwen: + self.qwen.requires_grad_(False) + + # Build prompt placeholders. + # Use the encoder's actual tubelet_size when available (world model enabled), + # otherwise fall back to config. + _tubelet_size = ( + self.video_encoder.config.tubelet_size + if config.enable_world_model + else self.config.jepa_tubelet_size + ) + num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1 + self.replace_prompt = "".join( + token * self.config.num_action_tokens_per_timestep + for token in self.action_tokens[:num_action_prompt_steps] + ) + self.embodied_replace_prompt = ( + self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction + ) + + def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor: + """Return the last decoder hidden state before the final RMSNorm. + + The model was trained with the output of the last transformer block BEFORE + the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from + `output_hidden_states=True` is post-norm (tied to `last_hidden_state` via + `@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers + the correct pre-RMSNorm state, matching the training-time representation. + """ + captured: list[torch.Tensor] = [] + + def _hook(module, input, output): + h = output[0] if isinstance(output, tuple) else output + captured.append(h) + + last_layer = self.qwen.model.model.language_model.layers[-1] + handle = last_layer.register_forward_hook(_hook) + try: + self.qwen.model( + **qwen_inputs, + output_hidden_states=False, + output_attentions=False, + return_dict=True, + ) + finally: + handle.remove() + + return captured[0] # [B, seq_len, H] + + # ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ---- + + def forward(self, examples: list[dict]) -> dict[str, Tensor]: + """ + Native forward pass following original starVLA VLA_JEPA.forward. + + Args: + examples: List of per-sample dicts with keys: + "image" : List[PIL.Image] — multi-view images + "video" : np.ndarray [V, T, H, W, 3] + "lang" : str — task instruction + "action" : np.ndarray [T, action_dim] (optional) + "state" : np.ndarray [1, state_dim] (optional) + + Returns: + dict with "action_loss" and "wm_loss" keys (scalar Tensors). + """ + # Unpack native format (same pattern as original VLA_JEPA.py) + batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]] + batch_videos = [ex["video"] for ex in examples] # List[np.ndarray] + instructions = [ex["lang"] for ex in examples] # List[str] + has_action = "action" in examples[0] and examples[0]["action"] is not None + actions = [ex["action"] for ex in examples] if has_action else None + has_state = "state" in examples[0] and examples[0]["state"] is not None + state = [ex["state"] for ex in examples] if has_state else None + action_is_pad = ( + [ex["action_is_pad"] for ex in examples] + if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None + else None + ) + + # Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W] + batch_videos = np.stack(batch_videos) + batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W] + + # Adjust number of views for the world model: + # - fewer views than expected: duplicate the first view to fill up + # - more views than expected: keep only the first num_views_world_model views + num_views_world_model = self.config.jepa_tubelet_size + if batch_videos.shape[1] < num_views_world_model: + num_missing_views = num_views_world_model - batch_videos.shape[1] + first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1) + batch_videos = np.concatenate([batch_videos, first_view], axis=1) + elif batch_videos.shape[1] > num_views_world_model: + batch_videos = batch_videos[:, :num_views_world_model] + + # ---- Step 1: QwenVL encode (same as original) ---- + qwen_inputs = self.qwen.build_inputs( + images=batch_images, + instructions=instructions, + action_prompt=self.replace_prompt, + embodied_prompt=self.embodied_replace_prompt, + ) + + # Locate embodied-action tokens (always needed for action head) + embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id + embodied_indices = embodied_mask.nonzero(as_tuple=True) + + # Locate action tokens (only needed for world model predictor) + if self.config.enable_world_model: + action_mask = torch.isin( + qwen_inputs["input_ids"], + torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device), + ) + action_indices = action_mask.nonzero(as_tuple=True) + + device_type = next(self.parameters()).device.type + + with torch.autocast(device_type=device_type, dtype=torch.bfloat16): + last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H] + b, _, h = last_hidden.shape + + if self.config.enable_world_model: + action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h) + + embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h) + + # ---- Step 2+3: JEPA Encoder + Predictor ---- + device_wm = last_hidden.device + if not self.config.enable_world_model: + wm_loss = torch.tensor(0.0, device=device_wm) + else: + b, v, t_frames, c, h_img, w_img = batch_videos.shape + batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img) + + video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[ + "pixel_values_videos" + ].to(self.video_encoder.device) # [B*V, T, C, H, W] + + with torch.no_grad(): + video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels) + # Merge views: [B*V, ...] -> [B, ..., V*embed_dim] + video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2) + + tubelet_size = self.video_encoder.config.tubelet_size + device_wm = video_embeddings.device + # num_video_frames raw frames → t_enc_total temporal positions after tubelet compression + t_enc_total = self.config.num_video_frames // tubelet_size + + if t_enc_total < 2: + wm_loss = torch.tensor(0.0, device=device_wm) + else: + # Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232): + # input_states: positions 0..T-2, gt_states: positions 1..T-1 + t_enc_ctx = t_enc_total - 1 + tokens_per_frame = video_embeddings.shape[1] // t_enc_total + + input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :] + gt_states = video_embeddings[:, tokens_per_frame:, :] + + expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep + if action_tokens.shape[1] < expected_actions: + pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1) + action_tokens = torch.cat([action_tokens, pad], dim=1) + + predicted_states = self.video_predictor( + input_states.float(), + action_tokens[:, :expected_actions].float(), + ) + + wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean") + + if not has_action: + return {"wm_loss": wm_loss} + + # ---- Step 4: Action Head ---- + with torch.autocast(device_type=device_type, dtype=torch.float32): + actions_tensor = torch.tensor( + np.array(actions), device=last_hidden.device, dtype=torch.float32 + ) # [B, T_full, action_dim] + action_horizon = self.config.chunk_size + actions_target = actions_tensor[:, -action_horizon:, :] + + state_tensor = None + if state is not None: + state_tensor = torch.tensor( + np.array(state), device=last_hidden.device, dtype=last_hidden.dtype + ) # [B, 1, state_dim] + + repeated_diffusion_steps = self.config.repeated_diffusion_steps + actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1) + embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1) + if state_tensor is not None: + state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1) + + action_is_pad_rep = None + if action_is_pad is not None: + pad_tensor = torch.stack( + [ + p.to(actions_target.device) + if isinstance(p, Tensor) + else torch.tensor(p, device=actions_target.device) + for p in action_is_pad + ] + ) # [B, T_full] + pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon] + action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon] + + action_loss = self.action_model( + embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep + ) + + return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight} + + # ---- Native predict_action (follows original VLA_JEPA.predict_action) ---- + + @torch.no_grad() + def predict_action( + self, + batch_images: list[list[Image.Image]], + instructions: list[str], + state: np.ndarray | None = None, + ) -> np.ndarray: + """ + Native action prediction following original VLA_JEPA.predict_action. + + Args: + batch_images: List of samples; each is List[PIL.Image] (multi-view). + instructions: Task instructions, one per sample. + state: Optional [B, state_dim] numpy array. + + Returns: + np.ndarray [B, action_horizon, action_dim] — predicted actions. + """ + if self.config.resize_images_to is not None: + height, width = self.config.resize_images_to + resampling = getattr(Image, "Resampling", Image).BOX + batch_images = [ + [image.resize((width, height), resample=resampling) for image in sample_images] + for sample_images in batch_images + ] + + qwen_inputs = self.qwen.build_inputs( + images=batch_images, + instructions=instructions, + action_prompt=self.replace_prompt, + embodied_prompt=self.embodied_replace_prompt, + ) + + embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id + embodied_indices = embodied_mask.nonzero(as_tuple=True) + + device_type = next(self.parameters()).device.type + + with torch.autocast(device_type=device_type, dtype=torch.bfloat16): + last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H] + b, _, h = last_hidden.shape + embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h) + + state_tensor = None + if state is not None: + state_tensor = torch.from_numpy(np.array(state)).to( + device=last_hidden.device, dtype=last_hidden.dtype + ) + + pred_actions = self.action_model.predict_action( + embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None + ) # [B, action_horizon, action_dim] + + return pred_actions.detach().cpu().numpy() + + +# ============================================================================ +# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format +# ============================================================================ + + +class VLAJEPAPolicy(PreTrainedPolicy): + """ + LeRobot adapter for VLA-JEPA. + + Converts LeRobot's standard batch format (dict[str, Tensor]) to the native + VLA-JEPA format (List[dict]), calls the native model, and converts outputs + back to LeRobot format. + """ + + config_class = VLAJEPAConfig + name = "vla_jepa" + + def __init__(self, config: VLAJEPAConfig, **kwargs) -> None: + super().__init__(config) + config.validate_features() + if dataset_meta := kwargs.get("dataset_meta"): + # cfg.input_features keeps the pretrained model's feature keys (needed for rename_map + # compatibility), so validate_features() may have read stale dims from a pretrained + # config. Override state_dim/action_dim from the actual dataset being used. + ds_features = dataset_meta.features + if OBS_STATE in ds_features: + config.state_dim = ds_features[OBS_STATE]["shape"][0] + if ACTION in ds_features: + config.action_dim = ds_features[ACTION]["shape"][0] + + self.model = VLAJEPAModel(config) + self.reset() + + def reset(self) -> None: + self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)} + + # ---- Format Conversion: LeRobot → Native ---- + + def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]: + """ + Convert LeRobot batch format to native VLA-JEPA examples format. + + LeRobot format: + batch = { + "observation.images.": Tensor [B, C, H, W] or [B, T, C, H, W], + "observation.state": Tensor [B, state_dim] or [B, T, state_dim], + "action": Tensor [B, chunk_size, action_dim], (training only) + "task": str | List[str], (optional instruction) + } + + Native format (List[dict]): + { + "image": List[PIL.Image], # multi-view images per sample + "video": np.ndarray [V, T, H, W, 3], + "lang": str, # task instruction + "action": np.ndarray [T, action_dim], # optional + "state": np.ndarray [1, state_dim], # optional + } + """ + # Determine batch size from the first image feature + image_keys = list(self.config.image_features.keys()) + if not image_keys: + raise ValueError("VLAJEPA requires at least one image feature.") + first_key = image_keys[0] + first_tensor = batch[first_key] + batch_size = first_tensor.shape[0] + + # ---- Collect images per sample ---- + # images_per_sample[b][v] = PIL.Image for view v + images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)] + for key in image_keys: + tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W] + if tensor.ndim == 5: + # observation_delta_indices = [0, 1, ..., num_video_frames-1] + # index 0 is the current observation (delta=0) + tensor = tensor[:, 0] + for b in range(batch_size): + images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b])) + + # ---- Collect videos per sample ---- + # Build video arrays: for each sample, stack views as [V, T, H, W, 3] + # Check whether any image feature has a time dimension + video_source = None + for k in image_keys: + if k in batch: + video_source = batch[k] # Use first available for shape inspection + break + + if video_source is None: + raise ValueError("No image data found in batch for video construction.") + + videos_per_sample = [] + for b in range(batch_size): + sample_views = [] + for k in image_keys: + t = batch[k][b] # [C, H, W] or [T, C, H, W] + if t.ndim == 3: + t = t.unsqueeze(0) # [1, C, H, W] + # Convert to [T, H, W, 3] numpy + t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy() + # Clamp to [0, 255] + if t_np.max() <= 1.0: + t_np = t_np * 255.0 + t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8) + sample_views.append(t_np) + # Stack views: [V, T, H, W, 3] + videos_per_sample.append(np.stack(sample_views, axis=0)) + + # ---- Collect instructions ---- + tasks = batch.get("task") + if tasks is None: + instructions = ["Execute the robot action."] * batch_size + elif isinstance(tasks, str): + instructions = [tasks] * batch_size + else: + instructions = list(tasks) + + # ---- Collect actions (training only) ---- + actions_list = None + action_is_pad_list = None + actions_tensor = batch.get(ACTION) + if actions_tensor is not None: + if actions_tensor.ndim == 2: + actions_tensor = actions_tensor.unsqueeze(1) + actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)] + action_is_pad_tensor = batch.get("action_is_pad") + if action_is_pad_tensor is not None: + action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)] + + # ---- Collect state ---- + state_list = None + state_tensor = batch.get(OBS_STATE) + if state_tensor is not None: + if state_tensor.ndim > 2: + state_tensor = state_tensor[:, -1, :] + if state_tensor.ndim == 2: + state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim] + state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)] + + # ---- Assemble native examples ---- + examples = [] + for b in range(batch_size): + example = { + "image": images_per_sample[b], + "video": videos_per_sample[b], + "lang": instructions[b], + } + if actions_list is not None: + example["action"] = actions_list[b] + if action_is_pad_list is not None: + example["action_is_pad"] = action_is_pad_list[b] + if state_list is not None: + example["state"] = state_list[b] + examples.append(example) + + return examples + + # ---- LeRobot Policy Interface ---- + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """LeRobot train forward: convert → native forward → aggregate losses.""" + examples = self._prepare_model_inputs(batch) + native_output = self.model.forward(examples) + + ref = next(iter(native_output.values())) + zero = torch.zeros((), device=ref.device, dtype=ref.dtype) + total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero) + logs = {k: v.detach().item() for k, v in native_output.items()} + logs["loss"] = total_loss.detach().item() + return total_loss, logs + + def get_optim_params(self) -> dict: + return self.model.parameters() + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """LeRobot inference: convert → native predict → return as Tensor.""" + self.eval() + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + + examples = self._prepare_model_inputs(batch) + batch_images = [ex["image"] for ex in examples] + instructions = [ex["lang"] for ex in examples] + + state_np = None + if "state" in examples[0] and examples[0]["state"] is not None: + state_np = np.stack([ex["state"] for ex in examples]) + + actions_np = self.model.predict_action(batch_images, instructions, state_np) + return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32) + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """LeRobot select_action with action queue caching.""" + self.eval() + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) + return self._queues[ACTION].popleft() + + @classmethod + def from_pretrained( + cls: type[T], + pretrained_name_or_path: str | Path, + **kwargs, + ): + return super().from_pretrained(pretrained_name_or_path, **kwargs) + + @classmethod + def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + reinit_prefixes = model.config.reinit_modules + if not reinit_prefixes: + return super()._load_as_safetensor(model, model_file, map_location, strict) + + from safetensors.torch import load_file + + state_dict = load_file(model_file, device=map_location) + current = model.state_dict() + + reinitialized: list[str] = [] + filtered: dict = {} + for key, value in state_dict.items(): + if key in current and value.shape != current[key].shape: + if not any(key.startswith(p) for p in reinit_prefixes): + raise ValueError( + f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model " + f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`." + ) + reinitialized.append( + f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}" + ) + else: + filtered[key] = value + + if reinitialized: + logging.warning( + f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes " + f"(randomly re-initialised):\n " + "\n ".join(reinitialized) + ) + + from lerobot.policies.utils import log_model_loading_keys + + missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False) + log_model_loading_keys(missing_keys, unexpected_keys) + return model diff --git a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py new file mode 100644 index 000000000..b59cc0e90 --- /dev/null +++ b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py @@ -0,0 +1,155 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch + +from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + EnvTransition, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + + +@ProcessorStepRegistry.register(name="vla_jepa_clip_actions") +class ClipActionsProcessorStep(ProcessorStep): + """Clips action tensor to [-1, 1] before unnormalization.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is not None: + transition = dict(transition) + transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0) + return transition + + def transform_features(self, features): + return features + + +@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper") +class PreSnapGripperProcessorStep(ProcessorStep): + """Snaps a gripper dimension to {0, 1} BEFORE unnormalization. + + Mirrors the original starVLA LIBERO eval: + normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1) + This ensures the unnormalizer receives an exact binary value, which is + required when the model was trained with gripper in identity (mask=False) + space where 0=open and 1=close. + """ + + def __init__(self, gripper_dim: int = 6, threshold: float = 0.5): + self.gripper_dim = gripper_dim + self.threshold = threshold + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is not None and action.shape[-1] > self.gripper_dim: + transition = dict(transition) + a = action.clone() + a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float() + transition[TransitionKey.ACTION] = a + return transition + + def transform_features(self, features): + return features + + +@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper") +class BinarizeGripperProcessorStep(ProcessorStep): + """Binarizes a gripper dimension after unnormalization. + + Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention). + Only applied when action has more dimensions than gripper_dim. + """ + + def __init__(self, gripper_dim: int = 6, threshold: float = 0.5): + self.gripper_dim = gripper_dim + self.threshold = threshold + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is not None and action.shape[-1] > self.gripper_dim: + transition = dict(transition) + a = action.clone() + a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float() + transition[TransitionKey.ACTION] = a + return transition + + def transform_features(self, features): + return features + + +def make_vla_jepa_pre_post_processors( + config: VLAJEPAConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + features = {**config.input_features, **config.output_features} + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features=features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps: list[ProcessorStep] = [] + if config.clip_normalized_actions: + output_steps.append(ClipActionsProcessorStep()) + if config.pre_snap_gripper_action: + output_steps.append( + PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold) + ) + output_steps.append( + UnnormalizerProcessorStep( + features=features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ) + ) + if config.binarize_gripper_action: + output_steps.append( + BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold) + ) + output_steps.append(DeviceProcessorStep(device="cpu")) + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/vla_jepa/qwen_interface.py b/src/lerobot/policies/vla_jepa/qwen_interface.py new file mode 100644 index 000000000..24f530efc --- /dev/null +++ b/src/lerobot/policies/vla_jepa/qwen_interface.py @@ -0,0 +1,117 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import numpy as np +import torch +from PIL import Image + +from lerobot.utils.import_utils import _transformers_available + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoProcessor, Qwen3VLForConditionalGeneration +else: + AutoProcessor = None + Qwen3VLForConditionalGeneration = None + +from .configuration_vla_jepa import VLAJEPAConfig + + +class Qwen3VLInterface(torch.nn.Module): + def __init__(self, config: VLAJEPAConfig) -> None: + super().__init__() + self.config = config + self.model = Qwen3VLForConditionalGeneration.from_pretrained( + config.qwen_model_name, + torch_dtype=self._get_torch_dtype(config.torch_dtype), + ) + self.processor = AutoProcessor.from_pretrained(config.qwen_model_name) + self.processor.tokenizer.padding_side = config.tokenizer_padding_side + self.model.config.hidden_size = self.model.config.text_config.hidden_size + + @staticmethod + def _get_torch_dtype(dtype_name: str) -> torch.dtype: + if dtype_name == "float32": + return torch.float32 + if dtype_name == "float16": + return torch.float16 + return torch.bfloat16 + + def expand_tokenizer(self) -> tuple[list[str], list[int], int]: + # starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4, + # independent of vj2 num_action_tokens_per_timestep. Keeping this count + # is required for Qwen embedding/lm_head checkpoint shapes to match. + max_action_tokens = self.config.chunk_size * 4 + tokenizer = self.processor.tokenizer + action_tokens = [] + action_token_ids = [] + for idx in range(max_action_tokens): + token = self.config.special_action_token.format(idx) + action_tokens.append(token) + if token not in tokenizer.get_vocab(): + tokenizer.add_tokens([token], special_tokens=True) + action_token_ids.append(tokenizer.convert_tokens_to_ids(token)) + + embodied_action_token = self.config.embodied_action_token + if embodied_action_token not in tokenizer.get_vocab(): + tokenizer.add_tokens([embodied_action_token], special_tokens=True) + embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token) + + if self.model.get_input_embeddings().weight.size(0) < len(tokenizer): + self.model.resize_token_embeddings(len(tokenizer)) + return action_tokens, action_token_ids, embodied_action_token_id + + def build_inputs( + self, + images: Sequence[Sequence[Image.Image]], + instructions: Sequence[str], + action_prompt: str, + embodied_prompt: str, + ) -> dict[str, torch.Tensor]: + messages = [] + for sample_images, instruction in zip(images, instructions, strict=True): + prompt = self.config.prompt_template.format( + instruction=instruction, + actions=action_prompt, + e_actions=embodied_prompt, + ) + content = [{"type": "image", "image": img} for img in sample_images] + content.append({"type": "text", "text": prompt}) + messages.append([{"role": "user", "content": content}]) + + batch_inputs = self.processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + processor_kwargs={"padding": True, "return_tensors": "pt"}, + ) + return batch_inputs.to(self.model.device) + + @staticmethod + def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image: + image = image_tensor.detach().cpu() + if image.ndim == 3 and image.shape[0] in (1, 3): + image = image.permute(1, 2, 0) + image = image.float() + if image.max() <= 1.0: + image = image * 255.0 + image = image.clamp(0, 255).round().to(torch.uint8).numpy() + if image.shape[-1] == 1: + image = np.repeat(image, 3, axis=-1) + return Image.fromarray(image) diff --git a/src/lerobot/policies/vla_jepa/world_model.py b/src/lerobot/policies/vla_jepa/world_model.py new file mode 100644 index 000000000..87f78448c --- /dev/null +++ b/src/lerobot/policies/vla_jepa/world_model.py @@ -0,0 +1,418 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import nn + + +def build_action_block_causal_attention_mask( + num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1 +) -> torch.Tensor: + tokens_per_frame = add_tokens + grid_height * grid_width + num_tokens = num_frames * tokens_per_frame + mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool) + mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool) + local_window_time = num_frames + + for current_frame in range(num_frames): + first_context_frame = max(0, current_frame - local_window_time + 1) + for context_frame in range(first_context_frame, current_frame + 1): + row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame) + col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame) + mask[row, col] = mask_block + return mask + + +def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + _, _, _, dim = x.size() + if dim % 2 != 0: + raise ValueError("Embedding dimension must be even for rotary position encoding.") + + omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device) + omega /= dim / 2.0 + omega = 1.0 / 10000**omega + freqs = torch.einsum("..., f -> ... f", pos, omega) + emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2) + emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2) + + y = x.unflatten(-1, (-1, 2)) + y1, y2 = y.unbind(dim=-1) + y = torch.stack((-y2, y1), dim=-1).flatten(-2) + return x * emb_cos + y * emb_sin + + +class DropPath(nn.Module): + def __init__(self, drop_prob: float = 0.0) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() + return x.div(keep_prob) * random_tensor + + +class MLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: type[nn.Module] = nn.GELU, + drop: float = 0.0, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ACRoPEAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float | None = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + use_sdpa: bool = True, + is_causal: bool = False, + grid_size: int = 16, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop_prob = proj_drop + self.proj_drop = nn.Dropout(proj_drop) + self.use_sdpa = use_sdpa + self.d_dim = int(2 * ((self.head_dim // 3) // 2)) + self.h_dim = int(2 * ((self.head_dim // 3) // 2)) + self.w_dim = int(2 * ((self.head_dim // 3) // 2)) + self.grid_size = grid_size + self.is_causal = is_causal + + @staticmethod + def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor: + return ids // int(height * width) + + def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor: + frame_ids = self._get_frame_pos(ids, height, width) + ids = ids - int(height * width) * frame_ids + return ids // width + + def separate_positions( + self, ids: torch.Tensor, height: int, width: int + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + frame_ids = self._get_frame_pos(ids, height, width) + height_ids = self._get_height_pos(ids, height, width) + width_ids = ids - int(height * width) * frame_ids - width * height_ids + return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None = None, + attn_mask: torch.Tensor | None = None, + num_frames: int | None = None, + grid_height: int | None = None, + grid_width: int | None = None, + action_tokens: int = 0, + ) -> torch.Tensor: + batch_size, num_tokens, channels = x.size() + if num_frames is None or grid_height is None or grid_width is None: + raise ValueError("num_frames, grid_height and grid_width are required.") + + if mask is not None: + mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1) + d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width) + else: + mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device) + d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width) + + h_mask *= self.grid_size / grid_height + w_mask *= self.grid_size / grid_width + + if action_tokens > 0: + x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels) + action_q, action_k, action_v = [], [], [] + for idx in range(action_tokens): + action_token = x[:, :, idx : idx + 1, :].flatten(1, 2) + qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + qd = rotate_queries_or_keys( + q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device) + ) + kd = rotate_queries_or_keys( + k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device) + ) + qr = q[..., self.d_dim :] + kr = k[..., self.d_dim :] + action_q.append( + torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1) + ) + action_k.append( + torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1) + ) + action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1)) + + action_q = torch.cat(action_q, dim=3).flatten(2, 3) + action_k = torch.cat(action_k, dim=3).flatten(2, 3) + action_v = torch.cat(action_v, dim=3).flatten(2, 3) + x = x[:, :, action_tokens:, :].flatten(1, 2) + + qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + offset = 0 + qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask) + kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask) + offset += self.d_dim + qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask) + kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask) + offset += self.h_dim + qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask) + kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask) + offset += self.w_dim + + if offset < self.head_dim: + q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1) + k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1) + else: + q = torch.cat([qd, qh, qw], dim=-1) + k = torch.cat([kd, kh, kw], dim=-1) + + if action_tokens > 0: + + def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor: + frame_tokens = frame_tokens.view( + batch_size, self.num_heads, num_frames, grid_height * grid_width, -1 + ) + action_token_values = action_token_values.view( + batch_size, self.num_heads, num_frames, action_tokens, -1 + ) + return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3) + + q = merge(q, action_q) + k = merge(k, action_k) + v = merge(v, action_v) + + if attn_mask is not None or self.use_sdpa: + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask + ) + else: + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels) + x = self.proj(x) + return self.proj_drop(x) + + +class ACBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: float | None = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + norm_layer: type[nn.Module] = nn.LayerNorm, + use_sdpa: bool = True, + is_causal: bool = False, + grid_size: int = 16, + use_rope: bool = True, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + if not use_rope: + raise ValueError("JEVLA1 world predictor uses AC RoPE attention.") + self.attn = ACRoPEAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + use_sdpa=use_sdpa, + is_causal=is_causal, + grid_size=grid_size, + proj_drop=drop, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = MLP( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=nn.GELU, + drop=drop, + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor | None = None, + num_frames: int | None = None, + grid_height: int | None = None, + grid_width: int | None = None, + action_tokens: int = 0, + ) -> torch.Tensor: + y = self.norm1(x) + y = self.attn( + y, + mask=None, + attn_mask=attn_mask, + num_frames=num_frames, + grid_height=grid_height, + grid_width=grid_width, + action_tokens=action_tokens, + ) + x = x + self.drop_path(y) + y = self.norm2(x) + return x + self.drop_path(self.mlp(y)) + + +class ActionConditionedVideoPredictor(nn.Module): + """JEVLA1-compatible action-conditioned V-JEPA predictor.""" + + def __init__( + self, + num_frames: int, + img_size: tuple[int, int], + patch_size: int, + tubelet_size: int, + embed_dim: int, + action_embed_dim: int, + predictor_embed_dim: int, + depth: int, + num_heads: int, + mlp_ratio: float, + num_action_tokens_per_step: int, + use_extrinsics: bool = False, + ) -> None: + super().__init__() + self.is_frame_causal = True + self.use_extrinsics = use_extrinsics + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True) + self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True) + self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True) + + self.img_height, self.img_width = img_size + self.patch_size = patch_size + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.grid_height = self.img_height // self.patch_size + self.grid_width = self.img_width // self.patch_size + + self.predictor_blocks = nn.ModuleList( + [ + ACBlock( + dim=predictor_embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6), + grid_size=self.grid_height, + use_rope=True, + ) + for _ in range(depth) + ] + ) + self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + self.num_action_tokens_per_step = num_action_tokens_per_step + + @property + def norm(self) -> nn.LayerNorm: + return self.predictor_norm + + @property + def proj(self) -> nn.Linear: + return self.predictor_proj + + def forward( + self, + frame_tokens: torch.Tensor, + action_tokens: torch.Tensor, + extrinsics: torch.Tensor | None = None, + ) -> torch.Tensor: + # starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D]. + x = self.predictor_embed(frame_tokens) + batch_size, num_context_tokens, hidden_dim = x.size() + num_frames = num_context_tokens // (self.grid_height * self.grid_width) + + actions = self.action_encoder(action_tokens) + actions = actions.view(batch_size, num_frames, -1, hidden_dim) + cond_tokens = actions.shape[2] + + x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim) + if self.use_extrinsics: + if extrinsics is None: + raise ValueError("extrinsics are required when use_extrinsics=True.") + cond_tokens += 1 + extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2) + x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2) + else: + x = torch.cat([actions, x], dim=2).flatten(1, 2) + + attn_mask = build_action_block_causal_attention_mask( + num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens + ) + attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True) + + for block in self.predictor_blocks: + x = block( + x, + attn_mask=attn_mask, + num_frames=num_frames, + grid_height=self.grid_height, + grid_width=self.grid_width, + action_tokens=cond_tokens, + ) + + x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim) + x = x[:, :, cond_tokens:, :].flatten(1, 2) + x = self.predictor_norm(x) + return self.predictor_proj(x) diff --git a/tests/policies/vla_jepa/conftest.py b/tests/policies/vla_jepa/conftest.py new file mode 100644 index 000000000..5301b5bc7 --- /dev/null +++ b/tests/policies/vla_jepa/conftest.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python +"""Shared fixtures and helpers for VLA-JEPA tests.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np +import pytest +import torch +from PIL import Image +from torch import Tensor, nn + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + +# --------------------------------------------------------------------------- +# Shared constants +# --------------------------------------------------------------------------- + +BATCH_SIZE = 2 +ACTION_DIM = 3 +STATE_DIM = 4 +IMAGE_SIZE = 8 +ACTION_HORIZON = 4 +N_ACTION_STEPS = 2 +NUM_VIDEO_FRAMES = 3 +QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone + +EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) +EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def set_seed_all(seed: int) -> None: + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def make_config( + action_dim: int = ACTION_DIM, + state_dim: int = STATE_DIM, + action_horizon: int = ACTION_HORIZON, + num_video_frames: int = NUM_VIDEO_FRAMES, +) -> VLAJEPAConfig: + config = VLAJEPAConfig( + input_features={ + f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)), + }, + output_features={ + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)), + }, + device="cpu", + chunk_size=action_horizon, + n_action_steps=min(N_ACTION_STEPS, action_horizon), + action_dim=action_dim, + state_dim=state_dim, + num_video_frames=num_video_frames, + num_action_tokens_per_timestep=2, + num_embodied_action_tokens_per_instruction=3, + num_inference_timesteps=2, + action_hidden_size=QWEN_HIDDEN_SIZE, + action_model_type="DiT-test", + action_num_layers=1, + predictor_depth=1, + predictor_num_heads=2, + predictor_mlp_ratio=2.0, + jepa_tubelet_size=1, + ) + config.validate_features() + return config + + +def make_train_batch( + batch_size: int = BATCH_SIZE, + action_dim: int = ACTION_DIM, + state_dim: int = STATE_DIM, + action_horizon: int = ACTION_HORIZON, + num_video_frames: int = NUM_VIDEO_FRAMES, +) -> dict[str, Tensor | list[str]]: + return { + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE), + OBS_STATE: torch.randn(batch_size, 1, state_dim), + ACTION: torch.randn(batch_size, action_horizon, action_dim), + "task": ["pick up the cube"] * batch_size, + } + + +def make_inference_batch( + batch_size: int = BATCH_SIZE, + state_dim: int = STATE_DIM, +) -> dict[str, Tensor | list[str]]: + return { + f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE), + OBS_STATE: torch.randn(batch_size, state_dim), + "task": ["pick up the cube"] * batch_size, + } + + +# --------------------------------------------------------------------------- +# Fake external models (replace Qwen3-VL and V-JEPA at test time) +# --------------------------------------------------------------------------- + + +class _FakeLanguageLayer(nn.Module): + """Leaf module whose forward hook is captured by _qwen_last_decoder_hidden.""" + + def __init__(self, hidden_size: int) -> None: + super().__init__() + self._hidden_size = hidden_size + + def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]: + return (hidden,) + + +class _FakeLanguageModel(nn.Module): + def __init__(self, hidden_size: int) -> None: + super().__init__() + self._hidden_size = hidden_size + self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)]) + + def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace: + batch_size, seq_len = input_ids.shape + hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device) + self.layers[-1](hidden) + return SimpleNamespace() + + +class _FakeQwenInnerModel(nn.Module): + """Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into.""" + + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.language_model = _FakeLanguageModel(hidden_size) + + def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace: + return self.language_model(input_ids) + + +class _FakeQwenBackbone(nn.Module): + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(1)) + self.config = SimpleNamespace( + hidden_size=hidden_size, + text_config=SimpleNamespace(hidden_size=hidden_size), + ) + self.model = _FakeQwenInnerModel(hidden_size) + + @property + def device(self) -> torch.device: + return self.weight.device + + def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace: + batch_size, seq_len = input_ids.shape + hidden_size = self.config.hidden_size + values = torch.arange( + batch_size * seq_len * hidden_size, + device=input_ids.device, + dtype=torch.float32, + ).view(batch_size, seq_len, hidden_size) + hidden = values / values.numel() + self.weight + self.model(input_ids) # call through so the forward hook on layers[-1] fires + return SimpleNamespace(hidden_states=[hidden]) + + +class _FakeQwenInterface(nn.Module): + def __init__(self, config: VLAJEPAConfig) -> None: + super().__init__() + self.config = config + self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE) + + @staticmethod + def _get_torch_dtype(dtype_name: str) -> torch.dtype: + return torch.float32 if dtype_name == "float32" else torch.bfloat16 + + def expand_tokenizer(self) -> tuple[list[str], list[int], int]: + max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep + action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)] + action_token_ids = list(range(1000, 1000 + max_action_tokens)) + return action_tokens, action_token_ids, 2000 + + def build_inputs( + self, + images: list[list[Image.Image]], + instructions: list[str], + action_prompt: str, + embodied_prompt: str, + ) -> dict[str, Tensor]: + batch_size = len(images) + del images, instructions, action_prompt, embodied_prompt + action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep + token_ids = ( + [10] + + list(range(1000, 1000 + action_count)) + + [2000] * self.config.num_embodied_action_tokens_per_instruction + + [11] + ) + return { + "input_ids": torch.tensor( + [token_ids] * batch_size, + device=self.model.device, + dtype=torch.long, + ) + } + + @staticmethod + def tensor_to_pil(image_tensor: Tensor) -> Image.Image: + image = image_tensor.detach().cpu() + if image.ndim == 3 and image.shape[0] in (1, 3): + image = image.permute(1, 2, 0) + image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy() + return Image.fromarray(image) + + +class _FakeVideoEncoder(nn.Module): + def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(1)) + # image_size must be >= patch_size (16) so the predictor grid is non-zero. + # Setting image_size=16 gives a 1x1 grid (1 patch per frame). + self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16) + + @property + def device(self) -> torch.device: + return self.weight.device + + def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor: + batch_size, num_frames = pixel_values_videos.shape[:2] + hidden_size = self.config.hidden_size + frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False) + return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size) + + +class _FakeVideoProcessor: + def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]: + assert return_tensors == "pt" + if isinstance(videos, list): + pixel_values = torch.stack([torch.as_tensor(v) for v in videos]) + else: + pixel_values = torch.as_tensor(videos).unsqueeze(0) + return {"pixel_values_videos": pixel_values} + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None: + from lerobot.policies.vla_jepa import modeling_vla_jepa + + monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface) + monkeypatch.setattr( + modeling_vla_jepa.AutoModel, + "from_pretrained", + lambda *args, **kwargs: _FakeVideoEncoder(), + ) + monkeypatch.setattr( + modeling_vla_jepa.AutoVideoProcessor, + "from_pretrained", + lambda *args, **kwargs: _FakeVideoProcessor(), + ) diff --git a/tests/policies/vla_jepa/test_action_head.py b/tests/policies/vla_jepa/test_action_head.py new file mode 100644 index 000000000..5acff6371 --- /dev/null +++ b/tests/policies/vla_jepa/test_action_head.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import pytest +import torch + +pytest.importorskip("diffusers") + +from conftest import ( + ACTION_DIM, + ACTION_HORIZON, + BATCH_SIZE, + QWEN_HIDDEN_SIZE, + STATE_DIM, + make_config, + set_seed_all, +) # noqa: E402 + +from lerobot.policies.vla_jepa.action_head import ( # noqa: E402 + VLAJEPAActionHead, +) + +# --------------------------------------------------------------------------- +# VLAJEPAActionHead +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), # default test dims + (7, 0, 16), # no proprioceptive state, production-like action space + (6, 8, 8), # medium dims + ], +) +def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None: + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32) + assert t.shape == (200,) + assert torch.isfinite(t).all() + + +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), + (7, 0, 16), + (6, 8, 8), + ], +) +def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None: + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(2, action_horizon, action_dim) + timesteps = torch.randint(0, 100, (2,)) + + state = torch.randn(2, state_dim) if state_dim > 0 else None + out_with = head._build_inputs(conditioning, actions, state, timesteps) + out_none = head._build_inputs(conditioning, actions, None, timesteps) + + assert out_with.ndim == 3 and out_none.ndim == 3 + if state_dim > 0: + assert out_with.shape[1] > out_none.shape[1] + assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all() + + +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), + (7, 0, 16), + (6, 8, 8), + ], +) +def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None: + set_seed_all(42) + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(2, action_horizon, action_dim) + state = torch.randn(2, state_dim) if state_dim > 0 else None + loss = head.forward(conditioning, actions, state) + assert loss.shape == () + assert torch.isfinite(loss) and loss > 0 + + +def test_action_head_forward_gradient_flows() -> None: + set_seed_all(42) + config = make_config() + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) + state = torch.randn(BATCH_SIZE, STATE_DIM) + loss = head.forward(conditioning, actions, state) + loss.backward() + assert any(p.grad is not None for p in head.parameters() if p.requires_grad) + + +@torch.no_grad() +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), + (7, 0, 16), + (6, 8, 8), + ], +) +def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None: + set_seed_all(42) + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE) + state = torch.randn(2, state_dim) if state_dim > 0 else None + pred = head.predict_action(conditioning, state) + assert tuple(pred.shape) == (2, action_horizon, action_dim) + assert torch.isfinite(pred).all() + + +# --------------------------------------------------------------------------- +# action_is_pad masking +# --------------------------------------------------------------------------- + + +def test_action_head_loss_fully_padded_is_zero() -> None: + """Loss is 0 when every timestep is padded (exercises the clamp_min guard).""" + set_seed_all(42) + config = make_config() + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) + state = torch.randn(BATCH_SIZE, STATE_DIM) + + action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool) + loss = head.forward(conditioning, actions, state, action_is_pad) + assert loss.item() == 0.0 + + +def test_action_head_loss_none_matches_no_padding() -> None: + """action_is_pad=None is equivalent to an all-False (no padding) mask.""" + set_seed_all(42) + config = make_config() + head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE) + conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE) + actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM) + state = torch.randn(BATCH_SIZE, STATE_DIM) + + set_seed_all(0) + loss_none = head.forward(conditioning, actions, state, action_is_pad=None) + + set_seed_all(0) + no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool) + loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad) + + assert torch.isclose(loss_none, loss_zeros) diff --git a/tests/policies/vla_jepa/test_configuration.py b/tests/policies/vla_jepa/test_configuration.py new file mode 100644 index 000000000..2eda08ad3 --- /dev/null +++ b/tests/policies/vla_jepa/test_configuration.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import pytest +from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE + + +def test_delta_indices() -> None: + config = make_config() + assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES)) + assert config.action_delta_indices == list(range(ACTION_HORIZON)) + + +def test_n_action_steps_exceeds_chunk_size_raises() -> None: + with pytest.raises(ValueError, match="n_action_steps"): + VLAJEPAConfig(chunk_size=4, n_action_steps=8) + + +def test_too_few_video_frames_raises() -> None: + with pytest.raises(ValueError, match="video_horizon"): + VLAJEPAConfig( + chunk_size=16, + n_action_steps=16, + num_video_frames=2, + jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0 + ) + + +def test_validate_features_no_image_raises() -> None: + config = VLAJEPAConfig( + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}, + ) + with pytest.raises(ValueError, match="at least one visual input feature"): + config.validate_features() + + +def test_validate_features_no_action_raises() -> None: + config = VLAJEPAConfig( + input_features={ + f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)), + }, + output_features={}, + ) + with pytest.raises(ValueError, match="action output feature"): + config.validate_features() + + +def test_validate_features_sets_action_dim_from_feature() -> None: + config = make_config(action_dim=6, state_dim=10) + assert config.action_dim == 6 + assert config.state_dim == 10 diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py new file mode 100644 index 000000000..70194dd59 --- /dev/null +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -0,0 +1,598 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import os +from copy import deepcopy + +import numpy as np +import pytest +import torch +from torch import Tensor + +pytest.importorskip("transformers") +pytest.importorskip("diffusers") + +pytestmark = pytest.mark.filterwarnings( + "ignore:In CPU autocast, but the target dtype is not supported:UserWarning" +) + +from conftest import ( # noqa: E402 + ACTION_DIM, + ACTION_HORIZON, + BATCH_SIZE, + EXPECTED_ACTION_CHUNK_SHAPE, + EXPECTED_SELECT_ACTION_SHAPE, + IMAGE_SIZE, + N_ACTION_STEPS, + QWEN_HIDDEN_SIZE, + STATE_DIM, + make_config, + make_inference_batch, + make_train_batch, + set_seed_all, +) + +from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig # noqa: E402 +from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402 +from lerobot.utils.constants import ACTION # noqa: E402 + +PRETRAINED_REPO_ID = "ginwind/VLA-JEPA" +PRETRAINED_SUBFOLDER = "LIBERO" + +# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are +# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in. +_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0" +extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests") + + +# --------------------------------------------------------------------------- +# Core training / inference tests +# --------------------------------------------------------------------------- + + +def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None: + set_seed_all(42) + policy = VLAJEPAPolicy(make_config()) + policy.train() + + batch = make_train_batch() + batch_before = deepcopy(batch) + + loss, logs = policy.forward(batch) + + assert loss.shape == () + assert torch.isfinite(loss) + assert set(logs) == {"action_loss", "wm_loss", "loss"} + assert logs["action_loss"] > 0 + assert logs["wm_loss"] >= 0 + + loss.backward() + assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad) + # Batch must not be mutated. + assert set(batch) == set(batch_before) + for key, value in batch.items(): + if isinstance(value, Tensor): + assert torch.equal(value, batch_before[key]) + else: + assert value == batch_before[key] + + +@pytest.mark.parametrize("batch_size", [1, 2, 4]) +def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None: + set_seed_all(42) + policy = VLAJEPAPolicy(make_config()) + policy.train() + loss, logs = policy.forward(make_train_batch(batch_size=batch_size)) + assert torch.isfinite(loss) and loss > 0 + assert set(logs) == {"action_loss", "wm_loss", "loss"} + + +@pytest.mark.parametrize( + "action_dim,state_dim,action_horizon", + [ + (3, 4, 4), + (7, 0, 16), + (6, 8, 8), + ], +) +def test_training_forward_various_dims( + patch_vla_jepa_external_models: None, + action_dim: int, + state_dim: int, + action_horizon: int, +) -> None: + set_seed_all(42) + config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + policy = VLAJEPAPolicy(config) + policy.train() + batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon) + loss, _ = policy.forward(batch) + assert torch.isfinite(loss) and loss > 0 + + +@torch.no_grad() +def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None: + set_seed_all(42) + policy = VLAJEPAPolicy(make_config()) + policy.eval() + batch = make_inference_batch() + + chunk = policy.predict_action_chunk(batch) + assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE + assert chunk.device.type == "cpu" + assert torch.isfinite(chunk).all() + + a1 = policy.select_action(batch) + a2 = policy.select_action(batch) + assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE + assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE + assert torch.isfinite(a1).all() and torch.isfinite(a2).all() + + +@torch.no_grad() +@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)]) +def test_action_generation_various_dims( + patch_vla_jepa_external_models: None, action_dim: int, state_dim: int +) -> None: + set_seed_all(42) + config = make_config(action_dim=action_dim, state_dim=state_dim) + policy = VLAJEPAPolicy(config) + policy.eval() + batch = make_inference_batch(state_dim=state_dim) + chunk = policy.predict_action_chunk(batch) + assert chunk.shape[-1] == action_dim + assert torch.isfinite(chunk).all() + + +@torch.no_grad() +def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None: + set_seed_all(42) + policy = VLAJEPAPolicy(make_config()) + policy.eval() + batch = make_inference_batch() + + set_seed_all(123) + actions_1 = policy.predict_action_chunk(batch) + set_seed_all(123) + actions_2 = policy.predict_action_chunk(batch) + + assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE + assert torch.allclose(actions_1, actions_2, atol=1e-6) + + +@torch.no_grad() +def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + policy.eval() + for seed in [0, 42, 123]: + set_seed_all(seed) + chunk = policy.predict_action_chunk(make_inference_batch()) + assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}" + + +# --------------------------------------------------------------------------- +# Action queue behaviour +# --------------------------------------------------------------------------- + + +@torch.no_grad() +def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None: + set_seed_all(42) + policy = VLAJEPAPolicy(make_config()) + policy.eval() + batch = make_inference_batch() + + # First call fills the queue (n_action_steps items) and pops one. + a1 = policy.select_action(batch) + assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1 + + # Second call pops from the existing queue without calling predict_action_chunk. + a2 = policy.select_action(batch) + assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE + assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE + + +@torch.no_grad() +def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None: + set_seed_all(42) + policy = VLAJEPAPolicy(make_config()) + policy.eval() + policy.select_action(make_inference_batch()) + assert len(policy._queues[ACTION]) > 0 + + policy.reset() + assert len(policy._queues[ACTION]) == 0 + + +# --------------------------------------------------------------------------- +# Format conversion +# --------------------------------------------------------------------------- + + +def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None: + from PIL import Image + + policy = VLAJEPAPolicy(make_config()) + examples = policy._prepare_model_inputs(make_train_batch()) + + assert len(examples) == BATCH_SIZE + for ex in examples: + assert set(ex) >= {"image", "video", "lang", "action", "state"} + assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image) + assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C] + assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM) + assert ex["state"].shape == (1, STATE_DIM) + + +def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + for ex in policy._prepare_model_inputs(make_inference_batch()): + assert "action" not in ex + assert "image" in ex and "video" in ex and "lang" in ex + + +def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + batch = make_inference_batch() + del batch["task"] + examples = policy._prepare_model_inputs(batch) + assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples) + + +def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None: + policy = VLAJEPAPolicy(make_config()) + batch = make_inference_batch() + batch["task"] = "open the drawer" + assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch)) + + +def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None: + from lerobot.utils.constants import OBS_STATE + + policy = VLAJEPAPolicy(make_config()) + batch = make_inference_batch() + del batch[OBS_STATE] + assert all("state" not in ex for ex in policy._prepare_model_inputs(batch)) + + +# --------------------------------------------------------------------------- +# Pretrained checkpoint +# Hub tests (opt-in: VLA_JEPA_EXTENDED=1) +# --------------------------------------------------------------------------- + + +def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict: + """Build a training batch whose keys/shapes match a hub-loaded policy config.""" + cfg = policy.config + batch: dict = {"task": ["pick up the cube"] * batch_size} + for key, feat in cfg.image_features.items(): + h, w = feat.shape[-2], feat.shape[-1] + batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w) + if cfg.robot_state_feature is not None: + batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0]) + batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim) + return batch + + +def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict: + """Build an inference batch whose keys/shapes match a hub-loaded policy config.""" + cfg = policy.config + batch: dict = {"task": ["pick up the cube"] * batch_size} + for key, feat in cfg.image_features.items(): + h, w = feat.shape[-2], feat.shape[-1] + batch[key] = torch.rand(batch_size, 3, h, w) + if cfg.robot_state_feature is not None: + batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0]) + return batch + + +_CP_ROOT = "lerobot" + +# Each tuple: (repo_id, enable_world_model) +_HUB_VARIANTS = [ + (f"{_CP_ROOT}/VLA-JEPA-LIBERO", True), + (f"{_CP_ROOT}/VLA-JEPA-Pretrain", True), + (f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False), +] + + +@extended_test +@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS) +def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None: + """Policy loads from the converted safetensors checkpoint on the Hub.""" + policy = VLAJEPAPolicy.from_pretrained(repo_id) + assert policy.config.enable_world_model == enable_world_model + assert sum(p.numel() for p in policy.parameters()) > 0 + + +@extended_test +@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS) +def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None: + """Policy loaded from hub produces finite losses with a correctly-shaped batch.""" + policy = VLAJEPAPolicy.from_pretrained(repo_id) + policy.train() + + batch = _make_hub_train_batch(policy) + loss, logs = policy.forward(batch) + assert torch.isfinite(loss) + assert "action_loss" in logs + if enable_world_model: + assert "wm_loss" in logs + + +@extended_test +def test_hub_freeze_qwen_disables_world_model() -> None: + """freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model.""" + policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"]) + assert not policy.config.enable_world_model + assert policy.model.video_predictor is None + qwen_params = list(policy.model.qwen.parameters()) + assert all(not p.requires_grad for p in qwen_params) + assert any(p.requires_grad for p in policy.model.action_model.parameters()) + + +@extended_test +def test_hub_disable_world_model_loads_simpler_env() -> None: + """SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference.""" + policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv") + assert not policy.config.enable_world_model + assert policy.model.video_predictor is None + assert policy.model.video_encoder is None + + +@extended_test +def test_hub_libero_inference_shape() -> None: + """select_action returns the expected shape using the LIBERO hub checkpoint.""" + policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO") + policy.eval() + batch = _make_hub_inference_batch(policy) + action = policy.select_action(batch) + assert action.shape[-1] == policy.config.action_dim + + +# --------------------------------------------------------------------------- +# Postprocessor unnormalization tests +# +# These tests verify that the postprocessor pipeline (clip → unnorm → binarize) +# correctly applies MIN_MAX unnormalization after predict_action_chunk. +# --------------------------------------------------------------------------- + + +def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict: + """Returns sample dataset_stats with a simple [i, i+10] range per action dim.""" + from lerobot.utils.constants import ACTION + + return { + ACTION: { + "min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32), + "max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32), + } + } + + +@torch.no_grad() +def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None: + """UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization.""" + from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + from lerobot.processor import UnnormalizerProcessorStep + from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + from lerobot.utils.constants import ACTION + + dataset_stats = _make_dataset_stats() + + rng = np.random.default_rng(7) + actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32) + + a_min = dataset_stats[ACTION]["min"].numpy() + a_max = dataset_stats[ACTION]["max"].numpy() + expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min + + features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))} + unnorm_step = UnnormalizerProcessorStep( + features=features, + norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX}, + stats=dataset_stats, + ) + + actions_tensor = torch.from_numpy(actions_np) + transition = policy_action_to_transition(actions_tensor) + result = transition_to_policy_action(unnorm_step(transition)).numpy() + + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + +@torch.no_grad() +def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None: + """ClipActionsProcessorStep clamps to [-1, 1] before unnormalization.""" + from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep + from lerobot.processor import UnnormalizerProcessorStep + from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + from lerobot.utils.constants import ACTION + + dataset_stats = _make_dataset_stats() + a_min = dataset_stats[ACTION]["min"].numpy() + a_max = dataset_stats[ACTION]["max"].numpy() + + # Deliberately out-of-range inputs + actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32) + clipped = np.clip(actions_np, -1.0, 1.0) + expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min + + features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))} + clip_step = ClipActionsProcessorStep() + unnorm_step = UnnormalizerProcessorStep( + features=features, + norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX}, + stats=dataset_stats, + ) + + transition = policy_action_to_transition(torch.from_numpy(actions_np)) + transition = clip_step(transition) + result = transition_to_policy_action(unnorm_step(transition)).numpy() + + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + +@torch.no_grad() +def test_postprocessor_applied_after_predict_action_chunk( + patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch +) -> None: + """predict_action_chunk returns raw actions; the postprocessor applies unnormalization. + + Verifies the split: predict_action_chunk returns normalized actions, and calling the + postprocessor on them produces the correctly unnormalized result. + """ + from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors + + raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32) + + cfg = make_config() + cfg.clip_normalized_actions = False + cfg.binarize_gripper_action = False + policy = VLAJEPAPolicy(cfg) + policy.eval() + monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy()) + + dataset_stats = _make_dataset_stats() + _, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats) + + batch = make_inference_batch() + chunk = policy.predict_action_chunk(batch) + + # predict_action_chunk returns raw (normalized) actions + assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), ( + "predict_action_chunk should return raw actions without unnormalization applied." + ) + + # Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i + unnormed = postprocessor(chunk) + from lerobot.utils.constants import ACTION + + a_min = dataset_stats[ACTION]["min"].numpy() + a_max = dataset_stats[ACTION]["max"].numpy() + expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0] + assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5) + + +# --------------------------------------------------------------------------- +# World-model view adjustment (padding / trimming) tests +# --------------------------------------------------------------------------- + + +_MULTIVIEW_NUM_FRAMES = 4 # must be >= 2 * jepa_tubelet_size (=2) for world-model tests + + +def _make_multiview_config(num_views: int, jepa_tubelet_size: int = 2) -> VLAJEPAConfig: + from lerobot.configs.types import FeatureType, PolicyFeature + from lerobot.utils.constants import OBS_IMAGES, OBS_STATE + + config = VLAJEPAConfig( + input_features={ + **{ + f"{OBS_IMAGES}.cam{i}": PolicyFeature( + type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE) + ) + for i in range(num_views) + }, + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)), + }, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}, + device="cpu", + chunk_size=ACTION_HORIZON, + n_action_steps=N_ACTION_STEPS, + action_dim=ACTION_DIM, + state_dim=STATE_DIM, + num_video_frames=_MULTIVIEW_NUM_FRAMES, + num_action_tokens_per_timestep=2, + num_embodied_action_tokens_per_instruction=3, + num_inference_timesteps=2, + action_hidden_size=QWEN_HIDDEN_SIZE, + action_model_type="DiT-test", + action_num_layers=1, + predictor_depth=1, + predictor_num_heads=2, + predictor_mlp_ratio=2.0, + jepa_tubelet_size=jepa_tubelet_size, + ) + config.validate_features() + return config + + +def _make_multiview_train_batch(num_views: int, batch_size: int = BATCH_SIZE) -> dict: + from lerobot.utils.constants import OBS_IMAGES, OBS_STATE + + batch = { + f"{OBS_IMAGES}.cam{i}": torch.rand(batch_size, _MULTIVIEW_NUM_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE) + for i in range(num_views) + } + batch[OBS_STATE] = torch.randn(batch_size, 1, STATE_DIM) + batch[ACTION] = torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM) + batch["task"] = ["pick up the cube"] * batch_size + return batch + + +@pytest.mark.parametrize( + "num_views", + [ + 1, # fewer views than jepa_tubelet_size → first view duplicated + 2, # exact match → unchanged + 3, # more views than jepa_tubelet_size → trimmed to first two + ], +) +def test_training_forward_world_model_view_adjustment( + patch_vla_jepa_external_models: None, + num_views: int, +) -> None: + """World-model view padding/trimming must not break the training forward pass.""" + set_seed_all(42) + policy = VLAJEPAPolicy(_make_multiview_config(num_views=num_views, jepa_tubelet_size=2)) + policy.train() + loss, logs = policy.forward(_make_multiview_train_batch(num_views=num_views)) + assert torch.isfinite(loss) + assert logs["wm_loss"] >= 0 + + +def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_models: None) -> None: + """With one dataset view and jepa_tubelet_size=2, the view must be duplicated before encoding.""" + set_seed_all(42) + policy = VLAJEPAPolicy(_make_multiview_config(num_views=1, jepa_tubelet_size=2)) + policy.train() + + captured_videos: list = [] + original_processor = policy.model.video_processor + + class _CapturingProcessor: + def __call__(self, videos: list, return_tensors: str) -> dict: + captured_videos.extend(videos) + return original_processor(videos=videos, return_tensors=return_tensors) + + policy.model.video_processor = _CapturingProcessor() + policy.forward(_make_multiview_train_batch(num_views=1)) + + # reshape is batch-major: (b0v0, b0v1, b1v0, b1v1, …) + assert len(captured_videos) == BATCH_SIZE * 2 + for i in range(BATCH_SIZE): + np.testing.assert_array_equal(captured_videos[2 * i], captured_videos[2 * i + 1]) + + +def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: None) -> None: + """With three dataset views and jepa_tubelet_size=2, only the first two views reach the encoder.""" + set_seed_all(42) + policy = VLAJEPAPolicy(_make_multiview_config(num_views=3, jepa_tubelet_size=2)) + policy.train() + + captured_videos: list = [] + original_processor = policy.model.video_processor + + class _CapturingProcessor: + def __call__(self, videos: list, return_tensors: str) -> dict: + captured_videos.extend(videos) + return original_processor(videos=videos, return_tensors=return_tensors) + + policy.model.video_processor = _CapturingProcessor() + policy.forward(_make_multiview_train_batch(num_views=3)) + + # Only B*2 items must reach the encoder, not B*3. + assert len(captured_videos) == BATCH_SIZE * 2 diff --git a/tests/policies/vla_jepa/test_world_model.py b/tests/policies/vla_jepa/test_world_model.py new file mode 100644 index 000000000..555b2cd11 --- /dev/null +++ b/tests/policies/vla_jepa/test_world_model.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import pytest +import torch + +from lerobot.policies.vla_jepa.world_model import ( + ActionConditionedVideoPredictor, +) + +_ACTION_EMBED_DIM = 8 + + +def _make_predictor( + embed_dim: int = 8, + action_embed_dim: int = _ACTION_EMBED_DIM, + predictor_embed_dim: int = 24, + num_action_tokens: int = 2, + tokens_per_frame: int = 1, +) -> ActionConditionedVideoPredictor: + return ActionConditionedVideoPredictor( + num_frames=1, + img_size=(1, tokens_per_frame), + patch_size=1, + tubelet_size=1, + embed_dim=embed_dim, + action_embed_dim=action_embed_dim, + predictor_embed_dim=predictor_embed_dim, + depth=1, + num_heads=2, + mlp_ratio=2.0, + num_action_tokens_per_step=num_action_tokens, + ) + + +@pytest.mark.parametrize( + "batch,num_steps,tokens_per_frame,embed_dim", + [ + (1, 2, 1, 8), + (2, 3, 4, 8), + (4, 5, 2, 16), + ], +) +def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None: + predictor = _make_predictor( + embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame + ) + frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim) + action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM) + out = predictor(frame_tokens, action_tokens) + assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim) + assert torch.isfinite(out).all() + + +def test_predictor_step_mismatch_raises() -> None: + predictor = _make_predictor(tokens_per_frame=4) + frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each + with pytest.raises(RuntimeError): + predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch diff --git a/uv.lock b/uv.lock index fbcdf1a83..6acacab56 100644 --- a/uv.lock +++ b/uv.lock @@ -3052,6 +3052,11 @@ video-benchmark = [ viz = [ { name = "rerun-sdk" }, ] +vla-jepa = [ + { name = "diffusers" }, + { name = "qwen-vl-utils" }, + { name = "transformers" }, +] wallx = [ { name = "peft" }, { name = "qwen-vl-utils" }, @@ -3120,6 +3125,7 @@ requires-dist = [ { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" }, { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" }, { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" }, + { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" }, { name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" }, @@ -3171,6 +3177,7 @@ requires-dist = [ { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" }, { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'robometer'" }, { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" }, + { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'vla-jepa'" }, { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" }, @@ -3200,12 +3207,14 @@ requires-dist = [ { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'topreward'" }, + { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'vla-jepa'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" }, { name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["viz"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" }, { name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" }, + { name = "lerobot", extras = ["vla-jepa"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" }, { name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" }, @@ -3267,7 +3276,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ] -provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] +provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] [[package]] name = "librt"