mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 02:11:25 +00:00
Fix Qwen norm layer output; LIBERO eval is now as expected
This commit is contained in:
@@ -182,9 +182,7 @@ def _apply_source_config(kwargs: dict, source_config: dict) -> None:
|
||||
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
|
||||
|
||||
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
|
||||
_set_if_present(
|
||||
kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth"))
|
||||
)
|
||||
_set_if_present(kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")))
|
||||
_set_if_present(
|
||||
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
|
||||
)
|
||||
|
||||
@@ -117,21 +117,33 @@ class VLAJEPAModel(nn.Module):
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
    """Return the last decoder hidden state before the final RMSNorm.

    The model was trained with the output of the last transformer block BEFORE
    the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
    `output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
    `@capture_outputs`), so it cannot be used here. A forward hook on
    `language_model.layers[-1]` recovers the pre-RMSNorm state instead,
    matching the training-time representation.

    Args:
        qwen_inputs: Keyword arguments forwarded unchanged to `self.qwen.model`.

    Returns:
        The first output captured from the last decoder layer during the
        forward pass, i.e. `[B, seq_len, H]`.

    Raises:
        RuntimeError: If the forward pass finishes without the last decoder
            layer being invoked, so that no hidden state was captured.
    """
    captured: list[torch.Tensor] = []

    def _capture_layer_output(_module, _inputs, output):
        # Decoder layers may return either a bare tensor or a tuple whose
        # first element is the hidden state; normalise to the tensor.
        hidden = output[0] if isinstance(output, tuple) else output
        captured.append(hidden)

    last_layer = self.qwen.model.model.language_model.layers[-1]
    handle = last_layer.register_forward_hook(_capture_layer_output)
    try:
        # The return value is deliberately ignored: only the hooked layer
        # output is needed, so hidden-state collection is switched off.
        self.qwen.model(
            **qwen_inputs,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=True,
        )
    finally:
        # Always detach the hook, even if the forward pass raises, so it
        # cannot leak into later calls.
        handle.remove()

    if not captured:
        raise RuntimeError(
            "Forward hook on the last Qwen decoder layer did not fire; "
            "cannot recover the pre-RMSNorm hidden state."
        )

    return captured[0]  # [B, seq_len, H]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
|
||||
Reference in New Issue
Block a user