Compare commits

...

34 Commits

Author SHA1 Message Date
Maximellerbach
b75b3ce02d fix qwen norm layer output libero eval is now as expected 2026-05-22 17:36:43 +02:00
Maxime Ellerbach
5495c10cdf trying to close success rate gap 2026-05-22 07:59:47 +00:00
Maximellerbach
1bcba9dec6 fixing training and exal examples 2026-05-21 16:35:50 +02:00
Maximellerbach
dd13eda002 adding guard for diffusers 2026-05-21 15:41:26 +02:00
Maximellerbach
c01a00a972 adressing dtype zeros issue 2026-05-21 15:40:17 +02:00
Maxime Ellerbach
f75a2ee2f5 fixing doc defaults args
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
2026-05-21 15:37:37 +02:00
Maximellerbach
0edbb68ec3 pre-commit cleanup 2026-05-21 13:29:14 +02:00
Maxime Ellerbach
47f8a50fa0 refactoring into using pre and post processor 2026-05-21 13:29:14 +02:00
Maxime Ellerbach
51e57789ba lots of changes to make existing weights work, need to massively refactor the pre and post processing 2026-05-21 13:29:14 +02:00
Maximellerbach
27669724d2 removing missleading future_action_window_size to just use chunk_size 2026-05-21 13:29:14 +02:00
Maximellerbach
e32a552edb allow different state dim and action dim 2026-05-21 13:29:14 +02:00
Maximellerbach
596fda60d7 trying out to re-init the action head to avoid pretraining dimension mismatch 2026-05-21 13:29:14 +02:00
Maximellerbach
3fcea935b2 make default params more aligned with paper and pretrained models
- adding possibility of freezing qwen backbone and world model
- added tests for weight loading
2026-05-21 13:29:14 +02:00
Maximellerbach
76b63ebb26 add one-shot script to convert ginwind/VLA-JEPA checkpoints to safetensors (will remove once migrated) 2026-05-21 13:29:14 +02:00
Maximellerbach
bbe4ba7a53 add VLA-JEPA documentation
Covers architecture overview, pretrained checkpoints, config reference,
training/eval commands for LIBERO-10, and guidance on fine-tuning for
single-camera datasets.
2026-05-21 13:29:14 +02:00
Maximellerbach
593253e155 update VLA-JEPA tests for arch changes and action_is_pad
- Switch conftest to use `action_model_type="DiT-test"` now that
  `action_num_heads` / `action_attention_head_dim` have been removed.
- Add action_head tests covering fully-padded loss (zero) and equivalence
  of action_is_pad=None vs all-zeros mask.
- Remove obsolete `test_native_to_lerobot_wm_only` test.
2026-05-21 13:29:14 +02:00
Maximellerbach
acf65faaff propagate action_is_pad masking through VLA-JEPA policy pipeline
Pass the `action_is_pad` tensor from the batch through to the action head
so padded timesteps are excluded from the flow-matching loss.
2026-05-21 13:29:14 +02:00
Maximellerbach
82a05f9cb4 align VLA-JEPA architecture with original checkpoint
- Remove stale `action_num_heads` / `action_attention_head_dim` config fields;
  DiT head dimensions are now always derived from the preset (DiT-B/L/test).
- Add `num_target_vision_tokens` and `action_max_seq_len` config fields required
  by the action head's future-token embedding and positional embedding tables.
- Fix default `qwen_model_name` to 2B (matches all released checkpoints).
- Rename `ActionEncoder` attrs w1/w2/w3 → layer1/layer2/layer3 to match
  checkpoint key names; replace `nn.Sequential` decoder/state-encoder with
  `_MLP2` (layer1/layer2 naming).
- Fix `VLAJEPAActionHead` to size ActionEncoder and StateEncoder at `inner_dim`
  (DiT input width) rather than `action_hidden_size` (DiT output width).
- Rename `DiT.blocks` → `transformer_blocks` and `attn` → `attn1` to match
  checkpoint; add alternating cross/self attention (even blocks cross-attend to
  Qwen context, odd blocks self-attend).
- Add `DiT-test` preset for unit tests.
- Rewrite `ActionConditionedVideoPredictor` with explicit ViT-style blocks
  (`_PredictorBlock` with fused qkv) to match checkpoint structure; rename
  `encoder`/`norm`/`proj` → `predictor_blocks`/`predictor_norm`/`predictor_proj`.
2026-05-21 13:29:14 +02:00
Maximellerbach
d4abb9d562 adding more tests to ensure good coverage 2026-05-21 13:29:13 +02:00
Maximellerbach
090d392b19 some more fixes to be closer to the original implem 2026-05-21 13:29:13 +02:00
Maxime Ellerbach
e36d742d7d adjusting obs steps, tublets size to match original implementation 2026-05-21 13:29:13 +02:00
Maxime Ellerbach
f8a1acb6c9 fixing wm_loss not propagating 2026-05-21 13:29:13 +02:00
Maxime Ellerbach
60347bc742 fix warnings with qwen processor kwargs 2026-05-21 13:29:13 +02:00
Maxime Ellerbach
c6ec8d00e3 fixing action and state dim 2026-05-21 13:29:13 +02:00
Maximellerbach
0edb693ee4 adding guards to avoid needing transformers and diffusers for type checking and basic tests 2026-05-21 13:29:13 +02:00
Maximellerbach
cdae1b9ad8 updating uv lock 2026-05-21 13:29:13 +02:00
Maximellerbach
80ecf7bf53 adding deps to pyproject.toml 2026-05-21 13:28:04 +02:00
Maximellerbach
5597d539e7 linting 2026-05-21 13:27:07 +02:00
ginwind
dfbedb71d7 (feat)policies: add VLA-JEPA 2026-05-21 13:27:07 +02:00
ginwind
ebe6c66263 support vla_jepa 2026-05-21 13:27:07 +02:00
ginwind
0e18bdaf7a feat(policies): add VLA-JEPA 2026-05-21 13:27:07 +02:00
ginwind
d5944c410c feat(policies): add VLA-JEPA 2026-05-21 13:27:07 +02:00
ginwind
0d37efdb4b first commit 2026-05-21 13:27:07 +02:00
Virgileboat
f4b834844e Feat/clean can bus (#3526)
* change timeout  for handshake

* enforce last state read when querry

* change import order

* fix(motors): flush stale robstride RX and harden feedback drain

* robstride: remove redundant timeout and max_messages casts

* bugfix + %-style

* update exception catch
2026-05-21 11:44:04 +02:00
20 changed files with 3519 additions and 20 deletions

View File

@@ -0,0 +1,196 @@
# VLA-JEPA
This is the LeRobot port of **VLA-JEPA**, a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head.
---
## Architecture Overview
VLA-JEPA has three main components:
| Component | Module | Role |
| ----------------------- | --------------------------------- | ------------------------------------------------------- |
| **Qwen3-VL backbone** | `Qwen3VLInterface` | Fuses images + language instruction into context tokens |
| **DiT-B action head** | `VLAJEPAActionHead` | Flow-matching diffusion over the action chunk |
| **V-JEPA2 world model** | `ActionConditionedVideoPredictor` | Self-supervised video prediction loss (training only) |
### Data flow
**Training:**
1. A video clip of `num_video_frames` frames is encoded by V-JEPA2 into per-frame patch tokens.
2. The Qwen3-VL backbone processes multi-view images + the task instruction and produces a sequence of context tokens that includes special action tokens (for world model conditioning) and embodied tokens.
3. The action head receives those context tokens as cross-attention keys/values and predicts a denoised action chunk via flow matching.
4. The world model predictor uses the action tokens extracted from Qwen to predict future V-JEPA2 frame embeddings; a regression loss on those predictions is added to the action loss.
**Inference:**
Only Qwen + the action head are used. The world model is not needed at inference time.
### Action head details
Available presets via `action_model_type`:
| Preset | Hidden dim | Heads | Head dim |
| ------- | ---------- | ----- | -------- |
| `DiT-B` | 768 | 12 | 64 |
| `DiT-L` | 1536 | 32 | 48 |
### World model details
The video predictor is a ViT-style transformer (`ActionConditionedVideoPredictor`) that takes:
- **Frame tokens**: V-JEPA2 patch embeddings projected to `predictor_embed_dim`
- **Action tokens**: Qwen action token embeddings projected to `predictor_embed_dim`
It uses block-causal attention so each temporal step can attend to all previous steps. The predictor's input `embed_dim` equals `num_views × video_encoder_hidden_size` (e.g. 2 views × 1024 = 2048 for the pretrained checkpoints).
---
## Pretrained Checkpoints
Three checkpoints are available, converted from [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA):
| Checkpoint | Dataset | Cameras | World model | Action dim |
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 | Disabled\* | 7 |
\* The SimplerEnv checkpoint was fine-tuned from Pretrain. The world model predictor architecture expects `embed_dim=2048` (2-camera input) but SimplerEnv is single-camera, so the world model cannot be loaded cleanly. Since inference only needs Qwen + the action head, `enable_world_model=False` is set for this variant. See [Fine-tuning on single-camera datasets](#fine-tuning-on-single-camera-datasets) for implications.
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
---
## Configuration
Key parameters in `VLAJEPAConfig`:
| Parameter | Default | Description |
| ------------------------- | ------- | -------------------------------------------------------------- |
| `chunk_size` | 7 | Number of actions predicted per inference call |
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
| `num_video_frames` | 8 | Video clip length fed to the world model |
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
---
## Training
Number of training steps may vary based on dataset size and compute budget. The original paper pretrained for 50k on ssv2 + droid jointly, then additional 30k steps for LIBERO, but fewer steps may still yield good performance when fine-tuning from the provided pretrained checkpoints.
### Full training from scratch
```bash
lerobot-train \
policy.type=vla_jepa \
policy.repo_id=your_org/your_repo \
dataset.repo_id=your_org/your_dataset
```
### Fine-tuning from a pretrained checkpoint
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/your_dataset
```
If you want to go further and freeze the Qwen backbone and only train the action head, set `policy.freeze_qwen=True`:
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--policy.freeze_qwen=true \
--dataset.repo_id=your_org/your_dataset
```
### Reproducing the LIBERO results
**Training on LIBERO:**
starts the training from the Pretrain checkpoint, trains for 30k steps on the LIBERO dataset.
Original paper mentions training across 8 GPUs with a batch size of 32, meaning global batch size of 256.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=HuggingFaceVLA/libero \
--steps=30000
```
**Evaluating the pretrained LIBERO-10 checkpoint:**
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.n_episodes=10 \
--eval.batch_size=5 \
```
To evaluate a subset of tasks only:
```bash
lerobot-eval \
--policy.path=lerobot/VLA-JEPA-LIBERO \
--env.type=libero \
--env.task=libero_10 \
--env.task_ids='[0,1,2]' \
--eval.n_episodes=10 \
--eval.batch_size=5 \
```
---
## Fine-tuning on single-camera datasets
The pretrained world model predictor was trained with `embed_dim = num_views × 1024`. If your target dataset has fewer cameras than the source checkpoint, the predictor input projection will have a shape mismatch and cannot be loaded.
**Option 1 — Disable the world model (recommended)**
Set `enable_world_model=False`. Only the Qwen backbone and action head are loaded and trained. This matches the original SimplerEnv fine-tuning strategy and is sufficient for good action performance.
```bash
lerobot-train \
--policy.path=lerobot/VLA-JEPA-Pretrain \
--policy.enable_world_model=false \
--policy.repo_id=your_org/your_repo \
--dataset.repo_id=your_org/single_camera_dataset
```
**Option 2 — Reinitialize the predictor input projection**
If you want the JEPA self-supervised signal during fine-tuning, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
**Option 3 - Duplicate frames to match the expected number of cameras**
A bit more advanced, you would need to change some parts of the code to support that.
---
## Citation
```bibtex
@misc{sun2026vlajepaenhancingvisionlanguageactionmodel,
title = {VLA-JEPA: Enhancing Vision-Language-Action Model with Latent World Model},
author = {Jingwen Sun and Wenyao Zhang and Zekun Qi and Shaojie Ren and Zezhi Liu and Hanxin Zhu and Guangzhong Sun and Xin Jin and Zhibo Chen},
year = {2026},
eprint = {2602.10098},
archivePrefix = {arXiv},
primaryClass = {cs.RO},
url = {https://arxiv.org/abs/2602.10098},
}
```
---
## License
Weights are distributed under the license terms of the original [ginwind/VLA-JEPA](https://huggingface.co/ginwind/VLA-JEPA) repository (**Apache 2.0 License**). The LeRobot integration code follows the **Apache 2.0 License**.

View File

@@ -212,6 +212,7 @@ sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<3
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
@@ -276,6 +277,7 @@ all = [
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
"lerobot[async]",
"lerobot[dev]",
"lerobot[test]",

View File

@@ -43,6 +43,7 @@ from .tables import (
CAN_CMD_SET_ZERO,
DEFAULT_BAUDRATE,
DEFAULT_TIMEOUT_MS,
HANDSHAKE_TIMEOUT_S,
MODEL_RESOLUTION,
MOTOR_LIMIT_PARAMS,
NORMALIZED_DATA,
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
self._is_connected = False
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
def _query_status_via_clear_fault(
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
) -> tuple[bool, can.Message | None]:
motor_name = self._get_motor_name(motor)
motor_id = self._get_motor_id(motor_name)
recv_id = self._get_motor_recv_id(motor_name)
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
def _recv_status_via_clear_fault(
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
faulted_motors = []
for motor_name in self.motors:
has_fault, msg = self._query_status_via_clear_fault(motor_name)
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
if msg is None:
missing_motors.append(motor_name)
elif has_fault:
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
return responses
def _recv_all_messages_until_quiet(
self,
*,
timeout: float = RUNNING_TIMEOUT,
max_messages: int = 4096,
) -> list[can.Message]:
"""
Receive frames until the bus goes quiet.
Args:
timeout: Poll timeout used for each recv() call. Collection stops
when one recv() times out (quiet gap).
max_messages: Safety cap to prevent unbounded loops.
"""
out: list[can.Message] = []
max_messages = max(1, max_messages)
timeout = max(0.0, timeout)
try:
while len(out) < max_messages:
msg = self._bus().recv(timeout=timeout)
if msg is None:
break
out.append(msg)
except (can.CanError, OSError) as e:
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
return out
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
"""
Decode all received feedback frames and update cached motor states.
Returns:
Set of payload recv_ids that were successfully mapped to motors.
"""
processed_recv_ids: set[int] = set()
for msg in messages:
if len(msg.data) < 1:
logger.debug(
f"Dropping short CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
)
continue
recv_id = int(msg.data[0])
motor_name = self._recv_id_to_motor.get(recv_id)
if motor_name is None:
logger.debug(
f"Unmapped CAN frame on {self.port} "
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
)
continue
self._process_response(motor_name, msg)
processed_recv_ids.add(recv_id)
return processed_recv_ids
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
"""
Drain pending RX frames from the CAN interface.
This is used by higher-level controllers to drop stale feedback before issuing
a fresh read cycle, so subsequent state reads are based on most recent replies.
It should also be called once when a controller instance is created/connected,
to clear residual frames left on the interface from previous sessions.
"""
drained = 0
poll_timeout_s = max(0.0, poll_timeout_s)
max_messages = max(1, max_messages)
try:
while drained < max_messages:
msg = self._bus().recv(timeout=poll_timeout_s)
if msg is None:
break
drained += 1
except (can.CanError, OSError) as e:
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
return drained
def _speed_control(
self,
motor: NameOrID,
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
self._bus().send(msg)
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
# Read every feedback frame until RX goes quiet, then decode all of them.
# This avoids dropping useful frames when responses from different motors interleave.
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
for recv_id, motor_name in recv_id_to_motor.items():
if msg := responses.get(recv_id):
self._process_response(motor_name, msg)
if recv_id not in processed_recv_ids:
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
"""Convert float to unsigned integer for CAN transmission."""
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
try:
self._decode_motor_state(msg.data)
except Exception as e:
logger.warning(f"Failed to decode response from {motor}: {e}")
logger.warning(
f"Failed to decode response from {motor} "
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
)
def _get_cached_value(self, motor: str, data_name: str) -> Value:
"""Retrieve a specific value from the state cache."""
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
self._bus().send(msg)
updated_motors.append(motor)
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
for response in responses.values():
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
if payload_motor_name is not None:
self._process_response(payload_motor_name, response)
else:
# Fallback: still attempt to decode based on payload byte0 mapping.
self._decode_motor_state(response.data)
messages = self._recv_all_messages_until_quiet()
processed_recv_ids = self._process_feedback_messages(messages)
for motor in updated_motors:
recv_id = self._get_motor_recv_id(motor)
if recv_id not in responses:
if recv_id not in processed_recv_ids:
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
def read_calibration(self) -> dict[str, MotorCalibration]:

View File

@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
CAN_PARAM_ID = 0x7FF
RUNNING_TIMEOUT = 0.001
RUNNING_TIMEOUT = 0.003
HANDSHAKE_TIMEOUT_S = 0.05
PARAM_TIMEOUT = 0.01
STATE_CACHE_TTL_S = 0.02

View File

@@ -56,6 +56,7 @@ from .pretrained import PreTrainedPolicy
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
from .utils import validate_visual_features_consistency
from .vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from .vqbet.configuration_vqbet import VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig
from .xvla.configuration_xvla import XVLAConfig
@@ -151,6 +152,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .eo1.modeling_eo1 import EO1Policy
return EO1Policy
elif name == "vla_jepa":
from .vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
return VLAJEPAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -203,6 +208,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return WallXConfig(**kwargs)
elif policy_type == "eo1":
return EO1Config(**kwargs)
elif policy_type == "vla_jepa":
return VLAJEPAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -406,6 +413,7 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, EO1Config):
from .eo1.processor_eo1 import make_eo1_pre_post_processors
@@ -414,6 +422,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, VLAJEPAConfig):
from .vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
processors = make_vla_jepa_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(

View File

@@ -0,0 +1 @@
/home/maxime/github/robots/lerobot/docs/source/policy_vla_jepa_README.md

View File

@@ -0,0 +1,10 @@
from .configuration_vla_jepa import VLAJEPAConfig
from .modeling_vla_jepa import VLAJEPAPolicy
from .processor_vla_jepa import VLAJEPANewLineProcessor, make_vla_jepa_pre_post_processors
__all__ = [
"VLAJEPAConfig",
"VLAJEPAPolicy",
"VLAJEPANewLineProcessor",
"make_vla_jepa_pre_post_processors",
]

View File

@@ -0,0 +1,327 @@
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _diffusers_available, require_package
if TYPE_CHECKING or _diffusers_available:
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
else:
class ModelMixin: # type: ignore[no-redef]
pass
class ConfigMixin: # type: ignore[no-redef]
pass
register_to_config = lambda f: f # noqa: E731
Attention = FeedForward = TimestepEmbedding = Timesteps = None
from .configuration_vla_jepa import VLAJEPAConfig
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
batch_size, seq_len = timesteps.shape
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device)
exponent = exponent * (torch.log(torch.tensor(10000.0, device=timesteps.device)) / max(half_dim, 1))
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1).view(batch_size, seq_len, -1)
class ActionEncoder(nn.Module):
def __init__(self, action_dim: int, hidden_size: int):
super().__init__()
self.layer1 = nn.Linear(action_dim, hidden_size)
self.layer2 = nn.Linear(hidden_size * 2, hidden_size)
self.layer3 = nn.Linear(hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, _ = actions.shape
if timesteps.ndim != 1 or timesteps.shape[0] != batch_size:
raise ValueError("timesteps must have shape [batch_size].")
timesteps = timesteps.unsqueeze(1).expand(-1, seq_len)
action_emb = self.layer1(actions)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
return self.layer3(swish(self.layer2(torch.cat([action_emb, time_emb], dim=-1))))
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
require_package("diffusers", extra="vla_jepa")
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
projected = self.time_proj(timesteps).to(dtype=next(self.parameters()).dtype)
return self.timestep_embedder(projected)
class AdaLayerNorm(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
self.silu = nn.SiLU()
def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(self.silu(temb)).chunk(2, dim=-1)
return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int,
is_cross_attention: bool = True,
) -> None:
super().__init__()
self.is_cross_attention = is_cross_attention
self.norm1 = AdaLayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=True,
cross_attention_dim=cross_attention_dim,
out_bias=True,
)
self.norm2 = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)
self.ff = FeedForward(dim, dropout=dropout, activation_fn="gelu-approximate", final_dropout=True)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
attention_context = encoder_hidden_states if self.is_cross_attention else None
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
output_dim: int,
num_layers: int,
dropout: float,
cross_attention_dim: int,
) -> None:
super().__init__()
self.inner_dim = num_attention_heads * attention_head_dim
self.timestep_encoder = TimestepEncoder(self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
is_cross_attention=layer_idx % 2 == 0,
)
for layer_idx in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out_1 = nn.Linear(self.inner_dim, self.inner_dim * 2)
self.proj_out_2 = nn.Linear(self.inner_dim, output_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
) -> torch.Tensor:
temb = self.timestep_encoder(timestep)
x = hidden_states
for block in self.transformer_blocks:
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(x)
@dataclass
class ActionModelPreset:
hidden_size: int
attention_head_dim: int
num_attention_heads: int
DIT_PRESETS = {
"DiT-B": ActionModelPreset(hidden_size=768, attention_head_dim=64, num_attention_heads=12),
"DiT-L": ActionModelPreset(hidden_size=1536, attention_head_dim=48, num_attention_heads=32),
"DiT-test": ActionModelPreset(hidden_size=16, attention_head_dim=8, num_attention_heads=2),
}
class VLAJEPAActionHead(nn.Module):
def __init__(self, config: VLAJEPAConfig, cross_attention_dim: int) -> None:
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
num_heads = config.action_num_heads or preset.num_attention_heads
head_dim = config.action_attention_head_dim or preset.attention_head_dim
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
self.input_embedding_dim = inner_dim
self.action_horizon = config.chunk_size
self.num_inference_timesteps = config.num_inference_timesteps
hidden_size = config.action_hidden_size
self.model = DiT(
num_attention_heads=num_heads,
attention_head_dim=head_dim,
output_dim=hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
]
)
)
self.state_encoder = (
nn.Sequential(
OrderedDict(
[
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
]
)
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(config.num_embodied_action_tokens_per_instruction, inner_dim)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
sample = self.beta_dist.sample([batch_size]).to(device=device, dtype=dtype)
return (self.config.action_noise_s - sample) / self.config.action_noise_s
def _build_inputs(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None,
timesteps: torch.Tensor,
) -> torch.Tensor:
action_features = self.action_encoder(actions, timesteps)
pos_ids = torch.arange(action_features.shape[1], device=actions.device)
action_features = action_features + self.position_embedding(pos_ids)[None]
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(actions.shape[0], -1, -1)
seq = [future_tokens, action_features]
if state is not None and self.state_encoder is not None:
if state.ndim == 2:
state = state.unsqueeze(1)
seq.insert(0, self.state_encoder(state))
return torch.cat(seq, dim=1)
def forward(
self,
conditioning_tokens: torch.Tensor,
actions: torch.Tensor,
state: torch.Tensor | None = None,
action_is_pad: torch.Tensor | None = None,
) -> torch.Tensor:
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], actions.device, actions.dtype)
noisy_actions = (1 - t[:, None, None]) * noise + t[:, None, None] * actions
velocity = actions - noise
t_discretized = (t * self.config.action_num_timestep_buckets).long()
hidden_states = self._build_inputs(conditioning_tokens, noisy_actions, state, t_discretized)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=t_discretized,
)
pred_actions = self.action_decoder(pred[:, -actions.shape[1] :])
if action_is_pad is None:
action_is_pad = torch.zeros(actions.shape[:2], dtype=torch.bool, device=actions.device)
loss = F.mse_loss(pred_actions, velocity, reduction="none") # [B, T, action_dim]
valid_mask = ~action_is_pad.unsqueeze(-1) # [B, T, 1]
num_valid = valid_mask.sum() * loss.shape[-1]
return (loss * valid_mask).sum() / num_valid.clamp_min(1)
@torch.no_grad()
def predict_action(
self,
conditioning_tokens: torch.Tensor,
state: torch.Tensor | None = None,
) -> torch.Tensor:
batch_size = conditioning_tokens.shape[0]
actions = torch.randn(
batch_size,
self.action_horizon,
self.config.action_dim,
dtype=conditioning_tokens.dtype,
device=conditioning_tokens.device,
)
dt = 1.0 / max(self.num_inference_timesteps, 1)
for step in range(self.num_inference_timesteps):
t_cont = step / float(max(self.num_inference_timesteps, 1))
t_value = int(t_cont * self.config.action_num_timestep_buckets)
timesteps = torch.full(
(batch_size,), t_value, device=conditioning_tokens.device, dtype=torch.long
)
hidden_states = self._build_inputs(conditioning_tokens, actions, state, timesteps)
pred = self.model(
hidden_states=hidden_states,
encoder_hidden_states=conditioning_tokens,
timestep=timesteps,
)
pred_velocity = self.action_decoder(pred[:, -self.action_horizon :])
actions = actions + dt * pred_velocity
return actions

View File

@@ -0,0 +1,133 @@
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("vla_jepa")
@dataclass
class VLAJEPAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 7
n_action_steps: int = 7
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
qwen_model_name: str = "Qwen/Qwen3-VL-2B-Instruct"
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
freeze_qwen: bool = False
enable_world_model: bool = True
reinit_action_head: bool = False
tokenizer_padding_side: str = "left"
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
special_action_token: str = "<|action_{}|>"
embodied_action_token: str = "<|embodied_action|>"
action_dim: int = 7
state_dim: int = 8
num_action_tokens_per_timestep: int = 8
num_embodied_action_tokens_per_instruction: int = 32
num_inference_timesteps: int = 4
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 16
action_num_heads: int | None = None
action_attention_head_dim: int | None = None
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
num_target_vision_tokens: int = 32
action_max_seq_len: int = 1024
# total video frames loaded per sample
num_video_frames: int = 8
predictor_depth: int = 12
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10.0
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.freeze_qwen and self.enable_world_model:
# freezing qwen backbone makes world model training irrelevant since no grad flows
self.enable_world_model = False
if self.n_action_steps > self.chunk_size:
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.num_video_frames < 2 * self.jepa_tubelet_size:
raise ValueError(
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
)
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("VLAJEPA requires at least one visual input feature.")
if self.action_feature is None:
raise ValueError("VLAJEPA requires an action output feature.")
self.action_dim = self.action_feature.shape[0]
if self.robot_state_feature is not None:
self.state_dim = self.robot_state_feature.shape[0]
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int]:
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
# matches original repo's observation_indices=list(range(video_horizon))
return list(range(self.num_video_frames))
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,454 @@
#!/usr/bin/env python
"""
Convert all VLA-JEPA .pt checkpoints (ginwind/VLA-JEPA) to LeRobot safetensors
format and upload them to maximellerbach org inside a HF collection.
Usage:
uv run python convert_vla_jepa_checkpoints.py
For each variant the script:
1. Downloads the .pt checkpoint.
2. Extracts the state dict.
3. Instantiates VLAJEPAPolicy with the variant's confirmed config.
4. Loads the state dict (strict=False — mismatches printed to stdout).
5. push_to_hub → writes model.safetensors + config.json in LeRobot format.
6. Adds the new repo to a shared HF collection.
Config sources
--------------
Numeric hyper-params : ginwind/VLA-JEPA/<variant>/config.json
Image keys LIBERO : lerobot/libero_10 meta/info.json ✓ confirmed
Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json ✓ confirmed
Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed
"""
from __future__ import annotations
import logging
import tempfile
from pathlib import Path
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file as save_safetensors
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
# ---------------------------------------------------------------------------
# Top-level settings
# ---------------------------------------------------------------------------
SOURCE_REPO_ID = "ginwind/VLA-JEPA"
TARGET_ORG = "maximellerbach"
COLLECTION_TITLE = "VLA-JEPA"
COLLECTION_DESCRIPTION = (
"VLA-JEPA model checkpoints (LIBERO, Pretrain, SimplerEnv) converted from .pt to safetensors via LeRobot."
)
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Key mapping — mirrors todo_converter.py map_key() so both converters
# produce identical safetensors layouts that match the LeRobot action_head code.
# ---------------------------------------------------------------------------
def _normalize_source_key(key: str) -> str:
return key[len("module.") :] if key.startswith("module.") else key
def _map_checkpoint_key(raw_key: str) -> str | None:
"""Map original VLA-JEPA state-dict keys to LeRobot vla_jepa layout."""
key = _normalize_source_key(raw_key)
if key.startswith("qwen_vl_interface."):
return "model.qwen." + key[len("qwen_vl_interface.") :]
if key.startswith("vj_encoder."):
return "model.video_encoder." + key[len("vj_encoder.") :]
if key.startswith("vj_predictor."):
return "model.video_predictor." + key[len("vj_predictor.") :]
if key.startswith("action_model."):
# LeRobot code uses the same sub-key names as the source checkpoint,
# so only the top-level "model." prefix needs to be added.
return "model." + key
return None
def _fetch_dataset_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None:
"""Download dataset_statistics.json and return {action: {...}, state: {...}} stats dict."""
import json
stats_file = f"{subfolder}/dataset_statistics.json"
try:
local = api.hf_hub_download(source_repo_id, stats_file)
data = json.loads(Path(local).read_text())
# Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}, "state": {...}}}
for robot_key in data:
robot_data = data[robot_key]
if isinstance(robot_data, dict) and "action" in robot_data:
log.info(" Loaded dataset stats from %s (robot key: %s)", stats_file, robot_key)
result = {"action": robot_data["action"]}
if "state" in robot_data:
result["observation.state"] = robot_data["state"]
log.info(" Also loaded state stats.")
return result
log.warning(" %s found but no 'action' key under any robot key — skipping stats.", stats_file)
except Exception as exc: # noqa: BLE001
log.warning(" Could not fetch %s: %s — postprocessor will have no unnorm stats.", stats_file, exc)
return None
def _set_if_present(d: dict, key: str, value) -> None:
if value is not None:
d[key] = value
def _deep_get(mapping: dict, path: tuple, default=None):
current = mapping
for key in path:
if not isinstance(current, dict) or key not in current:
return default
current = current[key]
return current
def _fetch_source_config(api: HfApi, source_repo_id: str, subfolder: str) -> dict:
"""Download config.yaml from the source HF repo for a given variant subfolder."""
try:
import yaml
except ImportError:
log.warning("PyYAML not installed — cannot apply source config.yaml overrides.")
return {}
config_file = f"{subfolder}/config.yaml"
try:
local = api.hf_hub_download(source_repo_id, config_file)
data = yaml.safe_load(Path(local).read_text()) or {}
if isinstance(data, dict):
log.info(" Loaded source config from %s", config_file)
return data
except Exception as exc: # noqa: BLE001
log.warning(" Could not fetch %s: %s — using hardcoded defaults.", config_file, exc)
return {}
def _apply_source_config(kwargs: dict, source_config: dict) -> None:
"""Apply ginwind/VLA-JEPA config.yaml values to kwargs, mirroring todo_converter.py logic."""
if not source_config:
return
data_cfg = _deep_get(source_config, ("datasets", "vla_data"), {})
action_cfg = _deep_get(source_config, ("framework", "action_model"), {})
diffusion_cfg = _deep_get(source_config, ("framework", "action_model", "diffusion_model_cfg"), {})
video_cfg = _deep_get(source_config, ("framework", "vj2_model"), {})
trainer_cfg = source_config.get("trainer", {})
prompt_template = data_cfg.get("CoT_prompt")
if prompt_template:
kwargs["prompt_template"] = str(prompt_template)
action_horizon = action_cfg.get("action_horizon")
if action_horizon is not None:
kwargs["chunk_size"] = int(action_horizon)
kwargs["n_action_steps"] = int(action_horizon)
_set_if_present(
kwargs,
"num_action_tokens_per_timestep",
video_cfg.get("num_action_tokens_per_timestep", action_cfg.get("num_action_tokens_per_timestep")),
)
_set_if_present(
kwargs,
"num_embodied_action_tokens_per_instruction",
video_cfg.get(
"num_embodied_action_tokens_per_instruction",
action_cfg.get("num_embodied_action_tokens_per_instruction"),
),
)
_set_if_present(kwargs, "num_inference_timesteps", action_cfg.get("num_inference_timesteps"))
_set_if_present(kwargs, "special_action_token", video_cfg.get("special_action_token"))
_set_if_present(kwargs, "embodied_action_token", video_cfg.get("embodied_action_token"))
_set_if_present(
kwargs, "action_hidden_size", action_cfg.get("action_hidden_dim", action_cfg.get("hidden_size"))
)
_set_if_present(kwargs, "action_model_type", action_cfg.get("action_model_type"))
_set_if_present(kwargs, "action_noise_beta_alpha", action_cfg.get("noise_beta_alpha"))
_set_if_present(kwargs, "action_noise_beta_beta", action_cfg.get("noise_beta_beta"))
_set_if_present(kwargs, "action_noise_s", action_cfg.get("noise_s"))
_set_if_present(kwargs, "action_num_timestep_buckets", action_cfg.get("num_timestep_buckets"))
_set_if_present(kwargs, "repeated_diffusion_steps", action_cfg.get("repeated_diffusion_steps"))
_set_if_present(kwargs, "action_num_layers", diffusion_cfg.get("num_layers"))
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
_set_if_present(kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")))
_set_if_present(
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
)
_set_if_present(kwargs, "predictor_mlp_ratio", video_cfg.get("predictor_mlp_ratio"))
_set_if_present(kwargs, "optimizer_grad_clip_norm", trainer_cfg.get("max_grad_norm"))
learning_rate = trainer_cfg.get("learning_rate", {})
if isinstance(learning_rate, dict):
_set_if_present(kwargs, "optimizer_lr", learning_rate.get("action_model"))
optimizer_cfg = trainer_cfg.get("optimizer", {})
if isinstance(optimizer_cfg, dict):
_set_if_present(kwargs, "optimizer_eps", optimizer_cfg.get("eps"))
_set_if_present(kwargs, "optimizer_weight_decay", optimizer_cfg.get("weight_decay"))
betas = optimizer_cfg.get("betas")
if betas is not None:
kwargs["optimizer_betas"] = tuple(betas)
scheduler = trainer_cfg.get("scheduler", {})
if isinstance(scheduler, dict):
_set_if_present(kwargs, "scheduler_warmup_steps", scheduler.get("warmup_steps"))
_set_if_present(kwargs, "scheduler_decay_lr", scheduler.get("min_lr"))
_set_if_present(kwargs, "scheduler_warmup_steps", trainer_cfg.get("num_warmup_steps"))
scheduler_kwargs = trainer_cfg.get("scheduler_specific_kwargs", {})
if isinstance(scheduler_kwargs, dict):
_set_if_present(kwargs, "scheduler_decay_lr", scheduler_kwargs.get("min_lr"))
# ---------------------------------------------------------------------------
# Architecture — identical across all 4 variants (from config.json)
# ---------------------------------------------------------------------------
_ARCH = {
"qwen_model_name": "Qwen/Qwen3-VL-2B-Instruct", # 2B, NOT the default 4B
"chunk_size": 7,
"n_action_steps": 7,
"num_video_frames": 8,
"jepa_tubelet_size": 2,
"num_action_tokens_per_timestep": 8,
"num_embodied_action_tokens_per_instruction": 32,
"num_inference_timesteps": 4,
"action_hidden_size": 1024,
"action_model_type": "DiT-B",
# Explicit dims matching DiT-B preset and ginwind checkpoint shape
"action_num_heads": 12,
"action_attention_head_dim": 64,
"action_num_layers": 16,
"action_dropout": 0.2,
"repeated_diffusion_steps": 8,
"action_noise_beta_alpha": 1.5,
"action_noise_beta_beta": 1.0,
"action_noise_s": 0.999,
"action_num_timestep_buckets": 1000,
# World model predictor (12 blocks, confirmed from checkpoint)
"predictor_depth": 12,
}
# ---------------------------------------------------------------------------
# Image-key sets (confirmed sources in module docstring)
# ---------------------------------------------------------------------------
# LIBERO — confirmed from lerobot/libero_10 meta/info.json
_LIBERO_CAMS = [
"observation.images.image", # agentview camera
"observation.images.image2", # eye-in-hand camera
]
# DROID pretrain — 2 views match the predictor embed_dim=2 × 1024=2048 in checkpoint
_DROID_CAMS = [
"observation.images.exterior_1_left",
"observation.images.exterior_2_left",
]
# OXE Bridge + RT1 — single-camera; world model disabled (predictor embed_dim mismatch)
_OXE_CAMS = [
"observation.images.image",
]
# ---------------------------------------------------------------------------
# Config factories
# ---------------------------------------------------------------------------
def _build_config(
camera_keys: list[str],
with_state: bool,
enable_world_model: bool = True,
source_config: dict | None = None,
):
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
kwargs = dict(_ARCH)
_apply_source_config(kwargs, source_config or {})
# Image resolution: prefer source config, fall back to 224
data_cfg = _deep_get(source_config or {}, ("datasets", "vla_data"), {})
raw_res = data_cfg.get("resolution_size")
resolution_size = int(raw_res) if raw_res is not None else 224
image_shape = (3, resolution_size, resolution_size)
# Always set resize_images_to so the policy resizes env images to the training resolution,
# regardless of what resolution the eval env renders at.
kwargs["resize_images_to"] = (resolution_size, resolution_size)
# State / action dims: prefer source config
action_cfg = _deep_get(source_config or {}, ("framework", "action_model"), {})
state_dim = int(action_cfg["state_dim"]) if "state_dim" in action_cfg else 8
action_dim = int(action_cfg["action_dim"]) if "action_dim" in action_cfg else 7
input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=image_shape) for k in camera_keys}
if with_state:
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))
cfg = VLAJEPAConfig(
input_features=input_features,
output_features={
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
},
enable_world_model=enable_world_model,
binarize_gripper_action=True,
clip_normalized_actions=True,
**kwargs,
)
cfg.validate_features()
return cfg
# Maps each subfolder in SOURCE_REPO_ID to (camera_keys, with_state, enable_world_model, repo_suffix)
VARIANTS: dict[str, tuple] = {
"LIBERO": (_LIBERO_CAMS, True, True, "LIBERO"),
"Pretrain": (_DROID_CAMS, False, True, "Pretrain"),
# SimplerEnv uses a single camera; the predictor embed_dim (2048) would mismatch, so
# disable the world model — only qwen + action_model weights are needed for inference.
"SimplerEnv": (_OXE_CAMS, False, False, "SimplerEnv"),
}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def extract_state_dict(ckpt: object) -> dict[str, torch.Tensor]:
if isinstance(ckpt, dict):
sd = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt.get("model")
if sd is None:
sd = ckpt
else:
sd = ckpt
return {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)}
def subfolder_of(pt_path: str) -> str | None:
for part in Path(pt_path).parts:
if part in VARIANTS:
return part
return None
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
api = HfApi()
log.info("Listing .pt files in %s", SOURCE_REPO_ID)
pt_files = [f for f in api.list_repo_files(SOURCE_REPO_ID) if f.endswith(".pt")]
if not pt_files:
log.error("No .pt files found.")
return
for f in pt_files:
log.info(" %s", f)
# Create / reuse the collection once
collection = api.create_collection(
title=COLLECTION_TITLE,
description=COLLECTION_DESCRIPTION,
namespace=TARGET_ORG,
exists_ok=True,
)
log.info("Collection: %s", collection.url)
for pt_filename in pt_files:
log.info("\n=== %s ===", pt_filename)
subfolder = subfolder_of(pt_filename)
if subfolder is None:
log.warning(" No variant entry for '%s' — skipping.", pt_filename)
continue
camera_keys, with_state, enable_world_model, repo_suffix = VARIANTS[subfolder]
target_repo_id = f"{TARGET_ORG}/VLA-JEPA-{repo_suffix}"
log.info(
" cameras=%d with_state=%s wm=%s%s",
len(camera_keys),
with_state,
enable_world_model,
target_repo_id,
)
# 1. Download
local_pt = api.hf_hub_download(SOURCE_REPO_ID, pt_filename)
log.info(" Downloaded → %s", local_pt)
# 2. Load checkpoint
try:
ckpt = torch.load(local_pt, map_location="cpu", mmap=True, weights_only=False) # nosec B614
except TypeError:
ckpt = torch.load(local_pt, map_location="cpu") # nosec B614
sd = extract_state_dict(ckpt)
# Map source key names → LeRobot layout (handles layer1→w1, transformer_blocks→blocks, etc.)
mapped_sd: dict[str, torch.Tensor] = {}
skipped_keys: list[str] = []
for raw_key, value in sd.items():
target_key = _map_checkpoint_key(raw_key)
if target_key is None:
skipped_keys.append(raw_key)
else:
mapped_sd[target_key] = value
log.info(" %d tensors mapped, %d skipped", len(mapped_sd), len(skipped_keys))
if skipped_keys:
log.info(" Skipped sample: %s", skipped_keys[:5])
log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5])
# 3. Fetch action + state stats needed by the pre/postprocessor unnormalizers
dataset_stats = _fetch_dataset_stats(api, SOURCE_REPO_ID, subfolder)
# 4. Build config (no policy instantiation — avoids loading backbone from Hub)
source_config = _fetch_source_config(api, SOURCE_REPO_ID, subfolder)
config = _build_config(camera_keys, with_state, enable_world_model, source_config)
# 5. Save everything to a temp dir and upload in one shot
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)
with tempfile.TemporaryDirectory() as tmp:
save_dir = Path(tmp)
log.info(" Saving model.safetensors …")
save_safetensors(mapped_sd, save_dir / "model.safetensors")
config._save_pretrained(save_dir) # writes config.json via draccus
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats)
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json
log.info(" Uploading …")
commit_url = api.upload_folder(
folder_path=save_dir,
repo_id=target_repo_id,
repo_type="model",
commit_message=f"Convert {Path(pt_filename).name} to safetensors",
)
log.info(" Uploaded → %s", commit_url)
# 6. Add to collection
api.add_collection_item(
collection_slug=collection.slug,
item_id=target_repo_id,
item_type="model",
exists_ok=True,
)
log.info(" Added to collection.")
log.info("\nAll done. Collection: %s", collection.url)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,607 @@
from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
from .qwen_interface import Qwen3VLInterface
from .world_model import ActionConditionedVideoPredictor
# ============================================================================
# Native VLA-JEPA Model - follows original starVLA VLA_JEPA.py implementation
# ============================================================================
class VLAJEPAModel(nn.Module):
"""
Native VLA-JEPA model following the original starVLA VLA_JEPA.py.
Components:
- Qwen3-VL: vision-language backbone for fused embeddings
- DiT-B: flow-matching action head for future action prediction
- V-JEPA: world model for video frame prediction
Input: List[dict] native format (same as original starVLA)
- "image": List[PIL.Image] (multi-view images)
- "video": np.ndarray [V, T, H, W, 3]
- "lang": str (task instruction)
- "action": np.ndarray [T, action_dim] (optional, training only)
- "state": np.ndarray [1, state_dim] (optional)
"""
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
require_package("transformers", extra="vla_jepa")
self.config = config
# Vision-language backbone
self.qwen = Qwen3VLInterface(config)
# Tokenizer expansion for special action tokens
self.action_tokens, self.action_token_ids, self.embodied_action_token_id = (
self.qwen.expand_tokenizer()
)
# Action head (flow-matching DiT)
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
if config.enable_world_model:
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = max(1, len(config.image_features))
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
first_image_shape = next(iter(config.image_features.values())).shape
image_size = first_image_shape[-1]
self.video_predictor = ActionConditionedVideoPredictor(
num_frames=config.num_video_frames // tubelet_size,
img_size=(image_size, image_size),
patch_size=16,
tubelet_size=1,
embed_dim=self.video_encoder.config.hidden_size * num_views,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
else:
self.video_encoder = None
self.video_processor = None
self.video_predictor = None
if config.freeze_qwen:
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Use the encoder's actual tubelet_size when available (world model enabled),
# otherwise fall back to config.
_tubelet_size = (
self.video_encoder.config.tubelet_size
if config.enable_world_model
else self.config.jepa_tubelet_size
)
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[:num_action_prompt_steps]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the last decoder hidden state before the final RMSNorm.
The model was trained with the output of the last transformer block BEFORE
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
the correct pre-RMSNorm state, matching the training-time representation.
"""
captured: list[torch.Tensor] = []
def _hook(module, input, output):
h = output[0] if isinstance(output, tuple) else output
captured.append(h)
last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
"""
Native forward pass following original starVLA VLA_JEPA.forward.
Args:
examples: List of per-sample dicts with keys:
"image" : List[PIL.Image] — multi-view images
"video" : np.ndarray [V, T, H, W, 3]
"lang" : str — task instruction
"action" : np.ndarray [T, action_dim] (optional)
"state" : np.ndarray [1, state_dim] (optional)
Returns:
dict with "action_loss" and "wm_loss" keys (scalar Tensors).
"""
# Unpack native format (same pattern as original VLA_JEPA.py)
batch_images = [ex["image"] for ex in examples] # List[List[PIL.Image]]
batch_videos = [ex["video"] for ex in examples] # List[np.ndarray]
instructions = [ex["lang"] for ex in examples] # List[str]
has_action = "action" in examples[0] and examples[0]["action"] is not None
actions = [ex["action"] for ex in examples] if has_action else None
has_state = "state" in examples[0] and examples[0]["state"] is not None
state = [ex["state"] for ex in examples] if has_state else None
action_is_pad = (
[ex["action_is_pad"] for ex in examples]
if has_action and "action_is_pad" in examples[0] and examples[0]["action_is_pad"] is not None
else None
)
# Stack videos: [B, V, T, H, W, 3] -> [B, V, T, 3, H, W]
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = []
for i in range(b * v):
video_pixels.append(
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device)
)
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
action_horizon = self.config.chunk_size
actions_target = actions_tensor[:, -action_horizon:, :]
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
pad_tensor = torch.stack(
[
p.to(actions_target.device)
if isinstance(p, Tensor)
else torch.tensor(p, device=actions_target.device)
for p in action_is_pad
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
# ---- Native predict_action (follows original VLA_JEPA.predict_action) ----
@torch.no_grad()
def predict_action(
self,
batch_images: list[list[Image.Image]],
instructions: list[str],
state: np.ndarray | None = None,
) -> np.ndarray:
"""
Native action prediction following original VLA_JEPA.predict_action.
Args:
batch_images: List of samples; each is List[PIL.Image] (multi-view).
instructions: Task instructions, one per sample.
state: Optional [B, state_dim] numpy array.
Returns:
np.ndarray [B, action_horizon, action_dim] — predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
action_prompt=self.replace_prompt,
embodied_prompt=self.embodied_replace_prompt,
)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=last_hidden.dtype
)
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
# ============================================================================
# LeRobot Adapter Layer - converts between LeRobot batch format and native VLA-JEPA format
# ============================================================================
class VLAJEPAPolicy(PreTrainedPolicy):
"""
LeRobot adapter for VLA-JEPA.
Converts LeRobot's standard batch format (dict[str, Tensor]) to the native
VLA-JEPA format (List[dict]), calls the native model, and converts outputs
back to LeRobot format.
"""
config_class = VLAJEPAConfig
name = "vla_jepa"
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
if dataset_meta := kwargs.get("dataset_meta"):
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
# compatibility), so validate_features() may have read stale dims from a pretrained
# config. Override state_dim/action_dim from the actual dataset being used.
ds_features = dataset_meta.features
if OBS_STATE in ds_features:
config.state_dim = ds_features[OBS_STATE]["shape"][0]
if ACTION in ds_features:
config.action_dim = ds_features[ACTION]["shape"][0]
self.model = VLAJEPAModel(config)
self.reset()
def reset(self) -> None:
self._queues = {ACTION: deque(maxlen=self.config.n_action_steps)}
# ---- Format Conversion: LeRobot → Native ----
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
LeRobot format:
batch = {
"observation.images.<key>": Tensor [B, C, H, W] or [B, T, C, H, W],
"observation.state": Tensor [B, state_dim] or [B, T, state_dim],
"action": Tensor [B, chunk_size, action_dim], (training only)
"task": str | List[str], (optional instruction)
}
Native format (List[dict]):
{
"image": List[PIL.Image], # multi-view images per sample
"video": np.ndarray [V, T, H, W, 3],
"lang": str, # task instruction
"action": np.ndarray [T, action_dim], # optional
"state": np.ndarray [1, state_dim], # optional
}
"""
# Determine batch size from the first image feature
image_keys = list(self.config.image_features.keys())
if not image_keys:
raise ValueError("VLAJEPA requires at least one image feature.")
first_key = image_keys[0]
first_tensor = batch[first_key]
batch_size = first_tensor.shape[0]
# ---- Collect images per sample ----
# images_per_sample[b][v] = PIL.Image for view v
images_per_sample: list[list[Image.Image]] = [[] for _ in range(batch_size)]
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
# ---- Collect videos per sample ----
# Build video arrays: for each sample, stack views as [V, T, H, W, 3]
# Check whether any image feature has a time dimension
video_source = None
for k in image_keys:
if k in batch:
video_source = batch[k] # Use first available for shape inspection
break
if video_source is None:
raise ValueError("No image data found in batch for video construction.")
videos_per_sample = []
for b in range(batch_size):
sample_views = []
for k in image_keys:
t = batch[k][b] # [C, H, W] or [T, C, H, W]
if t.ndim == 3:
t = t.unsqueeze(0) # [1, C, H, W]
# Convert to [T, H, W, 3] numpy
t_np = t.permute(0, 2, 3, 1).detach().cpu().float().numpy()
# Clamp to [0, 255]
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))
# ---- Collect instructions ----
tasks = batch.get("task")
if tasks is None:
instructions = ["Execute the robot action."] * batch_size
elif isinstance(tasks, str):
instructions = [tasks] * batch_size
else:
instructions = list(tasks)
# ---- Collect actions (training only) ----
actions_list = None
action_is_pad_list = None
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
action_is_pad_tensor = batch.get("action_is_pad")
if action_is_pad_tensor is not None:
action_is_pad_list = [action_is_pad_tensor[b].detach().cpu() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2:
state_tensor = state_tensor.unsqueeze(1) # [B, 1, state_dim]
state_list = [state_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# ---- Assemble native examples ----
examples = []
for b in range(batch_size):
example = {
"image": images_per_sample[b],
"video": videos_per_sample[b],
"lang": instructions[b],
}
if actions_list is not None:
example["action"] = actions_list[b]
if action_is_pad_list is not None:
example["action_is_pad"] = action_is_pad_list[b]
if state_list is not None:
example["state"] = state_list[b]
examples.append(example)
return examples
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
ref = next(iter(native_output.values()))
zero = torch.zeros((), device=ref.device, dtype=ref.dtype)
total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero)
logs = {k: v.detach().item() for k, v in native_output.items()}
logs["loss"] = total_loss.detach().item()
return total_loss, logs
def get_optim_params(self) -> dict:
return self.model.parameters()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot inference: convert → native predict → return as Tensor."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
state_np = None
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
actions_np = self.model.predict_action(batch_images, instructions, state_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: type[T],
pretrained_name_or_path: str | Path,
**kwargs,
):
return super().from_pretrained(pretrained_name_or_path, **kwargs)
@classmethod
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
"""
Custom loading to enable opt reinit of action head
when loading pretrained weights with mismatched action head shapes.
"""
if not model.config.reinit_action_head:
return super()._load_as_safetensor(model, model_file, map_location, strict)
from safetensors.torch import load_file
state_dict = load_file(model_file, device=map_location)
current = model.state_dict()
mismatched: list[str] = []
filtered: dict = {}
for key, value in state_dict.items():
if key in current and value.shape != current[key].shape:
mismatched.append(
f"{key}: checkpoint {tuple(value.shape)} vs model {tuple(current[key].shape)}"
)
else:
filtered[key] = value
if mismatched:
logging.warning(
f"reinit_action_head=True: skipping {len(mismatched)} tensor(s) with mismatched shapes "
f"(randomly re-initialised):\n " + "\n ".join(mismatched)
)
from lerobot.policies.utils import log_model_loading_keys
missing_keys, unexpected_keys = model.load_state_dict(filtered, strict=False)
log_model_loading_keys(missing_keys, unexpected_keys)
return model

View File

@@ -0,0 +1,139 @@
from __future__ import annotations
from typing import Any
import torch
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
class ClipActionsProcessorStep(ProcessorStep):
"""Clips action tensor to [-1, 1] before unnormalization."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None:
transition = dict(transition)
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
"""Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization.
Mirrors the original starVLA LIBERO eval:
normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1)
This ensures the unnormalizer receives an exact binary value, which is
required when the model was trained with gripper in identity (mask=False)
space where 0=open and 1=close.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] >= 7:
transition = dict(transition)
a = action.clone()
a[..., 6] = (a[..., 6] >= 0.5).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes gripper dim (index 6) after unnormalization.
Maps continuous value to {-1, 1}: > 0.5 → -1, <= 0.5 → 1 (matches starVLA convention).
Only applied when action has >= 7 dimensions.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] >= 7:
transition = dict(transition)
a = action.clone()
a[..., 6] = 1.0 - 2.0 * (a[..., 6] > 0.5).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
def make_vla_jepa_pre_post_processors(
config: VLAJEPAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = []
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
if config.pre_snap_gripper_action:
output_steps.append(PreSnapGripperProcessorStep())
output_steps.append(
UnnormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
)
)
if config.binarize_gripper_action:
output_steps.append(BinarizeGripperProcessorStep())
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor")
class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep):
def complementary_data(self, complementary_data):
return complementary_data
def transform_features(self, features):
return features

View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,
torch_dtype=self._get_torch_dtype(config.torch_dtype),
)
self.processor = AutoProcessor.from_pretrained(config.qwen_model_name)
self.processor.tokenizer.padding_side = config.tokenizer_padding_side
self.model.config.hidden_size = self.model.config.text_config.hidden_size
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
if dtype_name == "float32":
return torch.float32
if dtype_name == "float16":
return torch.float16
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
# is required for Qwen embedding/lm_head checkpoint shapes to match.
max_action_tokens = self.config.chunk_size * 4
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []
for idx in range(max_action_tokens):
token = self.config.special_action_token.format(idx)
action_tokens.append(token)
if token not in tokenizer.get_vocab():
tokenizer.add_tokens([token], special_tokens=True)
action_token_ids.append(tokenizer.convert_tokens_to_ids(token))
embodied_action_token = self.config.embodied_action_token
if embodied_action_token not in tokenizer.get_vocab():
tokenizer.add_tokens([embodied_action_token], special_tokens=True)
embodied_action_token_id = tokenizer.convert_tokens_to_ids(embodied_action_token)
if self.model.get_input_embeddings().weight.size(0) < len(tokenizer):
self.model.resize_token_embeddings(len(tokenizer))
return action_tokens, action_token_ids, embodied_action_token_id
def build_inputs(
self,
images: Sequence[Sequence[Image.Image]],
instructions: Sequence[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, torch.Tensor]:
messages = []
for sample_images, instruction in zip(images, instructions, strict=True):
prompt = self.config.prompt_template.format(
instruction=instruction,
actions=action_prompt,
e_actions=embodied_prompt,
)
content = [{"type": "image", "image": img} for img in sample_images]
content.append({"type": "text", "text": prompt})
messages.append([{"role": "user", "content": content}])
batch_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
processor_kwargs={"padding": True, "return_tensors": "pt"},
)
return batch_inputs.to(self.model.device)
@staticmethod
def tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)

View File

@@ -0,0 +1,404 @@
from __future__ import annotations
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
def build_action_block_causal_attention_mask(
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
) -> torch.Tensor:
tokens_per_frame = add_tokens + grid_height * grid_width
num_tokens = num_frames * tokens_per_frame
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
local_window_time = num_frames
for current_frame in range(num_frames):
first_context_frame = max(0, current_frame - local_window_time + 1)
for context_frame in range(first_context_frame, current_frame + 1):
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
mask[row, col] = mask_block
return mask
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
_, _, _, dim = x.size()
if dim % 2 != 0:
raise ValueError("Embedding dimension must be even for rotary position encoding.")
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
omega /= dim / 2.0
omega = 1.0 / 10000**omega
freqs = torch.einsum("..., f -> ... f", pos, omega)
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
y = x.unflatten(-1, (-1, 2))
y1, y2 = y.unbind(dim=-1)
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
return x * emb_cos + y * emb_sin
class DropPath(nn.Module):
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ACRoPEAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: float | None = None,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop_prob = proj_drop
self.proj_drop = nn.Dropout(proj_drop)
self.use_sdpa = use_sdpa
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
self.grid_size = grid_size
self.is_causal = is_causal
@staticmethod
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
return ids // int(height * width)
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
frame_ids = self._get_frame_pos(ids, height, width)
ids = ids - int(height * width) * frame_ids
return ids // width
def separate_positions(
self, ids: torch.Tensor, height: int, width: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
frame_ids = self._get_frame_pos(ids, height, width)
height_ids = self._get_height_pos(ids, height, width)
width_ids = ids - int(height * width) * frame_ids - width * height_ids
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
batch_size, num_tokens, channels = x.size()
if num_frames is None or grid_height is None or grid_width is None:
raise ValueError("num_frames, grid_height and grid_width are required.")
if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
else:
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
h_mask *= self.grid_size / grid_height
w_mask *= self.grid_size / grid_width
if action_tokens > 0:
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
action_q, action_k, action_v = [], [], []
for idx in range(action_tokens):
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
qd = rotate_queries_or_keys(
q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
kd = rotate_queries_or_keys(
k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)
)
qr = q[..., self.d_dim :]
kr = k[..., self.d_dim :]
action_q.append(
torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_k.append(
torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)
)
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
x = x[:, :, action_tokens:, :].flatten(1, 2)
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
offset = 0
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
offset += self.d_dim
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
offset += self.h_dim
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
offset += self.w_dim
if offset < self.head_dim:
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
else:
q = torch.cat([qd, qh, qw], dim=-1)
k = torch.cat([kd, kh, kw], dim=-1)
if action_tokens > 0:
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
frame_tokens = frame_tokens.view(
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
)
action_token_values = action_token_values.view(
batch_size, self.num_heads, num_frames, action_tokens, -1
)
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
q = merge(q, action_q)
k = merge(k, action_k)
v = merge(v, action_v)
if attn_mask is not None or self.use_sdpa:
x = F.scaled_dot_product_attention(
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
x = self.proj(x)
return self.proj_drop(x)
class ACBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_scale: float | None = None,
drop: float = 0.0,
attn_drop: float = 0.0,
drop_path: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
use_sdpa: bool = True,
is_causal: bool = False,
grid_size: int = 16,
use_rope: bool = True,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
if not use_rope:
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
self.attn = ACRoPEAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
use_sdpa=use_sdpa,
is_causal=is_causal,
grid_size=grid_size,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = MLP(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=nn.GELU,
drop=drop,
)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
num_frames: int | None = None,
grid_height: int | None = None,
grid_width: int | None = None,
action_tokens: int = 0,
) -> torch.Tensor:
y = self.norm1(x)
y = self.attn(
y,
mask=None,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=grid_height,
grid_width=grid_width,
action_tokens=action_tokens,
)
x = x + self.drop_path(y)
y = self.norm2(x)
return x + self.drop_path(self.mlp(y))
class ActionConditionedVideoPredictor(nn.Module):
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
def __init__(
self,
num_frames: int,
img_size: tuple[int, int],
patch_size: int,
tubelet_size: int,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
depth: int,
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
use_extrinsics: bool = False,
) -> None:
super().__init__()
self.is_frame_causal = True
self.use_extrinsics = use_extrinsics
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
self.img_height, self.img_width = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
self.grid_height = self.img_height // self.patch_size
self.grid_width = self.img_width // self.patch_size
self.predictor_blocks = nn.ModuleList(
[
ACBlock(
dim=predictor_embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
grid_size=self.grid_height,
use_rope=True,
)
for _ in range(depth)
]
)
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
self.num_action_tokens_per_step = num_action_tokens_per_step
@property
def norm(self) -> nn.LayerNorm:
return self.predictor_norm
@property
def proj(self) -> nn.Linear:
return self.predictor_proj
def forward(
self,
frame_tokens: torch.Tensor,
action_tokens: torch.Tensor,
extrinsics: torch.Tensor | None = None,
) -> torch.Tensor:
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
x = self.predictor_embed(frame_tokens)
batch_size, num_context_tokens, hidden_dim = x.size()
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
actions = self.action_encoder(action_tokens)
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
cond_tokens = actions.shape[2]
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
if self.use_extrinsics:
if extrinsics is None:
raise ValueError("extrinsics are required when use_extrinsics=True.")
cond_tokens += 1
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
else:
x = torch.cat([actions, x], dim=2).flatten(1, 2)
attn_mask = build_action_block_causal_attention_mask(
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
)
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
for block in self.predictor_blocks:
x = block(
x,
attn_mask=attn_mask,
num_frames=num_frames,
grid_height=self.grid_height,
grid_width=self.grid_width,
action_tokens=cond_tokens,
)
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
x = x[:, :, cond_tokens:, :].flatten(1, 2)
x = self.predictor_norm(x)
return self.predictor_proj(x)

View File

@@ -0,0 +1,268 @@
#!/usr/bin/env python
"""Shared fixtures and helpers for VLA-JEPA tests."""
from __future__ import annotations
from types import SimpleNamespace
import numpy as np
import pytest
import torch
from PIL import Image
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
# ---------------------------------------------------------------------------
# Shared constants
# ---------------------------------------------------------------------------
BATCH_SIZE = 2
ACTION_DIM = 3
STATE_DIM = 4
IMAGE_SIZE = 8
ACTION_HORIZON = 4
N_ACTION_STEPS = 2
NUM_VIDEO_FRAMES = 3
QWEN_HIDDEN_SIZE = 16 # hidden size produced by _FakeQwenBackbone
EXPECTED_ACTION_CHUNK_SHAPE = (BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
EXPECTED_SELECT_ACTION_SHAPE = (BATCH_SIZE, ACTION_DIM)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def set_seed_all(seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def make_config(
action_dim: int = ACTION_DIM,
state_dim: int = STATE_DIM,
action_horizon: int = ACTION_HORIZON,
num_video_frames: int = NUM_VIDEO_FRAMES,
) -> VLAJEPAConfig:
config = VLAJEPAConfig(
input_features={
f"{OBS_IMAGES}.laptop": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)),
},
output_features={
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
},
device="cpu",
chunk_size=action_horizon,
n_action_steps=min(N_ACTION_STEPS, action_horizon),
action_dim=action_dim,
state_dim=state_dim,
num_video_frames=num_video_frames,
num_action_tokens_per_timestep=2,
num_embodied_action_tokens_per_instruction=3,
num_inference_timesteps=2,
action_hidden_size=QWEN_HIDDEN_SIZE,
action_model_type="DiT-test",
action_num_layers=1,
predictor_depth=1,
predictor_num_heads=2,
predictor_mlp_ratio=2.0,
jepa_tubelet_size=1,
)
config.validate_features()
return config
def make_train_batch(
batch_size: int = BATCH_SIZE,
action_dim: int = ACTION_DIM,
state_dim: int = STATE_DIM,
action_horizon: int = ACTION_HORIZON,
num_video_frames: int = NUM_VIDEO_FRAMES,
) -> dict[str, Tensor | list[str]]:
return {
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, num_video_frames, 3, IMAGE_SIZE, IMAGE_SIZE),
OBS_STATE: torch.randn(batch_size, 1, state_dim),
ACTION: torch.randn(batch_size, action_horizon, action_dim),
"task": ["pick up the cube"] * batch_size,
}
def make_inference_batch(
batch_size: int = BATCH_SIZE,
state_dim: int = STATE_DIM,
) -> dict[str, Tensor | list[str]]:
return {
f"{OBS_IMAGES}.laptop": torch.rand(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE),
OBS_STATE: torch.randn(batch_size, state_dim),
"task": ["pick up the cube"] * batch_size,
}
# ---------------------------------------------------------------------------
# Fake external models (replace Qwen3-VL and V-JEPA at test time)
# ---------------------------------------------------------------------------
class _FakeLanguageLayer(nn.Module):
"""Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self._hidden_size = hidden_size
def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
return (hidden,)
class _FakeLanguageModel(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
self._hidden_size = hidden_size
self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)])
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
batch_size, seq_len = input_ids.shape
hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device)
self.layers[-1](hidden)
return SimpleNamespace()
class _FakeQwenInnerModel(nn.Module):
"""Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.language_model = _FakeLanguageModel(hidden_size)
def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
return self.language_model(input_ids)
class _FakeQwenBackbone(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(
hidden_size=hidden_size,
text_config=SimpleNamespace(hidden_size=hidden_size),
)
self.model = _FakeQwenInnerModel(hidden_size)
@property
def device(self) -> torch.device:
return self.weight.device
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
batch_size, seq_len = input_ids.shape
hidden_size = self.config.hidden_size
values = torch.arange(
batch_size * seq_len * hidden_size,
device=input_ids.device,
dtype=torch.float32,
).view(batch_size, seq_len, hidden_size)
hidden = values / values.numel() + self.weight
return SimpleNamespace(hidden_states=[hidden])
class _FakeQwenInterface(nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
self.config = config
self.model = _FakeQwenBackbone(hidden_size=QWEN_HIDDEN_SIZE)
@staticmethod
def _get_torch_dtype(dtype_name: str) -> torch.dtype:
return torch.float32 if dtype_name == "float32" else torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
action_tokens = [self.config.special_action_token.format(idx) for idx in range(max_action_tokens)]
action_token_ids = list(range(1000, 1000 + max_action_tokens))
return action_tokens, action_token_ids, 2000
def build_inputs(
self,
images: list[list[Image.Image]],
instructions: list[str],
action_prompt: str,
embodied_prompt: str,
) -> dict[str, Tensor]:
batch_size = len(images)
del images, instructions, action_prompt, embodied_prompt
action_count = (self.config.num_video_frames - 1) * self.config.num_action_tokens_per_timestep
token_ids = (
[10]
+ list(range(1000, 1000 + action_count))
+ [2000] * self.config.num_embodied_action_tokens_per_instruction
+ [11]
)
return {
"input_ids": torch.tensor(
[token_ids] * batch_size,
device=self.model.device,
dtype=torch.long,
)
}
@staticmethod
def tensor_to_pil(image_tensor: Tensor) -> Image.Image:
image = image_tensor.detach().cpu()
if image.ndim == 3 and image.shape[0] in (1, 3):
image = image.permute(1, 2, 0)
image = (image.float().clamp(0, 1) * 255).to(torch.uint8).numpy()
return Image.fromarray(image)
class _FakeVideoEncoder(nn.Module):
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
# image_size must be >= patch_size (16) so the predictor grid is non-zero.
# Setting image_size=16 gives a 1x1 grid (1 patch per frame).
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)
@property
def device(self) -> torch.device:
return self.weight.device
def get_vision_features(self, pixel_values_videos: Tensor) -> Tensor:
batch_size, num_frames = pixel_values_videos.shape[:2]
hidden_size = self.config.hidden_size
frame_values = pixel_values_videos.float().mean(dim=(2, 3, 4), keepdim=False)
return frame_values[:, :, None].expand(batch_size, num_frames, hidden_size)
class _FakeVideoProcessor:
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
assert return_tensors == "pt"
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
from lerobot.policies.vla_jepa import modeling_vla_jepa
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
monkeypatch.setattr(
modeling_vla_jepa.AutoModel,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoEncoder(),
)
monkeypatch.setattr(
modeling_vla_jepa.AutoVideoProcessor,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoProcessor(),
)

View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
import torch
pytest.importorskip("diffusers")
from conftest import (
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
QWEN_HIDDEN_SIZE,
STATE_DIM,
make_config,
set_seed_all,
) # noqa: E402
from lerobot.policies.vla_jepa.action_head import ( # noqa: E402
VLAJEPAActionHead,
)
# ---------------------------------------------------------------------------
# VLAJEPAActionHead
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4), # default test dims
(7, 0, 16), # no proprioceptive state, production-like action space
(6, 8, 8), # medium dims
],
)
def test_action_head_sample_time_range(action_dim: int, state_dim: int, action_horizon: int) -> None:
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
t = head.sample_time(batch_size=200, device=torch.device("cpu"), dtype=torch.float32)
assert t.shape == (200,)
assert torch.isfinite(t).all()
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_action_head_build_inputs_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(2, action_horizon, action_dim)
timesteps = torch.randint(0, 100, (2,))
state = torch.randn(2, state_dim) if state_dim > 0 else None
out_with = head._build_inputs(conditioning, actions, state, timesteps)
out_none = head._build_inputs(conditioning, actions, None, timesteps)
assert out_with.ndim == 3 and out_none.ndim == 3
if state_dim > 0:
assert out_with.shape[1] > out_none.shape[1]
assert torch.isfinite(out_with).all() and torch.isfinite(out_none).all()
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_action_head_forward_loss_valid(action_dim: int, state_dim: int, action_horizon: int) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(2, action_horizon, action_dim)
state = torch.randn(2, state_dim) if state_dim > 0 else None
loss = head.forward(conditioning, actions, state)
assert loss.shape == ()
assert torch.isfinite(loss) and loss > 0
def test_action_head_forward_gradient_flows() -> None:
set_seed_all(42)
config = make_config()
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
state = torch.randn(BATCH_SIZE, STATE_DIM)
loss = head.forward(conditioning, actions, state)
loss.backward()
assert any(p.grad is not None for p in head.parameters() if p.requires_grad)
@torch.no_grad()
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_action_head_predict_action_shape(action_dim: int, state_dim: int, action_horizon: int) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(2, 4, QWEN_HIDDEN_SIZE)
state = torch.randn(2, state_dim) if state_dim > 0 else None
pred = head.predict_action(conditioning, state)
assert tuple(pred.shape) == (2, action_horizon, action_dim)
assert torch.isfinite(pred).all()
# ---------------------------------------------------------------------------
# action_is_pad masking
# ---------------------------------------------------------------------------
def test_action_head_loss_fully_padded_is_zero() -> None:
"""Loss is 0 when every timestep is padded (exercises the clamp_min guard)."""
set_seed_all(42)
config = make_config()
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
state = torch.randn(BATCH_SIZE, STATE_DIM)
action_is_pad = torch.ones(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
loss = head.forward(conditioning, actions, state, action_is_pad)
assert loss.item() == 0.0
def test_action_head_loss_none_matches_no_padding() -> None:
"""action_is_pad=None is equivalent to an all-False (no padding) mask."""
set_seed_all(42)
config = make_config()
head = VLAJEPAActionHead(config, cross_attention_dim=QWEN_HIDDEN_SIZE)
conditioning = torch.randn(BATCH_SIZE, 4, QWEN_HIDDEN_SIZE)
actions = torch.randn(BATCH_SIZE, ACTION_HORIZON, ACTION_DIM)
state = torch.randn(BATCH_SIZE, STATE_DIM)
set_seed_all(0)
loss_none = head.forward(conditioning, actions, state, action_is_pad=None)
set_seed_all(0)
no_pad = torch.zeros(BATCH_SIZE, ACTION_HORIZON, dtype=torch.bool)
loss_zeros = head.forward(conditioning, actions, state, action_is_pad=no_pad)
assert torch.isclose(loss_none, loss_zeros)

View File

@@ -0,0 +1,57 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
from conftest import ACTION_DIM, ACTION_HORIZON, IMAGE_SIZE, NUM_VIDEO_FRAMES, STATE_DIM, make_config
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
def test_delta_indices() -> None:
config = make_config()
assert config.observation_delta_indices == list(range(NUM_VIDEO_FRAMES))
assert config.action_delta_indices == list(range(ACTION_HORIZON))
def test_n_action_steps_exceeds_chunk_size_raises() -> None:
with pytest.raises(ValueError, match="n_action_steps"):
VLAJEPAConfig(chunk_size=4, n_action_steps=8)
def test_too_few_video_frames_raises() -> None:
with pytest.raises(ValueError, match="video_horizon"):
VLAJEPAConfig(
chunk_size=16,
n_action_steps=16,
num_video_frames=2,
jepa_tubelet_size=2, # needs >= 4 frames (2 for current, 2 for future) to have a window of size > 0
)
def test_validate_features_no_image_raises() -> None:
config = VLAJEPAConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))},
)
with pytest.raises(ValueError, match="at least one visual input feature"):
config.validate_features()
def test_validate_features_no_action_raises() -> None:
config = VLAJEPAConfig(
input_features={
f"{OBS_IMAGES}.cam": PolicyFeature(type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE)),
},
output_features={},
)
with pytest.raises(ValueError, match="action output feature"):
config.validate_features()
def test_validate_features_sets_action_dim_from_feature() -> None:
config = make_config(action_dim=6, state_dim=10)
assert config.action_dim == 6
assert config.state_dim == 10

View File

@@ -0,0 +1,473 @@
#!/usr/bin/env python
from __future__ import annotations
import os
from copy import deepcopy
import numpy as np
import pytest
import torch
from torch import Tensor
pytest.importorskip("transformers")
pytest.importorskip("diffusers")
pytestmark = pytest.mark.filterwarnings(
"ignore:In CPU autocast, but the target dtype is not supported:UserWarning"
)
from conftest import ( # noqa: E402
ACTION_DIM,
ACTION_HORIZON,
BATCH_SIZE,
EXPECTED_ACTION_CHUNK_SHAPE,
EXPECTED_SELECT_ACTION_SHAPE,
N_ACTION_STEPS,
STATE_DIM,
make_config,
make_inference_batch,
make_train_batch,
set_seed_all,
)
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
from lerobot.utils.constants import ACTION # noqa: E402
PRETRAINED_REPO_ID = "ginwind/VLA-JEPA"
PRETRAINED_SUBFOLDER = "LIBERO"
# extended hub tests load the full converted safetensors checkpoints (~5 GB) and are
# skipped by default. Set VLA_JEPA_EXTENDED=1 to opt in.
_VLA_JEPA_EXTENDED = os.environ.get("VLA_JEPA_EXTENDED", "0") != "0"
extended_test = pytest.mark.skipif(not _VLA_JEPA_EXTENDED, reason="Set VLA_JEPA_EXTENDED=1 to run hub tests")
# ---------------------------------------------------------------------------
# Core training / inference tests
# ---------------------------------------------------------------------------
def test_training_forward_pass(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.train()
batch = make_train_batch()
batch_before = deepcopy(batch)
loss, logs = policy.forward(batch)
assert loss.shape == ()
assert torch.isfinite(loss)
assert set(logs) == {"action_loss", "wm_loss", "loss"}
assert logs["action_loss"] > 0
assert logs["wm_loss"] >= 0
loss.backward()
assert any(p.grad is not None for p in policy.model.action_model.parameters() if p.requires_grad)
# Batch must not be mutated.
assert set(batch) == set(batch_before)
for key, value in batch.items():
if isinstance(value, Tensor):
assert torch.equal(value, batch_before[key])
else:
assert value == batch_before[key]
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_training_forward_various_batch_sizes(patch_vla_jepa_external_models: None, batch_size: int) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.train()
loss, logs = policy.forward(make_train_batch(batch_size=batch_size))
assert torch.isfinite(loss) and loss > 0
assert set(logs) == {"action_loss", "wm_loss", "loss"}
@pytest.mark.parametrize(
"action_dim,state_dim,action_horizon",
[
(3, 4, 4),
(7, 0, 16),
(6, 8, 8),
],
)
def test_training_forward_various_dims(
patch_vla_jepa_external_models: None,
action_dim: int,
state_dim: int,
action_horizon: int,
) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
policy = VLAJEPAPolicy(config)
policy.train()
batch = make_train_batch(action_dim=action_dim, state_dim=state_dim, action_horizon=action_horizon)
loss, _ = policy.forward(batch)
assert torch.isfinite(loss) and loss > 0
@torch.no_grad()
def test_action_generation_shape(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
batch = make_inference_batch()
chunk = policy.predict_action_chunk(batch)
assert tuple(chunk.shape) == EXPECTED_ACTION_CHUNK_SHAPE
assert chunk.device.type == "cpu"
assert torch.isfinite(chunk).all()
a1 = policy.select_action(batch)
a2 = policy.select_action(batch)
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert torch.isfinite(a1).all() and torch.isfinite(a2).all()
@torch.no_grad()
@pytest.mark.parametrize("action_dim,state_dim", [(3, 4), (7, 0), (6, 8)])
def test_action_generation_various_dims(
patch_vla_jepa_external_models: None, action_dim: int, state_dim: int
) -> None:
set_seed_all(42)
config = make_config(action_dim=action_dim, state_dim=state_dim)
policy = VLAJEPAPolicy(config)
policy.eval()
batch = make_inference_batch(state_dim=state_dim)
chunk = policy.predict_action_chunk(batch)
assert chunk.shape[-1] == action_dim
assert torch.isfinite(chunk).all()
@torch.no_grad()
def test_inference_reproducibility(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
batch = make_inference_batch()
set_seed_all(123)
actions_1 = policy.predict_action_chunk(batch)
set_seed_all(123)
actions_2 = policy.predict_action_chunk(batch)
assert tuple(actions_1.shape) == EXPECTED_ACTION_CHUNK_SHAPE
assert torch.allclose(actions_1, actions_2, atol=1e-6)
@torch.no_grad()
def test_predict_action_chunk_always_finite(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
policy.eval()
for seed in [0, 42, 123]:
set_seed_all(seed)
chunk = policy.predict_action_chunk(make_inference_batch())
assert torch.isfinite(chunk).all(), f"non-finite actions with seed={seed}"
# ---------------------------------------------------------------------------
# Action queue behaviour
# ---------------------------------------------------------------------------
@torch.no_grad()
def test_select_action_queue_drains_before_refill(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
batch = make_inference_batch()
# First call fills the queue (n_action_steps items) and pops one.
a1 = policy.select_action(batch)
assert len(policy._queues[ACTION]) == N_ACTION_STEPS - 1
# Second call pops from the existing queue without calling predict_action_chunk.
a2 = policy.select_action(batch)
assert tuple(a1.shape) == EXPECTED_SELECT_ACTION_SHAPE
assert tuple(a2.shape) == EXPECTED_SELECT_ACTION_SHAPE
@torch.no_grad()
def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None:
set_seed_all(42)
policy = VLAJEPAPolicy(make_config())
policy.eval()
policy.select_action(make_inference_batch())
assert len(policy._queues[ACTION]) > 0
policy.reset()
assert len(policy._queues[ACTION]) == 0
# ---------------------------------------------------------------------------
# Format conversion
# ---------------------------------------------------------------------------
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
from PIL import Image
policy = VLAJEPAPolicy(make_config())
examples = policy._prepare_model_inputs(make_train_batch())
assert len(examples) == BATCH_SIZE
for ex in examples:
assert set(ex) >= {"image", "video", "lang", "action", "state"}
assert len(ex["image"]) == 1 and isinstance(ex["image"][0], Image.Image)
assert ex["video"].ndim == 5 and ex["video"].dtype == np.uint8 # [V,T,H,W,C]
assert ex["action"].shape == (ACTION_HORIZON, ACTION_DIM)
assert ex["state"].shape == (1, STATE_DIM)
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
for ex in policy._prepare_model_inputs(make_inference_batch()):
assert "action" not in ex
assert "image" in ex and "video" in ex and "lang" in ex
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch["task"]
examples = policy._prepare_model_inputs(batch)
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
batch["task"] = "open the drawer"
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
from lerobot.utils.constants import OBS_STATE
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch[OBS_STATE]
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
# ---------------------------------------------------------------------------
# Pretrained checkpoint
# Hub tests (opt-in: VLA_JEPA_EXTENDED=1)
# ---------------------------------------------------------------------------
def _make_hub_train_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
"""Build a training batch whose keys/shapes match a hub-loaded policy config."""
cfg = policy.config
batch: dict = {"task": ["pick up the cube"] * batch_size}
for key, feat in cfg.image_features.items():
h, w = feat.shape[-2], feat.shape[-1]
batch[key] = torch.rand(batch_size, cfg.num_video_frames, 3, h, w)
if cfg.robot_state_feature is not None:
batch["observation.state"] = torch.randn(batch_size, 1, cfg.robot_state_feature.shape[0])
batch[ACTION] = torch.randn(batch_size, cfg.chunk_size, cfg.action_dim)
return batch
def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dict:
"""Build an inference batch whose keys/shapes match a hub-loaded policy config."""
cfg = policy.config
batch: dict = {"task": ["pick up the cube"] * batch_size}
for key, feat in cfg.image_features.items():
h, w = feat.shape[-2], feat.shape[-1]
batch[key] = torch.rand(batch_size, 3, h, w)
if cfg.robot_state_feature is not None:
batch["observation.state"] = torch.randn(batch_size, cfg.robot_state_feature.shape[0])
return batch
_CP_ROOT = "lerobot" # TODO: upload converted checkpoints
# Each tuple: (repo_id, enable_world_model)
_HUB_VARIANTS = [
(f"{_CP_ROOT}/VLA-JEPA-LIBERO", True),
(f"{_CP_ROOT}/VLA-JEPA-Pretrain", True),
(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv", False),
]
@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_loads(repo_id: str, enable_world_model: bool) -> None:
"""Policy loads from the converted safetensors checkpoint on the Hub."""
policy = VLAJEPAPolicy.from_pretrained(repo_id)
assert policy.config.enable_world_model == enable_world_model
assert sum(p.numel() for p in policy.parameters()) > 0
@extended_test
@pytest.mark.parametrize("repo_id,enable_world_model", _HUB_VARIANTS)
def test_hub_checkpoint_forward_pass(repo_id: str, enable_world_model: bool) -> None:
"""Policy loaded from hub produces finite losses with a correctly-shaped batch."""
policy = VLAJEPAPolicy.from_pretrained(repo_id)
policy.train()
batch = _make_hub_train_batch(policy)
loss, logs = policy.forward(batch)
assert torch.isfinite(loss)
assert "action_loss" in logs
if enable_world_model:
assert "wm_loss" in logs
@extended_test
def test_hub_freeze_qwen_disables_world_model() -> None:
"""freeze_qwen=True (via cli_overrides) freezes qwen and disables the world model."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO", cli_overrides=["freeze_qwen=true"])
assert not policy.config.enable_world_model
assert policy.model.video_predictor is None
qwen_params = list(policy.model.qwen.parameters())
assert all(not p.requires_grad for p in qwen_params)
assert any(p.requires_grad for p in policy.model.action_model.parameters())
@extended_test
def test_hub_disable_world_model_loads_simpler_env() -> None:
"""SimplerEnv checkpoint (world model disabled) loads cleanly and runs inference."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-SimplerEnv")
assert not policy.config.enable_world_model
assert policy.model.video_predictor is None
assert policy.model.video_encoder is None
@extended_test
def test_hub_libero_inference_shape() -> None:
"""select_action returns the expected shape using the LIBERO hub checkpoint."""
policy = VLAJEPAPolicy.from_pretrained(f"{_CP_ROOT}/VLA-JEPA-LIBERO")
policy.eval()
batch = _make_hub_inference_batch(policy)
action = policy.select_action(batch)
assert action.shape[-1] == policy.config.action_dim
# ---------------------------------------------------------------------------
# Postprocessor unnormalization tests
#
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
# ---------------------------------------------------------------------------
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
"""Returns sample dataset_stats with a simple [i, i+10] range per action dim."""
from lerobot.utils.constants import ACTION
return {
ACTION: {
"min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32),
"max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32),
}
}
@torch.no_grad()
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
"""UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization."""
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor import UnnormalizerProcessorStep
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import ACTION
dataset_stats = _make_dataset_stats()
rng = np.random.default_rng(7)
actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
unnorm_step = UnnormalizerProcessorStep(
features=features,
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
stats=dataset_stats,
)
actions_tensor = torch.from_numpy(actions_np)
transition = policy_action_to_transition(actions_tensor)
result = transition_to_policy_action(unnorm_step(transition)).numpy()
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
@torch.no_grad()
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
from lerobot.processor import UnnormalizerProcessorStep
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import ACTION
dataset_stats = _make_dataset_stats()
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
# Deliberately out-of-range inputs
actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32)
clipped = np.clip(actions_np, -1.0, 1.0)
expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
clip_step = ClipActionsProcessorStep()
unnorm_step = UnnormalizerProcessorStep(
features=features,
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
stats=dataset_stats,
)
transition = policy_action_to_transition(torch.from_numpy(actions_np))
transition = clip_step(transition)
result = transition_to_policy_action(unnorm_step(transition)).numpy()
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
@torch.no_grad()
def test_postprocessor_applied_after_predict_action_chunk(
patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
) -> None:
"""predict_action_chunk returns raw actions; the postprocessor applies unnormalization.
Verifies the split: predict_action_chunk returns normalized actions, and calling the
postprocessor on them produces the correctly unnormalized result.
"""
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
cfg = make_config()
cfg.clip_normalized_actions = False
cfg.binarize_gripper_action = False
policy = VLAJEPAPolicy(cfg)
policy.eval()
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
dataset_stats = _make_dataset_stats()
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
batch = make_inference_batch()
chunk = policy.predict_action_chunk(batch)
# predict_action_chunk returns raw (normalized) actions
assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
"predict_action_chunk should return raw actions without unnormalization applied."
)
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
unnormed = postprocessor(chunk)
from lerobot.utils.constants import ACTION
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)

View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python
from __future__ import annotations
import pytest
import torch
from lerobot.policies.vla_jepa.world_model import (
ActionConditionedVideoPredictor,
)
_ACTION_EMBED_DIM = 8
def _make_predictor(
embed_dim: int = 8,
action_embed_dim: int = _ACTION_EMBED_DIM,
predictor_embed_dim: int = 24,
num_action_tokens: int = 2,
tokens_per_frame: int = 1,
) -> ActionConditionedVideoPredictor:
return ActionConditionedVideoPredictor(
num_frames=1,
img_size=(1, tokens_per_frame),
patch_size=1,
tubelet_size=1,
embed_dim=embed_dim,
action_embed_dim=action_embed_dim,
predictor_embed_dim=predictor_embed_dim,
depth=1,
num_heads=2,
mlp_ratio=2.0,
num_action_tokens_per_step=num_action_tokens,
)
@pytest.mark.parametrize(
"batch,num_steps,tokens_per_frame,embed_dim",
[
(1, 2, 1, 8),
(2, 3, 4, 8),
(4, 5, 2, 16),
],
)
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
predictor = _make_predictor(
embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame
)
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
out = predictor(frame_tokens, action_tokens)
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
assert torch.isfinite(out).all()
def test_predictor_step_mismatch_raises() -> None:
predictor = _make_predictor(tokens_per_frame=4)
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
with pytest.raises(RuntimeError):
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch

11
uv.lock generated
View File

@@ -3039,6 +3039,11 @@ video-benchmark = [
viz = [
{ name = "rerun-sdk" },
]
vla-jepa = [
{ name = "diffusers" },
{ name = "qwen-vl-utils" },
{ name = "transformers" },
]
wallx = [
{ name = "peft" },
{ name = "qwen-vl-utils" },
@@ -3107,6 +3112,7 @@ requires-dist = [
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
{ name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
@@ -3154,6 +3160,7 @@ requires-dist = [
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["rebot"], marker = "extra == 'all'" },
@@ -3177,12 +3184,14 @@ requires-dist = [
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'pi'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'smolvla'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'vla-jepa'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
{ name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
{ name = "lerobot", extras = ["vla-jepa"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
@@ -3244,7 +3253,7 @@ requires-dist = [
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"