diff --git a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py index 591ce5db9..2566dc18e 100644 --- a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py +++ b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py @@ -182,9 +182,7 @@ def _apply_source_config(kwargs: dict, source_config: dict) -> None: _set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout")) _set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames")) - _set_if_present( - kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")) - ) + _set_if_present(kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth"))) _set_if_present( kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads")) ) diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index adf8b7540..8b7916add 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -117,21 +117,33 @@ class VLAJEPAModel(nn.Module): ) def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor: - """Return Qwen's last decoder hidden state matching training behaviour. + """Return the last decoder hidden state before the final RMSNorm. - The original starVLA uses `output_hidden_states=True` and takes `hidden_states[-1]`. - In transformers 5.x the `@capture_outputs` decorator explicitly replaces - `hidden_states[-1]` with `last_hidden_state` (post-RMSNorm), so this call - consistently returns the post-norm output regardless of transformers version, - matching what the model was trained with. + The model was trained with the output of the last transformer block BEFORE + the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from + `output_hidden_states=True` is post-norm (tied to `last_hidden_state` via + `@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers + the correct pre-RMSNorm state, matching the training-time representation. """ - outputs = self.qwen.model( - **qwen_inputs, - output_hidden_states=True, - output_attentions=False, - return_dict=True, - ) - return outputs.hidden_states[-1] + captured: list[torch.Tensor] = [] + + def _hook(module, input, output): + h = output[0] if isinstance(output, tuple) else output + captured.append(h) + + last_layer = self.qwen.model.model.language_model.layers[-1] + handle = last_layer.register_forward_hook(_hook) + try: + self.qwen.model( + **qwen_inputs, + output_hidden_states=False, + output_attentions=False, + return_dict=True, + ) + finally: + handle.remove() + + return captured[0] # [B, seq_len, H] # ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----