fixing misconception about multiview / singleview handling

This commit is contained in:
Maximellerbach
2026-05-27 11:16:37 +02:00
parent 952e5146dc
commit 58eac863aa
4 changed files with 160 additions and 22 deletions

View File

@@ -54,9 +54,7 @@ Three checkpoints are available, converted from [ginwind/VLA-JEPA](https://huggi
| ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- |
| `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 |
| `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 |
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 | Disabled\* | 7 |
\* The SimplerEnv checkpoint was fine-tuned from Pretrain. The world model predictor architecture expects `embed_dim=2048` (2-camera input) but SimplerEnv is single-camera, so the world model cannot be loaded cleanly. Since inference only needs Qwen + the action head, `enable_world_model=False` is set for this variant. See [Fine-tuning on single-camera datasets](#fine-tuning-on-single-camera-datasets) for implications.
| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 |
All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
@@ -184,13 +182,20 @@ lerobot-eval \
---
## Fine-tuning on single-camera datasets
## Fine-tuning on datasets with a different number of cameras
The pretrained world model predictor was trained with `embed_dim = num_views × 1024`. If your target dataset has fewer cameras than the source checkpoint, the predictor input projection will have a shape mismatch and cannot be loaded.
The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`).
**Option 1 — Disable the world model (recommended)**
**Default behaviour — view padding / trimming (no action required)**
Set `enable_world_model=False`. Only the Qwen backbone and action head are loaded and trained. This matches the original SimplerEnv fine-tuning strategy and is sufficient for good action performance.
When fine-tuning from `VLA-JEPA-Pretrain`, the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`:
- **Single-view datasets (e.g. BridgeV2):** the single view is duplicated before video encoding to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch.
- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views, following the configured view order, are used for the world model. Order your views so that the ones the world model should see (e.g. one wrist + one third-person) come first.
**Option 1 — Disable the world model**
Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance.
```bash
lerobot-train \
@@ -202,10 +207,7 @@ lerobot-train \
**Option 2 — Reinitialize the predictor input projection**
If you want the JEPA self-supervised signal during fine-tuning, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
**Option 3 - Duplicate frames to match the expected number of cameras**
A bit more advanced, you would need to change some parts of the code to support that.
If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint.
---

View File

@@ -17,9 +17,9 @@ For each variant the script:
Config sources
--------------
Numeric hyper-params : ginwind/VLA-JEPA/<variant>/config.json
Image keys LIBERO : lerobot/libero_10 meta/info.json ✓ confirmed
Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json ✓ confirmed
Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed
Image keys LIBERO : lerobot/libero_10 meta/info.json
Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json
Image keys SimplerEnv: OXE Bridge/RT1 single-camera (view x2)
"""
from __future__ import annotations
@@ -252,7 +252,7 @@ _DROID_CAMS = [
"observation.images.exterior_2_left",
]
# OXE Bridge + RT1 — single-camera; world model disabled (predictor embed_dim mismatch)
# OXE Bridge + RT1 — single-camera; view is duplicated at runtime to match the 2-view world model
_OXE_CAMS = [
"observation.images.image",
]
@@ -311,9 +311,9 @@ def _build_config(
VARIANTS: dict[str, tuple] = {
"LIBERO": (_LIBERO_CAMS, True, True, "LIBERO"),
"Pretrain": (_DROID_CAMS, False, True, "Pretrain"),
# SimplerEnv uses a single camera; the predictor embed_dim (2048) would mismatch, so
# disable the world model — only qwen + action_model weights are needed for inference.
"SimplerEnv": (_OXE_CAMS, False, False, "SimplerEnv"),
# SimplerEnv uses a single camera; the single view is duplicated at runtime to produce
# the 2-view input the world model expects (embed_dim=2048).
"SimplerEnv": (_OXE_CAMS, False, True, "SimplerEnv"),
}
# ---------------------------------------------------------------------------
@@ -423,13 +423,13 @@ def main() -> None:
log.info(" Saving model.safetensors …")
save_safetensors(mapped_sd, save_dir / "model.safetensors")
config.device = None # don't bake in the conversion machine's device
config._save_pretrained(save_dir) # writes config.json via draccus
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats)
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json
config.device = None # don't bake in the conversion machine's device
config._save_pretrained(save_dir) # writes config.json via draccus
log.info(" Uploading …")
commit_url = api.upload_folder(
folder_path=save_dir,

View File

@@ -72,7 +72,7 @@ class VLAJEPAModel(nn.Module):
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = max(1, len(config.image_features))
num_views = config.jepa_tubelet_size
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
@@ -180,6 +180,17 @@ class VLAJEPAModel(nn.Module):
batch_videos = np.stack(batch_videos)
batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W]
# Adjust number of views for the world model:
# - fewer views than expected: duplicate the first view to fill up
# - more views than expected: keep only the first num_views_world_model views
num_views_world_model = self.config.jepa_tubelet_size
if batch_videos.shape[1] < num_views_world_model:
num_missing_views = num_views_world_model - batch_videos.shape[1]
first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1)
batch_videos = np.concatenate([batch_videos, first_view], axis=1)
elif batch_videos.shape[1] > num_views_world_model:
batch_videos = batch_videos[:, :num_views_world_model]
# ---- Step 1: QwenVL encode (same as original) ----
qwen_inputs = self.qwen.build_inputs(
images=batch_images,

View File

@@ -23,7 +23,9 @@ from conftest import ( # noqa: E402
BATCH_SIZE,
EXPECTED_ACTION_CHUNK_SHAPE,
EXPECTED_SELECT_ACTION_SHAPE,
IMAGE_SIZE,
N_ACTION_STEPS,
QWEN_HIDDEN_SIZE,
STATE_DIM,
make_config,
make_inference_batch,
@@ -31,6 +33,7 @@ from conftest import ( # noqa: E402
set_seed_all,
)
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig # noqa: E402
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402
from lerobot.utils.constants import ACTION # noqa: E402
@@ -471,3 +474,125 @@ def test_postprocessor_applied_after_predict_action_chunk(
a_max = dataset_stats[ACTION]["max"].numpy()
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)
# ---------------------------------------------------------------------------
# World-model view adjustment (padding / trimming) tests
# ---------------------------------------------------------------------------
# Frame count shared by the multi-view test config (`num_video_frames`) and the
# synthetic training batches built for these tests.
# Must be >= 2 * jepa_tubelet_size (=2) for world-model tests.
_MULTIVIEW_NUM_FRAMES = 4
def _make_multiview_config(num_views: int, jepa_tubelet_size: int = 2) -> VLAJEPAConfig:
    """Build a small CPU ``VLAJEPAConfig`` declaring ``num_views`` camera inputs.

    Cameras are registered as ``<OBS_IMAGES>.cam0`` ... ``cam{num_views - 1}``,
    followed by a single state feature. The returned config has already been
    through ``validate_features()``.

    Args:
        num_views: Number of visual input features to declare.
        jepa_tubelet_size: Value forwarded to the config field of the same name.

    Returns:
        The validated configuration.
    """
    from lerobot.configs.types import FeatureType, PolicyFeature
    from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

    # Cameras go in first so they keep their index order ahead of the state entry.
    image_shape = (3, IMAGE_SIZE, IMAGE_SIZE)
    inputs = {}
    for cam_idx in range(num_views):
        inputs[f"{OBS_IMAGES}.cam{cam_idx}"] = PolicyFeature(
            type=FeatureType.VISUAL,
            shape=image_shape,
        )
    inputs[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,))

    outputs = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}

    cfg = VLAJEPAConfig(
        input_features=inputs,
        output_features=outputs,
        device="cpu",
        chunk_size=ACTION_HORIZON,
        n_action_steps=N_ACTION_STEPS,
        action_dim=ACTION_DIM,
        state_dim=STATE_DIM,
        num_video_frames=_MULTIVIEW_NUM_FRAMES,
        num_action_tokens_per_timestep=2,
        num_embodied_action_tokens_per_instruction=3,
        num_inference_timesteps=2,
        action_hidden_size=QWEN_HIDDEN_SIZE,
        action_model_type="DiT-test",
        action_num_layers=1,
        predictor_depth=1,
        predictor_num_heads=2,
        predictor_mlp_ratio=2.0,
        jepa_tubelet_size=jepa_tubelet_size,
    )
    cfg.validate_features()
    return cfg
def _make_multiview_train_batch(num_views: int, batch_size: int = BATCH_SIZE) -> dict:
    """Create a random training batch with ``num_views`` camera streams.

    Each camera entry is a uniform-random tensor of shape
    ``(batch_size, _MULTIVIEW_NUM_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE)``. The batch
    also carries a random state, a random action chunk and one task string per
    sample.

    Args:
        num_views: Number of ``<OBS_IMAGES>.cam{i}`` entries to generate.
        batch_size: Number of samples in the batch.

    Returns:
        A dict mapping feature keys to tensors, plus the ``"task"`` list.
    """
    from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

    video_shape = (batch_size, _MULTIVIEW_NUM_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE)

    # Random draws happen in a fixed order (cameras, then state, then action) so
    # a seeded run yields the same tensors every time.
    samples: dict = {}
    for cam_idx in range(num_views):
        samples[f"{OBS_IMAGES}.cam{cam_idx}"] = torch.rand(video_shape)
    samples.update(
        {
            OBS_STATE: torch.randn(batch_size, 1, STATE_DIM),
            ACTION: torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM),
            "task": ["pick up the cube"] * batch_size,
        }
    )
    return samples
@pytest.mark.parametrize(
    "num_views",
    [
        1,  # below jepa_tubelet_size: the first view is repeated to fill the gap
        2,  # equal to jepa_tubelet_size: views pass through untouched
        3,  # above jepa_tubelet_size: only the leading two views are kept
    ],
)
def test_training_forward_world_model_view_adjustment(
    patch_vla_jepa_external_models: None,
    num_views: int,
) -> None:
    """Training forward must succeed whether views are padded, kept or trimmed.

    Runs one training-mode forward pass for each view count and checks that the
    total loss is finite and the logged world-model loss is non-negative.
    """
    set_seed_all(42)

    cfg = _make_multiview_config(num_views=num_views, jepa_tubelet_size=2)
    policy = VLAJEPAPolicy(cfg)
    policy.train()

    # Build the batch only after the policy exists, keeping the seeded RNG
    # consumption order stable.
    train_batch = _make_multiview_train_batch(num_views=num_views)
    loss, logs = policy.forward(train_batch)

    assert torch.isfinite(loss)
    assert logs["wm_loss"] >= 0
def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_models: None) -> None:
    """A lone dataset view must reach the video processor twice when jepa_tubelet_size=2.

    The policy's video processor is wrapped so that every video handed to it is
    recorded; the recorded videos are then checked pairwise for equality.
    """
    set_seed_all(42)
    single_view_cfg = _make_multiview_config(num_views=1, jepa_tubelet_size=2)
    policy = VLAJEPAPolicy(single_view_cfg)
    policy.train()

    seen: list = []

    class _RecordingProcessor:
        """Callable stand-in that logs each video, then defers to the wrapped processor."""

        def __init__(self, wrapped) -> None:
            self._wrapped = wrapped

        def __call__(self, videos: list, return_tensors: str) -> dict:
            seen.extend(videos)
            return self._wrapped(videos=videos, return_tensors=return_tensors)

    policy.model.video_processor = _RecordingProcessor(policy.model.video_processor)

    policy.forward(_make_multiview_train_batch(num_views=1))

    # Videos arrive batch-major: (b0v0, b0v1, b1v0, b1v1, ...), so even/odd
    # positions hold the two views of the same sample.
    assert len(seen) == BATCH_SIZE * 2
    for view_a, view_b in zip(seen[0::2], seen[1::2]):
        np.testing.assert_array_equal(view_a, view_b)
def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: None) -> None:
    """With three dataset views and jepa_tubelet_size=2, only two views per sample are encoded.

    The policy's video processor is wrapped to record its inputs; the test
    checks that exactly ``BATCH_SIZE * 2`` videos were passed to it.
    """
    set_seed_all(42)
    three_view_cfg = _make_multiview_config(num_views=3, jepa_tubelet_size=2)
    policy = VLAJEPAPolicy(three_view_cfg)
    policy.train()

    seen: list = []

    class _RecordingProcessor:
        """Callable stand-in that logs each video, then defers to the wrapped processor."""

        def __init__(self, wrapped) -> None:
            self._wrapped = wrapped

        def __call__(self, videos: list, return_tensors: str) -> dict:
            seen.extend(videos)
            return self._wrapped(videos=videos, return_tensors=return_tensors)

    policy.model.video_processor = _RecordingProcessor(policy.model.video_processor)

    policy.forward(_make_multiview_train_batch(num_views=3))

    # B*2 videos are expected at the processor, not B*3.
    expected_count = BATCH_SIZE * 2
    assert len(seen) == expected_count