Compare commits

...

44 Commits

Author SHA1 Message Date
Maximellerbach
ac9128b319 fixing simlink 2026-05-29 15:18:24 +02:00
Maximellerbach
61d82a5773 adding configuration gripper index and threshold 2026-05-29 14:11:51 +02:00
Maximellerbach
d8ed30d58c removing swish in favor of silu 2026-05-29 13:10:21 +02:00
Maximellerbach
79f6756505 cleanup 2026-05-29 12:53:40 +02:00
Maximellerbach
7a98c56ce4 removing useless pre-processor 2026-05-29 12:28:49 +02:00
Maximellerbach
9961eb6918 adding .mdx docs and shortening polivy_vla_jepa_README.md 2026-05-29 11:52:42 +02:00
Maximellerbach
e1852db71a adding licences 2026-05-29 11:51:58 +02:00
Maximellerbach
c2b29c8ae0 removing conversion script 2026-05-28 12:05:35 +02:00
Maximellerbach
58eac863aa fixing misconception about multiview / singleview handling 2026-05-28 12:05:35 +02:00
Maximellerbach
952e5146dc smol fix to avoid having default CPU device when training 2026-05-28 12:05:35 +02:00
Maximellerbach
37fda2a6fc adding instructions for different embodiement + fixing some tests 2026-05-28 12:05:35 +02:00
Maximellerbach
df7d5132d1 fix qwen norm layer output libero eval is now as expected 2026-05-28 12:05:35 +02:00
Maxime Ellerbach
8efa5cabe9 trying to close success rate gap 2026-05-28 12:05:35 +02:00
Maximellerbach
7e23859c55 fixing training and exal examples 2026-05-28 12:05:35 +02:00
Maximellerbach
a24f669deb adding guard for diffusers 2026-05-28 12:05:35 +02:00
Maximellerbach
b7727b8a6c adressing dtype zeros issue 2026-05-28 12:05:35 +02:00
Maxime Ellerbach
7db4414e6b fixing doc defaults args
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-05-28 12:05:35 +02:00
Maximellerbach
7da594fda8 pre-commit cleanup 2026-05-28 12:05:35 +02:00
Maxime Ellerbach
01ce5d7af1 refactoring into using pre and post processor 2026-05-28 12:05:35 +02:00
Maxime Ellerbach
83ef59e020 lots of changes to make existing weights work, need to massively refactor the pre and post processing 2026-05-28 12:05:35 +02:00
Maximellerbach
997b713f14 removing missleading future_action_window_size to just use chunk_size 2026-05-28 12:05:35 +02:00
Maximellerbach
1e3d25f10e allow different state dim and action dim 2026-05-28 12:05:35 +02:00
Maximellerbach
8c0efb8295 trying out to re-init the action head to avoid pretraining dimension mismatch 2026-05-28 12:05:35 +02:00
Maximellerbach
9ef0fd5433 make default params more aligned with paper and pretrained models
- adding possibility of freezing qwen backbone and world model
- added tests for weight loading
2026-05-28 12:05:35 +02:00
Maximellerbach
26d2ac48a8 add one-shot script to convert ginwind/VLA-JEPA checkpoints to safetensors (will remove once migrated) 2026-05-28 12:05:34 +02:00
Maximellerbach
0f29cd3167 add VLA-JEPA documentation
Covers architecture overview, pretrained checkpoints, config reference,
training/eval commands for LIBERO-10, and guidance on fine-tuning for
single-camera datasets.
2026-05-28 12:05:34 +02:00
Maximellerbach
64c9570547 update VLA-JEPA tests for arch changes and action_is_pad
- Switch conftest to use `action_model_type="DiT-test"` now that
  `action_num_heads` / `action_attention_head_dim` have been removed.
- Add action_head tests covering fully-padded loss (zero) and equivalence
  of action_is_pad=None vs all-zeros mask.
- Remove obsolete `test_native_to_lerobot_wm_only` test.
2026-05-28 12:05:34 +02:00
Maximellerbach
8b03d25fef propagate action_is_pad masking through VLA-JEPA policy pipeline
Pass the `action_is_pad` tensor from the batch through to the action head
so padded timesteps are excluded from the flow-matching loss.
2026-05-28 12:05:34 +02:00
Maximellerbach
5e37d97631 align VLA-JEPA architecture with original checkpoint
- Remove stale `action_num_heads` / `action_attention_head_dim` config fields;
  DiT head dimensions are now always derived from the preset (DiT-B/L/test).
- Add `num_target_vision_tokens` and `action_max_seq_len` config fields required
  by the action head's future-token embedding and positional embedding tables.
- Fix default `qwen_model_name` to 2B (matches all released checkpoints).
- Rename `ActionEncoder` attrs w1/w2/w3 → layer1/layer2/layer3 to match
  checkpoint key names; replace `nn.Sequential` decoder/state-encoder with
  `_MLP2` (layer1/layer2 naming).
- Fix `VLAJEPAActionHead` to size ActionEncoder and StateEncoder at `inner_dim`
  (DiT input width) rather than `action_hidden_size` (DiT output width).
- Rename `DiT.blocks` → `transformer_blocks` and `attn` → `attn1` to match
  checkpoint; add alternating cross/self attention (even blocks cross-attend to
  Qwen context, odd blocks self-attend).
- Add `DiT-test` preset for unit tests.
- Rewrite `ActionConditionedVideoPredictor` with explicit ViT-style blocks
  (`_PredictorBlock` with fused qkv) to match checkpoint structure; rename
  `encoder`/`norm`/`proj` → `predictor_blocks`/`predictor_norm`/`predictor_proj`.
2026-05-28 12:05:34 +02:00
Maximellerbach
a71e0d34ad adding more tests to ensure good coverage 2026-05-28 12:05:34 +02:00
Maximellerbach
d00b3e993a some more fixes to be closer to the original implem 2026-05-28 12:05:34 +02:00
Maxime Ellerbach
596c72bfc6 adjusting obs steps, tublets size to match original implementation 2026-05-28 12:05:34 +02:00
Maxime Ellerbach
ea535ad98d fixing wm_loss not propagating 2026-05-28 12:05:34 +02:00
Maxime Ellerbach
7368a0085a fix warnings with qwen processor kwargs 2026-05-28 12:05:34 +02:00
Maxime Ellerbach
7dba4f19a9 fixing action and state dim 2026-05-28 12:05:34 +02:00
Maximellerbach
90d398ea59 adding guards to avoid needing transformers and diffusers for type checking and basic tests 2026-05-28 12:05:34 +02:00
Maximellerbach
16a4643000 updating uv lock 2026-05-28 12:05:34 +02:00
Maximellerbach
6fa78ca8b4 adding deps to pyproject.toml 2026-05-28 11:59:32 +02:00
Maximellerbach
ec4bf4e47f linting 2026-05-28 11:59:32 +02:00
ginwind
da56489174 (feat)policies: add VLA-JEPA 2026-05-28 11:59:32 +02:00
ginwind
addb354296 support vla_jepa 2026-05-28 11:59:31 +02:00
ginwind
848ed3240e feat(policies): add VLA-JEPA 2026-05-28 11:58:03 +02:00
ginwind
f7bb1795e7 feat(policies): add VLA-JEPA 2026-05-28 11:58:03 +02:00
ginwind
0355902fba first commit 2026-05-28 11:58:03 +02:00
19 changed files with 3283 additions and 1 deletions

View File

@@ -61,6 +61,8 @@
title: π₀.₅ (Pi05)
- local: molmoact2
title: MolmoAct2
- local: vla_jepa
title: VLA-JEPA
- local: eo1
title: EO-1
- local: groot

View File

@@ -0,0 +1,39 @@
# VLA-JEPA
This repository contains the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
Converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA).
---
## Architecture Overview
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
At inference time only the Qwen backbone and action head are used; the world model is not needed.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.

235
docs/source/vla_jepa.mdx Normal file
View File

@@ -0,0 +1,235 @@
# VLA-JEPA
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
---
## Architecture Overview
VLA-JEPA has three main components:
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
### Data flow
**Training:**
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
**Inference:**
Only Qwen + the action head are used. The world model is not needed at inference time.
### Action head details
Available presets via `action_model_type`:
| Preset | Hidden dim | Heads | Head dim |
| ------- | ---------- | ----- | -------- |
| `DiT-B` | 768 | 12 | 64 |
| `DiT-L` | 1536 | 32 | 48 |
### World model details
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
---
## Pretrained Checkpoints
Three checkpoints are available directly inside the LeRobot org here: [`lerobot/VLA-JEPA`](https://huggingface.co/collections/lerobot/vla-jepa), converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
| Checkpoint | Dataset | Cameras | World model | Action dim |
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
---
## Configuration
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
| `gripper_dim` | 6 | Index of the gripper dimension in the action vector (e.g. 6 for a 7-DoF arm with gripper as the last joint) |
| `gripper_threshold` | 0.5 | Threshold used by `pre_snap_gripper_action` and `binarize_gripper_action` to binarize the gripper dimension |
| `pre_snap_gripper_action` | `True` | Snap the gripper dim to {0, 1} before unnormalization. Set to `False` for robots without a binary gripper |
| `binarize_gripper_action` | `True` | Binarize the gripper dim to {-1, 1} after unnormalization. Set to `False` for robots without a binary gripper |
---
## Training
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
### Full training from scratch
```bash
lerobot-train \
policy.type=vla_jepa \
policy.repo_id=your_org/your_repo \
dataset.repo_id=your_org/your_dataset
```
### Fine-tuning from a pretrained checkpoint
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/your_dataset
```
If you want to freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--dataset.repo_id=your_org/your_dataset
```
### Fine-tuning on a different embodiment
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
The layers that depend on `action_dim` and `state_dim` are:
| Layer | Key prefix |
| ----------------------------------------- | ----------------------------------- |
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
--dataset.repo_id=your_org/your_dataset
```
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
### Reproducing the LIBERO results
**Training on LIBERO:**
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=HuggingFaceVLA/libero \
--steps=30000
```
**Evaluating the pretrained LIBERO-10 checkpoint:**
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.n_episodes=10 \
--eval.batch_size=5
```
To evaluate a subset of tasks only:
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_10 \
--env.task_ids='[0,1,2]' \
--eval.n_episodes=10 \
--eval.batch_size=5
```
**Expected results:**
| Suite | Episodes | Successes | Success Rate |
| -------------- | -------- | --------- | ------------ |
| libero_spatial | 100 | 93 | **95.0%** |
| libero_object | 100 | 100 | **100.0%** |
| libero_goal | 100 | 98 | **98.0%** |
| libero_10 | 100 | 96 | **93.0%** |
| **Overall** | **400** | **387** | **96.5%** |
---
## Fine-tuning on datasets with a different number of cameras
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
**Default behaviour — view padding / trimming (no action required)**
When fine-tuning from `VLA-JEPA-Pretrain` the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
- **Single-view datasets (e.g. BridgeV2):** the single-view latent is duplicated to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views (one wrist + one third-person, following the configured view order) are used for the world model.
**Option 1 — Disable the world model**
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.enable_world_model=false \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/single_camera_dataset
```
**Option 2 — Reinitialize the predictor input projection**
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.

View File

@@ -216,6 +216,7 @@ topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -281,6 +282,7 @@ all = [
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",

View File

@@ -57,6 +57,7 @@ from .pretrained import PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from .vqbet.configuration_vqbet import VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
@@ -157,6 +158,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .molmoact2.modeling_molmoact2 import MolmoAct2Policy
return MolmoAct2Policy
elif name == "vla_jepa":
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -211,6 +216,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return EO1Config(**kwargs)
elif policy_type == "molmoact2":
return MolmoAct2Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -415,6 +422,7 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
@@ -432,6 +440,14 @@ def make_pre_post_processors(
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, VLAJEPAConfig):
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
processors = make_vla_jepa_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(

View File

@@ -0,0 +1 @@
../../../../docs/source/policy_vla_jepa_README.md

View File

@@ -0,0 +1,23 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_vla_jepa import VLAJEPAConfig
from .modeling_vla_jepa import VLAJEPAPolicy
from .processor_vla_jepa import make_vla_jepa_pre_post_processors
__all__ = [
"VLAJEPAConfig",
"VLAJEPAPolicy",
"make_vla_jepa_pre_post_processors",
]

View File

@@ -0,0 +1,337 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
else:
class ModelMixin: # type: ignore[no-redef]
pass
class ConfigMixin: # type: ignore[no-redef]
pass
register_to_config = lambda f: f # noqa: E731
Attention = FeedForward = TimestepEmbedding = Timesteps = None
from .configuration_vla_jepa import VLAJEPAConfig
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
batch_size, seq_len = timesteps.shape
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
class ActionEncoder(nn.Module):
def __init__(self, action_dim: int, hidden_size: int):
super().__init__()
self.layer1 = nn.Linear(action_dim, hidden_size)
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
self.layer3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = actions.shape
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
raise ValueError("timesteps must have shape [batch_size].")
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.layer1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.layer3(F.silu(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
require_package("diffusers", extra="vla_jepa")
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
return self.timestep_embedder(projected)
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int,
is_cross_attention: bool = True,
) -> None:
super().__init__()
self.is_cross_attention = is_cross_attention
self.norm1 = AdaLayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=True,
cross_attention_dim=cross_attention_dim,
out_bias=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
attention_context = encoder_hidden_states if self.is_cross_attention else None
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
output_dim: int,
num_layers: int,
dropout: float,
cross_attention_dim: int,
) -> None:
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
self.timestep_encoder = TimestepEncoder(self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
is_cross_attention=layer_idx % 2 == 0,
)
for layer_idx in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
) -> torch.Tensor:
temb = self.timestep_encoder(timestep)
x = hidden_states
for block in self.transformer_blocks:
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(x)
@dataclass
class ActionModelPreset:
hidden_size: int
attention_head_dim: int
num_attention_heads: int
DIT_PRESETS = {
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
}
class VLAJEPAActionHead(nn.Module):
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
num_heads = config.action_num_heads or preset.num_attention_heads
head_dim = config.action_attention_head_dim or preset.attention_head_dim
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
self.input_embedding_dim = inner_dim
self.action_horizon = config.chunk_size
self.num_inference_timesteps = config.num_inference_timesteps
hidden_size = config.action_hidden_size
self.model = DiT(
num_attention_heads=num_heads,
attention_head_dim=head_dim,
output_dim=hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
]
)
)
self.state_encoder = (
nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
]
)
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
return (self.config.action_noise_s - sample) / self.config.action_noise_s
def _build_inputs(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None,
timesteps: torch.Tensor,
) -> torch.Tensor:
action_features = self.action_encoder(actions, timesteps)
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
action_features = action_features + self.position_embedding(pos_ids)[None]
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
seq = [future_tokens, action_features]
if state is not None and self.state_encoder is not None:
if state.ndim == 2:
state = state.unsqueeze(1)
seq.insert(0, self.state_encoder(state))
return torch.cat(seq, dim=1)
def forward(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None = None,
action_is_pad: torch.Tensor | None = None,
) -> torch.Tensor:
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
velocity = actions - noise
t_discretized = (t * self.config.action_num_timestep_buckets).long()
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=t_discretized,
)
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
if action_is_pad is None:
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
num_valid = valid_mask.sum() * loss.shape[-1]
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
@torch.no_grad()
def predict_action(
self,
conditioning_tokens: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = conditioning_tokens.shape[0]
actions = torch.randn(
batch_size,
self.action_horizon,
self.config.action_dim,
dtype=conditioning_tokens.dtype,
device=conditioning_tokens.device,
)
dt = 1.0 / max(self.num_inference_timesteps, 1)
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full(
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=timesteps,
)
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
actions = actions + dt * pred_velocity
return actions

View File

@@ -0,0 +1,154 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("vla_jepa")
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 7
n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
# different action or state dimensionality, the input/output projection layers must be
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
reinit_modules: list[str] | None = None
tokenizer_padding_side: str = "left"
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 8
num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 4
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 16
action_num_heads: int | None = None
action_attention_head_dim: int | None = None
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
num_target_vision_tokens: int = 32
action_max_seq_len: int = 1024
# total video frames loaded per sample
num_video_frames: int = 8
predictor_depth: int = 12
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
gripper_dim: int = 6
gripper_threshold: float = 0.5
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# freezing qwen backbone makes world model training irrelevant since no grad flows
self.enable_world_model = False
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.num_video_frames < 2 * self.jepa_tubelet_size:
raise ValueError(
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
)
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("VLAJEPA requires at least one visual input feature.")
if self.action_feature is None:
raise ValueError("VLAJEPA requires an action output feature.")
self.action_dim = self.action_feature.shape[0]
if self.robot_state_feature is not None:
self.state_dim = self.robot_state_feature.shape[0]
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
# matches original repo's observation_indices=list(range(video_horizon))
return list(range(self.num_video_frames))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,629 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
# ============================================================================
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
# ============================================================================
class VLAJEPAModel(nn.Module):
"""
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
Components:
- Qwen3-VL: vision-language backbone for fused embeddings
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
require_package("transformers", extra="vla_jepa")
self.config = config
# Vision-language backbone
self.qwen = Qwen3VLInterface(config)
# Tokenizer expansion for special action tokens
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
self.qwen.expand_tokenizer()
)
# Action head (flow-matching DiT)
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
if config.enable_world_model:
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = config.jepa_tubelet_size
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
first_image_shape = next(iter(config.image_features.values())).shape
image_size = first_image_shape[-1]
self.video_predictor = ActionConditionedVideoPredictor(
num_frames=config.num_video_frames // tubelet_size,
img_size=(image_size, image_size),
patch_size=16,
tubelet_size=1,
embed_dim=self.video_encoder.config.hidden_size * num_views,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
else:
self.video_encoder = None
self.video_processor = None
self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Use the encoder's actual tubelet_size when available (world model enabled),
# otherwise fall back to config.
_tubelet_size = (
self.video_encoder.config.tubelet_size
if config.enable_world_model
else self.config.jepa_tubelet_size
)
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[:num_action_prompt_steps]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the last decoder hidden state before the final RMSNorm.
The model was trained with the output of the last transformer block BEFORE
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
the correct pre-RMSNorm state, matching the training-time representation.
"""
captured: list[torch.Tensor] = []
def _hook(module, input, output):
h = output[0] if isinstance(output, tuple) else output
captured.append(h)
last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] — multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str — task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@torch.no_grad()
def predict_action(
self,
batch_images: list[list[Image.Image]],
instructions: list[str],
state: np.ndarray | None = None,
) -> np.ndarray:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[PIL.Image] (multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] numpy array.
Returns:
np.ndarray [B, action_horizon, action_dim] — predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=last_hidden.dtype
)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
# ============================================================================
class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
name = "vla_jepa"
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
if dataset_meta := kwargs.get("dataset_meta"):
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
# compatibility), so validate_features() may have read stale dims from a pretrained
# config. Override state_dim/action_dim from the actual dataset being used.
ds_features = dataset_meta.features
if OBS_STATE in ds_features:
config.state_dim = ds_features[OBS_STATE]["shape"][0]
if ACTION in ds_features:
config.action_dim = ds_features[ACTION]["shape"][0]
self.model = VLAJEPAModel(config)
self.reset()
def reset(self) -> None:
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
LeRobot format:
batch = {
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
"action": Tensor [B, chunk_size, action_dim], (training only)
"task": str | List[str], (optional instruction)
}
Native format (List[dict]):
{
"image": List[PIL.Image], # multi-view images per sample
"video": np.ndarray [V, T, H, W, 3],
"lang": str, # task instruction
"action": np.ndarray [T, action_dim], # optional
"state": np.ndarray [1, state_dim], # optional
}
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
# ---- Collect images per sample ----
# images_per_sample[b][v] = PIL.Image for view v
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
# ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
# Check whether any image feature has a time dimension
video_source = None
for k in image_keys:
if k in batch:
video_source = batch[k] # Use first available for shape inspection
break
if video_source is None:
raise ValueError("No image data found in batch for video construction.")
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
# Convert to [T, H, W, 3] numpy
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
# Clamp to [0, 255]
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
elif isinstance(tasks, str):
instructions = [tasks] * batch_size
else:
instructions = list(tasks)
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
return examples
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
logs = {k: v.detach().item() for k, v in native_output.items()}
logs["loss"] = total_loss.detach().item()
return total_loss, logs
def get_optim_params(self) -> dict:
return self.model.parameters()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot inference: convert → native predict → return as Tensor."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
actions_np = self.model.predict_action(batch_images, instructions, state_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: type[T],
pretrained_name_or_path: str | Path,
**kwargs,
):
return super().from_pretrained(pretrained_name_or_path, **kwargs)
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
reinit_prefixes = model.config.reinit_modules
if not reinit_prefixes:
return super()._load_as_safetensor(model, model_file, map_location, strict)
from safetensors.torch import load_file
state_dict = load_file(model_file, device=map_location)
current = model.state_dict()
reinitialized: list[str] = []
filtered: dict = {}
for key, value in state_dict.items():
if key in current and value.shape != current[key].shape:
if not any(key.startswith(p) for p in reinit_prefixes):
raise ValueError(
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
)
reinitialized.append(
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
)
else:
filtered[key] = value
if reinitialized:
logging.warning(
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
)
from lerobot.policies.utils import log_model_loading_keys
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
log_model_loading_keys(missing_keys, unexpected_keys)
return model

View File

@@ -0,0 +1,155 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
class ClipActionsProcessorStep(ProcessorStep):
"""Clips action tensor to [-1, 1] before unnormalization."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None:
transition = dict(transition)
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
"""Snaps a gripper dimension to {0, 1} BEFORE unnormalization.
Mirrors the original starVLA LIBERO eval:
normalized[:, gripper_dim] = np.where(normalized[:, gripper_dim] < threshold, 0, 1)
This ensures the unnormalizer receives an exact binary value, which is
required when the model was trained with gripper in identity (mask=False)
space where 0=open and 1=close.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = (a[..., self.gripper_dim] >= self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes a gripper dimension after unnormalization.
Maps continuous value to {-1, 1}: > threshold → -1, <= threshold → 1 (matches starVLA convention).
Only applied when action has more dimensions than gripper_dim.
"""
def __init__(self, gripper_dim: int = 6, threshold: float = 0.5):
self.gripper_dim = gripper_dim
self.threshold = threshold
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] > self.gripper_dim:
transition = dict(transition)
a = action.clone()
a[..., self.gripper_dim] = 1.0 - 2.0 * (a[..., self.gripper_dim] > self.threshold).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
def make_vla_jepa_pre_post_processors(
config: VLAJEPAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = []
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
if config.pre_snap_gripper_action:
output_steps.append(
PreSnapGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(
UnnormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
)
)
if config.binarize_gripper_action:
output_steps.append(
BinarizeGripperProcessorStep(gripper_dim=config.gripper_dim, threshold=config.gripper_threshold)
)
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -0,0 +1,117 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,
torch_dtype=self._get_torch_dtype(config.torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
self.model.config.hidden_size = self.model.config.text_config.hidden_size
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
if dtype_name == "float32":
return torch.float32
if dtype_name == "float16":
return torch.float16
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
# is required for Qwen embedding/lm_head checkpoint shapes to match.
max_action_tokens = self.config.chunk_size * 4
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []
for idx in range(max_action_tokens):
token = self.config.special_action_token.format(idx)
action_tokens.append(token)
if token not in tokenizer.get_vocab():
tokenizer.add_tokens([token], special_tokens=True)
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
embodied_action_token = self.config.embodied_action_token
if embodied_action_token not in tokenizer.get_vocab():
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
self.model.resize_token_embeddings(len(tokenizer))
return action_tokens, action_token_ids, embodied_action_token_id
def build_inputs(
self,
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, torch.Tensor]:
messages = []
for sample_images, instruction in zip(images, instructions, strict=True):
prompt = self.config.prompt_template.format(
instruction=instruction,
actions=action_prompt,
e_actions=embodied_prompt,
)
content = [{"type": "image", "image": img} for img in sample_images]
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
processor_kwargs={"padding": True, "return_tensors": "pt"},
)
return batch_inputs.to(self.model.device)
@staticmethod
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)

View File

@@ -0,0 +1,418 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
def build_action_block_causal_attention_mask(
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
) -> torch.Tensor:
tokens_per_frame = add_tokens + grid_height * grid_width
num_tokens = num_frames * tokens_per_frame
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
local_window_time = num_frames
for current_frame in range(num_frames):
first_context_frame = max(0, current_frame - local_window_time + 1)
for context_frame in range(first_context_frame, current_frame + 1):
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
mask[row, col] = mask_block
return mask
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
_, _, _, dim = x.size()
if dim % 2 != 0:
raise ValueError("Embedding dimension must be even for rotary position encoding.")
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
omega /= dim / 2.0
omega = 1.0 / 10000**omega
freqs = torch.einsum("..., f -> ... f", pos, omega)
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
y = x.unflatten(-1, (-1, 2))
y1, y2 = y.unbind(dim=-1)
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
return x * emb_cos + y * emb_sin
class DropPath(nn.Module):
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ACRoPEAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: float | None = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop_prob = proj_drop
self.proj_drop = nn.Dropout(proj_drop)
self.use_sdpa = use_sdpa
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
self.grid_size = grid_size
self.is_causal = is_causal
@staticmethod
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
return ids // int(height * width)
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
frame_ids = self._get_frame_pos(ids, height, width)
ids = ids - int(height * width) * frame_ids
return ids // width
def separate_positions(
self, ids: torch.Tensor, height: int, width: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
frame_ids = self._get_frame_pos(ids, height, width)
height_ids = self._get_height_pos(ids, height, width)
width_ids = ids - int(height * width) * frame_ids - width * height_ids
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
batch_size, num_tokens, channels = x.size()
if num_frames is None or grid_height is None or grid_width is None:
raise ValueError("num_frames, grid_height and grid_width are required.")
if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
else:
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
h_mask *= self.grid_size / grid_height
w_mask *= self.grid_size / grid_width
if action_tokens > 0:
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
action_q, action_k, action_v = [], [], []
for idx in range(action_tokens):
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
qd = rotate_queries_or_keys(
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
kd = rotate_queries_or_keys(
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
qr = q[..., self.d_dim :]
kr = k[..., self.d_dim :]
action_q.append(
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_k.append(
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
x = x[:, :, action_tokens:, :].flatten(1, 2)
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
offset = 0
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
offset += self.d_dim
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
offset += self.h_dim
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
offset += self.w_dim
if offset < self.head_dim:
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
else:
q = torch.cat([qd, qh, qw], dim=-1)
k = torch.cat([kd, kh, kw], dim=-1)
if action_tokens > 0:
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
frame_tokens = frame_tokens.view(
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
)
action_token_values = action_token_values.view(
batch_size, self.num_heads, num_frames, action_tokens, -1
)
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
q = merge(q, action_q)
k = merge(k, action_k)
v = merge(v, action_v)
if attn_mask is not None or self.use_sdpa:
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
x = self.proj(x)
return self.proj_drop(x)
class ACBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: float | None = None,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
use_rope: bool = True,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
if not use_rope:
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
self.attn = ACRoPEAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
use_sdpa=use_sdpa,
is_causal=is_causal,
grid_size=grid_size,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MLP(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
drop=drop,
)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
y = self.norm1(x)
y = self.attn(
y,
mask=None,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=grid_height,
grid_width=grid_width,
action_tokens=action_tokens,
)
x = x + self.drop_path(y)
y = self.norm2(x)
return x + self.drop_path(self.mlp(y))
class ActionConditionedVideoPredictor(nn.Module):
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
def __init__(
self,
num_frames: int,
img_size: tuple[int, int],
patch_size: int,
tubelet_size: int,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
use_extrinsics: bool = False,
) -> None:
super().__init__()
self.is_frame_causal = True
self.use_extrinsics = use_extrinsics
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
self.img_height, self.img_width = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_height = self.img_height // self.patch_size
self.grid_width = self.img_width // self.patch_size
self.predictor_blocks = nn.ModuleList(
[
ACBlock(
dim=predictor_embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
grid_size=self.grid_height,
use_rope=True,
)
for _ in range(depth)
]
)
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
self.num_action_tokens_per_step = num_action_tokens_per_step
@property
def norm(self) -> nn.LayerNorm:
return self.predictor_norm
@property
def proj(self) -> nn.Linear:
return self.predictor_proj
def forward(
self,
frame_tokens: torch.Tensor,
action_tokens: torch.Tensor,
extrinsics: torch.Tensor | None = None,
) -> torch.Tensor:
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
x = self.predictor_embed(frame_tokens)
batch_size, num_context_tokens, hidden_dim = x.size()
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
actions = self.action_encoder(action_tokens)
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
cond_tokens = actions.shape[2]
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
if self.use_extrinsics:
if extrinsics is None:
raise ValueError("extrinsics are required when use_extrinsics=True.")
cond_tokens += 1
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
else:
x = torch.cat([actions, x], dim=2).flatten(1, 2)
attn_mask = build_action_block_causal_attention_mask(
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
)
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
for block in self.predictor_blocks:
x = block(
x,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=self.grid_height,
grid_width=self.grid_width,
action_tokens=cond_tokens,
)
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
x = x[:, :, cond_tokens:, :].flatten(1, 2)
x = self.predictor_norm(x)
return self.predictor_proj(x)

View File

@@ -0,0 +1,273 @@
#!/usr/bin/env python
"""Shared fixtures and helpers for VLA-JEPA tests."""
from __future__ import annotations
from types import SimpleNamespace
import numpy as np
import pytest
import torch
from PIL import Image
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
# ---------------------------------------------------------------------------
# Shared constants
# ---------------------------------------------------------------------------
BATCH_SIZE = 2
ACTION_DIM = 3
STATE_DIM = 4
IMAGE_SIZE = 8
ACTION_HORIZON = 4
N_ACTION_STEPS = 2
NUM_VIDEO_FRAMES = 3
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def set_seed_all(seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def make_config(
action_dim: int = ACTION_DIM,
state_dim: int = STATE_DIM,
action_horizon: int = ACTION_HORIZON,
num_video_frames: int = NUM_VIDEO_FRAMES,
) -> VLAJEPAConfig:
config = VLAJEPAConfig(
input_features={
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
},
output_features={
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
},
device="cpu",
chunk_size=action_horizon,
n_action_steps=min(N_ACTION_STEPS, action_horizon),
action_dim=action_dim,
state_dim=state_dim,
num_video_frames=num_video_frames,
num_action_tokens_per_timestep=2,
num_embodied_action_tokens_per_instruction=3,
num_inference_timesteps=2,
action_hidden_size=QWEN_HIDDEN_SIZE,
action_model_type="DiT-test",
action_num_layers=1,
predictor_depth=1,
predictor_num_heads=2,
predictor_mlp_ratio=2.0,
jepa_tubelet_size=1,
)
config.validate_features()
return config
def make_train_batch(
batch_size: int = BATCH_SIZE,
action_dim: int = ACTION_DIM,
state_dim: int = STATE_DIM,
action_horizon: int = ACTION_HORIZON,
num_video_frames: int = NUM_VIDEO_FRAMES,
) -> dict[str, Tensor | list[str]]:
return {
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE),
OBS_STATE: torch.randn(batch_size, 1, state_dim),
ACTION: torch.randn(batch_size, action_horizon, action_dim),
"task": ["pick up the cube"] * batch_size,
}
def make_inference_batch(
batch_size: int = BATCH_SIZE,
state_dim: int = STATE_DIM,
) -> dict[str, Tensor | list[str]]:
return {
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
OBS_STATE: torch.randn(batch_size, state_dim),
"task": ["pick up the cube"] * batch_size,
}
# ---------------------------------------------------------------------------
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
# ---------------------------------------------------------------------------
class _FakeLanguageLayer(nn.Module):
"""Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self._hidden_size = hidden_size
def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
return (hidden,)
class _FakeLanguageModel(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
self._hidden_size = hidden_size
self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)])
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
batch_size, seq_len = input_ids.shape
hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device)
self.layers[-1](hidden)
return SimpleNamespace()
class _FakeQwenInnerModel(nn.Module):
"""Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.language_model = _FakeLanguageModel(hidden_size)
def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
return self.language_model(input_ids)
class _FakeQwenBackbone(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(
hidden_size=hidden_size,
text_config=SimpleNamespace(hidden_size=hidden_size),
)
self.model = _FakeQwenInnerModel(hidden_size)
@property
def device(self) -> torch.device:
return self.weight.device
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
batch_size, seq_len = input_ids.shape
hidden_size = self.config.hidden_size
values = torch.arange(
batch_size * seq_len * hidden_size,
device=input_ids.device,
dtype=torch.float32,
).view(batch_size, seq_len, hidden_size)
hidden = values / values.numel() + self.weight
self.model(input_ids) # call through so the forward hook on layers[-1] fires
return SimpleNamespace(hidden_states=[hidden])
class _FakeQwenInterface(nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
return torch.float32 if dtype_name == "float32" else torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
action_token_ids = list(range(1000, 1000 + max_action_tokens))
return action_tokens, action_token_ids, 2000
def build_inputs(
self,
images: list[list[Image.Image]],
instructions: list[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, Tensor]:
batch_size = len(images)
del images, instructions, action_prompt, embodied_prompt
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
token_ids = (
[10]
+ list(range(1000, 1000 + action_count))
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
+ [11]
)
return {
"input_ids": torch.tensor(
[token_ids] * batch_size,
device=self.model.device,
dtype=torch.long,
)
}
@staticmethod
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
return Image.fromarray(image)
class _FakeVideoEncoder(nn.Module):
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
# image_size must be >= patch_size (16) so the predictor grid is non-zero.
# Setting image_size=16 gives a 1x1 grid (1 patch per frame).
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)
@property
def device(self) -> torch.device:
return self.weight.device
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
batch_size, num_frames = pixel_values_videos.shape[:2]
hidden_size = self.config.hidden_size
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
class _FakeVideoProcessor:
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
assert return_tensors == "pt"
if isinstance(videos, list):
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
else:
pixel_values = torch.as_tensor(videos).unsqueeze(0)
return {"pixel_values_videos": pixel_values}
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
from lerobot.policies.vla_jepa import modeling_vla_jepa
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
monkeypatch.setattr(
modeling_vla_jepa.AutoModel,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoEncoder(),
)
monkeypatch.setattr(
modeling_vla_jepa.AutoVideoProcessor,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoProcessor(),
)

View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
import torch
pytest.importorskip("diffusers")
from conftest import (
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
QWEN_HIDDEN_SIZE,
STATE_DIM,
make_config,
set_seed_all,
) # noqa: E402
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
VLAJEPAActionHead,
)
# ---------------------------------------------------------------------------
# VLAJEPAActionHead
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4), # default test dims
(7, 0, 16), # no proprioceptive state, production-like action space
(6, 8, 8), # medium dims
],
)
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32)
assert t.shape == (200,)
assert torch.isfinite(t).all()
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(2, action_horizon, action_dim)
timesteps = torch.randint(0, 100, (2,))
state = torch.randn(2, state_dim) if state_dim > 0 else None
out_with = head._build_inputs(conditioning, actions, state, timesteps)
out_none = head._build_inputs(conditioning, actions, None, timesteps)
assert out_with.ndim == 3 and out_none.ndim == 3
if state_dim > 0:
assert out_with.shape[1] > out_none.shape[1]
assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all()
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(2, action_horizon, action_dim)
state = torch.randn(2, state_dim) if state_dim > 0 else None
loss = head.forward(conditioning, actions, state)
assert loss.shape == ()
assert torch.isfinite(loss) and loss > 0
def test_action_head_forward_gradient_flows() -> None:
set_seed_all(42)
config = make_config()
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
state = torch.randn(BATCH_SIZE, STATE_DIM)
loss = head.forward(conditioning, actions, state)
loss.backward()
assert any(p.grad is not None for p in head.parameters() if p.requires_grad)
@torch.no_grad()
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
state = torch.randn(2, state_dim) if state_dim > 0 else None
pred = head.predict_action(conditioning, state)
assert tuple(pred.shape) == (2, action_horizon, action_dim)
assert torch.isfinite(pred).all()
# ---------------------------------------------------------------------------
# action_is_pad masking
# ---------------------------------------------------------------------------
def test_action_head_loss_fully_padded_is_zero() -> None:
"""Loss is 0 when every timestep is padded (exercises the clamp_min guard)."""
set_seed_all(42)
config = make_config()
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
state = torch.randn(BATCH_SIZE, STATE_DIM)
action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
loss = head.forward(conditioning, actions, state, action_is_pad)
assert loss.item() == 0.0
def test_action_head_loss_none_matches_no_padding() -> None:
"""action_is_pad=None is equivalent to an all-False (no padding) mask."""
set_seed_all(42)
config = make_config()
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
state = torch.randn(BATCH_SIZE, STATE_DIM)
set_seed_all(0)
loss_none = head.forward(conditioning, actions, state, action_is_pad=None)
set_seed_all(0)
no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad)
assert torch.isclose(loss_none, loss_zeros)

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
def test_delta_indices() -> None:
config = make_config()
assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES))
assert config.action_delta_indices == list(range(ACTION_HORIZON))
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
with pytest.raises(ValueError, match="n_action_steps"):
VLAJEPAConfig(chunk_size=4, n_action_steps=8)
def test_too_few_video_frames_raises() -> None:
with pytest.raises(ValueError, match="video_horizon"):
VLAJEPAConfig(
chunk_size=16,
n_action_steps=16,
num_video_frames=2,
jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
)
def test_validate_features_no_image_raises() -> None:
config = VLAJEPAConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
)
with pytest.raises(ValueError, match="at least one visual input feature"):
config.validate_features()
def test_validate_features_no_action_raises() -> None:
config = VLAJEPAConfig(
input_features={
f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
},
output_features={},
)
with pytest.raises(ValueError, match="action output feature"):
config.validate_features()
def test_validate_features_sets_action_dim_from_feature() -> None:
config = make_config(action_dim=6, state_dim=10)
assert config.action_dim == 6
assert config.state_dim == 10

View File

@@ -0,0 +1,598 @@
#!/usr/bin/env python
from __future__ import annotations
import os
from copy import deepcopy
import numpy as np
import pytest
import torch
from torch import Tensor
pytest.importorskip("transformers")
pytest.importorskip("diffusers")
pytestmark = pytest.mark.filterwarnings(
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
)
from conftest import ( # noqa: E402
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
EXPECTED_ACTION_CHUNK_SHAPE,
EXPECTED_SELECT_ACTION_SHAPE,
IMAGE_SIZE,
N_ACTION_STEPS,
QWEN_HIDDEN_SIZE,
STATE_DIM,
make_config,
make_inference_batch,
make_train_batch,
set_seed_all,
)
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig # noqa: E402
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
from lerobot.utils.constants import ACTION # noqa: E402
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
PRETRAINED_SUBFOLDER = "LIBERO"
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
# ---------------------------------------------------------------------------
# Core training / inference tests
# ---------------------------------------------------------------------------
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.train()
batch = make_train_batch()
batch_before = deepcopy(batch)
loss, logs = policy.forward(batch)
assert loss.shape == ()
assert torch.isfinite(loss)
assert set(logs) == {"action_loss", "wm_loss", "loss"}
assert logs["action_loss"] > 0
assert logs["wm_loss"] >= 0
loss.backward()
assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
# Batch must not be mutated.
assert set(batch) == set(batch_before)
for key, value in batch.items():
if isinstance(value, Tensor):
assert torch.equal(value, batch_before[key])
else:
assert value == batch_before[key]
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.train()
loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
assert torch.isfinite(loss) and loss > 0
assert set(logs) == {"action_loss", "wm_loss", "loss"}
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_training_forward_various_dims(
patch_vla_jepa_external_models: None,
action_dim: int,
state_dim: int,
action_horizon: int,
) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
policy = VLAJEPAPolicy(config)
policy.train()
batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
loss, _ = policy.forward(batch)
assert torch.isfinite(loss) and loss > 0
@torch.no_grad()
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
batch = make_inference_batch()
chunk = policy.predict_action_chunk(batch)
assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
assert chunk.device.type == "cpu"
assert torch.isfinite(chunk).all()
a1 = policy.select_action(batch)
a2 = policy.select_action(batch)
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
@torch.no_grad()
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
def test_action_generation_various_dims(
patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim)
policy = VLAJEPAPolicy(config)
policy.eval()
batch = make_inference_batch(state_dim=state_dim)
chunk = policy.predict_action_chunk(batch)
assert chunk.shape[-1] == action_dim
assert torch.isfinite(chunk).all()
@torch.no_grad()
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
batch = make_inference_batch()
set_seed_all(123)
actions_1 = policy.predict_action_chunk(batch)
set_seed_all(123)
actions_2 = policy.predict_action_chunk(batch)
assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
assert torch.allclose(actions_1, actions_2, atol=1e-6)
@torch.no_grad()
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
policy.eval()
for seed in [0, 42, 123]:
set_seed_all(seed)
chunk = policy.predict_action_chunk(make_inference_batch())
assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}"
# ---------------------------------------------------------------------------
# Action queue behaviour
# ---------------------------------------------------------------------------
@torch.no_grad()
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
batch = make_inference_batch()
# First call fills the queue (n_action_steps items) and pops one.
a1 = policy.select_action(batch)
assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1
# Second call pops from the existing queue without calling predict_action_chunk.
a2 = policy.select_action(batch)
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
@torch.no_grad()
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
policy.select_action(make_inference_batch())
assert len(policy._queues[ACTION]) > 0
policy.reset()
assert len(policy._queues[ACTION]) == 0
# ---------------------------------------------------------------------------
# Format conversion
# ---------------------------------------------------------------------------
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
from PIL import Image
policy = VLAJEPAPolicy(make_config())
examples = policy._prepare_model_inputs(make_train_batch())
assert len(examples) == BATCH_SIZE
for ex in examples:
assert set(ex) >= {"image", "video", "lang", "action", "state"}
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
assert ex["state"].shape == (1, STATE_DIM)
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
for ex in policy._prepare_model_inputs(make_inference_batch()):
assert "action" not in ex
assert "image" in ex and "video" in ex and "lang" in ex
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch["task"]
examples = policy._prepare_model_inputs(batch)
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
batch["task"] = "open the drawer"
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
from lerobot.utils.constants import OBS_STATE
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch[OBS_STATE]
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
# ---------------------------------------------------------------------------
# Pretrained checkpoint
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
# ---------------------------------------------------------------------------
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
cfg = policy.config
batch: dict = {"task": ["pick up the cube"] * batch_size}
for key, feat in cfg.image_features.items():
h, w = feat.shape[-2], feat.shape[-1]
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
if cfg.robot_state_feature is not None:
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
return batch
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
cfg = policy.config
batch: dict = {"task": ["pick up the cube"] * batch_size}
for key, feat in cfg.image_features.items():
h, w = feat.shape[-2], feat.shape[-1]
batch[key] = torch.rand(batch_size, 3, h, w)
if cfg.robot_state_feature is not None:
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
return batch
_CP_ROOT = "lerobot"
# Each tuple: (repo_id, enable_world_model)
_HUB_VARIANTS = [
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
]
@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
"""Policy loads from the converted safetensors checkpoint on the Hub."""
policy = VLAJEPAPolicy.from_pretrained(repo_id)
assert policy.config.enable_world_model == enable_world_model
assert sum(p.numel() for p in policy.parameters()) > 0
@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
policy = VLAJEPAPolicy.from_pretrained(repo_id)
policy.train()
batch = _make_hub_train_batch(policy)
loss, logs = policy.forward(batch)
assert torch.isfinite(loss)
assert "action_loss" in logs
if enable_world_model:
assert "wm_loss" in logs
@extended_test
def test_hub_freeze_qwen_disables_world_model() -> None:
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
assert not policy.config.enable_world_model
assert policy.model.video_predictor is None
qwen_params = list(policy.model.qwen.parameters())
assert all(not p.requires_grad for p in qwen_params)
assert any(p.requires_grad for p in policy.model.action_model.parameters())
@extended_test
def test_hub_disable_world_model_loads_simpler_env() -> None:
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
assert not policy.config.enable_world_model
assert policy.model.video_predictor is None
assert policy.model.video_encoder is None
@extended_test
def test_hub_libero_inference_shape() -> None:
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
policy.eval()
batch = _make_hub_inference_batch(policy)
action = policy.select_action(batch)
assert action.shape[-1] == policy.config.action_dim
# ---------------------------------------------------------------------------
# Postprocessor unnormalization tests
#
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
# ---------------------------------------------------------------------------
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
"""Returns sample dataset_stats with a simple [i, i+10] range per action dim."""
from lerobot.utils.constants import ACTION
return {
ACTION: {
"min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32),
"max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32),
}
}
@torch.no_grad()
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
"""UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization."""
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor import UnnormalizerProcessorStep
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import ACTION
dataset_stats = _make_dataset_stats()
rng = np.random.default_rng(7)
actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
unnorm_step = UnnormalizerProcessorStep(
features=features,
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
stats=dataset_stats,
)
actions_tensor = torch.from_numpy(actions_np)
transition = policy_action_to_transition(actions_tensor)
result = transition_to_policy_action(unnorm_step(transition)).numpy()
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
@torch.no_grad()
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
from lerobot.processor import UnnormalizerProcessorStep
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import ACTION
dataset_stats = _make_dataset_stats()
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
# Deliberately out-of-range inputs
actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32)
clipped = np.clip(actions_np, -1.0, 1.0)
expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
clip_step = ClipActionsProcessorStep()
unnorm_step = UnnormalizerProcessorStep(
features=features,
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
stats=dataset_stats,
)
transition = policy_action_to_transition(torch.from_numpy(actions_np))
transition = clip_step(transition)
result = transition_to_policy_action(unnorm_step(transition)).numpy()
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
@torch.no_grad()
def test_postprocessor_applied_after_predict_action_chunk(
patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
) -> None:
"""predict_action_chunk returns raw actions; the postprocessor applies unnormalization.
Verifies the split: predict_action_chunk returns normalized actions, and calling the
postprocessor on them produces the correctly unnormalized result.
"""
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
cfg = make_config()
cfg.clip_normalized_actions = False
cfg.binarize_gripper_action = False
policy = VLAJEPAPolicy(cfg)
policy.eval()
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
dataset_stats = _make_dataset_stats()
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
batch = make_inference_batch()
chunk = policy.predict_action_chunk(batch)
# predict_action_chunk returns raw (normalized) actions
assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
"predict_action_chunk should return raw actions without unnormalization applied."
)
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
unnormed = postprocessor(chunk)
from lerobot.utils.constants import ACTION
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)
# ---------------------------------------------------------------------------
# World-model view adjustment (padding / trimming) tests
# ---------------------------------------------------------------------------
_MULTIVIEW_NUM_FRAMES = 4 # must be >= 2 * jepa_tubelet_size (=2) for world-model tests
def _make_multiview_config(num_views: int, jepa_tubelet_size: int = 2) -> VLAJEPAConfig:
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
config = VLAJEPAConfig(
input_features={
**{
f"{OBS_IMAGES}.cam{i}": PolicyFeature(
type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)
)
for i in range(num_views)
},
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
device="cpu",
chunk_size=ACTION_HORIZON,
n_action_steps=N_ACTION_STEPS,
action_dim=ACTION_DIM,
state_dim=STATE_DIM,
num_video_frames=_MULTIVIEW_NUM_FRAMES,
num_action_tokens_per_timestep=2,
num_embodied_action_tokens_per_instruction=3,
num_inference_timesteps=2,
action_hidden_size=QWEN_HIDDEN_SIZE,
action_model_type="DiT-test",
action_num_layers=1,
predictor_depth=1,
predictor_num_heads=2,
predictor_mlp_ratio=2.0,
jepa_tubelet_size=jepa_tubelet_size,
)
config.validate_features()
return config
def _make_multiview_train_batch(num_views: int, batch_size: int = BATCH_SIZE) -> dict:
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
batch = {
f"{OBS_IMAGES}.cam{i}": torch.rand(batch_size, _MULTIVIEW_NUM_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE)
for i in range(num_views)
}
batch[OBS_STATE] = torch.randn(batch_size, 1, STATE_DIM)
batch[ACTION] = torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM)
batch["task"] = ["pick up the cube"] * batch_size
return batch
@pytest.mark.parametrize(
"num_views",
[
1, # fewer views than jepa_tubelet_size → first view duplicated
2, # exact match → unchanged
3, # more views than jepa_tubelet_size → trimmed to first two
],
)
def test_training_forward_world_model_view_adjustment(
patch_vla_jepa_external_models: None,
num_views: int,
) -> None:
"""World-model view padding/trimming must not break the training forward pass."""
set_seed_all(42)
policy = VLAJEPAPolicy(_make_multiview_config(num_views=num_views, jepa_tubelet_size=2))
policy.train()
loss, logs = policy.forward(_make_multiview_train_batch(num_views=num_views))
assert torch.isfinite(loss)
assert logs["wm_loss"] >= 0
def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_models: None) -> None:
"""With one dataset view and jepa_tubelet_size=2, the view must be duplicated before encoding."""
set_seed_all(42)
policy = VLAJEPAPolicy(_make_multiview_config(num_views=1, jepa_tubelet_size=2))
policy.train()
captured_videos: list = []
original_processor = policy.model.video_processor
class _CapturingProcessor:
def __call__(self, videos: list, return_tensors: str) -> dict:
captured_videos.extend(videos)
return original_processor(videos=videos, return_tensors=return_tensors)
policy.model.video_processor = _CapturingProcessor()
policy.forward(_make_multiview_train_batch(num_views=1))
# reshape is batch-major: (b0v0, b0v1, b1v0, b1v1, …)
assert len(captured_videos) == BATCH_SIZE * 2
for i in range(BATCH_SIZE):
np.testing.assert_array_equal(captured_videos[2 * i], captured_videos[2 * i + 1])
def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: None) -> None:
"""With three dataset views and jepa_tubelet_size=2, only the first two views reach the encoder."""
set_seed_all(42)
policy = VLAJEPAPolicy(_make_multiview_config(num_views=3, jepa_tubelet_size=2))
policy.train()
captured_videos: list = []
original_processor = policy.model.video_processor
class _CapturingProcessor:
def __call__(self, videos: list, return_tensors: str) -> dict:
captured_videos.extend(videos)
return original_processor(videos=videos, return_tensors=return_tensors)
policy.model.video_processor = _CapturingProcessor()
policy.forward(_make_multiview_train_batch(num_views=3))
# Only B*2 items must reach the encoder, not B*3.
assert len(captured_videos) == BATCH_SIZE * 2

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
import torch
from lerobot.policies.vla_jepa.world_model import (
ActionConditionedVideoPredictor,
)
_ACTION_EMBED_DIM = 8
def _make_predictor(
embed_dim: int = 8,
action_embed_dim: int = _ACTION_EMBED_DIM,
predictor_embed_dim: int = 24,
num_action_tokens: int = 2,
tokens_per_frame: int = 1,
) -> ActionConditionedVideoPredictor:
return ActionConditionedVideoPredictor(
num_frames=1,
img_size=(1, tokens_per_frame),
patch_size=1,
tubelet_size=1,
embed_dim=embed_dim,
action_embed_dim=action_embed_dim,
predictor_embed_dim=predictor_embed_dim,
depth=1,
num_heads=2,
mlp_ratio=2.0,
num_action_tokens_per_step=num_action_tokens,
)
@pytest.mark.parametrize(
"batch,num_steps,tokens_per_frame,embed_dim",
[
(1, 2, 1, 8),
(2, 3, 4, 8),
(4, 5, 2, 16),
],
)
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
predictor = _make_predictor(
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
)
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
out = predictor(frame_tokens, action_tokens)
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
assert torch.isfinite(out).all()
def test_predictor_step_mismatch_raises() -> None:
predictor = _make_predictor(tokens_per_frame=4)
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
with pytest.raises(RuntimeError):
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch

11
uv.lock generated
View File

@@ -3047,6 +3047,11 @@ video-benchmark = [
viz = [
{ name = "rerun-sdk" },
]
vla-jepa = [
{ name = "diffusers" },
{ name = "qwen-vl-utils" },
{ name = "transformers" },
]
wallx = [
{ name = "peft" },
{ name = "qwen-vl-utils" },
@@ -3115,6 +3120,7 @@ requires-dist = [
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
@@ -3164,6 +3170,7 @@ requires-dist = [
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
@@ -3191,12 +3198,14 @@ requires-dist = [
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'topreward'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
{ name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
{ name = "lerobot", extras = ["vla-jepa"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
@@ -3258,7 +3267,7 @@ requires-dist = [
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "topreward", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"