mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
Compare commits
33 Commits
docs/compl
...
b75b3ce02d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b75b3ce02d | ||
|
|
5495c10cdf | ||
|
|
1bcba9dec6 | ||
|
|
dd13eda002 | ||
|
|
c01a00a972 | ||
|
|
f75a2ee2f5 | ||
|
|
0edbb68ec3 | ||
|
|
47f8a50fa0 | ||
|
|
51e57789ba | ||
|
|
27669724d2 | ||
|
|
e32a552edb | ||
|
|
596fda60d7 | ||
|
|
3fcea935b2 | ||
|
|
76b63ebb26 | ||
|
|
bbe4ba7a53 | ||
|
|
593253e155 | ||
|
|
acf65faaff | ||
|
|
82a05f9cb4 | ||
|
|
d4abb9d562 | ||
|
|
090d392b19 | ||
|
|
e36d742d7d | ||
|
|
f8a1acb6c9 | ||
|
|
60347bc742 | ||
|
|
c6ec8d00e3 | ||
|
|
0edb693ee4 | ||
|
|
cdae1b9ad8 | ||
|
|
80ecf7bf53 | ||
|
|
5597d539e7 | ||
|
|
dfbedb71d7 | ||
|
|
ebe6c66263 | ||
|
|
0e18bdaf7a | ||
|
|
d5944c410c | ||
|
|
0d37efdb4b |
196
docs/source/policy_vla_jepa_README.md
Normal file
196
docs/source/policy_vla_jepa_README.md
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
# VLA-JEPA
|
||||||
|
|
||||||
|
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture Overview
|
||||||
|
|
||||||
|
VLA-JEPA has three main components:
|
||||||
|
|
||||||
|
| Component | Module | Role |
|
||||||
|
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
|
||||||
|
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
|
||||||
|
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
|
||||||
|
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
|
||||||
|
|
||||||
|
### Data flow
|
||||||
|
|
||||||
|
**Training:**
|
||||||
|
|
||||||
|
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
|
||||||
|
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
|
||||||
|
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
|
||||||
|
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
|
||||||
|
|
||||||
|
**Inference:**
|
||||||
|
Only Qwen + the action head are used. The world model is not needed at inference time.
|
||||||
|
|
||||||
|
### Action head details
|
||||||
|
|
||||||
|
Available presets via `action_model_type`:
|
||||||
|
|
||||||
|
| Preset | Hidden dim | Heads | Head dim |
|
||||||
|
| ------- | ---------- | ----- | -------- |
|
||||||
|
| `DiT-B` | 768 | 12 | 64 |
|
||||||
|
| `DiT-L` | 1536 | 32 | 48 |
|
||||||
|
|
||||||
|
### World model details
|
||||||
|
|
||||||
|
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
|
||||||
|
|
||||||
|
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
|
||||||
|
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
|
||||||
|
|
||||||
|
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Pretrained Checkpoints
|
||||||
|
|
||||||
|
Three checkpoints are available, converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
|
||||||
|
|
||||||
|
| Checkpoint | Dataset | Cameras | World model | Action dim |
|
||||||
|
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
|
||||||
|
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
|
||||||
|
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
|
||||||
|
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 | Disabled\* | 7 |
|
||||||
|
|
||||||
|
\* The SimplerEnv checkpoint was fine-tuned from Pretrain. The world model predictor architecture expects `embed_dim=2048` (2-camera input) but SimplerEnv is single-camera, so the world model cannot be loaded cleanly. Since inference only needs Qwen + the action head, `enable_world_model=False` is set for this variant. See [Fine-tuning on single-camera datasets](#fine-tuning-on-single-camera-datasets) for implications.
|
||||||
|
|
||||||
|
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Key parameters in `VLAJEPAConfig`:
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
| ------------------------- | ------- | -------------------------------------------------------------- |
|
||||||
|
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||||
|
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||||
|
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||||
|
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||||
|
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||||
|
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||||
|
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
|
||||||
|
|
||||||
|
### Full training from scratch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
policy.type=vla_jepa \
|
||||||
|
policy.repo_id=your_org/your_repo \
|
||||||
|
dataset.repo_id=your_org/your_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fine-tuning from a pretrained checkpoint
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--dataset.repo_id=your_org/your_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--policy.freeze_qwen=true \
|
||||||
|
--dataset.repo_id=your_org/your_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
### Reproducing the LIBERO results
|
||||||
|
|
||||||
|
**Training on LIBERO:**
|
||||||
|
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
|
||||||
|
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||||
|
--steps=30000
|
||||||
|
```
|
||||||
|
|
||||||
|
**Evaluating the pretrained LIBERO-10 checkpoint:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||||
|
--eval.n_episodes=10 \
|
||||||
|
--eval.batch_size=5 \
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
To evaluate a subset of tasks only:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-LIBERO \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_10 \
|
||||||
|
--env.task_ids='[0,1,2]' \
|
||||||
|
--eval.n_episodes=10 \
|
||||||
|
--eval.batch_size=5 \
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Fine-tuning on single-camera datasets
|
||||||
|
|
||||||
|
The pretrained world model predictor was trained with `embed_dim = num_views × 1024`. If your target dataset has fewer cameras than the source checkpoint, the predictor input projection will have a shape mismatch and cannot be loaded.
|
||||||
|
|
||||||
|
**Option 1 — Disable the world model (recommended)**
|
||||||
|
|
||||||
|
Set `enable_world_model=False`. Only the Qwen backbone and action head are loaded and trained. This matches the original SimplerEnv fine-tuning strategy and is sufficient for good action performance.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.enable_world_model=false \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--dataset.repo_id=your_org/single_camera_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
**Option 2 — Reinitialize the predictor input projection**
|
||||||
|
|
||||||
|
If you want the JEPA self-supervised signal during fine-tuning, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
|
||||||
|
|
||||||
|
**Option 3 - Duplicate frames to match the expected number of cameras**
|
||||||
|
A bit more advanced, you would need to change some parts of the code to support that.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
|
||||||
|
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
|
||||||
|
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
|
||||||
|
year = {2026},
|
||||||
|
eprint = {2602.10098},
|
||||||
|
archivePrefix = {arXiv},
|
||||||
|
primaryClass = {cs.RO},
|
||||||
|
url = {https://arxiv.org/abs/2602.10098},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.
|
||||||
@@ -212,6 +212,7 @@ sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<3
|
|||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||||
@@ -276,6 +277,7 @@ all = [
|
|||||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||||
"lerobot[xvla]",
|
"lerobot[xvla]",
|
||||||
"lerobot[hilserl]",
|
"lerobot[hilserl]",
|
||||||
|
"lerobot[vla_jepa]",
|
||||||
"lerobot[async]",
|
"lerobot[async]",
|
||||||
"lerobot[dev]",
|
"lerobot[dev]",
|
||||||
"lerobot[test]",
|
"lerobot[test]",
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ from .pretrained import PreTrainedPolicy
|
|||||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from .utils import validate_visual_features_consistency
|
from .utils import validate_visual_features_consistency
|
||||||
|
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||||
from .vqbet.configuration_vqbet import VQBeTConfig
|
from .vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from .wall_x.configuration_wall_x import WallXConfig
|
from .wall_x.configuration_wall_x import WallXConfig
|
||||||
from .xvla.configuration_xvla import XVLAConfig
|
from .xvla.configuration_xvla import XVLAConfig
|
||||||
@@ -151,6 +152,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from .eo1.modeling_eo1 import EO1Policy
|
from .eo1.modeling_eo1 import EO1Policy
|
||||||
|
|
||||||
return EO1Policy
|
return EO1Policy
|
||||||
|
elif name == "vla_jepa":
|
||||||
|
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||||
|
|
||||||
|
return VLAJEPAPolicy
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return _get_policy_cls_from_policy_name(name=name)
|
return _get_policy_cls_from_policy_name(name=name)
|
||||||
@@ -203,6 +208,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return WallXConfig(**kwargs)
|
return WallXConfig(**kwargs)
|
||||||
elif policy_type == "eo1":
|
elif policy_type == "eo1":
|
||||||
return EO1Config(**kwargs)
|
return EO1Config(**kwargs)
|
||||||
|
elif policy_type == "vla_jepa":
|
||||||
|
return VLAJEPAConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||||
@@ -406,6 +413,7 @@ def make_pre_post_processors(
|
|||||||
config=policy_cfg,
|
config=policy_cfg,
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(policy_cfg, EO1Config):
|
elif isinstance(policy_cfg, EO1Config):
|
||||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||||
|
|
||||||
@@ -414,6 +422,14 @@ def make_pre_post_processors(
|
|||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif isinstance(policy_cfg, VLAJEPAConfig):
|
||||||
|
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||||
|
|
||||||
|
processors = make_vla_jepa_pre_post_processors(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
processors = _make_processors_from_policy_config(
|
processors = _make_processors_from_policy_config(
|
||||||
|
|||||||
1
src/lerobot/policies/vla_jepa/README.md
Symbolic link
1
src/lerobot/policies/vla_jepa/README.md
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
/home/maxime/github/robots/lerobot/docs/source/policy_vla_jepa_README.md
|
||||||
10
src/lerobot/policies/vla_jepa/__init__.py
Normal file
10
src/lerobot/policies/vla_jepa/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from .configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
from .modeling_vla_jepa import VLAJEPAPolicy
|
||||||
|
from .processor_vla_jepa import VLAJEPANewLineProcessor, make_vla_jepa_pre_post_processors
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"VLAJEPAConfig",
|
||||||
|
"VLAJEPAPolicy",
|
||||||
|
"VLAJEPANewLineProcessor",
|
||||||
|
"make_vla_jepa_pre_post_processors",
|
||||||
|
]
|
||||||
327
src/lerobot/policies/vla_jepa/action_head.py
Normal file
327
src/lerobot/policies/vla_jepa/action_head.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from torch import nn
|
||||||
|
from torch.distributions import Beta
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _diffusers_available:
|
||||||
|
from diffusers import ConfigMixin, ModelMixin
|
||||||
|
from diffusers.configuration_utils import register_to_config
|
||||||
|
from diffusers.models.attention import Attention, FeedForward
|
||||||
|
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||||
|
else:
|
||||||
|
|
||||||
|
class ModelMixin: # type: ignore[no-redef]
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ConfigMixin: # type: ignore[no-redef]
|
||||||
|
pass
|
||||||
|
|
||||||
|
register_to_config = lambda f: f # noqa: E731
|
||||||
|
Attention = FeedForward = TimestepEmbedding = Timesteps = None
|
||||||
|
|
||||||
|
from .configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
|
||||||
|
|
||||||
|
def swish(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalPositionalEncoding(nn.Module):
|
||||||
|
def __init__(self, embedding_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
|
||||||
|
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||||
|
timesteps = timesteps.float()
|
||||||
|
batch_size, seq_len = timesteps.shape
|
||||||
|
half_dim = self.embedding_dim // 2
|
||||||
|
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
|
||||||
|
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
|
||||||
|
freqs = timesteps.unsqueeze(-1) * exponent.exp()
|
||||||
|
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionEncoder(nn.Module):
|
||||||
|
def __init__(self, action_dim: int, hidden_size: int):
|
||||||
|
super().__init__()
|
||||||
|
self.layer1 = nn.Linear(action_dim, hidden_size)
|
||||||
|
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||||
|
self.layer3 = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||||
|
|
||||||
|
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size, seq_len, _ = actions.shape
|
||||||
|
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
|
||||||
|
raise ValueError("timesteps must have shape [batch_size].")
|
||||||
|
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
|
||||||
|
action_emb = self.layer1(actions)
|
||||||
|
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
|
||||||
|
return self.layer3(swish(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEncoder(nn.Module):
|
||||||
|
def __init__(self, embedding_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
require_package("diffusers", extra="vla_jepa")
|
||||||
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||||
|
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||||
|
|
||||||
|
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||||
|
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
|
||||||
|
return self.timestep_embedder(projected)
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNorm(nn.Module):
|
||||||
|
def __init__(self, embedding_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||||
|
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
|
||||||
|
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
|
||||||
|
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||||
|
|
||||||
|
|
||||||
|
class BasicTransformerBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
attention_head_dim: int,
|
||||||
|
dropout: float,
|
||||||
|
cross_attention_dim: int,
|
||||||
|
is_cross_attention: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.is_cross_attention = is_cross_attention
|
||||||
|
self.norm1 = AdaLayerNorm(dim)
|
||||||
|
self.attn1 = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=num_attention_heads,
|
||||||
|
dim_head=attention_head_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
bias=True,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
out_bias=True,
|
||||||
|
)
|
||||||
|
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
|
||||||
|
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
encoder_hidden_states: torch.Tensor | None,
|
||||||
|
temb: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
attn_input = self.norm1(hidden_states, temb)
|
||||||
|
attention_context = encoder_hidden_states if self.is_cross_attention else None
|
||||||
|
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
|
||||||
|
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class DiT(ModelMixin, ConfigMixin):
|
||||||
|
_supports_gradient_checkpointing = False
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_attention_heads: int,
|
||||||
|
attention_head_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
num_layers: int,
|
||||||
|
dropout: float,
|
||||||
|
cross_attention_dim: int,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.timestep_encoder = TimestepEncoder(self.inner_dim)
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
BasicTransformerBlock(
|
||||||
|
dim=self.inner_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
|
||||||
|
is_cross_attention=layer_idx % 2 == 0,
|
||||||
|
)
|
||||||
|
for layer_idx in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||||
|
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
|
||||||
|
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
temb = self.timestep_encoder(timestep)
|
||||||
|
x = hidden_states
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||||
|
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||||
|
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||||
|
return self.proj_out_2(x)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ActionModelPreset:
|
||||||
|
hidden_size: int
|
||||||
|
attention_head_dim: int
|
||||||
|
num_attention_heads: int
|
||||||
|
|
||||||
|
|
||||||
|
DIT_PRESETS = {
|
||||||
|
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
|
||||||
|
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
|
||||||
|
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class VLAJEPAActionHead(nn.Module):
|
||||||
|
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
preset = DIT_PRESETS[config.action_model_type]
|
||||||
|
self.config = config
|
||||||
|
num_heads = config.action_num_heads or preset.num_attention_heads
|
||||||
|
head_dim = config.action_attention_head_dim or preset.attention_head_dim
|
||||||
|
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||||
|
|
||||||
|
self.input_embedding_dim = inner_dim
|
||||||
|
self.action_horizon = config.chunk_size
|
||||||
|
self.num_inference_timesteps = config.num_inference_timesteps
|
||||||
|
|
||||||
|
hidden_size = config.action_hidden_size
|
||||||
|
self.model = DiT(
|
||||||
|
num_attention_heads=num_heads,
|
||||||
|
attention_head_dim=head_dim,
|
||||||
|
output_dim=hidden_size,
|
||||||
|
num_layers=config.action_num_layers,
|
||||||
|
dropout=config.action_dropout,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
)
|
||||||
|
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||||
|
self.action_decoder = nn.Sequential(
|
||||||
|
OrderedDict(
|
||||||
|
[
|
||||||
|
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||||
|
("relu", nn.ReLU()),
|
||||||
|
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.state_encoder = (
|
||||||
|
nn.Sequential(
|
||||||
|
OrderedDict(
|
||||||
|
[
|
||||||
|
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||||
|
("relu", nn.ReLU()),
|
||||||
|
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if config.state_dim > 0
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
|
||||||
|
self.position_embedding = nn.Embedding(
|
||||||
|
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||||
|
inner_dim,
|
||||||
|
)
|
||||||
|
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||||
|
|
||||||
|
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
|
||||||
|
return (self.config.action_noise_s - sample) / self.config.action_noise_s
|
||||||
|
|
||||||
|
def _build_inputs(
|
||||||
|
self,
|
||||||
|
conditioning_tokens: torch.Tensor,
|
||||||
|
actions: torch.Tensor,
|
||||||
|
state: torch.Tensor | None,
|
||||||
|
timesteps: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
action_features = self.action_encoder(actions, timesteps)
|
||||||
|
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
|
||||||
|
action_features = action_features + self.position_embedding(pos_ids)[None]
|
||||||
|
|
||||||
|
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
|
||||||
|
seq = [future_tokens, action_features]
|
||||||
|
if state is not None and self.state_encoder is not None:
|
||||||
|
if state.ndim == 2:
|
||||||
|
state = state.unsqueeze(1)
|
||||||
|
seq.insert(0, self.state_encoder(state))
|
||||||
|
return torch.cat(seq, dim=1)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
conditioning_tokens: torch.Tensor,
|
||||||
|
actions: torch.Tensor,
|
||||||
|
state: torch.Tensor | None = None,
|
||||||
|
action_is_pad: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
noise = torch.randn_like(actions)
|
||||||
|
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||||
|
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
|
||||||
|
velocity = actions - noise
|
||||||
|
t_discretized = (t * self.config.action_num_timestep_buckets).long()
|
||||||
|
|
||||||
|
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
|
||||||
|
pred = self.model(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
encoder_hidden_states=conditioning_tokens,
|
||||||
|
timestep=t_discretized,
|
||||||
|
)
|
||||||
|
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
|
||||||
|
|
||||||
|
if action_is_pad is None:
|
||||||
|
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
|
||||||
|
|
||||||
|
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
|
||||||
|
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
|
||||||
|
num_valid = valid_mask.sum() * loss.shape[-1]
|
||||||
|
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action(
|
||||||
|
self,
|
||||||
|
conditioning_tokens: torch.Tensor,
|
||||||
|
state: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size = conditioning_tokens.shape[0]
|
||||||
|
actions = torch.randn(
|
||||||
|
batch_size,
|
||||||
|
self.action_horizon,
|
||||||
|
self.config.action_dim,
|
||||||
|
dtype=conditioning_tokens.dtype,
|
||||||
|
device=conditioning_tokens.device,
|
||||||
|
)
|
||||||
|
dt = 1.0 / max(self.num_inference_timesteps, 1)
|
||||||
|
for step in range(self.num_inference_timesteps):
|
||||||
|
t_cont = step / float(max(self.num_inference_timesteps, 1))
|
||||||
|
t_value = int(t_cont * self.config.action_num_timestep_buckets)
|
||||||
|
timesteps = torch.full(
|
||||||
|
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
|
||||||
|
)
|
||||||
|
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
|
||||||
|
pred = self.model(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
encoder_hidden_states=conditioning_tokens,
|
||||||
|
timestep=timesteps,
|
||||||
|
)
|
||||||
|
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
|
||||||
|
actions = actions + dt * pred_velocity
|
||||||
|
return actions
|
||||||
133
src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
Normal file
133
src/lerobot/policies/vla_jepa/configuration_vla_jepa.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import NormalizationMode
|
||||||
|
from lerobot.optim.optimizers import AdamWConfig
|
||||||
|
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||||
|
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("vla_jepa")
|
||||||
|
@dataclass
|
||||||
|
class VLAJEPAConfig(PreTrainedConfig):
|
||||||
|
n_obs_steps: int = 1
|
||||||
|
chunk_size: int = 7
|
||||||
|
n_action_steps: int = 7
|
||||||
|
|
||||||
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.MEAN_STD,
|
||||||
|
"ACTION": NormalizationMode.MIN_MAX,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
|
||||||
|
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||||
|
freeze_qwen: bool = False
|
||||||
|
enable_world_model: bool = True
|
||||||
|
reinit_action_head: bool = False
|
||||||
|
|
||||||
|
tokenizer_padding_side: str = "left"
|
||||||
|
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||||
|
special_action_token: str = "<|action_{}|>"
|
||||||
|
embodied_action_token: str = "<|embodied_action|>"
|
||||||
|
|
||||||
|
action_dim: int = 7
|
||||||
|
state_dim: int = 8
|
||||||
|
|
||||||
|
num_action_tokens_per_timestep: int = 8
|
||||||
|
num_embodied_action_tokens_per_instruction: int = 32
|
||||||
|
num_inference_timesteps: int = 4
|
||||||
|
|
||||||
|
action_hidden_size: int = 1024
|
||||||
|
action_model_type: str = "DiT-B"
|
||||||
|
action_num_layers: int = 16
|
||||||
|
action_num_heads: int | None = None
|
||||||
|
action_attention_head_dim: int | None = None
|
||||||
|
action_dropout: float = 0.2
|
||||||
|
action_num_timestep_buckets: int = 1000
|
||||||
|
action_noise_beta_alpha: float = 1.5
|
||||||
|
action_noise_beta_beta: float = 1.0
|
||||||
|
action_noise_s: float = 0.999
|
||||||
|
num_target_vision_tokens: int = 32
|
||||||
|
action_max_seq_len: int = 1024
|
||||||
|
|
||||||
|
# total video frames loaded per sample
|
||||||
|
num_video_frames: int = 8
|
||||||
|
predictor_depth: int = 12
|
||||||
|
predictor_num_heads: int = 8
|
||||||
|
predictor_mlp_ratio: float = 4.0
|
||||||
|
predictor_dropout: float = 0.0
|
||||||
|
world_model_loss_weight: float = 0.1
|
||||||
|
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||||
|
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||||
|
|
||||||
|
resize_images_to: tuple[int, int] | None = None
|
||||||
|
binarize_gripper_action: bool = True
|
||||||
|
pre_snap_gripper_action: bool = True
|
||||||
|
clip_normalized_actions: bool = True
|
||||||
|
torch_dtype: str = "bfloat16"
|
||||||
|
|
||||||
|
optimizer_lr: float = 1e-4
|
||||||
|
optimizer_betas: tuple[float, float] = (0.9, 0.95)
|
||||||
|
optimizer_eps: float = 1e-8
|
||||||
|
optimizer_weight_decay: float = 1e-10
|
||||||
|
optimizer_grad_clip_norm: float = 10.0
|
||||||
|
scheduler_warmup_steps: int = 1_000
|
||||||
|
scheduler_decay_steps: int = 30_000
|
||||||
|
scheduler_decay_lr: float = 2.5e-6
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
super().__post_init__()
|
||||||
|
if self.freeze_qwen and self.enable_world_model:
|
||||||
|
# freezing qwen backbone makes world model training irrelevant since no grad flows
|
||||||
|
self.enable_world_model = False
|
||||||
|
if self.n_action_steps > self.chunk_size:
|
||||||
|
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||||
|
if self.num_video_frames < 2 * self.jepa_tubelet_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
|
||||||
|
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
if not self.image_features:
|
||||||
|
raise ValueError("VLAJEPA requires at least one visual input feature.")
|
||||||
|
if self.action_feature is None:
|
||||||
|
raise ValueError("VLAJEPA requires an action output feature.")
|
||||||
|
self.action_dim = self.action_feature.shape[0]
|
||||||
|
if self.robot_state_feature is not None:
|
||||||
|
self.state_dim = self.robot_state_feature.shape[0]
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> AdamWConfig:
|
||||||
|
return AdamWConfig(
|
||||||
|
lr=self.optimizer_lr,
|
||||||
|
betas=self.optimizer_betas,
|
||||||
|
eps=self.optimizer_eps,
|
||||||
|
weight_decay=self.optimizer_weight_decay,
|
||||||
|
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||||
|
return CosineDecayWithWarmupSchedulerConfig(
|
||||||
|
peak_lr=self.optimizer_lr,
|
||||||
|
decay_lr=self.scheduler_decay_lr,
|
||||||
|
num_warmup_steps=self.scheduler_warmup_steps,
|
||||||
|
num_decay_steps=self.scheduler_decay_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> list[int]:
|
||||||
|
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
|
||||||
|
# matches original repo's observation_indices=list(range(video_horizon))
|
||||||
|
return list(range(self.num_video_frames))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list[int]:
|
||||||
|
return list(range(self.chunk_size))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
454
src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py
Normal file
454
src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py
Normal file
@@ -0,0 +1,454 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Convert all VLA-JEPA .pt checkpoints (ginwind/VLA-JEPA) to LeRobot safetensors
|
||||||
|
format and upload them to maximellerbach org inside a HF collection.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run python convert_vla_jepa_checkpoints.py
|
||||||
|
|
||||||
|
For each variant the script:
|
||||||
|
1. Downloads the .pt checkpoint.
|
||||||
|
2. Extracts the state dict.
|
||||||
|
3. Instantiates VLAJEPAPolicy with the variant's confirmed config.
|
||||||
|
4. Loads the state dict (strict=False — mismatches printed to stdout).
|
||||||
|
5. push_to_hub → writes model.safetensors + config.json in LeRobot format.
|
||||||
|
6. Adds the new repo to a shared HF collection.
|
||||||
|
|
||||||
|
Config sources
|
||||||
|
--------------
|
||||||
|
Numeric hyper-params : ginwind/VLA-JEPA/<variant>/config.json
|
||||||
|
Image keys LIBERO : lerobot/libero_10 meta/info.json ✓ confirmed
|
||||||
|
Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json ✓ confirmed
|
||||||
|
Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from safetensors.torch import save_file as save_safetensors
|
||||||
|
|
||||||
|
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Top-level settings
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
SOURCE_REPO_ID = "ginwind/VLA-JEPA"
|
||||||
|
TARGET_ORG = "maximellerbach"
|
||||||
|
COLLECTION_TITLE = "VLA-JEPA"
|
||||||
|
COLLECTION_DESCRIPTION = (
|
||||||
|
"VLA-JEPA model checkpoints (LIBERO, Pretrain, SimplerEnv) converted from .pt to safetensors via LeRobot."
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Key mapping — mirrors todo_converter.py map_key() so both converters
|
||||||
|
# produce identical safetensors layouts that match the LeRobot action_head code.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_source_key(key: str) -> str:
|
||||||
|
return key[len("module.") :] if key.startswith("module.") else key
|
||||||
|
|
||||||
|
|
||||||
|
def _map_checkpoint_key(raw_key: str) -> str | None:
|
||||||
|
"""Map original VLA-JEPA state-dict keys to LeRobot vla_jepa layout."""
|
||||||
|
key = _normalize_source_key(raw_key)
|
||||||
|
|
||||||
|
if key.startswith("qwen_vl_interface."):
|
||||||
|
return "model.qwen." + key[len("qwen_vl_interface.") :]
|
||||||
|
if key.startswith("vj_encoder."):
|
||||||
|
return "model.video_encoder." + key[len("vj_encoder.") :]
|
||||||
|
if key.startswith("vj_predictor."):
|
||||||
|
return "model.video_predictor." + key[len("vj_predictor.") :]
|
||||||
|
if key.startswith("action_model."):
|
||||||
|
# LeRobot code uses the same sub-key names as the source checkpoint,
|
||||||
|
# so only the top-level "model." prefix needs to be added.
|
||||||
|
return "model." + key
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_dataset_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None:
|
||||||
|
"""Download dataset_statistics.json and return {action: {...}, state: {...}} stats dict."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
stats_file = f"{subfolder}/dataset_statistics.json"
|
||||||
|
try:
|
||||||
|
local = api.hf_hub_download(source_repo_id, stats_file)
|
||||||
|
data = json.loads(Path(local).read_text())
|
||||||
|
# Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}, "state": {...}}}
|
||||||
|
for robot_key in data:
|
||||||
|
robot_data = data[robot_key]
|
||||||
|
if isinstance(robot_data, dict) and "action" in robot_data:
|
||||||
|
log.info(" Loaded dataset stats from %s (robot key: %s)", stats_file, robot_key)
|
||||||
|
result = {"action": robot_data["action"]}
|
||||||
|
if "state" in robot_data:
|
||||||
|
result["observation.state"] = robot_data["state"]
|
||||||
|
log.info(" Also loaded state stats.")
|
||||||
|
return result
|
||||||
|
log.warning(" %s found but no 'action' key under any robot key — skipping stats.", stats_file)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
log.warning(" Could not fetch %s: %s — postprocessor will have no unnorm stats.", stats_file, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _set_if_present(d: dict, key: str, value) -> None:
|
||||||
|
if value is not None:
|
||||||
|
d[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def _deep_get(mapping: dict, path: tuple, default=None):
|
||||||
|
current = mapping
|
||||||
|
for key in path:
|
||||||
|
if not isinstance(current, dict) or key not in current:
|
||||||
|
return default
|
||||||
|
current = current[key]
|
||||||
|
return current
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_source_config(api: HfApi, source_repo_id: str, subfolder: str) -> dict:
|
||||||
|
"""Download config.yaml from the source HF repo for a given variant subfolder."""
|
||||||
|
try:
|
||||||
|
import yaml
|
||||||
|
except ImportError:
|
||||||
|
log.warning("PyYAML not installed — cannot apply source config.yaml overrides.")
|
||||||
|
return {}
|
||||||
|
config_file = f"{subfolder}/config.yaml"
|
||||||
|
try:
|
||||||
|
local = api.hf_hub_download(source_repo_id, config_file)
|
||||||
|
data = yaml.safe_load(Path(local).read_text()) or {}
|
||||||
|
if isinstance(data, dict):
|
||||||
|
log.info(" Loaded source config from %s", config_file)
|
||||||
|
return data
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
log.warning(" Could not fetch %s: %s — using hardcoded defaults.", config_file, exc)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_source_config(kwargs: dict, source_config: dict) -> None:
|
||||||
|
"""Apply ginwind/VLA-JEPA config.yaml values to kwargs, mirroring todo_converter.py logic."""
|
||||||
|
if not source_config:
|
||||||
|
return
|
||||||
|
|
||||||
|
data_cfg = _deep_get(source_config, ("datasets", "vla_data"), {})
|
||||||
|
action_cfg = _deep_get(source_config, ("framework", "action_model"), {})
|
||||||
|
diffusion_cfg = _deep_get(source_config, ("framework", "action_model", "diffusion_model_cfg"), {})
|
||||||
|
video_cfg = _deep_get(source_config, ("framework", "vj2_model"), {})
|
||||||
|
trainer_cfg = source_config.get("trainer", {})
|
||||||
|
|
||||||
|
prompt_template = data_cfg.get("CoT_prompt")
|
||||||
|
if prompt_template:
|
||||||
|
kwargs["prompt_template"] = str(prompt_template)
|
||||||
|
|
||||||
|
action_horizon = action_cfg.get("action_horizon")
|
||||||
|
if action_horizon is not None:
|
||||||
|
kwargs["chunk_size"] = int(action_horizon)
|
||||||
|
kwargs["n_action_steps"] = int(action_horizon)
|
||||||
|
|
||||||
|
_set_if_present(
|
||||||
|
kwargs,
|
||||||
|
"num_action_tokens_per_timestep",
|
||||||
|
video_cfg.get("num_action_tokens_per_timestep", action_cfg.get("num_action_tokens_per_timestep")),
|
||||||
|
)
|
||||||
|
_set_if_present(
|
||||||
|
kwargs,
|
||||||
|
"num_embodied_action_tokens_per_instruction",
|
||||||
|
video_cfg.get(
|
||||||
|
"num_embodied_action_tokens_per_instruction",
|
||||||
|
action_cfg.get("num_embodied_action_tokens_per_instruction"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
_set_if_present(kwargs, "num_inference_timesteps", action_cfg.get("num_inference_timesteps"))
|
||||||
|
_set_if_present(kwargs, "special_action_token", video_cfg.get("special_action_token"))
|
||||||
|
_set_if_present(kwargs, "embodied_action_token", video_cfg.get("embodied_action_token"))
|
||||||
|
_set_if_present(
|
||||||
|
kwargs, "action_hidden_size", action_cfg.get("action_hidden_dim", action_cfg.get("hidden_size"))
|
||||||
|
)
|
||||||
|
_set_if_present(kwargs, "action_model_type", action_cfg.get("action_model_type"))
|
||||||
|
_set_if_present(kwargs, "action_noise_beta_alpha", action_cfg.get("noise_beta_alpha"))
|
||||||
|
_set_if_present(kwargs, "action_noise_beta_beta", action_cfg.get("noise_beta_beta"))
|
||||||
|
_set_if_present(kwargs, "action_noise_s", action_cfg.get("noise_s"))
|
||||||
|
_set_if_present(kwargs, "action_num_timestep_buckets", action_cfg.get("num_timestep_buckets"))
|
||||||
|
_set_if_present(kwargs, "repeated_diffusion_steps", action_cfg.get("repeated_diffusion_steps"))
|
||||||
|
_set_if_present(kwargs, "action_num_layers", diffusion_cfg.get("num_layers"))
|
||||||
|
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
|
||||||
|
|
||||||
|
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
|
||||||
|
_set_if_present(kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")))
|
||||||
|
_set_if_present(
|
||||||
|
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
|
||||||
|
)
|
||||||
|
_set_if_present(kwargs, "predictor_mlp_ratio", video_cfg.get("predictor_mlp_ratio"))
|
||||||
|
|
||||||
|
_set_if_present(kwargs, "optimizer_grad_clip_norm", trainer_cfg.get("max_grad_norm"))
|
||||||
|
learning_rate = trainer_cfg.get("learning_rate", {})
|
||||||
|
if isinstance(learning_rate, dict):
|
||||||
|
_set_if_present(kwargs, "optimizer_lr", learning_rate.get("action_model"))
|
||||||
|
optimizer_cfg = trainer_cfg.get("optimizer", {})
|
||||||
|
if isinstance(optimizer_cfg, dict):
|
||||||
|
_set_if_present(kwargs, "optimizer_eps", optimizer_cfg.get("eps"))
|
||||||
|
_set_if_present(kwargs, "optimizer_weight_decay", optimizer_cfg.get("weight_decay"))
|
||||||
|
betas = optimizer_cfg.get("betas")
|
||||||
|
if betas is not None:
|
||||||
|
kwargs["optimizer_betas"] = tuple(betas)
|
||||||
|
scheduler = trainer_cfg.get("scheduler", {})
|
||||||
|
if isinstance(scheduler, dict):
|
||||||
|
_set_if_present(kwargs, "scheduler_warmup_steps", scheduler.get("warmup_steps"))
|
||||||
|
_set_if_present(kwargs, "scheduler_decay_lr", scheduler.get("min_lr"))
|
||||||
|
_set_if_present(kwargs, "scheduler_warmup_steps", trainer_cfg.get("num_warmup_steps"))
|
||||||
|
scheduler_kwargs = trainer_cfg.get("scheduler_specific_kwargs", {})
|
||||||
|
if isinstance(scheduler_kwargs, dict):
|
||||||
|
_set_if_present(kwargs, "scheduler_decay_lr", scheduler_kwargs.get("min_lr"))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Architecture — identical across all 4 variants (from config.json)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
_ARCH = {
|
||||||
|
"qwen_model_name": "Qwen/Qwen3-VL-2B-Instruct", # 2B, NOT the default 4B
|
||||||
|
"chunk_size": 7,
|
||||||
|
"n_action_steps": 7,
|
||||||
|
"num_video_frames": 8,
|
||||||
|
"jepa_tubelet_size": 2,
|
||||||
|
"num_action_tokens_per_timestep": 8,
|
||||||
|
"num_embodied_action_tokens_per_instruction": 32,
|
||||||
|
"num_inference_timesteps": 4,
|
||||||
|
"action_hidden_size": 1024,
|
||||||
|
"action_model_type": "DiT-B",
|
||||||
|
# Explicit dims matching DiT-B preset and ginwind checkpoint shape
|
||||||
|
"action_num_heads": 12,
|
||||||
|
"action_attention_head_dim": 64,
|
||||||
|
"action_num_layers": 16,
|
||||||
|
"action_dropout": 0.2,
|
||||||
|
"repeated_diffusion_steps": 8,
|
||||||
|
"action_noise_beta_alpha": 1.5,
|
||||||
|
"action_noise_beta_beta": 1.0,
|
||||||
|
"action_noise_s": 0.999,
|
||||||
|
"action_num_timestep_buckets": 1000,
|
||||||
|
# World model predictor (12 blocks, confirmed from checkpoint)
|
||||||
|
"predictor_depth": 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Image-key sets (confirmed sources in module docstring)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LIBERO — confirmed from lerobot/libero_10 meta/info.json
|
||||||
|
_LIBERO_CAMS = [
|
||||||
|
"observation.images.image", # agentview camera
|
||||||
|
"observation.images.image2", # eye-in-hand camera
|
||||||
|
]
|
||||||
|
|
||||||
|
# DROID pretrain — 2 views match the predictor embed_dim=2 × 1024=2048 in checkpoint
|
||||||
|
_DROID_CAMS = [
|
||||||
|
"observation.images.exterior_1_left",
|
||||||
|
"observation.images.exterior_2_left",
|
||||||
|
]
|
||||||
|
|
||||||
|
# OXE Bridge + RT1 — single-camera; world model disabled (predictor embed_dim mismatch)
|
||||||
|
_OXE_CAMS = [
|
||||||
|
"observation.images.image",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Config factories
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _build_config(
|
||||||
|
camera_keys: list[str],
|
||||||
|
with_state: bool,
|
||||||
|
enable_world_model: bool = True,
|
||||||
|
source_config: dict | None = None,
|
||||||
|
):
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
|
||||||
|
kwargs = dict(_ARCH)
|
||||||
|
_apply_source_config(kwargs, source_config or {})
|
||||||
|
|
||||||
|
# Image resolution: prefer source config, fall back to 224
|
||||||
|
data_cfg = _deep_get(source_config or {}, ("datasets", "vla_data"), {})
|
||||||
|
raw_res = data_cfg.get("resolution_size")
|
||||||
|
resolution_size = int(raw_res) if raw_res is not None else 224
|
||||||
|
image_shape = (3, resolution_size, resolution_size)
|
||||||
|
# Always set resize_images_to so the policy resizes env images to the training resolution,
|
||||||
|
# regardless of what resolution the eval env renders at.
|
||||||
|
kwargs["resize_images_to"] = (resolution_size, resolution_size)
|
||||||
|
|
||||||
|
# State / action dims: prefer source config
|
||||||
|
action_cfg = _deep_get(source_config or {}, ("framework", "action_model"), {})
|
||||||
|
state_dim = int(action_cfg["state_dim"]) if "state_dim" in action_cfg else 8
|
||||||
|
action_dim = int(action_cfg["action_dim"]) if "action_dim" in action_cfg else 7
|
||||||
|
|
||||||
|
input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=image_shape) for k in camera_keys}
|
||||||
|
if with_state:
|
||||||
|
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))
|
||||||
|
|
||||||
|
cfg = VLAJEPAConfig(
|
||||||
|
input_features=input_features,
|
||||||
|
output_features={
|
||||||
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||||
|
},
|
||||||
|
enable_world_model=enable_world_model,
|
||||||
|
binarize_gripper_action=True,
|
||||||
|
clip_normalized_actions=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
cfg.validate_features()
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
# Maps each subfolder in SOURCE_REPO_ID to (camera_keys, with_state, enable_world_model, repo_suffix)
|
||||||
|
VARIANTS: dict[str, tuple] = {
|
||||||
|
"LIBERO": (_LIBERO_CAMS, True, True, "LIBERO"),
|
||||||
|
"Pretrain": (_DROID_CAMS, False, True, "Pretrain"),
|
||||||
|
# SimplerEnv uses a single camera; the predictor embed_dim (2048) would mismatch, so
|
||||||
|
# disable the world model — only qwen + action_model weights are needed for inference.
|
||||||
|
"SimplerEnv": (_OXE_CAMS, False, False, "SimplerEnv"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def extract_state_dict(ckpt: object) -> dict[str, torch.Tensor]:
|
||||||
|
if isinstance(ckpt, dict):
|
||||||
|
sd = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt.get("model")
|
||||||
|
if sd is None:
|
||||||
|
sd = ckpt
|
||||||
|
else:
|
||||||
|
sd = ckpt
|
||||||
|
return {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)}
|
||||||
|
|
||||||
|
|
||||||
|
def subfolder_of(pt_path: str) -> str | None:
|
||||||
|
for part in Path(pt_path).parts:
|
||||||
|
if part in VARIANTS:
|
||||||
|
return part
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
api = HfApi()
|
||||||
|
|
||||||
|
log.info("Listing .pt files in %s …", SOURCE_REPO_ID)
|
||||||
|
pt_files = [f for f in api.list_repo_files(SOURCE_REPO_ID) if f.endswith(".pt")]
|
||||||
|
if not pt_files:
|
||||||
|
log.error("No .pt files found.")
|
||||||
|
return
|
||||||
|
for f in pt_files:
|
||||||
|
log.info(" %s", f)
|
||||||
|
|
||||||
|
# Create / reuse the collection once
|
||||||
|
collection = api.create_collection(
|
||||||
|
title=COLLECTION_TITLE,
|
||||||
|
description=COLLECTION_DESCRIPTION,
|
||||||
|
namespace=TARGET_ORG,
|
||||||
|
exists_ok=True,
|
||||||
|
)
|
||||||
|
log.info("Collection: %s", collection.url)
|
||||||
|
|
||||||
|
for pt_filename in pt_files:
|
||||||
|
log.info("\n=== %s ===", pt_filename)
|
||||||
|
|
||||||
|
subfolder = subfolder_of(pt_filename)
|
||||||
|
if subfolder is None:
|
||||||
|
log.warning(" No variant entry for '%s' — skipping.", pt_filename)
|
||||||
|
continue
|
||||||
|
|
||||||
|
camera_keys, with_state, enable_world_model, repo_suffix = VARIANTS[subfolder]
|
||||||
|
target_repo_id = f"{TARGET_ORG}/VLA-JEPA-{repo_suffix}"
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
" cameras=%d with_state=%s wm=%s → %s",
|
||||||
|
len(camera_keys),
|
||||||
|
with_state,
|
||||||
|
enable_world_model,
|
||||||
|
target_repo_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Download
|
||||||
|
local_pt = api.hf_hub_download(SOURCE_REPO_ID, pt_filename)
|
||||||
|
log.info(" Downloaded → %s", local_pt)
|
||||||
|
|
||||||
|
# 2. Load checkpoint
|
||||||
|
try:
|
||||||
|
ckpt = torch.load(local_pt, map_location="cpu", mmap=True, weights_only=False) # nosec B614
|
||||||
|
except TypeError:
|
||||||
|
ckpt = torch.load(local_pt, map_location="cpu") # nosec B614
|
||||||
|
|
||||||
|
sd = extract_state_dict(ckpt)
|
||||||
|
|
||||||
|
# Map source key names → LeRobot layout (handles layer1→w1, transformer_blocks→blocks, etc.)
|
||||||
|
mapped_sd: dict[str, torch.Tensor] = {}
|
||||||
|
skipped_keys: list[str] = []
|
||||||
|
for raw_key, value in sd.items():
|
||||||
|
target_key = _map_checkpoint_key(raw_key)
|
||||||
|
if target_key is None:
|
||||||
|
skipped_keys.append(raw_key)
|
||||||
|
else:
|
||||||
|
mapped_sd[target_key] = value
|
||||||
|
log.info(" %d tensors mapped, %d skipped", len(mapped_sd), len(skipped_keys))
|
||||||
|
if skipped_keys:
|
||||||
|
log.info(" Skipped sample: %s", skipped_keys[:5])
|
||||||
|
log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5])
|
||||||
|
|
||||||
|
# 3. Fetch action + state stats needed by the pre/postprocessor unnormalizers
|
||||||
|
dataset_stats = _fetch_dataset_stats(api, SOURCE_REPO_ID, subfolder)
|
||||||
|
|
||||||
|
# 4. Build config (no policy instantiation — avoids loading backbone from Hub)
|
||||||
|
source_config = _fetch_source_config(api, SOURCE_REPO_ID, subfolder)
|
||||||
|
config = _build_config(camera_keys, with_state, enable_world_model, source_config)
|
||||||
|
|
||||||
|
# 5. Save everything to a temp dir and upload in one shot
|
||||||
|
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
save_dir = Path(tmp)
|
||||||
|
|
||||||
|
log.info(" Saving model.safetensors …")
|
||||||
|
save_safetensors(mapped_sd, save_dir / "model.safetensors")
|
||||||
|
|
||||||
|
config._save_pretrained(save_dir) # writes config.json via draccus
|
||||||
|
|
||||||
|
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats)
|
||||||
|
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
|
||||||
|
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json
|
||||||
|
|
||||||
|
log.info(" Uploading …")
|
||||||
|
commit_url = api.upload_folder(
|
||||||
|
folder_path=save_dir,
|
||||||
|
repo_id=target_repo_id,
|
||||||
|
repo_type="model",
|
||||||
|
commit_message=f"Convert {Path(pt_filename).name} to safetensors",
|
||||||
|
)
|
||||||
|
log.info(" Uploaded → %s", commit_url)
|
||||||
|
|
||||||
|
# 6. Add to collection
|
||||||
|
api.add_collection_item(
|
||||||
|
collection_slug=collection.slug,
|
||||||
|
item_id=target_repo_id,
|
||||||
|
item_type="model",
|
||||||
|
exists_ok=True,
|
||||||
|
)
|
||||||
|
log.info(" Added to collection.")
|
||||||
|
|
||||||
|
log.info("\nAll done. Collection: %s", collection.url)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
607
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
607
src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
Normal file
@@ -0,0 +1,607 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from PIL import Image
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
|
from lerobot.policies.utils import populate_queues
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers import AutoModel, AutoVideoProcessor
|
||||||
|
else:
|
||||||
|
AutoModel = None
|
||||||
|
AutoVideoProcessor = None
|
||||||
|
|
||||||
|
from .action_head import VLAJEPAActionHead
|
||||||
|
from .configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
from .qwen_interface import Qwen3VLInterface
|
||||||
|
from .world_model import ActionConditionedVideoPredictor
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class VLAJEPAModel(nn.Module):
|
||||||
|
"""
|
||||||
|
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
|
||||||
|
|
||||||
|
Components:
|
||||||
|
- Qwen3-VL: vision-language backbone for fused embeddings
|
||||||
|
- DiT-B: flow-matching action head for future action prediction
|
||||||
|
- V-JEPA: world model for video frame prediction
|
||||||
|
|
||||||
|
Input: List[dict] native format (same as original starVLA)
|
||||||
|
- "image": List[PIL.Image] (multi-view images)
|
||||||
|
- "video": np.ndarray [V, T, H, W, 3]
|
||||||
|
- "lang": str (task instruction)
|
||||||
|
- "action": np.ndarray [T, action_dim] (optional, training only)
|
||||||
|
- "state": np.ndarray [1, state_dim] (optional)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||||
|
super().__init__()
|
||||||
|
require_package("transformers", extra="vla_jepa")
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Vision-language backbone
|
||||||
|
self.qwen = Qwen3VLInterface(config)
|
||||||
|
|
||||||
|
# Tokenizer expansion for special action tokens
|
||||||
|
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
|
||||||
|
self.qwen.expand_tokenizer()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Action head (flow-matching DiT)
|
||||||
|
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||||
|
|
||||||
|
# JEPA world model components
|
||||||
|
if config.enable_world_model:
|
||||||
|
self.video_encoder = AutoModel.from_pretrained(
|
||||||
|
config.jepa_encoder_name,
|
||||||
|
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||||
|
)
|
||||||
|
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||||
|
num_views = max(1, len(config.image_features))
|
||||||
|
tubelet_size = self.video_encoder.config.tubelet_size
|
||||||
|
image_size = getattr(self.video_encoder.config, "image_size", None)
|
||||||
|
if image_size is None:
|
||||||
|
first_image_shape = next(iter(config.image_features.values())).shape
|
||||||
|
image_size = first_image_shape[-1]
|
||||||
|
self.video_predictor = ActionConditionedVideoPredictor(
|
||||||
|
num_frames=config.num_video_frames // tubelet_size,
|
||||||
|
img_size=(image_size, image_size),
|
||||||
|
patch_size=16,
|
||||||
|
tubelet_size=1,
|
||||||
|
embed_dim=self.video_encoder.config.hidden_size * num_views,
|
||||||
|
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||||
|
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||||
|
depth=config.predictor_depth,
|
||||||
|
num_heads=config.predictor_num_heads,
|
||||||
|
mlp_ratio=config.predictor_mlp_ratio,
|
||||||
|
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.video_encoder = None
|
||||||
|
self.video_processor = None
|
||||||
|
self.video_predictor = None
|
||||||
|
|
||||||
|
if config.freeze_qwen:
|
||||||
|
self.qwen.requires_grad_(False)
|
||||||
|
|
||||||
|
# Build prompt placeholders.
|
||||||
|
# Use the encoder's actual tubelet_size when available (world model enabled),
|
||||||
|
# otherwise fall back to config.
|
||||||
|
_tubelet_size = (
|
||||||
|
self.video_encoder.config.tubelet_size
|
||||||
|
if config.enable_world_model
|
||||||
|
else self.config.jepa_tubelet_size
|
||||||
|
)
|
||||||
|
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
|
||||||
|
self.replace_prompt = "".join(
|
||||||
|
token * self.config.num_action_tokens_per_timestep
|
||||||
|
for token in self.action_tokens[:num_action_prompt_steps]
|
||||||
|
)
|
||||||
|
self.embodied_replace_prompt = (
|
||||||
|
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||||
|
)
|
||||||
|
|
||||||
|
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
|
"""Return the last decoder hidden state before the final RMSNorm.
|
||||||
|
|
||||||
|
The model was trained with the output of the last transformer block BEFORE
|
||||||
|
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
|
||||||
|
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
|
||||||
|
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
|
||||||
|
the correct pre-RMSNorm state, matching the training-time representation.
|
||||||
|
"""
|
||||||
|
captured: list[torch.Tensor] = []
|
||||||
|
|
||||||
|
def _hook(module, input, output):
|
||||||
|
h = output[0] if isinstance(output, tuple) else output
|
||||||
|
captured.append(h)
|
||||||
|
|
||||||
|
last_layer = self.qwen.model.model.language_model.layers[-1]
|
||||||
|
handle = last_layer.register_forward_hook(_hook)
|
||||||
|
try:
|
||||||
|
self.qwen.model(
|
||||||
|
**qwen_inputs,
|
||||||
|
output_hidden_states=False,
|
||||||
|
output_attentions=False,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
handle.remove()
|
||||||
|
|
||||||
|
return captured[0] # [B, seq_len, H]
|
||||||
|
|
||||||
|
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||||
|
|
||||||
|
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||||
|
"""
|
||||||
|
Native forward pass following original starVLA VLA_JEPA.forward.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
examples: List of per-sample dicts with keys:
|
||||||
|
"image" : List[PIL.Image] — multi-view images
|
||||||
|
"video" : np.ndarray [V, T, H, W, 3]
|
||||||
|
"lang" : str — task instruction
|
||||||
|
"action" : np.ndarray [T, action_dim] (optional)
|
||||||
|
"state" : np.ndarray [1, state_dim] (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
|
||||||
|
"""
|
||||||
|
# Unpack native format (same pattern as original VLA_JEPA.py)
|
||||||
|
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
|
||||||
|
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
|
||||||
|
instructions = [ex["lang"] for ex in examples] # List[str]
|
||||||
|
has_action = "action" in examples[0] and examples[0]["action"] is not None
|
||||||
|
actions = [ex["action"] for ex in examples] if has_action else None
|
||||||
|
has_state = "state" in examples[0] and examples[0]["state"] is not None
|
||||||
|
state = [ex["state"] for ex in examples] if has_state else None
|
||||||
|
action_is_pad = (
|
||||||
|
[ex["action_is_pad"] for ex in examples]
|
||||||
|
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
|
||||||
|
batch_videos = np.stack(batch_videos)
|
||||||
|
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
|
||||||
|
|
||||||
|
# ---- Step 1: QwenVL encode (same as original) ----
|
||||||
|
qwen_inputs = self.qwen.build_inputs(
|
||||||
|
images=batch_images,
|
||||||
|
instructions=instructions,
|
||||||
|
action_prompt=self.replace_prompt,
|
||||||
|
embodied_prompt=self.embodied_replace_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Locate embodied-action tokens (always needed for action head)
|
||||||
|
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||||
|
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||||
|
|
||||||
|
# Locate action tokens (only needed for world model predictor)
|
||||||
|
if self.config.enable_world_model:
|
||||||
|
action_mask = torch.isin(
|
||||||
|
qwen_inputs["input_ids"],
|
||||||
|
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||||
|
)
|
||||||
|
action_indices = action_mask.nonzero(as_tuple=True)
|
||||||
|
|
||||||
|
device_type = next(self.parameters()).device.type
|
||||||
|
|
||||||
|
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||||
|
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||||
|
b, _, h = last_hidden.shape
|
||||||
|
|
||||||
|
if self.config.enable_world_model:
|
||||||
|
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||||
|
|
||||||
|
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||||
|
|
||||||
|
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||||
|
device_wm = last_hidden.device
|
||||||
|
if not self.config.enable_world_model:
|
||||||
|
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||||
|
else:
|
||||||
|
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||||
|
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||||
|
|
||||||
|
video_pixels = []
|
||||||
|
for i in range(b * v):
|
||||||
|
video_pixels.append(
|
||||||
|
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
|
||||||
|
"pixel_values_videos"
|
||||||
|
].to(self.video_encoder.device)
|
||||||
|
)
|
||||||
|
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||||
|
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||||
|
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||||
|
|
||||||
|
tubelet_size = self.video_encoder.config.tubelet_size
|
||||||
|
device_wm = video_embeddings.device
|
||||||
|
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||||
|
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||||
|
|
||||||
|
if t_enc_total < 2:
|
||||||
|
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||||
|
else:
|
||||||
|
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||||
|
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||||
|
t_enc_ctx = t_enc_total - 1
|
||||||
|
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||||
|
|
||||||
|
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||||
|
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||||
|
|
||||||
|
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||||
|
if action_tokens.shape[1] < expected_actions:
|
||||||
|
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||||
|
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||||
|
|
||||||
|
predicted_states = self.video_predictor(
|
||||||
|
input_states.float(),
|
||||||
|
action_tokens[:, :expected_actions].float(),
|
||||||
|
)
|
||||||
|
|
||||||
|
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||||
|
|
||||||
|
if not has_action:
|
||||||
|
return {"wm_loss": wm_loss}
|
||||||
|
|
||||||
|
# ---- Step 4: Action Head ----
|
||||||
|
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||||
|
actions_tensor = torch.tensor(
|
||||||
|
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||||
|
) # [B, T_full, action_dim]
|
||||||
|
action_horizon = self.config.chunk_size
|
||||||
|
actions_target = actions_tensor[:, -action_horizon:, :]
|
||||||
|
|
||||||
|
state_tensor = None
|
||||||
|
if state is not None:
|
||||||
|
state_tensor = torch.tensor(
|
||||||
|
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||||
|
) # [B, 1, state_dim]
|
||||||
|
|
||||||
|
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||||
|
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||||
|
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||||
|
if state_tensor is not None:
|
||||||
|
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||||
|
|
||||||
|
action_is_pad_rep = None
|
||||||
|
if action_is_pad is not None:
|
||||||
|
pad_tensor = torch.stack(
|
||||||
|
[
|
||||||
|
p.to(actions_target.device)
|
||||||
|
if isinstance(p, Tensor)
|
||||||
|
else torch.tensor(p, device=actions_target.device)
|
||||||
|
for p in action_is_pad
|
||||||
|
]
|
||||||
|
) # [B, T_full]
|
||||||
|
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||||
|
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||||
|
|
||||||
|
action_loss = self.action_model(
|
||||||
|
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||||
|
|
||||||
|
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action(
|
||||||
|
self,
|
||||||
|
batch_images: list[list[Image.Image]],
|
||||||
|
instructions: list[str],
|
||||||
|
state: np.ndarray | None = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Native action prediction following original VLA_JEPA.predict_action.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_images: List of samples; each is List[PIL.Image] (multi-view).
|
||||||
|
instructions: Task instructions, one per sample.
|
||||||
|
state: Optional [B, state_dim] numpy array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||||
|
"""
|
||||||
|
if self.config.resize_images_to is not None:
|
||||||
|
height, width = self.config.resize_images_to
|
||||||
|
resampling = getattr(Image, "Resampling", Image).BOX
|
||||||
|
batch_images = [
|
||||||
|
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||||
|
for sample_images in batch_images
|
||||||
|
]
|
||||||
|
|
||||||
|
qwen_inputs = self.qwen.build_inputs(
|
||||||
|
images=batch_images,
|
||||||
|
instructions=instructions,
|
||||||
|
action_prompt=self.replace_prompt,
|
||||||
|
embodied_prompt=self.embodied_replace_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||||
|
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||||
|
|
||||||
|
device_type = next(self.parameters()).device.type
|
||||||
|
|
||||||
|
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||||
|
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||||
|
b, _, h = last_hidden.shape
|
||||||
|
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||||
|
|
||||||
|
state_tensor = None
|
||||||
|
if state is not None:
|
||||||
|
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||||
|
device=last_hidden.device, dtype=last_hidden.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
pred_actions = self.action_model.predict_action(
|
||||||
|
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||||
|
) # [B, action_horizon, action_dim]
|
||||||
|
|
||||||
|
return pred_actions.detach().cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class VLAJEPAPolicy(PreTrainedPolicy):
|
||||||
|
"""
|
||||||
|
LeRobot adapter for VLA-JEPA.
|
||||||
|
|
||||||
|
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
|
||||||
|
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
|
||||||
|
back to LeRobot format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = VLAJEPAConfig
|
||||||
|
name = "vla_jepa"
|
||||||
|
|
||||||
|
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
config.validate_features()
|
||||||
|
if dataset_meta := kwargs.get("dataset_meta"):
|
||||||
|
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
|
||||||
|
# compatibility), so validate_features() may have read stale dims from a pretrained
|
||||||
|
# config. Override state_dim/action_dim from the actual dataset being used.
|
||||||
|
ds_features = dataset_meta.features
|
||||||
|
if OBS_STATE in ds_features:
|
||||||
|
config.state_dim = ds_features[OBS_STATE]["shape"][0]
|
||||||
|
if ACTION in ds_features:
|
||||||
|
config.action_dim = ds_features[ACTION]["shape"][0]
|
||||||
|
|
||||||
|
self.model = VLAJEPAModel(config)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
|
||||||
|
|
||||||
|
# ---- Format Conversion: LeRobot → Native ----
|
||||||
|
|
||||||
|
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Convert LeRobot batch format to native VLA-JEPA examples format.
|
||||||
|
|
||||||
|
LeRobot format:
|
||||||
|
batch = {
|
||||||
|
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
|
||||||
|
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
|
||||||
|
"action": Tensor [B, chunk_size, action_dim], (training only)
|
||||||
|
"task": str | List[str], (optional instruction)
|
||||||
|
}
|
||||||
|
|
||||||
|
Native format (List[dict]):
|
||||||
|
{
|
||||||
|
"image": List[PIL.Image], # multi-view images per sample
|
||||||
|
"video": np.ndarray [V, T, H, W, 3],
|
||||||
|
"lang": str, # task instruction
|
||||||
|
"action": np.ndarray [T, action_dim], # optional
|
||||||
|
"state": np.ndarray [1, state_dim], # optional
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Determine batch size from the first image feature
|
||||||
|
image_keys = list(self.config.image_features.keys())
|
||||||
|
if not image_keys:
|
||||||
|
raise ValueError("VLAJEPA requires at least one image feature.")
|
||||||
|
first_key = image_keys[0]
|
||||||
|
first_tensor = batch[first_key]
|
||||||
|
batch_size = first_tensor.shape[0]
|
||||||
|
|
||||||
|
# ---- Collect images per sample ----
|
||||||
|
# images_per_sample[b][v] = PIL.Image for view v
|
||||||
|
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
|
||||||
|
for key in image_keys:
|
||||||
|
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||||
|
if tensor.ndim == 5:
|
||||||
|
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||||
|
# index 0 is the current observation (delta=0)
|
||||||
|
tensor = tensor[:, 0]
|
||||||
|
for b in range(batch_size):
|
||||||
|
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||||
|
|
||||||
|
# ---- Collect videos per sample ----
|
||||||
|
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
|
||||||
|
# Check whether any image feature has a time dimension
|
||||||
|
video_source = None
|
||||||
|
for k in image_keys:
|
||||||
|
if k in batch:
|
||||||
|
video_source = batch[k] # Use first available for shape inspection
|
||||||
|
break
|
||||||
|
|
||||||
|
if video_source is None:
|
||||||
|
raise ValueError("No image data found in batch for video construction.")
|
||||||
|
|
||||||
|
videos_per_sample = []
|
||||||
|
for b in range(batch_size):
|
||||||
|
sample_views = []
|
||||||
|
for k in image_keys:
|
||||||
|
t = batch[k][b] # [C, H, W] or [T, C, H, W]
|
||||||
|
if t.ndim == 3:
|
||||||
|
t = t.unsqueeze(0) # [1, C, H, W]
|
||||||
|
# Convert to [T, H, W, 3] numpy
|
||||||
|
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
|
||||||
|
# Clamp to [0, 255]
|
||||||
|
if t_np.max() <= 1.0:
|
||||||
|
t_np = t_np * 255.0
|
||||||
|
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||||
|
sample_views.append(t_np)
|
||||||
|
# Stack views: [V, T, H, W, 3]
|
||||||
|
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||||
|
|
||||||
|
# ---- Collect instructions ----
|
||||||
|
tasks = batch.get("task")
|
||||||
|
if tasks is None:
|
||||||
|
instructions = ["Execute the robot action."] * batch_size
|
||||||
|
elif isinstance(tasks, str):
|
||||||
|
instructions = [tasks] * batch_size
|
||||||
|
else:
|
||||||
|
instructions = list(tasks)
|
||||||
|
|
||||||
|
# ---- Collect actions (training only) ----
|
||||||
|
actions_list = None
|
||||||
|
action_is_pad_list = None
|
||||||
|
actions_tensor = batch.get(ACTION)
|
||||||
|
if actions_tensor is not None:
|
||||||
|
if actions_tensor.ndim == 2:
|
||||||
|
actions_tensor = actions_tensor.unsqueeze(1)
|
||||||
|
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||||
|
action_is_pad_tensor = batch.get("action_is_pad")
|
||||||
|
if action_is_pad_tensor is not None:
|
||||||
|
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
|
||||||
|
|
||||||
|
# ---- Collect state ----
|
||||||
|
state_list = None
|
||||||
|
state_tensor = batch.get(OBS_STATE)
|
||||||
|
if state_tensor is not None:
|
||||||
|
if state_tensor.ndim > 2:
|
||||||
|
state_tensor = state_tensor[:, -1, :]
|
||||||
|
if state_tensor.ndim == 2:
|
||||||
|
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
|
||||||
|
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||||
|
|
||||||
|
# ---- Assemble native examples ----
|
||||||
|
examples = []
|
||||||
|
for b in range(batch_size):
|
||||||
|
example = {
|
||||||
|
"image": images_per_sample[b],
|
||||||
|
"video": videos_per_sample[b],
|
||||||
|
"lang": instructions[b],
|
||||||
|
}
|
||||||
|
if actions_list is not None:
|
||||||
|
example["action"] = actions_list[b]
|
||||||
|
if action_is_pad_list is not None:
|
||||||
|
example["action_is_pad"] = action_is_pad_list[b]
|
||||||
|
if state_list is not None:
|
||||||
|
example["state"] = state_list[b]
|
||||||
|
examples.append(example)
|
||||||
|
|
||||||
|
return examples
|
||||||
|
|
||||||
|
# ---- LeRobot Policy Interface ----
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||||
|
"""LeRobot train forward: convert → native forward → aggregate losses."""
|
||||||
|
examples = self._prepare_model_inputs(batch)
|
||||||
|
native_output = self.model.forward(examples)
|
||||||
|
|
||||||
|
ref = next(iter(native_output.values()))
|
||||||
|
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
|
||||||
|
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
|
||||||
|
logs = {k: v.detach().item() for k, v in native_output.items()}
|
||||||
|
logs["loss"] = total_loss.detach().item()
|
||||||
|
return total_loss, logs
|
||||||
|
|
||||||
|
def get_optim_params(self) -> dict:
|
||||||
|
return self.model.parameters()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
|
"""LeRobot inference: convert → native predict → return as Tensor."""
|
||||||
|
self.eval()
|
||||||
|
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||||
|
|
||||||
|
examples = self._prepare_model_inputs(batch)
|
||||||
|
batch_images = [ex["image"] for ex in examples]
|
||||||
|
instructions = [ex["lang"] for ex in examples]
|
||||||
|
|
||||||
|
state_np = None
|
||||||
|
if "state" in examples[0] and examples[0]["state"] is not None:
|
||||||
|
state_np = np.stack([ex["state"] for ex in examples])
|
||||||
|
|
||||||
|
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||||
|
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
|
"""LeRobot select_action with action queue caching."""
|
||||||
|
self.eval()
|
||||||
|
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||||
|
if len(self._queues[ACTION]) == 0:
|
||||||
|
actions = self.predict_action_chunk(batch)
|
||||||
|
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||||
|
return self._queues[ACTION].popleft()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls: type[T],
|
||||||
|
pretrained_name_or_path: str | Path,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
return super().from_pretrained(pretrained_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||||
|
"""
|
||||||
|
Custom loading to enable opt reinit of action head
|
||||||
|
when loading pretrained weights with mismatched action head shapes.
|
||||||
|
"""
|
||||||
|
if not model.config.reinit_action_head:
|
||||||
|
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||||
|
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
state_dict = load_file(model_file, device=map_location)
|
||||||
|
current = model.state_dict()
|
||||||
|
|
||||||
|
mismatched: list[str] = []
|
||||||
|
filtered: dict = {}
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if key in current and value.shape != current[key].shape:
|
||||||
|
mismatched.append(
|
||||||
|
f"{key}: checkpoint {tuple(value.shape)} vs model {tuple(current[key].shape)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
filtered[key] = value
|
||||||
|
|
||||||
|
if mismatched:
|
||||||
|
logging.warning(
|
||||||
|
f"reinit_action_head=True: skipping {len(mismatched)} tensor(s) with mismatched shapes "
|
||||||
|
f"(randomly re-initialised):\n " + "\n ".join(mismatched)
|
||||||
|
)
|
||||||
|
|
||||||
|
from lerobot.policies.utils import log_model_loading_keys
|
||||||
|
|
||||||
|
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
|
||||||
|
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||||
|
return model
|
||||||
139
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
139
src/lerobot/policies/vla_jepa/processor_vla_jepa.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
from lerobot.processor import (
|
||||||
|
AddBatchDimensionProcessorStep,
|
||||||
|
ComplementaryDataProcessorStep,
|
||||||
|
DeviceProcessorStep,
|
||||||
|
EnvTransition,
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
PolicyAction,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
ProcessorStep,
|
||||||
|
ProcessorStepRegistry,
|
||||||
|
RenameObservationsProcessorStep,
|
||||||
|
TransitionKey,
|
||||||
|
UnnormalizerProcessorStep,
|
||||||
|
)
|
||||||
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
|
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
|
||||||
|
class ClipActionsProcessorStep(ProcessorStep):
|
||||||
|
"""Clips action tensor to [-1, 1] before unnormalization."""
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
action = transition.get(TransitionKey.ACTION)
|
||||||
|
if action is not None:
|
||||||
|
transition = dict(transition)
|
||||||
|
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
|
||||||
|
class PreSnapGripperProcessorStep(ProcessorStep):
|
||||||
|
"""Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization.
|
||||||
|
|
||||||
|
Mirrors the original starVLA LIBERO eval:
|
||||||
|
normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1)
|
||||||
|
This ensures the unnormalizer receives an exact binary value, which is
|
||||||
|
required when the model was trained with gripper in identity (mask=False)
|
||||||
|
space where 0=open and 1=close.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
action = transition.get(TransitionKey.ACTION)
|
||||||
|
if action is not None and action.shape[-1] >= 7:
|
||||||
|
transition = dict(transition)
|
||||||
|
a = action.clone()
|
||||||
|
a[..., 6] = (a[..., 6] >= 0.5).float()
|
||||||
|
transition[TransitionKey.ACTION] = a
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
|
||||||
|
class BinarizeGripperProcessorStep(ProcessorStep):
|
||||||
|
"""Binarizes gripper dim (index 6) after unnormalization.
|
||||||
|
|
||||||
|
Maps continuous value to {-1, 1}: > 0.5 → -1, <= 0.5 → 1 (matches starVLA convention).
|
||||||
|
Only applied when action has >= 7 dimensions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
action = transition.get(TransitionKey.ACTION)
|
||||||
|
if action is not None and action.shape[-1] >= 7:
|
||||||
|
transition = dict(transition)
|
||||||
|
a = action.clone()
|
||||||
|
a[..., 6] = 1.0 - 2.0 * (a[..., 6] > 0.5).float()
|
||||||
|
transition[TransitionKey.ACTION] = a
|
||||||
|
return transition
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def make_vla_jepa_pre_post_processors(
|
||||||
|
config: VLAJEPAConfig,
|
||||||
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
|
) -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
|
features = {**config.input_features, **config.output_features}
|
||||||
|
input_steps = [
|
||||||
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
|
AddBatchDimensionProcessorStep(),
|
||||||
|
DeviceProcessorStep(device=config.device),
|
||||||
|
NormalizerProcessorStep(
|
||||||
|
features=features,
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
output_steps: list[ProcessorStep] = []
|
||||||
|
if config.clip_normalized_actions:
|
||||||
|
output_steps.append(ClipActionsProcessorStep())
|
||||||
|
if config.pre_snap_gripper_action:
|
||||||
|
output_steps.append(PreSnapGripperProcessorStep())
|
||||||
|
output_steps.append(
|
||||||
|
UnnormalizerProcessorStep(
|
||||||
|
features=features,
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if config.binarize_gripper_action:
|
||||||
|
output_steps.append(BinarizeGripperProcessorStep())
|
||||||
|
output_steps.append(DeviceProcessorStep(device="cpu"))
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||||
|
steps=input_steps,
|
||||||
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||||
|
steps=output_steps,
|
||||||
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
to_transition=policy_action_to_transition,
|
||||||
|
to_output=transition_to_policy_action,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor")
|
||||||
|
class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep):
|
||||||
|
def complementary_data(self, complementary_data):
|
||||||
|
return complementary_data
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
return features
|
||||||
103
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
103
src/lerobot/policies/vla_jepa/qwen_interface.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
|
||||||
|
else:
|
||||||
|
AutoProcessor = None
|
||||||
|
Qwen3VLForConditionalGeneration = None
|
||||||
|
|
||||||
|
from .configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen3VLInterface(torch.nn.Module):
|
||||||
|
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
|
||||||
|
config.qwen_model_name,
|
||||||
|
torch_dtype=self._get_torch_dtype(config.torch_dtype),
|
||||||
|
)
|
||||||
|
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
|
||||||
|
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
|
||||||
|
self.model.config.hidden_size = self.model.config.text_config.hidden_size
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||||
|
if dtype_name == "float32":
|
||||||
|
return torch.float32
|
||||||
|
if dtype_name == "float16":
|
||||||
|
return torch.float16
|
||||||
|
return torch.bfloat16
|
||||||
|
|
||||||
|
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||||
|
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
|
||||||
|
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
|
||||||
|
# is required for Qwen embedding/lm_head checkpoint shapes to match.
|
||||||
|
max_action_tokens = self.config.chunk_size * 4
|
||||||
|
tokenizer = self.processor.tokenizer
|
||||||
|
action_tokens = []
|
||||||
|
action_token_ids = []
|
||||||
|
for idx in range(max_action_tokens):
|
||||||
|
token = self.config.special_action_token.format(idx)
|
||||||
|
action_tokens.append(token)
|
||||||
|
if token not in tokenizer.get_vocab():
|
||||||
|
tokenizer.add_tokens([token], special_tokens=True)
|
||||||
|
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
|
||||||
|
|
||||||
|
embodied_action_token = self.config.embodied_action_token
|
||||||
|
if embodied_action_token not in tokenizer.get_vocab():
|
||||||
|
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
|
||||||
|
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
|
||||||
|
|
||||||
|
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
|
||||||
|
self.model.resize_token_embeddings(len(tokenizer))
|
||||||
|
return action_tokens, action_token_ids, embodied_action_token_id
|
||||||
|
|
||||||
|
def build_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence[Sequence[Image.Image]],
|
||||||
|
instructions: Sequence[str],
|
||||||
|
action_prompt: str,
|
||||||
|
embodied_prompt: str,
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
messages = []
|
||||||
|
for sample_images, instruction in zip(images, instructions, strict=True):
|
||||||
|
prompt = self.config.prompt_template.format(
|
||||||
|
instruction=instruction,
|
||||||
|
actions=action_prompt,
|
||||||
|
e_actions=embodied_prompt,
|
||||||
|
)
|
||||||
|
content = [{"type": "image", "image": img} for img in sample_images]
|
||||||
|
content.append({"type": "text", "text": prompt})
|
||||||
|
messages.append([{"role": "user", "content": content}])
|
||||||
|
|
||||||
|
batch_inputs = self.processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=True,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
return_dict=True,
|
||||||
|
processor_kwargs={"padding": True, "return_tensors": "pt"},
|
||||||
|
)
|
||||||
|
return batch_inputs.to(self.model.device)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
|
||||||
|
image = image_tensor.detach().cpu()
|
||||||
|
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||||
|
image = image.permute(1, 2, 0)
|
||||||
|
image = image.float()
|
||||||
|
if image.max() <= 1.0:
|
||||||
|
image = image * 255.0
|
||||||
|
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||||
|
if image.shape[-1] == 1:
|
||||||
|
image = np.repeat(image, 3, axis=-1)
|
||||||
|
return Image.fromarray(image)
|
||||||
404
src/lerobot/policies/vla_jepa/world_model.py
Normal file
404
src/lerobot/policies/vla_jepa/world_model.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
def build_action_block_causal_attention_mask(
|
||||||
|
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
|
||||||
|
) -> torch.Tensor:
|
||||||
|
tokens_per_frame = add_tokens + grid_height * grid_width
|
||||||
|
num_tokens = num_frames * tokens_per_frame
|
||||||
|
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
|
||||||
|
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
|
||||||
|
local_window_time = num_frames
|
||||||
|
|
||||||
|
for current_frame in range(num_frames):
|
||||||
|
first_context_frame = max(0, current_frame - local_window_time + 1)
|
||||||
|
for context_frame in range(first_context_frame, current_frame + 1):
|
||||||
|
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
|
||||||
|
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
|
||||||
|
mask[row, col] = mask_block
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||||
|
_, _, _, dim = x.size()
|
||||||
|
if dim % 2 != 0:
|
||||||
|
raise ValueError("Embedding dimension must be even for rotary position encoding.")
|
||||||
|
|
||||||
|
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
|
||||||
|
omega /= dim / 2.0
|
||||||
|
omega = 1.0 / 10000**omega
|
||||||
|
freqs = torch.einsum("..., f -> ... f", pos, omega)
|
||||||
|
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
|
||||||
|
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
|
||||||
|
|
||||||
|
y = x.unflatten(-1, (-1, 2))
|
||||||
|
y1, y2 = y.unbind(dim=-1)
|
||||||
|
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
|
||||||
|
return x * emb_cos + y * emb_sin
|
||||||
|
|
||||||
|
|
||||||
|
class DropPath(nn.Module):
|
||||||
|
def __init__(self, drop_prob: float = 0.0) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.drop_prob == 0.0 or not self.training:
|
||||||
|
return x
|
||||||
|
keep_prob = 1 - self.drop_prob
|
||||||
|
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||||
|
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||||
|
random_tensor.floor_()
|
||||||
|
return x.div(keep_prob) * random_tensor
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
hidden_features: int | None = None,
|
||||||
|
out_features: int | None = None,
|
||||||
|
act_layer: type[nn.Module] = nn.GELU,
|
||||||
|
drop: float = 0.0,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||||
|
self.act = act_layer()
|
||||||
|
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||||
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
x = self.drop(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ACRoPEAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int = 8,
|
||||||
|
qkv_bias: bool = False,
|
||||||
|
qk_scale: float | None = None,
|
||||||
|
attn_drop: float = 0.0,
|
||||||
|
proj_drop: float = 0.0,
|
||||||
|
use_sdpa: bool = True,
|
||||||
|
is_causal: bool = False,
|
||||||
|
grid_size: int = 16,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.scale = qk_scale or self.head_dim**-0.5
|
||||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||||
|
self.attn_drop = nn.Dropout(attn_drop)
|
||||||
|
self.proj = nn.Linear(dim, dim)
|
||||||
|
self.proj_drop_prob = proj_drop
|
||||||
|
self.proj_drop = nn.Dropout(proj_drop)
|
||||||
|
self.use_sdpa = use_sdpa
|
||||||
|
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||||
|
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||||
|
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||||
|
self.grid_size = grid_size
|
||||||
|
self.is_causal = is_causal
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
return ids // int(height * width)
|
||||||
|
|
||||||
|
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
frame_ids = self._get_frame_pos(ids, height, width)
|
||||||
|
ids = ids - int(height * width) * frame_ids
|
||||||
|
return ids // width
|
||||||
|
|
||||||
|
def separate_positions(
|
||||||
|
self, ids: torch.Tensor, height: int, width: int
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
frame_ids = self._get_frame_pos(ids, height, width)
|
||||||
|
height_ids = self._get_height_pos(ids, height, width)
|
||||||
|
width_ids = ids - int(height * width) * frame_ids - width * height_ids
|
||||||
|
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
mask: torch.Tensor | None = None,
|
||||||
|
attn_mask: torch.Tensor | None = None,
|
||||||
|
num_frames: int | None = None,
|
||||||
|
grid_height: int | None = None,
|
||||||
|
grid_width: int | None = None,
|
||||||
|
action_tokens: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, num_tokens, channels = x.size()
|
||||||
|
if num_frames is None or grid_height is None or grid_width is None:
|
||||||
|
raise ValueError("num_frames, grid_height and grid_width are required.")
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
||||||
|
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||||
|
else:
|
||||||
|
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
|
||||||
|
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||||
|
|
||||||
|
h_mask *= self.grid_size / grid_height
|
||||||
|
w_mask *= self.grid_size / grid_width
|
||||||
|
|
||||||
|
if action_tokens > 0:
|
||||||
|
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
|
||||||
|
action_q, action_k, action_v = [], [], []
|
||||||
|
for idx in range(action_tokens):
|
||||||
|
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||||
|
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||||
|
qd = rotate_queries_or_keys(
|
||||||
|
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||||
|
)
|
||||||
|
kd = rotate_queries_or_keys(
|
||||||
|
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
|
||||||
|
)
|
||||||
|
qr = q[..., self.d_dim :]
|
||||||
|
kr = k[..., self.d_dim :]
|
||||||
|
action_q.append(
|
||||||
|
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||||
|
)
|
||||||
|
action_k.append(
|
||||||
|
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
|
||||||
|
)
|
||||||
|
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||||
|
|
||||||
|
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||||
|
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
||||||
|
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
||||||
|
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
||||||
|
|
||||||
|
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
|
||||||
|
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
|
||||||
|
offset += self.d_dim
|
||||||
|
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
|
||||||
|
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
|
||||||
|
offset += self.h_dim
|
||||||
|
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
|
||||||
|
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
|
||||||
|
offset += self.w_dim
|
||||||
|
|
||||||
|
if offset < self.head_dim:
|
||||||
|
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
|
||||||
|
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
|
||||||
|
else:
|
||||||
|
q = torch.cat([qd, qh, qw], dim=-1)
|
||||||
|
k = torch.cat([kd, kh, kw], dim=-1)
|
||||||
|
|
||||||
|
if action_tokens > 0:
|
||||||
|
|
||||||
|
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
|
||||||
|
frame_tokens = frame_tokens.view(
|
||||||
|
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
|
||||||
|
)
|
||||||
|
action_token_values = action_token_values.view(
|
||||||
|
batch_size, self.num_heads, num_frames, action_tokens, -1
|
||||||
|
)
|
||||||
|
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
|
||||||
|
|
||||||
|
q = merge(q, action_q)
|
||||||
|
k = merge(k, action_k)
|
||||||
|
v = merge(v, action_v)
|
||||||
|
|
||||||
|
if attn_mask is not None or self.use_sdpa:
|
||||||
|
x = F.scaled_dot_product_attention(
|
||||||
|
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||||
|
attn = attn.softmax(dim=-1)
|
||||||
|
attn = self.attn_drop(attn)
|
||||||
|
x = attn @ v
|
||||||
|
|
||||||
|
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||||
|
x = self.proj(x)
|
||||||
|
return self.proj_drop(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ACBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
qkv_bias: bool = True,
|
||||||
|
qk_scale: float | None = None,
|
||||||
|
drop: float = 0.0,
|
||||||
|
attn_drop: float = 0.0,
|
||||||
|
drop_path: float = 0.0,
|
||||||
|
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||||
|
use_sdpa: bool = True,
|
||||||
|
is_causal: bool = False,
|
||||||
|
grid_size: int = 16,
|
||||||
|
use_rope: bool = True,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.norm1 = norm_layer(dim)
|
||||||
|
if not use_rope:
|
||||||
|
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
|
||||||
|
self.attn = ACRoPEAttention(
|
||||||
|
dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
attn_drop=attn_drop,
|
||||||
|
use_sdpa=use_sdpa,
|
||||||
|
is_causal=is_causal,
|
||||||
|
grid_size=grid_size,
|
||||||
|
proj_drop=drop,
|
||||||
|
)
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||||
|
self.norm2 = norm_layer(dim)
|
||||||
|
self.mlp = MLP(
|
||||||
|
in_features=dim,
|
||||||
|
hidden_features=int(dim * mlp_ratio),
|
||||||
|
act_layer=nn.GELU,
|
||||||
|
drop=drop,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
attn_mask: torch.Tensor | None = None,
|
||||||
|
num_frames: int | None = None,
|
||||||
|
grid_height: int | None = None,
|
||||||
|
grid_width: int | None = None,
|
||||||
|
action_tokens: int = 0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
y = self.norm1(x)
|
||||||
|
y = self.attn(
|
||||||
|
y,
|
||||||
|
mask=None,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
num_frames=num_frames,
|
||||||
|
grid_height=grid_height,
|
||||||
|
grid_width=grid_width,
|
||||||
|
action_tokens=action_tokens,
|
||||||
|
)
|
||||||
|
x = x + self.drop_path(y)
|
||||||
|
y = self.norm2(x)
|
||||||
|
return x + self.drop_path(self.mlp(y))
|
||||||
|
|
||||||
|
|
||||||
|
class ActionConditionedVideoPredictor(nn.Module):
|
||||||
|
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_frames: int,
|
||||||
|
img_size: tuple[int, int],
|
||||||
|
patch_size: int,
|
||||||
|
tubelet_size: int,
|
||||||
|
embed_dim: int,
|
||||||
|
action_embed_dim: int,
|
||||||
|
predictor_embed_dim: int,
|
||||||
|
depth: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float,
|
||||||
|
num_action_tokens_per_step: int,
|
||||||
|
use_extrinsics: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.is_frame_causal = True
|
||||||
|
self.use_extrinsics = use_extrinsics
|
||||||
|
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
||||||
|
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||||
|
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||||
|
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
||||||
|
|
||||||
|
self.img_height, self.img_width = img_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.num_frames = num_frames
|
||||||
|
self.tubelet_size = tubelet_size
|
||||||
|
self.grid_height = self.img_height // self.patch_size
|
||||||
|
self.grid_width = self.img_width // self.patch_size
|
||||||
|
|
||||||
|
self.predictor_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ACBlock(
|
||||||
|
dim=predictor_embed_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=True,
|
||||||
|
drop=0.0,
|
||||||
|
attn_drop=0.0,
|
||||||
|
drop_path=0.0,
|
||||||
|
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
|
||||||
|
grid_size=self.grid_height,
|
||||||
|
use_rope=True,
|
||||||
|
)
|
||||||
|
for _ in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
|
||||||
|
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
||||||
|
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||||
|
|
||||||
|
@property
|
||||||
|
def norm(self) -> nn.LayerNorm:
|
||||||
|
return self.predictor_norm
|
||||||
|
|
||||||
|
@property
|
||||||
|
def proj(self) -> nn.Linear:
|
||||||
|
return self.predictor_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
frame_tokens: torch.Tensor,
|
||||||
|
action_tokens: torch.Tensor,
|
||||||
|
extrinsics: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
|
||||||
|
x = self.predictor_embed(frame_tokens)
|
||||||
|
batch_size, num_context_tokens, hidden_dim = x.size()
|
||||||
|
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
|
||||||
|
|
||||||
|
actions = self.action_encoder(action_tokens)
|
||||||
|
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
|
||||||
|
cond_tokens = actions.shape[2]
|
||||||
|
|
||||||
|
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
|
||||||
|
if self.use_extrinsics:
|
||||||
|
if extrinsics is None:
|
||||||
|
raise ValueError("extrinsics are required when use_extrinsics=True.")
|
||||||
|
cond_tokens += 1
|
||||||
|
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
||||||
|
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
|
||||||
|
else:
|
||||||
|
x = torch.cat([actions, x], dim=2).flatten(1, 2)
|
||||||
|
|
||||||
|
attn_mask = build_action_block_causal_attention_mask(
|
||||||
|
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
|
||||||
|
)
|
||||||
|
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
||||||
|
|
||||||
|
for block in self.predictor_blocks:
|
||||||
|
x = block(
|
||||||
|
x,
|
||||||
|
attn_mask=attn_mask,
|
||||||
|
num_frames=num_frames,
|
||||||
|
grid_height=self.grid_height,
|
||||||
|
grid_width=self.grid_width,
|
||||||
|
action_tokens=cond_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
|
||||||
|
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
||||||
|
x = self.predictor_norm(x)
|
||||||
|
return self.predictor_proj(x)
|
||||||
268
tests/policies/vla_jepa/conftest.py
Normal file
268
tests/policies/vla_jepa/conftest.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""Shared fixtures and helpers for VLA-JEPA tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared constants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
BATCH_SIZE = 2
|
||||||
|
ACTION_DIM = 3
|
||||||
|
STATE_DIM = 4
|
||||||
|
IMAGE_SIZE = 8
|
||||||
|
ACTION_HORIZON = 4
|
||||||
|
N_ACTION_STEPS = 2
|
||||||
|
NUM_VIDEO_FRAMES = 3
|
||||||
|
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
|
||||||
|
|
||||||
|
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||||
|
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed_all(seed: int) -> None:
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def make_config(
|
||||||
|
action_dim: int = ACTION_DIM,
|
||||||
|
state_dim: int = STATE_DIM,
|
||||||
|
action_horizon: int = ACTION_HORIZON,
|
||||||
|
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||||
|
) -> VLAJEPAConfig:
|
||||||
|
config = VLAJEPAConfig(
|
||||||
|
input_features={
|
||||||
|
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||||
|
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
|
||||||
|
},
|
||||||
|
output_features={
|
||||||
|
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||||
|
},
|
||||||
|
device="cpu",
|
||||||
|
chunk_size=action_horizon,
|
||||||
|
n_action_steps=min(N_ACTION_STEPS, action_horizon),
|
||||||
|
action_dim=action_dim,
|
||||||
|
state_dim=state_dim,
|
||||||
|
num_video_frames=num_video_frames,
|
||||||
|
num_action_tokens_per_timestep=2,
|
||||||
|
num_embodied_action_tokens_per_instruction=3,
|
||||||
|
num_inference_timesteps=2,
|
||||||
|
action_hidden_size=QWEN_HIDDEN_SIZE,
|
||||||
|
action_model_type="DiT-test",
|
||||||
|
action_num_layers=1,
|
||||||
|
predictor_depth=1,
|
||||||
|
predictor_num_heads=2,
|
||||||
|
predictor_mlp_ratio=2.0,
|
||||||
|
jepa_tubelet_size=1,
|
||||||
|
)
|
||||||
|
config.validate_features()
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def make_train_batch(
|
||||||
|
batch_size: int = BATCH_SIZE,
|
||||||
|
action_dim: int = ACTION_DIM,
|
||||||
|
state_dim: int = STATE_DIM,
|
||||||
|
action_horizon: int = ACTION_HORIZON,
|
||||||
|
num_video_frames: int = NUM_VIDEO_FRAMES,
|
||||||
|
) -> dict[str, Tensor | list[str]]:
|
||||||
|
return {
|
||||||
|
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||||
|
OBS_STATE: torch.randn(batch_size, 1, state_dim),
|
||||||
|
ACTION: torch.randn(batch_size, action_horizon, action_dim),
|
||||||
|
"task": ["pick up the cube"] * batch_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def make_inference_batch(
|
||||||
|
batch_size: int = BATCH_SIZE,
|
||||||
|
state_dim: int = STATE_DIM,
|
||||||
|
) -> dict[str, Tensor | list[str]]:
|
||||||
|
return {
|
||||||
|
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
|
||||||
|
OBS_STATE: torch.randn(batch_size, state_dim),
|
||||||
|
"task": ["pick up the cube"] * batch_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLanguageLayer(nn.Module):
|
||||||
|
"""Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._hidden_size = hidden_size
|
||||||
|
|
||||||
|
def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
|
||||||
|
return (hidden,)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLanguageModel(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._hidden_size = hidden_size
|
||||||
|
self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)])
|
||||||
|
|
||||||
|
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||||
|
batch_size, seq_len = input_ids.shape
|
||||||
|
hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device)
|
||||||
|
self.layers[-1](hidden)
|
||||||
|
return SimpleNamespace()
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeQwenInnerModel(nn.Module):
|
||||||
|
"""Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into."""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.language_model = _FakeLanguageModel(hidden_size)
|
||||||
|
|
||||||
|
def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
|
||||||
|
return self.language_model(input_ids)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeQwenBackbone(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(1))
|
||||||
|
self.config = SimpleNamespace(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
text_config=SimpleNamespace(hidden_size=hidden_size),
|
||||||
|
)
|
||||||
|
self.model = _FakeQwenInnerModel(hidden_size)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self.weight.device
|
||||||
|
|
||||||
|
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
|
||||||
|
batch_size, seq_len = input_ids.shape
|
||||||
|
hidden_size = self.config.hidden_size
|
||||||
|
values = torch.arange(
|
||||||
|
batch_size * seq_len * hidden_size,
|
||||||
|
device=input_ids.device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
).view(batch_size, seq_len, hidden_size)
|
||||||
|
hidden = values / values.numel() + self.weight
|
||||||
|
return SimpleNamespace(hidden_states=[hidden])
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeQwenInterface(nn.Module):
|
||||||
|
def __init__(self, config: VLAJEPAConfig) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
|
||||||
|
return torch.float32 if dtype_name == "float32" else torch.bfloat16
|
||||||
|
|
||||||
|
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||||
|
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
||||||
|
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
|
||||||
|
action_token_ids = list(range(1000, 1000 + max_action_tokens))
|
||||||
|
return action_tokens, action_token_ids, 2000
|
||||||
|
|
||||||
|
def build_inputs(
|
||||||
|
self,
|
||||||
|
images: list[list[Image.Image]],
|
||||||
|
instructions: list[str],
|
||||||
|
action_prompt: str,
|
||||||
|
embodied_prompt: str,
|
||||||
|
) -> dict[str, Tensor]:
|
||||||
|
batch_size = len(images)
|
||||||
|
del images, instructions, action_prompt, embodied_prompt
|
||||||
|
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
|
||||||
|
token_ids = (
|
||||||
|
[10]
|
||||||
|
+ list(range(1000, 1000 + action_count))
|
||||||
|
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
|
||||||
|
+ [11]
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"input_ids": torch.tensor(
|
||||||
|
[token_ids] * batch_size,
|
||||||
|
device=self.model.device,
|
||||||
|
dtype=torch.long,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
|
||||||
|
image = image_tensor.detach().cpu()
|
||||||
|
if image.ndim == 3 and image.shape[0] in (1, 3):
|
||||||
|
image = image.permute(1, 2, 0)
|
||||||
|
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
|
||||||
|
return Image.fromarray(image)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeVideoEncoder(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(1))
|
||||||
|
# image_size must be >= patch_size (16) so the predictor grid is non-zero.
|
||||||
|
# Setting image_size=16 gives a 1x1 grid (1 patch per frame).
|
||||||
|
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self.weight.device
|
||||||
|
|
||||||
|
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
|
||||||
|
batch_size, num_frames = pixel_values_videos.shape[:2]
|
||||||
|
hidden_size = self.config.hidden_size
|
||||||
|
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
|
||||||
|
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeVideoProcessor:
|
||||||
|
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
|
||||||
|
assert return_tensors == "pt"
|
||||||
|
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
from lerobot.policies.vla_jepa import modeling_vla_jepa
|
||||||
|
|
||||||
|
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
modeling_vla_jepa.AutoModel,
|
||||||
|
"from_pretrained",
|
||||||
|
lambda *args, **kwargs: _FakeVideoEncoder(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
modeling_vla_jepa.AutoVideoProcessor,
|
||||||
|
"from_pretrained",
|
||||||
|
lambda *args, **kwargs: _FakeVideoProcessor(),
|
||||||
|
)
|
||||||
157
tests/policies/vla_jepa/test_action_head.py
Normal file
157
tests/policies/vla_jepa/test_action_head.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pytest.importorskip("diffusers")
|
||||||
|
|
||||||
|
from conftest import (
|
||||||
|
ACTION_DIM,
|
||||||
|
ACTION_HORIZON,
|
||||||
|
BATCH_SIZE,
|
||||||
|
QWEN_HIDDEN_SIZE,
|
||||||
|
STATE_DIM,
|
||||||
|
make_config,
|
||||||
|
set_seed_all,
|
||||||
|
) # noqa: E402
|
||||||
|
|
||||||
|
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
|
||||||
|
VLAJEPAActionHead,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# VLAJEPAActionHead
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"action_dim,state_dim,action_horizon",
|
||||||
|
[
|
||||||
|
(3, 4, 4), # default test dims
|
||||||
|
(7, 0, 16), # no proprioceptive state, production-like action space
|
||||||
|
(6, 8, 8), # medium dims
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||||
|
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||||
|
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||||
|
t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32)
|
||||||
|
assert t.shape == (200,)
|
||||||
|
assert torch.isfinite(t).all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"action_dim,state_dim,action_horizon",
|
||||||
|
[
|
||||||
|
(3, 4, 4),
|
||||||
|
(7, 0, 16),
|
||||||
|
(6, 8, 8),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||||
|
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||||
|
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||||
|
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||||
|
actions = torch.randn(2, action_horizon, action_dim)
|
||||||
|
timesteps = torch.randint(0, 100, (2,))
|
||||||
|
|
||||||
|
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||||
|
out_with = head._build_inputs(conditioning, actions, state, timesteps)
|
||||||
|
out_none = head._build_inputs(conditioning, actions, None, timesteps)
|
||||||
|
|
||||||
|
assert out_with.ndim == 3 and out_none.ndim == 3
|
||||||
|
if state_dim > 0:
|
||||||
|
assert out_with.shape[1] > out_none.shape[1]
|
||||||
|
assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"action_dim,state_dim,action_horizon",
|
||||||
|
[
|
||||||
|
(3, 4, 4),
|
||||||
|
(7, 0, 16),
|
||||||
|
(6, 8, 8),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||||
|
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||||
|
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||||
|
actions = torch.randn(2, action_horizon, action_dim)
|
||||||
|
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||||
|
loss = head.forward(conditioning, actions, state)
|
||||||
|
assert loss.shape == ()
|
||||||
|
assert torch.isfinite(loss) and loss > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_head_forward_gradient_flows() -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
config = make_config()
|
||||||
|
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||||
|
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||||
|
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||||
|
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||||
|
loss = head.forward(conditioning, actions, state)
|
||||||
|
loss.backward()
|
||||||
|
assert any(p.grad is not None for p in head.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"action_dim,state_dim,action_horizon",
|
||||||
|
[
|
||||||
|
(3, 4, 4),
|
||||||
|
(7, 0, 16),
|
||||||
|
(6, 8, 8),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||||
|
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||||
|
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
|
||||||
|
state = torch.randn(2, state_dim) if state_dim > 0 else None
|
||||||
|
pred = head.predict_action(conditioning, state)
|
||||||
|
assert tuple(pred.shape) == (2, action_horizon, action_dim)
|
||||||
|
assert torch.isfinite(pred).all()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# action_is_pad masking
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_head_loss_fully_padded_is_zero() -> None:
|
||||||
|
"""Loss is 0 when every timestep is padded (exercises the clamp_min guard)."""
|
||||||
|
set_seed_all(42)
|
||||||
|
config = make_config()
|
||||||
|
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||||
|
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||||
|
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||||
|
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||||
|
|
||||||
|
action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||||
|
loss = head.forward(conditioning, actions, state, action_is_pad)
|
||||||
|
assert loss.item() == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_action_head_loss_none_matches_no_padding() -> None:
|
||||||
|
"""action_is_pad=None is equivalent to an all-False (no padding) mask."""
|
||||||
|
set_seed_all(42)
|
||||||
|
config = make_config()
|
||||||
|
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
|
||||||
|
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
|
||||||
|
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
|
||||||
|
state = torch.randn(BATCH_SIZE, STATE_DIM)
|
||||||
|
|
||||||
|
set_seed_all(0)
|
||||||
|
loss_none = head.forward(conditioning, actions, state, action_is_pad=None)
|
||||||
|
|
||||||
|
set_seed_all(0)
|
||||||
|
no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
|
||||||
|
loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad)
|
||||||
|
|
||||||
|
assert torch.isclose(loss_none, loss_zeros)
|
||||||
57
tests/policies/vla_jepa/test_configuration.py
Normal file
57
tests/policies/vla_jepa/test_configuration.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
|
|
||||||
|
|
||||||
|
def test_delta_indices() -> None:
|
||||||
|
config = make_config()
|
||||||
|
assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES))
|
||||||
|
assert config.action_delta_indices == list(range(ACTION_HORIZON))
|
||||||
|
|
||||||
|
|
||||||
|
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
|
||||||
|
with pytest.raises(ValueError, match="n_action_steps"):
|
||||||
|
VLAJEPAConfig(chunk_size=4, n_action_steps=8)
|
||||||
|
|
||||||
|
|
||||||
|
def test_too_few_video_frames_raises() -> None:
|
||||||
|
with pytest.raises(ValueError, match="video_horizon"):
|
||||||
|
VLAJEPAConfig(
|
||||||
|
chunk_size=16,
|
||||||
|
n_action_steps=16,
|
||||||
|
num_video_frames=2,
|
||||||
|
jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_features_no_image_raises() -> None:
|
||||||
|
config = VLAJEPAConfig(
|
||||||
|
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))},
|
||||||
|
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="at least one visual input feature"):
|
||||||
|
config.validate_features()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_features_no_action_raises() -> None:
|
||||||
|
config = VLAJEPAConfig(
|
||||||
|
input_features={
|
||||||
|
f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
|
||||||
|
},
|
||||||
|
output_features={},
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="action output feature"):
|
||||||
|
config.validate_features()
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_features_sets_action_dim_from_feature() -> None:
|
||||||
|
config = make_config(action_dim=6, state_dim=10)
|
||||||
|
assert config.action_dim == 6
|
||||||
|
assert config.state_dim == 10
|
||||||
473
tests/policies/vla_jepa/test_vla_jepa.py
Normal file
473
tests/policies/vla_jepa/test_vla_jepa.py
Normal file
@@ -0,0 +1,473 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
pytest.importorskip("diffusers")
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.filterwarnings(
|
||||||
|
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
|
||||||
|
)
|
||||||
|
|
||||||
|
from conftest import ( # noqa: E402
|
||||||
|
ACTION_DIM,
|
||||||
|
ACTION_HORIZON,
|
||||||
|
BATCH_SIZE,
|
||||||
|
EXPECTED_ACTION_CHUNK_SHAPE,
|
||||||
|
EXPECTED_SELECT_ACTION_SHAPE,
|
||||||
|
N_ACTION_STEPS,
|
||||||
|
STATE_DIM,
|
||||||
|
make_config,
|
||||||
|
make_inference_batch,
|
||||||
|
make_train_batch,
|
||||||
|
set_seed_all,
|
||||||
|
)
|
||||||
|
|
||||||
|
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
|
||||||
|
from lerobot.utils.constants import ACTION # noqa: E402
|
||||||
|
|
||||||
|
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
|
||||||
|
PRETRAINED_SUBFOLDER = "LIBERO"
|
||||||
|
|
||||||
|
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
|
||||||
|
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
|
||||||
|
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
|
||||||
|
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Core training / inference tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
policy.train()
|
||||||
|
|
||||||
|
batch = make_train_batch()
|
||||||
|
batch_before = deepcopy(batch)
|
||||||
|
|
||||||
|
loss, logs = policy.forward(batch)
|
||||||
|
|
||||||
|
assert loss.shape == ()
|
||||||
|
assert torch.isfinite(loss)
|
||||||
|
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||||
|
assert logs["action_loss"] > 0
|
||||||
|
assert logs["wm_loss"] >= 0
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
|
||||||
|
# Batch must not be mutated.
|
||||||
|
assert set(batch) == set(batch_before)
|
||||||
|
for key, value in batch.items():
|
||||||
|
if isinstance(value, Tensor):
|
||||||
|
assert torch.equal(value, batch_before[key])
|
||||||
|
else:
|
||||||
|
assert value == batch_before[key]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 2, 4])
|
||||||
|
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
policy.train()
|
||||||
|
loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
|
||||||
|
assert torch.isfinite(loss) and loss > 0
|
||||||
|
assert set(logs) == {"action_loss", "wm_loss", "loss"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"action_dim,state_dim,action_horizon",
|
||||||
|
[
|
||||||
|
(3, 4, 4),
|
||||||
|
(7, 0, 16),
|
||||||
|
(6, 8, 8),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_training_forward_various_dims(
|
||||||
|
patch_vla_jepa_external_models: None,
|
||||||
|
action_dim: int,
|
||||||
|
state_dim: int,
|
||||||
|
action_horizon: int,
|
||||||
|
) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||||
|
policy = VLAJEPAPolicy(config)
|
||||||
|
policy.train()
|
||||||
|
batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
|
||||||
|
loss, _ = policy.forward(batch)
|
||||||
|
assert torch.isfinite(loss) and loss > 0
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
policy.eval()
|
||||||
|
batch = make_inference_batch()
|
||||||
|
|
||||||
|
chunk = policy.predict_action_chunk(batch)
|
||||||
|
assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||||
|
assert chunk.device.type == "cpu"
|
||||||
|
assert torch.isfinite(chunk).all()
|
||||||
|
|
||||||
|
a1 = policy.select_action(batch)
|
||||||
|
a2 = policy.select_action(batch)
|
||||||
|
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||||
|
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||||
|
assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
|
||||||
|
def test_action_generation_various_dims(
|
||||||
|
patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
|
||||||
|
) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
config = make_config(action_dim=action_dim, state_dim=state_dim)
|
||||||
|
policy = VLAJEPAPolicy(config)
|
||||||
|
policy.eval()
|
||||||
|
batch = make_inference_batch(state_dim=state_dim)
|
||||||
|
chunk = policy.predict_action_chunk(batch)
|
||||||
|
assert chunk.shape[-1] == action_dim
|
||||||
|
assert torch.isfinite(chunk).all()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
policy.eval()
|
||||||
|
batch = make_inference_batch()
|
||||||
|
|
||||||
|
set_seed_all(123)
|
||||||
|
actions_1 = policy.predict_action_chunk(batch)
|
||||||
|
set_seed_all(123)
|
||||||
|
actions_2 = policy.predict_action_chunk(batch)
|
||||||
|
|
||||||
|
assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
|
||||||
|
assert torch.allclose(actions_1, actions_2, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
policy.eval()
|
||||||
|
for seed in [0, 42, 123]:
|
||||||
|
set_seed_all(seed)
|
||||||
|
chunk = policy.predict_action_chunk(make_inference_batch())
|
||||||
|
assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Action queue behaviour
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
policy.eval()
|
||||||
|
batch = make_inference_batch()
|
||||||
|
|
||||||
|
# First call fills the queue (n_action_steps items) and pops one.
|
||||||
|
a1 = policy.select_action(batch)
|
||||||
|
assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1
|
||||||
|
|
||||||
|
# Second call pops from the existing queue without calling predict_action_chunk.
|
||||||
|
a2 = policy.select_action(batch)
|
||||||
|
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||||
|
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
set_seed_all(42)
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
policy.eval()
|
||||||
|
policy.select_action(make_inference_batch())
|
||||||
|
assert len(policy._queues[ACTION]) > 0
|
||||||
|
|
||||||
|
policy.reset()
|
||||||
|
assert len(policy._queues[ACTION]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Format conversion
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
examples = policy._prepare_model_inputs(make_train_batch())
|
||||||
|
|
||||||
|
assert len(examples) == BATCH_SIZE
|
||||||
|
for ex in examples:
|
||||||
|
assert set(ex) >= {"image", "video", "lang", "action", "state"}
|
||||||
|
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
|
||||||
|
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
|
||||||
|
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
|
||||||
|
assert ex["state"].shape == (1, STATE_DIM)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
for ex in policy._prepare_model_inputs(make_inference_batch()):
|
||||||
|
assert "action" not in ex
|
||||||
|
assert "image" in ex and "video" in ex and "lang" in ex
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
batch = make_inference_batch()
|
||||||
|
del batch["task"]
|
||||||
|
examples = policy._prepare_model_inputs(batch)
|
||||||
|
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
batch = make_inference_batch()
|
||||||
|
batch["task"] = "open the drawer"
|
||||||
|
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
from lerobot.utils.constants import OBS_STATE
|
||||||
|
|
||||||
|
policy = VLAJEPAPolicy(make_config())
|
||||||
|
batch = make_inference_batch()
|
||||||
|
del batch[OBS_STATE]
|
||||||
|
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pretrained checkpoint
|
||||||
|
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||||
|
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
|
||||||
|
cfg = policy.config
|
||||||
|
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||||
|
for key, feat in cfg.image_features.items():
|
||||||
|
h, w = feat.shape[-2], feat.shape[-1]
|
||||||
|
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
|
||||||
|
if cfg.robot_state_feature is not None:
|
||||||
|
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
|
||||||
|
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
|
||||||
|
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
|
||||||
|
cfg = policy.config
|
||||||
|
batch: dict = {"task": ["pick up the cube"] * batch_size}
|
||||||
|
for key, feat in cfg.image_features.items():
|
||||||
|
h, w = feat.shape[-2], feat.shape[-1]
|
||||||
|
batch[key] = torch.rand(batch_size, 3, h, w)
|
||||||
|
if cfg.robot_state_feature is not None:
|
||||||
|
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
_CP_ROOT = "lerobot" # TODO: upload converted checkpoints
|
||||||
|
|
||||||
|
# Each tuple: (repo_id, enable_world_model)
|
||||||
|
_HUB_VARIANTS = [
|
||||||
|
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
|
||||||
|
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
|
||||||
|
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||||
|
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
|
||||||
|
"""Policy loads from the converted safetensors checkpoint on the Hub."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||||
|
assert policy.config.enable_world_model == enable_world_model
|
||||||
|
assert sum(p.numel() for p in policy.parameters()) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
|
||||||
|
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
|
||||||
|
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(repo_id)
|
||||||
|
policy.train()
|
||||||
|
|
||||||
|
batch = _make_hub_train_batch(policy)
|
||||||
|
loss, logs = policy.forward(batch)
|
||||||
|
assert torch.isfinite(loss)
|
||||||
|
assert "action_loss" in logs
|
||||||
|
if enable_world_model:
|
||||||
|
assert "wm_loss" in logs
|
||||||
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
def test_hub_freeze_qwen_disables_world_model() -> None:
|
||||||
|
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
|
||||||
|
assert not policy.config.enable_world_model
|
||||||
|
assert policy.model.video_predictor is None
|
||||||
|
qwen_params = list(policy.model.qwen.parameters())
|
||||||
|
assert all(not p.requires_grad for p in qwen_params)
|
||||||
|
assert any(p.requires_grad for p in policy.model.action_model.parameters())
|
||||||
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
def test_hub_disable_world_model_loads_simpler_env() -> None:
|
||||||
|
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
|
||||||
|
assert not policy.config.enable_world_model
|
||||||
|
assert policy.model.video_predictor is None
|
||||||
|
assert policy.model.video_encoder is None
|
||||||
|
|
||||||
|
|
||||||
|
@extended_test
|
||||||
|
def test_hub_libero_inference_shape() -> None:
|
||||||
|
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
|
||||||
|
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
|
||||||
|
policy.eval()
|
||||||
|
batch = _make_hub_inference_batch(policy)
|
||||||
|
action = policy.select_action(batch)
|
||||||
|
assert action.shape[-1] == policy.config.action_dim
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Postprocessor unnormalization tests
|
||||||
|
#
|
||||||
|
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
|
||||||
|
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
|
||||||
|
"""Returns sample dataset_stats with a simple [i, i+10] range per action dim."""
|
||||||
|
from lerobot.utils.constants import ACTION
|
||||||
|
|
||||||
|
return {
|
||||||
|
ACTION: {
|
||||||
|
"min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32),
|
||||||
|
"max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
"""UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization."""
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.processor import UnnormalizerProcessorStep
|
||||||
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
|
from lerobot.utils.constants import ACTION
|
||||||
|
|
||||||
|
dataset_stats = _make_dataset_stats()
|
||||||
|
|
||||||
|
rng = np.random.default_rng(7)
|
||||||
|
actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)
|
||||||
|
|
||||||
|
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||||
|
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||||
|
expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||||
|
|
||||||
|
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||||
|
unnorm_step = UnnormalizerProcessorStep(
|
||||||
|
features=features,
|
||||||
|
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||||
|
stats=dataset_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
actions_tensor = torch.from_numpy(actions_np)
|
||||||
|
transition = policy_action_to_transition(actions_tensor)
|
||||||
|
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||||
|
|
||||||
|
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
|
||||||
|
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
|
||||||
|
from lerobot.processor import UnnormalizerProcessorStep
|
||||||
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
|
from lerobot.utils.constants import ACTION
|
||||||
|
|
||||||
|
dataset_stats = _make_dataset_stats()
|
||||||
|
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||||
|
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||||
|
|
||||||
|
# Deliberately out-of-range inputs
|
||||||
|
actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32)
|
||||||
|
clipped = np.clip(actions_np, -1.0, 1.0)
|
||||||
|
expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min
|
||||||
|
|
||||||
|
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
||||||
|
clip_step = ClipActionsProcessorStep()
|
||||||
|
unnorm_step = UnnormalizerProcessorStep(
|
||||||
|
features=features,
|
||||||
|
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
|
||||||
|
stats=dataset_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
transition = policy_action_to_transition(torch.from_numpy(actions_np))
|
||||||
|
transition = clip_step(transition)
|
||||||
|
result = transition_to_policy_action(unnorm_step(transition)).numpy()
|
||||||
|
|
||||||
|
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_postprocessor_applied_after_predict_action_chunk(
|
||||||
|
patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
|
||||||
|
) -> None:
|
||||||
|
"""predict_action_chunk returns raw actions; the postprocessor applies unnormalization.
|
||||||
|
|
||||||
|
Verifies the split: predict_action_chunk returns normalized actions, and calling the
|
||||||
|
postprocessor on them produces the correctly unnormalized result.
|
||||||
|
"""
|
||||||
|
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||||
|
|
||||||
|
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
|
||||||
|
|
||||||
|
cfg = make_config()
|
||||||
|
cfg.clip_normalized_actions = False
|
||||||
|
cfg.binarize_gripper_action = False
|
||||||
|
policy = VLAJEPAPolicy(cfg)
|
||||||
|
policy.eval()
|
||||||
|
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
|
||||||
|
|
||||||
|
dataset_stats = _make_dataset_stats()
|
||||||
|
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
|
||||||
|
|
||||||
|
batch = make_inference_batch()
|
||||||
|
chunk = policy.predict_action_chunk(batch)
|
||||||
|
|
||||||
|
# predict_action_chunk returns raw (normalized) actions
|
||||||
|
assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
|
||||||
|
"predict_action_chunk should return raw actions without unnormalization applied."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
|
||||||
|
unnormed = postprocessor(chunk)
|
||||||
|
from lerobot.utils.constants import ACTION
|
||||||
|
|
||||||
|
a_min = dataset_stats[ACTION]["min"].numpy()
|
||||||
|
a_max = dataset_stats[ACTION]["max"].numpy()
|
||||||
|
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
|
||||||
|
assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)
|
||||||
60
tests/policies/vla_jepa/test_world_model.py
Normal file
60
tests/policies/vla_jepa/test_world_model.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.policies.vla_jepa.world_model import (
|
||||||
|
ActionConditionedVideoPredictor,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ACTION_EMBED_DIM = 8
|
||||||
|
|
||||||
|
|
||||||
|
def _make_predictor(
|
||||||
|
embed_dim: int = 8,
|
||||||
|
action_embed_dim: int = _ACTION_EMBED_DIM,
|
||||||
|
predictor_embed_dim: int = 24,
|
||||||
|
num_action_tokens: int = 2,
|
||||||
|
tokens_per_frame: int = 1,
|
||||||
|
) -> ActionConditionedVideoPredictor:
|
||||||
|
return ActionConditionedVideoPredictor(
|
||||||
|
num_frames=1,
|
||||||
|
img_size=(1, tokens_per_frame),
|
||||||
|
patch_size=1,
|
||||||
|
tubelet_size=1,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
action_embed_dim=action_embed_dim,
|
||||||
|
predictor_embed_dim=predictor_embed_dim,
|
||||||
|
depth=1,
|
||||||
|
num_heads=2,
|
||||||
|
mlp_ratio=2.0,
|
||||||
|
num_action_tokens_per_step=num_action_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"batch,num_steps,tokens_per_frame,embed_dim",
|
||||||
|
[
|
||||||
|
(1, 2, 1, 8),
|
||||||
|
(2, 3, 4, 8),
|
||||||
|
(4, 5, 2, 16),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
|
||||||
|
predictor = _make_predictor(
|
||||||
|
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
|
||||||
|
)
|
||||||
|
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
|
||||||
|
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
|
||||||
|
out = predictor(frame_tokens, action_tokens)
|
||||||
|
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
|
||||||
|
assert torch.isfinite(out).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_predictor_step_mismatch_raises() -> None:
|
||||||
|
predictor = _make_predictor(tokens_per_frame=4)
|
||||||
|
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch
|
||||||
11
uv.lock
generated
11
uv.lock
generated
@@ -3039,6 +3039,11 @@ video-benchmark = [
|
|||||||
viz = [
|
viz = [
|
||||||
{ name = "rerun-sdk" },
|
{ name = "rerun-sdk" },
|
||||||
]
|
]
|
||||||
|
vla-jepa = [
|
||||||
|
{ name = "diffusers" },
|
||||||
|
{ name = "qwen-vl-utils" },
|
||||||
|
{ name = "transformers" },
|
||||||
|
]
|
||||||
wallx = [
|
wallx = [
|
||||||
{ name = "peft" },
|
{ name = "peft" },
|
||||||
{ name = "qwen-vl-utils" },
|
{ name = "qwen-vl-utils" },
|
||||||
@@ -3107,6 +3112,7 @@ requires-dist = [
|
|||||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
|
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
|
||||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
|
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
|
||||||
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
|
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
|
||||||
|
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
|
||||||
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
|
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
|
||||||
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
|
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
|
||||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
|
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
|
||||||
@@ -3154,6 +3160,7 @@ requires-dist = [
|
|||||||
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
|
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
|
||||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
|
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
|
||||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
|
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
|
||||||
|
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'vla-jepa'" },
|
||||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
|
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
|
||||||
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
|
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
|
||||||
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
|
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
|
||||||
@@ -3177,12 +3184,14 @@ requires-dist = [
|
|||||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" },
|
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" },
|
||||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
|
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
|
||||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
|
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
|
||||||
|
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'vla-jepa'" },
|
||||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
|
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
|
||||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
|
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
|
||||||
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
|
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
|
||||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
|
{ name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
|
||||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
|
{ name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
|
||||||
{ name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
|
{ name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
|
||||||
|
{ name = "lerobot", extras = ["vla-jepa"], marker = "extra == 'all'" },
|
||||||
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
|
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
|
||||||
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
|
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
|
||||||
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
|
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
|
||||||
@@ -3244,7 +3253,7 @@ requires-dist = [
|
|||||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||||
]
|
]
|
||||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "librt"
|
name = "librt"
|
||||||
|
|||||||
Reference in New Issue
Block a user