From 1bcba9dec67e7b2c497fb8886e828dbda93c1b66 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Thu, 21 May 2026 16:35:50 +0200 Subject: [PATCH] fixing training and exal examples --- docs/source/policy_vla_jepa_README.md | 92 ++++++++++++--------------- 1 file changed, 39 insertions(+), 53 deletions(-) diff --git a/docs/source/policy_vla_jepa_README.md b/docs/source/policy_vla_jepa_README.md index 67f0d2015..1739ae072 100644 --- a/docs/source/policy_vla_jepa_README.md +++ b/docs/source/policy_vla_jepa_README.md @@ -28,14 +28,6 @@ Only Qwen + the action head are used. The world model is not needed at inference ### Action head details -The action head is a **Diffusion Transformer (DiT-B)** with flow matching: - -- **Inner dim**: 768 (12 heads × 64 head dim, DiT-B preset) -- **Output dim**: `action_hidden_size` (default 1024), projected down to `action_dim` -- **Cross/self alternation**: even-indexed DiT blocks attend to Qwen context tokens (cross-attention); odd-indexed blocks are self-attention -- **Noise schedule**: Beta distribution with parameters `action_noise_beta_alpha` / `action_noise_beta_beta` -- **Inference**: Euler integration over `num_inference_timesteps` steps - Available presets via `action_model_type`: | Preset | Hidden dim | Heads | Head dim | @@ -68,14 +60,6 @@ Three checkpoints are available, converted from [ginwind/VLA-JEPA](https://huggi All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone. -### Loading a pretrained checkpoint - -```python -from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy - -policy = VLAJEPAPolicy.from_pretrained("lerobot/VLA-JEPA-LIBERO") -``` - --- ## Configuration @@ -96,49 +80,48 @@ Key parameters in `VLAJEPAConfig`: ## Training +Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints. + ### Full training from scratch ```bash lerobot-train \ - dataset.repo_id=your_org/your_dataset \ - policy.chunk_size=16 \ - policy.n_action_steps=16 + policy.type=vla_jepa \ + policy.repo_id=your_org/your_repo \ + dataset.repo_id=your_org/your_dataset ``` ### Fine-tuning from a pretrained checkpoint ```bash lerobot-train \ - policy.path=lerobot/VLA-JEPA-Pretrain \ - dataset.repo_id=your_org/your_dataset + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.repo_id=your_org/your_repo \ + --dataset.repo_id=your_org/your_dataset ``` If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`: ```bash lerobot-train \ - policy.path=lerobot/VLA-JEPA-Pretrain \ - dataset.repo_id=your_org/your_dataset \ - policy.freeze_qwen=true + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.repo_id=your_org/your_repo \ + --policy.freeze_qwen=true \ + --dataset.repo_id=your_org/your_dataset ``` ### Reproducing the LIBERO results **Training on LIBERO:** - -TODO(Maxime): - -- [ ] double check the training command -- [ ] double check which LIBERO dataset (libero_10 or full libero) was used for training the checkpoint -- [ ] add the evaluation command for the pretrained checkpoint + check that the results match the original paper +starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset. +Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256. ```bash lerobot-train \ - policy.path=lerobot/VLA-JEPA-Pretrain \ - dataset.repo_id=lerobot/libero_10 \ - training.num_steps=50000 \ - env.type=libero \ - env.task=libero_10 + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.repo_id=your_org/your_repo \ + --dataset.repo_id=HuggingFaceVLA/libero \ + --steps=30000 ``` **Evaluating the pretrained LIBERO-10 checkpoint:** @@ -147,14 +130,11 @@ lerobot-train \ lerobot-eval \ --policy.path=lerobot/VLA-JEPA-LIBERO \ --env.type=libero \ - --env.task=libero_10 \ - --env.obs_type=pixels_agent_pos \ - --eval.n_episodes=500 \ - --eval.batch_size=10 \ - --policy.device=cuda -``` + --env.task=libero_spatial,libero_object,libero_goal,libero_10 \ + --eval.n_episodes=10 \ + --eval.batch_size=5 \ -This runs all 10 LIBERO-10 tasks (50 episodes each, 500 total) with the default camera setup (`agentview_image` → `observation.images.image`, `robot0_eye_in_hand_image` → `observation.images.image2`) and the `pixels_agent_pos` obs type that provides both images and robot state. +``` To evaluate a subset of tasks only: @@ -164,9 +144,8 @@ lerobot-eval \ --env.type=libero \ --env.task=libero_10 \ --env.task_ids='[0,1,2]' \ - --eval.n_episodes=50 \ + --eval.n_episodes=10 \ --eval.batch_size=5 \ - --policy.device=cuda ``` --- @@ -181,25 +160,32 @@ Set `enable_world_model=False`. Only the Qwen backbone and action head are loade ```bash lerobot-train \ - policy.path=lerobot/VLA-JEPA-Pretrain \ - policy.enable_world_model=false \ - dataset.repo_id=your_org/single_camera_dataset + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.enable_world_model=false \ + --policy.repo_id=your_org/your_repo \ + --dataset.repo_id=your_org/single_camera_dataset ``` **Option 2 — Reinitialize the predictor input projection** If you want the JEPA self-supervised signal during fine-tuning, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint. +**Option 3 - Duplicate frames to match the expected number of cameras** +A bit more advanced, you would need to change some parts of the code to support that. + --- ## Citation ```bibtex -@misc{vla_jepa_2025, - title = {VLA-JEPA: Vision-Language-Action Model with Joint-Embedding Predictive Architecture}, - author = {Gin, Wind and others}, - year = {2025}, - url = {https://huggingface.co/ginwind/VLA-JEPA}, +@misc{sun2026vlajepaenhancingvisionlanguageactionmodel, + title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model}, + author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen}, + year = {2026}, + eprint = {2602.10098}, + archivePrefix = {arXiv}, + primaryClass = {cs.RO}, + url = {https://arxiv.org/abs/2602.10098}, } ``` @@ -207,4 +193,4 @@ If you want the JEPA self-supervised signal during fine-tuning, load the checkpo ## License -Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository. The LeRobot integration code follows the **Apache 2.0 License**. +Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.