From b7727b8a6cdc5131fe6c37bb84f72cb36f46ed42 Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Thu, 21 May 2026 15:40:17 +0200 Subject: [PATCH] adressing dtype zeros issue --- src/lerobot/policies/vla_jepa/modeling_vla_jepa.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 7d728b774..88b2fbbdd 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -530,9 +530,9 @@ class VLAJEPAPolicy(PreTrainedPolicy): examples = self._prepare_model_inputs(batch) native_output = self.model.forward(examples) - total_loss = native_output.get("action_loss", torch.tensor(0.0)) + native_output.get( - "wm_loss", torch.tensor(0.0) - ) + ref = next(iter(native_output.values())) + zero = torch.zeros((), device=ref.device, dtype=ref.dtype) + total_loss = native_output.get("action_loss", zero) + native_output.get("wm_loss", zero) logs = {k: v.detach().item() for k, v in native_output.items()} logs["loss"] = total_loss.detach().item() return total_loss, logs @@ -577,6 +577,10 @@ class VLAJEPAPolicy(PreTrainedPolicy): @classmethod def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + """ + Custom loading to enable opt reinit of action head + when loading pretrained weights with mismatched action head shapes. + """ if not model.config.reinit_action_head: return super()._load_as_safetensor(model, model_file, map_location, strict)