diff --git a/docs/source/policy_vla_jepa_README.md b/docs/source/policy_vla_jepa_README.md index f0541ab56..3961c018f 100644 --- a/docs/source/policy_vla_jepa_README.md +++ b/docs/source/policy_vla_jepa_README.md @@ -54,9 +54,7 @@ Three checkpoints are available, converted from [ginwind/VLA-JEPA](https://huggi | ----------------------------- | ----------------- | ----------------------- | ----------- | ---------- | | `lerobot/VLA-JEPA-LIBERO` | LIBERO-10 | 2 (agentview + wrist) | Enabled | 7 | | `lerobot/VLA-JEPA-Pretrain` | DROID 1.0.1 | 2 (exterior left views) | Enabled | 7 | -| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 | Disabled\* | 7 | - -\* The SimplerEnv checkpoint was fine-tuned from Pretrain. The world model predictor architecture expects `embed_dim=2048` (2-camera input) but SimplerEnv is single-camera, so the world model cannot be loaded cleanly. Since inference only needs Qwen + the action head, `enable_world_model=False` is set for this variant. See [Fine-tuning on single-camera datasets](#fine-tuning-on-single-camera-datasets) for implications. +| `lerobot/VLA-JEPA-SimplerEnv` | OXE Bridge / RT-1 | 1 (view duplicated ×2) | Enabled | 7 | All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone. @@ -184,13 +182,20 @@ lerobot-eval \ --- -## Fine-tuning on single-camera datasets +## Fine-tuning on datasets with a different number of cameras -The pretrained world model predictor was trained with `embed_dim = num_views × 1024`. If your target dataset has fewer cameras than the source checkpoint, the predictor input projection will have a shape mismatch and cannot be loaded. +The pretrained world model predictor was trained with `embed_dim = jepa_tubelet_size × 1024` (default `jepa_tubelet_size=2`). -**Option 1 — Disable the world model (recommended)** +**Default behaviour — view padding / trimming (no action required)** -Set `enable_world_model=False`. Only the Qwen backbone and action head are loaded and trained. 
This matches the original SimplerEnv fine-tuning strategy and is sufficient for good action performance. +When fine-tuning from `VLA-JEPA-Pretrain`, the model automatically adjusts the number of views fed to the world model to match `jepa_tubelet_size`: + +- **Single-view datasets (e.g. BridgeV2):** the single view's video frames are duplicated before encoding to produce a two-view world-model input, preserving the JEPA self-supervised signal without any weight mismatch. +- **>2-view datasets (e.g. DROID with 3 views):** all views are passed to the Qwen backbone (for richer context), but only the first `jepa_tubelet_size` views are used for the world model (which views those are depends on the configured view order, e.g. one wrist + one third-person). + +**Option 1 — Disable the world model** + +Set `enable_world_model=False` to skip the JEPA loss entirely. Only the Qwen backbone and action head are loaded and trained. This is sufficient for good action performance. ```bash lerobot-train \ @@ -202,10 +207,7 @@ lerobot-train \ **Option 2 — Reinitialize the predictor input projection** -If you want the JEPA self-supervised signal during fine-tuning, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint. - -**Option 3 - Duplicate frames to match the expected number of cameras** -A bit more advanced, you would need to change some parts of the code to support that. +If you want to change `jepa_tubelet_size` to a value other than 2, load the checkpoint with `strict=False` and reinitialize `model.video_predictor.predictor_embed` for the new `embed_dim`. All other predictor block weights (attention, MLP, norm, output projection) are camera-count-agnostic and can be reused from the pretrained checkpoint. 
--- diff --git a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py index edcb06b80..7f04bdfa3 100644 --- a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py +++ b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py @@ -17,9 +17,9 @@ For each variant the script: Config sources -------------- Numeric hyper-params : ginwind/VLA-JEPA//config.json -Image keys LIBERO : lerobot/libero_10 meta/info.json ✓ confirmed -Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json ✓ confirmed -Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed +Image keys LIBERO : lerobot/libero_10 meta/info.json +Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json +Image keys SimplerEnv: OXE Bridge/RT1 single-camera (view x2) """ from __future__ import annotations @@ -252,7 +252,7 @@ _DROID_CAMS = [ "observation.images.exterior_2_left", ] -# OXE Bridge + RT1 — single-camera; world model disabled (predictor embed_dim mismatch) +# OXE Bridge + RT1 — single-camera; view is duplicated at runtime to match the 2-view world model _OXE_CAMS = [ "observation.images.image", ] @@ -311,9 +311,9 @@ def _build_config( VARIANTS: dict[str, tuple] = { "LIBERO": (_LIBERO_CAMS, True, True, "LIBERO"), "Pretrain": (_DROID_CAMS, False, True, "Pretrain"), - # SimplerEnv uses a single camera; the predictor embed_dim (2048) would mismatch, so - # disable the world model — only qwen + action_model weights are needed for inference. - "SimplerEnv": (_OXE_CAMS, False, False, "SimplerEnv"), + # SimplerEnv uses a single camera; the single view is duplicated at runtime to produce + # the 2-view input the world model expects (embed_dim=2048). 
+ "SimplerEnv": (_OXE_CAMS, False, True, "SimplerEnv"), } # --------------------------------------------------------------------------- @@ -423,13 +423,13 @@ def main() -> None: log.info(" Saving model.safetensors …") save_safetensors(mapped_sd, save_dir / "model.safetensors") - config.device = None # don't bake in the conversion machine's device - config._save_pretrained(save_dir) # writes config.json via draccus - preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats) preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json + config.device = None # don't bake in the conversion machine's device + config._save_pretrained(save_dir) # writes config.json via draccus + log.info(" Uploading …") commit_url = api.upload_folder( folder_path=save_dir, diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index ecafaa7c3..16b7e51ce 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -72,7 +72,7 @@ class VLAJEPAModel(nn.Module): torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype), ) self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name) - num_views = max(1, len(config.image_features)) + num_views = config.jepa_tubelet_size tubelet_size = self.video_encoder.config.tubelet_size image_size = getattr(self.video_encoder.config, "image_size", None) if image_size is None: @@ -180,6 +180,17 @@ class VLAJEPAModel(nn.Module): batch_videos = np.stack(batch_videos) batch_videos = batch_videos.transpose(0, 1, 2, 5, 3, 4) # [B, V, T, 3, H, W] + # Adjust number of views for the world model: + # - fewer views than expected: duplicate the first view to fill up + # - more views than expected: keep only the first num_views_world_model views + num_views_world_model = self.config.jepa_tubelet_size + if 
batch_videos.shape[1] < num_views_world_model: + num_missing_views = num_views_world_model - batch_videos.shape[1] + first_view = np.repeat(batch_videos[:, :1], num_missing_views, axis=1) + batch_videos = np.concatenate([batch_videos, first_view], axis=1) + elif batch_videos.shape[1] > num_views_world_model: + batch_videos = batch_videos[:, :num_views_world_model] + # ---- Step 1: QwenVL encode (same as original) ---- qwen_inputs = self.qwen.build_inputs( images=batch_images, diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py index 52b00697c..6c716f31a 100644 --- a/tests/policies/vla_jepa/test_vla_jepa.py +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -23,7 +23,9 @@ from conftest import ( # noqa: E402 BATCH_SIZE, EXPECTED_ACTION_CHUNK_SHAPE, EXPECTED_SELECT_ACTION_SHAPE, + IMAGE_SIZE, N_ACTION_STEPS, + QWEN_HIDDEN_SIZE, STATE_DIM, make_config, make_inference_batch, @@ -31,6 +33,7 @@ from conftest import ( # noqa: E402 set_seed_all, ) +from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig # noqa: E402 from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy # noqa: E402 from lerobot.utils.constants import ACTION # noqa: E402 @@ -471,3 +474,125 @@ def test_postprocessor_applied_after_predict_action_chunk( a_max = dataset_stats[ACTION]["max"].numpy() expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0] assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5) + + +# --------------------------------------------------------------------------- +# World-model view adjustment (padding / trimming) tests +# --------------------------------------------------------------------------- + + +_MULTIVIEW_NUM_FRAMES = 4 # must be >= 2 * jepa_tubelet_size (=2) for world-model tests + + +def _make_multiview_config(num_views: int, jepa_tubelet_size: int = 2) -> VLAJEPAConfig: + from lerobot.configs.types import FeatureType, PolicyFeature + from lerobot.utils.constants 
import OBS_IMAGES, OBS_STATE + + config = VLAJEPAConfig( + input_features={ + **{ + f"{OBS_IMAGES}.cam{i}": PolicyFeature( + type=FeatureType.VISUAL, shape=(3, IMAGE_SIZE, IMAGE_SIZE) + ) + for i in range(num_views) + }, + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)), + }, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}, + device="cpu", + chunk_size=ACTION_HORIZON, + n_action_steps=N_ACTION_STEPS, + action_dim=ACTION_DIM, + state_dim=STATE_DIM, + num_video_frames=_MULTIVIEW_NUM_FRAMES, + num_action_tokens_per_timestep=2, + num_embodied_action_tokens_per_instruction=3, + num_inference_timesteps=2, + action_hidden_size=QWEN_HIDDEN_SIZE, + action_model_type="DiT-test", + action_num_layers=1, + predictor_depth=1, + predictor_num_heads=2, + predictor_mlp_ratio=2.0, + jepa_tubelet_size=jepa_tubelet_size, + ) + config.validate_features() + return config + + +def _make_multiview_train_batch(num_views: int, batch_size: int = BATCH_SIZE) -> dict: + from lerobot.utils.constants import OBS_IMAGES, OBS_STATE + + batch = { + f"{OBS_IMAGES}.cam{i}": torch.rand(batch_size, _MULTIVIEW_NUM_FRAMES, 3, IMAGE_SIZE, IMAGE_SIZE) + for i in range(num_views) + } + batch[OBS_STATE] = torch.randn(batch_size, 1, STATE_DIM) + batch[ACTION] = torch.randn(batch_size, ACTION_HORIZON, ACTION_DIM) + batch["task"] = ["pick up the cube"] * batch_size + return batch + + +@pytest.mark.parametrize( + "num_views", + [ + 1, # fewer views than jepa_tubelet_size → first view duplicated + 2, # exact match → unchanged + 3, # more views than jepa_tubelet_size → trimmed to first two + ], +) +def test_training_forward_world_model_view_adjustment( + patch_vla_jepa_external_models: None, + num_views: int, +) -> None: + """World-model view padding/trimming must not break the training forward pass.""" + set_seed_all(42) + policy = VLAJEPAPolicy(_make_multiview_config(num_views=num_views, jepa_tubelet_size=2)) + policy.train() + loss, logs = 
policy.forward(_make_multiview_train_batch(num_views=num_views)) + assert torch.isfinite(loss) + assert logs["wm_loss"] >= 0 + + +def test_single_view_is_duplicated_for_world_model(patch_vla_jepa_external_models: None) -> None: + """With one dataset view and jepa_tubelet_size=2, the view must be duplicated before encoding.""" + set_seed_all(42) + policy = VLAJEPAPolicy(_make_multiview_config(num_views=1, jepa_tubelet_size=2)) + policy.train() + + captured_videos: list = [] + original_processor = policy.model.video_processor + + class _CapturingProcessor: + def __call__(self, videos: list, return_tensors: str) -> dict: + captured_videos.extend(videos) + return original_processor(videos=videos, return_tensors=return_tensors) + + policy.model.video_processor = _CapturingProcessor() + policy.forward(_make_multiview_train_batch(num_views=1)) + + # reshape is batch-major: (b0v0, b0v1, b1v0, b1v1, …) + assert len(captured_videos) == BATCH_SIZE * 2 + for i in range(BATCH_SIZE): + np.testing.assert_array_equal(captured_videos[2 * i], captured_videos[2 * i + 1]) + + +def test_excess_views_trimmed_for_world_model(patch_vla_jepa_external_models: None) -> None: + """With three dataset views and jepa_tubelet_size=2, only the first two views reach the encoder.""" + set_seed_all(42) + policy = VLAJEPAPolicy(_make_multiview_config(num_views=3, jepa_tubelet_size=2)) + policy.train() + + captured_videos: list = [] + original_processor = policy.model.video_processor + + class _CapturingProcessor: + def __call__(self, videos: list, return_tensors: str) -> dict: + captured_videos.extend(videos) + return original_processor(videos=videos, return_tensors=return_tensors) + + policy.model.video_processor = _CapturingProcessor() + policy.forward(_make_multiview_train_batch(num_views=3)) + + # Only B*2 items must reach the encoder, not B*3. + assert len(captured_videos) == BATCH_SIZE * 2