fix qwen norm layer output libero eval is now as expected

This commit is contained in:
Maximellerbach
2026-05-22 17:36:43 +02:00
parent 5495c10cdf
commit b75b3ce02d
2 changed files with 26 additions and 16 deletions

View File

@@ -182,9 +182,7 @@ def _apply_source_config(kwargs: dict, source_config: dict) -> None:
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
_set_if_present(
kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth"))
)
_set_if_present(kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")))
_set_if_present(
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
)

View File

@@ -117,21 +117,33 @@ class VLAJEPAModel(nn.Module):
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return Qwen's last decoder hidden state matching training behaviour.
"""Return the last decoder hidden state before the final RMSNorm.
The original starVLA uses `output_hidden_states=True` and takes `hidden_states[-1]`.
In transformers 5.x the `@capture_outputs` decorator explicitly replaces
`hidden_states[-1]` with `last_hidden_state` (post-RMSNorm), so this call
consistently returns the post-norm output regardless of transformers version,
matching what the model was trained with.
The model was trained with the output of the last transformer block BEFORE
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
the correct pre-RMSNorm state, matching the training-time representation.
"""
outputs = self.qwen.model(
**qwen_inputs,
output_hidden_states=True,
output_attentions=False,
return_dict=True,
)
return outputs.hidden_states[-1]
captured: list[torch.Tensor] = []
def _hook(module, input, output):
h = output[0] if isinstance(output, tuple) else output
captured.append(h)
last_layer = self.qwen.model.model.language_model.layers[-1]
handle = last_layer.register_forward_hook(_hook)
try:
self.qwen.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
return captured[0] # [B, seq_len, H]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----