fix qwen norm layer output libero eval is now as expected

This commit is contained in:
Maximellerbach
2026-05-22 17:36:43 +02:00
parent 8efa5cabe9
commit df7d5132d1
2 changed files with 26 additions and 16 deletions

View File

@@ -182,9 +182,7 @@ def _apply_source_config(kwargs: dict, source_config: dict) -> None:
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout")) _set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames")) _set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
_set_if_present( _set_if_present(kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")))
kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth"))
)
_set_if_present( _set_if_present(
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads")) kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
) )

View File

@@ -117,21 +117,33 @@ class VLAJEPAModel(nn.Module):
) )
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor: def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return Qwen's last decoder hidden state matching training behaviour. """Return the last decoder hidden state before the final RMSNorm.
The original starVLA uses `output_hidden_states=True` and takes `hidden_states[-1]`. The model was trained with the output of the last transformer block BEFORE
In transformers 5.x the `@capture_outputs` decorator explicitly replaces the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`hidden_states[-1]` with `last_hidden_state` (post-RMSNorm), so this call `output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
consistently returns the post-norm output regardless of transformers version, `@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
matching what the model was trained with. the correct pre-RMSNorm state, matching the training-time representation.
""" """
outputs = self.qwen.model( captured: list[torch.Tensor] = []
**qwen_inputs,
output_hidden_states=True, def _hook(module, input, output):
output_attentions=False, h = output[0] if isinstance(output, tuple) else output
return_dict=True, captured.append(h)
)
return outputs.hidden_states[-1] last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ---- # ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----