From 37fda2a6fc2dbade9dc8d1e07e2ea9167c953681 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Tue, 26 May 2026 11:51:54 +0200 Subject: [PATCH] adding instructions for different embodiment + fixing some tests --- docs/source/policy_vla_jepa_README.md | 56 +++++++++++++++---- .../vla_jepa/configuration_vla_jepa.py | 7 ++- .../policies/vla_jepa/modeling_vla_jepa.py | 35 ++++++------ tests/policies/vla_jepa/conftest.py | 9 ++- 4 files changed, 74 insertions(+), 33 deletions(-) diff --git a/docs/source/policy_vla_jepa_README.md b/docs/source/policy_vla_jepa_README.md index 1739ae072..f0541ab56 100644 --- a/docs/source/policy_vla_jepa_README.md +++ b/docs/source/policy_vla_jepa_README.md @@ -66,15 +66,16 @@ All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone. Key parameters in `VLAJEPAConfig`: -| Parameter | Default | Description | -| ------------------------- | ------- | -------------------------------------------------------------- | -| `chunk_size` | 7 | Number of actions predicted per inference call | -| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning | -| `num_video_frames` | 8 | Video clip length fed to the world model | -| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor | -| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss | -| `num_inference_timesteps` | 4 | Euler integration steps for action denoising | -| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head | +| Parameter | Default | Description | +| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `chunk_size` | 7 | Number of actions predicted per inference call | +| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning | +| `num_video_frames` | 
8 | Video clip length fed to the world model | +| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor | +| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss | +| `num_inference_timesteps` | 4 | Euler integration steps for action denoising | +| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head | +| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) | --- @@ -110,6 +111,29 @@ lerobot-train \ --dataset.repo_id=your_org/your_dataset ``` +### Fine-tuning on a different embodiment + +When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error. 
+ +The layers that depend on `action_dim` and `state_dim` are: + +| Layer | Key prefix | +| ----------------------------------------- | ----------------------------------- | +| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` | +| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` | +| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` | + +```bash +lerobot-train \ + --policy.path=lerobot/VLA-JEPA-Pretrain \ + --policy.repo_id=your_org/your_repo \ + --policy.freeze_qwen=true \ + --policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \ + --dataset.repo_id=your_org/your_dataset +``` + +If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list. + ### Reproducing the LIBERO results **Training on LIBERO:** @@ -132,7 +156,7 @@ lerobot-eval \ --env.type=libero \ --env.task=libero_spatial,libero_object,libero_goal,libero_10 \ --eval.n_episodes=10 \ - --eval.batch_size=5 \ + --eval.batch_size=5 ``` @@ -145,9 +169,19 @@ lerobot-eval \ --env.task=libero_10 \ --env.task_ids='[0,1,2]' \ --eval.n_episodes=10 \ - --eval.batch_size=5 \ + --eval.batch_size=5 ``` +**Expected results:** + +| Suite | Episodes | Successes | Success Rate | +| -------------- | -------- | --------- | ------------ | +| libero_spatial | 100 | 93 | **95.0%** | +| libero_object | 100 | 100 | **100.0%** | +| libero_goal | 100 | 98 | **98.0%** | +| libero_10 | 100 | 96 | **93.0%** | +| **Overall** | **400** | **387** | **96.5%** | + --- ## Fine-tuning on single-camera datasets diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index 1794a5c46..cf53e343b 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -27,7 +27,12 @@ class 
VLAJEPAConfig(PreTrainedConfig): jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256" freeze_qwen: bool = False enable_world_model: bool = True - reinit_action_head: bool = False + # Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a + # different action or state dimensionality, the input/output projection layers must be + # re-initialised from scratch while the rest of the network keeps its pretrained weights. + # List the key prefixes that are allowed to have shape mismatches; any other mismatch raises ValueError. + # e.g. for a new action_dim: ["model.action_model.action_encoder", "model.action_model.action_decoder"] + reinit_modules: list[str] | None = None tokenizer_padding_side: str = "left" prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}." diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 8b7916add..ecafaa7c3 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -219,14 +219,9 @@ class VLAJEPAModel(nn.Module): b, v, t_frames, c, h_img, w_img = batch_videos.shape batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img) - video_pixels = [] - for i in range(b * v): - video_pixels.append( - self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[ - "pixel_values_videos" - ].to(self.video_encoder.device) - ) - video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W] + video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[ + "pixel_values_videos" + ].to(self.video_encoder.device) # [B*V, T, C, H, W] with torch.no_grad(): video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels) @@ -572,11 +567,8 @@ class VLAJEPAPolicy(PreTrainedPolicy): @classmethod def _load_as_safetensor(cls, model: T, 
model_file: str, map_location: str, strict: bool) -> T: - """ - Custom loading to enable opt reinit of action head - when loading pretrained weights with mismatched action head shapes. - """ - if not model.config.reinit_action_head: + reinit_prefixes = model.config.reinit_modules + if not reinit_prefixes: return super()._load_as_safetensor(model, model_file, map_location, strict) from safetensors.torch import load_file @@ -584,20 +576,25 @@ class VLAJEPAPolicy(PreTrainedPolicy): state_dict = load_file(model_file, device=map_location) current = model.state_dict() - mismatched: list[str] = [] + reinitialized: list[str] = [] filtered: dict = {} for key, value in state_dict.items(): if key in current and value.shape != current[key].shape: - mismatched.append( - f"{key}: checkpoint {tuple(value.shape)} vs model {tuple(current[key].shape)}" + if not any(key.startswith(p) for p in reinit_prefixes): + raise ValueError( + f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model " + f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`." 
+ ) + reinitialized.append( + f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}" ) else: filtered[key] = value - if mismatched: + if reinitialized: logging.warning( - f"reinit_action_head=True: skipping {len(mismatched)} tensor(s) with mismatched shapes " - f"(randomly re-initialised):\n " + "\n ".join(mismatched) + f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes " + f"(randomly re-initialised):\n " + "\n ".join(reinitialized) ) from lerobot.policies.utils import log_model_loading_keys diff --git a/tests/policies/vla_jepa/conftest.py b/tests/policies/vla_jepa/conftest.py index da7c38cca..5301b5bc7 100644 --- a/tests/policies/vla_jepa/conftest.py +++ b/tests/policies/vla_jepa/conftest.py @@ -169,6 +169,7 @@ class _FakeQwenBackbone(nn.Module): dtype=torch.float32, ).view(batch_size, seq_len, hidden_size) hidden = values / values.numel() + self.weight + self.model(input_ids) # call through so the forward hook on layers[-1] fires return SimpleNamespace(hidden_states=[hidden]) @@ -241,9 +242,13 @@ class _FakeVideoEncoder(nn.Module): class _FakeVideoProcessor: - def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]: + def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]: assert return_tensors == "pt" - return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)} + if isinstance(videos, list): + pixel_values = torch.stack([torch.as_tensor(v) for v in videos]) + else: + pixel_values = torch.as_tensor(videos).unsqueeze(0) + return {"pixel_values_videos": pixel_values} # ---------------------------------------------------------------------------