diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index 92ea01007..bad553fbe 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -68,7 +68,7 @@ class XVLAModel(nn.Module): if projection_dim is None: raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.") - self.soft_prompted_transformer = SoftPromptedTransformer( + self.transformer = SoftPromptedTransformer( hidden_size=config.hidden_size, multi_modal_input_size=projection_dim, depth=config.depth, @@ -140,7 +140,7 @@ class XVLAModel(nn.Module): action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1) proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy) - pred_action = self.soft_prompted_transformer( + pred_action = self.transformer( domain_id=domain_id, action_with_noise=action_noisy_m, t=t, @@ -173,7 +173,7 @@ class XVLAModel(nn.Module): t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=proprio.dtype) x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1) proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t) - action = self.soft_prompted_transformer( + action = self.transformer( domain_id=domain_id, action_with_noise=x_t_m, proprio=proprio_m, @@ -422,6 +422,7 @@ class XVLAPolicy(PreTrainedPolicy): shared_key = "model.vlm.language_model.model.shared.weight" if encoder_key in state_dict: state_dict[shared_key] = state_dict[encoder_key] + # NOTE: this stores the same object under both keys (no copy is made); deepcopy it here if the two entries must not share storage # step 5: load into instance missing, unexpected = instance.load_state_dict(state_dict, strict=True) print("Loaded XVLA checkpoint")