mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Fix Qwen norm layer output; LIBERO eval is now as expected
This commit is contained in:
@@ -182,9 +182,7 @@ def _apply_source_config(kwargs: dict, source_config: dict) -> None:
|
|||||||
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
|
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
|
||||||
|
|
||||||
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
|
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
|
||||||
_set_if_present(
|
_set_if_present(kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")))
|
||||||
kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth"))
|
|
||||||
)
|
|
||||||
_set_if_present(
|
_set_if_present(
|
||||||
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
|
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -117,21 +117,33 @@ class VLAJEPAModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
"""Return Qwen's last decoder hidden state matching training behaviour.
|
"""Return the last decoder hidden state before the final RMSNorm.
|
||||||
|
|
||||||
The original starVLA uses `output_hidden_states=True` and takes `hidden_states[-1]`.
|
The model was trained with the output of the last transformer block BEFORE
|
||||||
In transformers 5.x the `@capture_outputs` decorator explicitly replaces
|
the final RMSNorm. In transformers 5.x, `hidden_states[-1]` from
|
||||||
`hidden_states[-1]` with `last_hidden_state` (post-RMSNorm), so this call
|
`output_hidden_states=True` is post-norm (tied to `last_hidden_state` via
|
||||||
consistently returns the post-norm output regardless of transformers version,
|
`@capture_outputs`). A forward hook on `language_model.layers[-1]` recovers
|
||||||
matching what the model was trained with.
|
the correct pre-RMSNorm state, matching the training-time representation.
|
||||||
"""
|
"""
|
||||||
outputs = self.qwen.model(
|
captured: list[torch.Tensor] = []
|
||||||
**qwen_inputs,
|
|
||||||
output_hidden_states=True,
|
def _hook(module, input, output):
|
||||||
output_attentions=False,
|
h = output[0] if isinstance(output, tuple) else output
|
||||||
return_dict=True,
|
captured.append(h)
|
||||||
)
|
|
||||||
return outputs.hidden_states[-1]
|
last_layer = self.qwen.model.model.language_model.layers[-1]
|
||||||
|
handle = last_layer.register_forward_hook(_hook)
|
||||||
|
try:
|
||||||
|
self.qwen.model(
|
||||||
|
**qwen_inputs,
|
||||||
|
output_hidden_states=False,
|
||||||
|
output_attentions=False,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
handle.remove()
|
||||||
|
|
||||||
|
return captured[0] # [B, seq_len, H]
|
||||||
|
|
||||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user