This commit is contained in:
Jade Choghari
2025-11-17 18:45:43 +01:00
parent b16bc5f1ff
commit efacf8f0e0

View File

@@ -402,22 +402,6 @@ class XVLAPolicy(PreTrainedPolicy):
print(f"Loading checkpoint from {model_file}")
state_dict = safetensors.torch.load_file(model_file)
# # --- Step 4: Modify keys ---
# new_state_dict = {f"model.{k}": v for k, v in state_dict.items()}
# # Layers to skip (reinitialize)
# keys_to_skip = [
# "model.transformer.action_encoder.fc.weight",
# "model.transformer.action_encoder.fc.bias",
# "model.transformer.action_decoder.fc.weight",
# "model.transformer.action_decoder.bias.weight"
# ]
# new_state_dict = {
# k: v for k, v in new_state_dict.items()
# if k not in keys_to_skip
# }
# # ---- ADD THIS: Fix shared embeddings ----
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
shared_key = "model.vlm.language_model.model.shared.weight"
if encoder_key in state_dict: