# VLA-JEPA

This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.

---

## Architecture Overview

VLA-JEPA has three main components:

| Component               | Module                            | Role                                                    |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone**   | `Qwen3VLInterface`                | Fuses images + language instruction into context tokens |
| **DiT-B action head**   | `VLAJEPAActionHead`               | Flow-matching diffusion over the action chunk           |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only)   |

### Data flow

**Training:**

1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.

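Step 4 combines two objectives. The sketch below illustrates one plausible way to do that, assuming the total loss is the action loss plus the world model regression loss scaled by `world_model_loss_weight` (default 0.1 in the configuration table). The function name `total_training_loss` is hypothetical and is not taken from the LeRobot code.

```python
def total_training_loss(
    action_loss: float,
    world_model_loss: float,
    world_model_loss_weight: float = 0.1,
    enable_world_model: bool = True,
) -> float:
    """Hypothetical helper: weighted sum of the two training objectives.

    Assumption: the JEPA regression loss is simply scaled and added to the
    flow-matching action loss. With the world model disabled, only the
    action loss remains.
    """
    if not enable_world_model:
        return action_loss
    return action_loss + world_model_loss_weight * world_model_loss


# Example: action loss 1.0 and world model loss 2.0 with the default weight.
combined = total_training_loss(1.0, 2.0)
action_only = total_training_loss(1.0, 2.0, enable_world_model=False)
```
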

**Inference:**

Only Qwen + the action head are used. The world model is not needed at inference time.

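At inference the action head turns noise into an action chunk by integrating a learned velocity field. The sketch below shows plain Euler integration with `num_inference_timesteps` steps in pure Python. It assumes the flow runs from noise at t = 0 to actions at t = 1, which may differ from the convention used in the actual implementation. `euler_sample` and `toy_velocity` are hypothetical names, and the toy velocity field stands in for the DiT action head, which in the real model is also conditioned on the Qwen context tokens.

```python
import random


def euler_sample(velocity_fn, action_dim, chunk_size=7,
                 num_inference_timesteps=4, seed=0):
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 with Euler steps."""
    rng = random.Random(seed)
    # Start from Gaussian noise with shape (chunk_size, action_dim).
    x = [[rng.gauss(0.0, 1.0) for _ in range(action_dim)]
         for _ in range(chunk_size)]
    dt = 1.0 / num_inference_timesteps
    for step in range(num_inference_timesteps):
        t = step * dt
        v = velocity_fn(x, t)
        # One Euler update: x <- x + dt * v
        x = [[xi + dt * vi for xi, vi in zip(x_row, v_row)]
             for x_row, v_row in zip(x, v)]
    return x


def toy_velocity(x, t, target=0.5):
    """Toy stand-in for the action head: a straight-line flow to `target`."""
    return [[(target - xi) / (1.0 - t) for xi in row] for row in x]


chunk = euler_sample(toy_velocity, action_dim=7)
```

With this straight-line toy field, four Euler steps land on the target value, which makes the loop easy to check by hand.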

### Action head details

Available presets via `action_model_type`:

| Preset  | Hidden dim | Heads | Head dim |
| ------- | ---------- | ----- | -------- |
| `DiT-B` | 768        | 12    | 64       |
| `DiT-L` | 1536       | 32    | 48       |

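In both presets the hidden dimension equals the number of heads times the head dimension. The snippet below restates the table as a Python dictionary and checks that relation. The dictionary name and field names are illustrative only and are not the keys used by `VLAJEPAConfig`.

```python
# Values copied from the preset table; key names are hypothetical.
DIT_PRESETS = {
    "DiT-B": {"hidden_dim": 768, "num_heads": 12, "head_dim": 64},
    "DiT-L": {"hidden_dim": 1536, "num_heads": 32, "head_dim": 48},
}


def is_consistent(preset: dict) -> bool:
    """Multi-head attention splits the hidden dim evenly across heads."""
    return preset["hidden_dim"] == preset["num_heads"] * preset["head_dim"]


consistency = {name: is_consistent(p) for name, p in DIT_PRESETS.items()}
```
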

### World model details

The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:

- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`

It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).

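To make the two ideas above concrete, here is a small pure-Python sketch. It assumes that the tokens of one temporal step form a contiguous block, that tokens attend freely within their own step, and that they attend to every earlier step. The real predictor may interleave frame and action tokens differently. `block_causal_mask` and `predictor_input_dim` are hypothetical helper names.

```python
def block_causal_mask(num_steps: int, tokens_per_step: int):
    """Return mask[i][j] == True when token i may attend to token j.

    Token i belongs to temporal step i // tokens_per_step. It can see every
    token whose step index is less than or equal to its own.
    """
    n = num_steps * tokens_per_step
    return [
        [(j // tokens_per_step) <= (i // tokens_per_step) for j in range(n)]
        for i in range(n)
    ]


def predictor_input_dim(num_views: int, video_encoder_hidden_size: int = 1024) -> int:
    """Input embed_dim of the predictor: views are concatenated channel-wise."""
    return num_views * video_encoder_hidden_size


mask = block_causal_mask(num_steps=3, tokens_per_step=2)
two_view_dim = predictor_input_dim(2)
one_view_dim = predictor_input_dim(1)
```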

---

## Pretrained Checkpoints

Three checkpoints are available, converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):

| Checkpoint                    | Dataset           | Cameras                 | World model | Action dim |
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO`     | LIBERO-10         | 2 (agentview + wrist)   | Enabled     | 7          |
| `lerobot/VLA-JEPA-Pretrain`   | DROID 1.0.1       | 2 (exterior left views) | Enabled     | 7          |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1                       | Disabled\*  | 7          |

\* The SimplerEnv checkpoint was fine-tuned from Pretrain. The world model predictor architecture expects `embed_dim=2048` (2-camera input) but SimplerEnv is single-camera, so the world model cannot be loaded cleanly. Since inference only needs Qwen + the action head, `enable_world_model=False` is set for this variant. See [Fine-tuning on single-camera datasets](#fine-tuning-on-single-camera-datasets) for implications.

All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.

---

## Configuration

Key parameters in `VLAJEPAConfig`:

| Parameter                 | Default | Description |
| ------------------------- | ------- | ----------- |
| `chunk_size`              | 7       | Number of actions predicted per inference call |
| `n_action_steps`          | 7       | Steps executed from the predicted chunk before re-planning |
| `num_video_frames`        | 8       | Video clip length fed to the world model |
| `enable_world_model`      | `True`  | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1     | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4       | Euler integration steps for action denoising |
| `freeze_qwen`             | `False` | Freeze the Qwen3-VL backbone and only train the action head |
| `reinit_modules`          | `None`  | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |

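The interaction between `chunk_size` and `n_action_steps` can be pictured as an action queue: the policy predicts a chunk, the first `n_action_steps` actions are executed, and a new chunk is predicted once the queue is empty. The sketch below is a simplified illustration under that assumption; `run_episode` and `fake_policy` are hypothetical and do not mirror the LeRobot policy API.

```python
from collections import deque


def run_episode(predict_chunk, num_env_steps, chunk_size=7, n_action_steps=7):
    """Execute actions from predicted chunks and count policy calls."""
    if n_action_steps > chunk_size:
        raise ValueError("n_action_steps cannot exceed chunk_size")
    queue = deque()
    executed, num_predictions = [], 0
    for _ in range(num_env_steps):
        if not queue:
            # Re-plan: predict a fresh chunk and keep its first n_action_steps.
            chunk = predict_chunk(chunk_size)
            num_predictions += 1
            queue.extend(chunk[:n_action_steps])
        executed.append(queue.popleft())
    return executed, num_predictions


def fake_policy(chunk_size):
    """Stand-in policy: returns the in-chunk index of each action."""
    return list(range(chunk_size))


executed, calls = run_episode(fake_policy, num_env_steps=20)
short_executed, short_calls = run_episode(fake_policy, num_env_steps=7, n_action_steps=3)
```

Lowering `n_action_steps` below `chunk_size` makes the policy re-plan more often at the cost of more inference calls.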

---

## Training

The number of training steps depends on dataset size and compute budget. The original paper pretrained for 50k steps on SSv2 and DROID jointly, then trained for an additional 30k steps on LIBERO. Fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.

### Full training from scratch

```bash
lerobot-train \
  --policy.type=vla_jepa \
  --policy.repo_id=your_org/your_repo \
  --dataset.repo_id=your_org/your_dataset
```

### Fine-tuning from a pretrained checkpoint

```bash
lerobot-train \
  --policy.path=lerobot/VLA-JEPA-Pretrain \
  --policy.repo_id=your_org/your_repo \
  --dataset.repo_id=your_org/your_dataset
```

To freeze the Qwen backbone and train only the action head, set `--policy.freeze_qwen=true`:

```bash
lerobot-train \
  --policy.path=lerobot/VLA-JEPA-Pretrain \
  --policy.repo_id=your_org/your_repo \
  --policy.freeze_qwen=true \
  --dataset.repo_id=your_org/your_dataset
```

### Fine-tuning on a different embodiment

When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch: those layers are randomly re-initialised, while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.

The layers that depend on `action_dim` and `state_dim` are:

| Layer                                     | Key prefix                          |
| ----------------------------------------- | ----------------------------------- |
| Action encoder (action_dim → inner_dim)   | `model.action_model.action_encoder` |
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
| State encoder (state_dim → inner_dim)     | `model.action_model.state_encoder`  |

```bash
lerobot-train \
  --policy.path=lerobot/VLA-JEPA-Pretrain \
  --policy.repo_id=your_org/your_repo \
  --policy.freeze_qwen=true \
  --policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
  --dataset.repo_id=your_org/your_dataset
```

If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.

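The loading rule described in this section can be sketched as a pure-Python check over parameter shapes. This is an illustration of the stated behaviour, not the LeRobot implementation: `plan_checkpoint_load` is a hypothetical name, the `.weight` suffixes and all shape values are made up for the example, and a key missing from the checkpoint is treated like a shape mismatch.

```python
def plan_checkpoint_load(model_shapes, checkpoint_shapes, reinit_modules=None):
    """Split parameter keys into reused vs. re-initialised.

    Raises ValueError when a key mismatches (or is missing) and is not
    covered by one of the allowed prefixes in `reinit_modules`.
    """
    prefixes = tuple(reinit_modules or ())
    reused, reinitialised = [], []
    for key, shape in model_shapes.items():
        if checkpoint_shapes.get(key) == shape:
            reused.append(key)
        elif key.startswith(prefixes):  # empty tuple never matches
            reinitialised.append(key)
        else:
            raise ValueError(f"shape mismatch outside reinit_modules: {key}")
    return reused, reinitialised


# Hypothetical shapes: checkpoint has action_dim=7, target robot has 14.
checkpoint_shapes = {
    "model.action_model.action_encoder.weight": (1024, 7),
    "model.action_model.action_decoder.weight": (7, 768),
    "model.action_model.blocks.0.weight": (768, 768),
}
model_shapes = {
    "model.action_model.action_encoder.weight": (1024, 14),
    "model.action_model.action_decoder.weight": (14, 768),
    "model.action_model.blocks.0.weight": (768, 768),
}

reused, reinitialised = plan_checkpoint_load(
    model_shapes,
    checkpoint_shapes,
    reinit_modules=[
        "model.action_model.action_encoder",
        "model.action_model.action_decoder",
    ],
)
```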

### Reproducing the LIBERO results

**Training on LIBERO:**

The command below starts from the Pretrain checkpoint and trains for 30k steps on the LIBERO dataset. The original paper reports training on 8 GPUs with a per-GPU batch size of 32, i.e. a global batch size of 256.

```bash
lerobot-train \
  --policy.path=lerobot/VLA-JEPA-Pretrain \
  --policy.repo_id=your_org/your_repo \
  --dataset.repo_id=HuggingFaceVLA/libero \
  --steps=30000
```

**Evaluating the pretrained LIBERO-10 checkpoint:**

```bash
lerobot-eval \
  --policy.path=lerobot/VLA-JEPA-LIBERO \
  --env.type=libero \
  --env.task=libero_spatial,libero_object,libero_goal,libero_10 \
  --eval.n_episodes=10 \
  --eval.batch_size=5
```

To evaluate a subset of tasks only:

```bash
lerobot-eval \
  --policy.path=lerobot/VLA-JEPA-LIBERO \
  --env.type=libero \
  --env.task=libero_10 \
  --env.task_ids='[0,1,2]' \
  --eval.n_episodes=10 \
  --eval.batch_size=5
```

**Expected results:**

| Suite          | Episodes | Successes | Success Rate |
| -------------- | -------- | --------- | ------------ |
| libero_spatial | 100      | 93        | **95.0%**    |
| libero_object  | 100      | 100       | **100.0%**   |
| libero_goal    | 100      | 98        | **98.0%**    |
| libero_10      | 100      | 96        | **93.0%**    |
| **Overall**    | **400**  | **387**   | **96.5%**    |

---

## Fine-tuning on single-camera datasets

The pretrained world model predictor was trained with `embed_dim = num_views × 1024`. If your target dataset has fewer cameras than the source checkpoint, the predictor input projection will have a shape mismatch and cannot be loaded.

**Option 1: Disable the world model (recommended)**

Set `enable_world_model=False`. Only the Qwen backbone and action head are loaded and trained. This matches the original SimplerEnv fine-tuning strategy and is sufficient for good action performance.

```bash
lerobot-train \
  --policy.path=lerobot/VLA-JEPA-Pretrain \
  --policy.enable_world_model=false \
  --policy.repo_id=your_org/your_repo \
  --dataset.repo_id=your_org/single_camera_dataset
```

**Option 2: Reinitialize the predictor input projection**

If you want the JEPA self-supervised signal during fine-tuning, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.

**Option 3: Duplicate frames to match the expected number of cameras**

This option is more advanced and is not supported out of the box: you would need to modify parts of the code so that the single camera view is duplicated to fill the expected number of views.

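As a rough illustration of the idea behind Option 3, the sketch below repeats the available camera views until the expected count is reached, so that the concatenated embedding again has `num_views × 1024` channels. `duplicate_views` is a hypothetical helper, not part of LeRobot, and where such a step would hook into the data pipeline is left open.

```python
def duplicate_views(views, expected_num_views=2):
    """Cycle through the available views until `expected_num_views` is reached.

    `views` is a list with one entry per camera (e.g. image tensors or
    per-view embeddings); entries are reused by reference, not copied.
    """
    if not views:
        raise ValueError("at least one camera view is required")
    if len(views) > expected_num_views:
        raise ValueError("more views than the checkpoint expects")
    return [views[i % len(views)] for i in range(expected_num_views)]


# A single-camera sample padded to the two views the Pretrain checkpoint expects.
padded = duplicate_views(["front_cam"])
embed_dim = len(padded) * 1024  # 1024 = video encoder hidden size from above
```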

---

## Citation

```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
  title         = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
  author        = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
  year          = {2026},
  eprint        = {2602.10098},
  archivePrefix = {arXiv},
  primaryClass  = {cs.RO},
  url           = {https://arxiv.org/abs/2602.10098},
}
```

---

## License

Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.