revert to self.transformer

This commit is contained in:
Jade Choghari
2025-11-17 14:59:45 +01:00
parent 8591fc10b3
commit 9896ba4ee4

View File

@@ -68,7 +68,7 @@ class XVLAModel(nn.Module):
if projection_dim is None:
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
self.soft_prompted_transformer = SoftPromptedTransformer(
self.transformer = SoftPromptedTransformer(
hidden_size=config.hidden_size,
multi_modal_input_size=projection_dim,
depth=config.depth,
@@ -140,7 +140,7 @@ class XVLAModel(nn.Module):
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
pred_action = self.soft_prompted_transformer(
pred_action = self.transformer(
domain_id=domain_id,
action_with_noise=action_noisy_m,
t=t,
@@ -173,7 +173,7 @@ class XVLAModel(nn.Module):
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=proprio.dtype)
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
action = self.soft_prompted_transformer(
action = self.transformer(
domain_id=domain_id,
action_with_noise=x_t_m,
proprio=proprio_m,
@@ -422,6 +422,7 @@ class XVLAPolicy(PreTrainedPolicy):
shared_key = "model.vlm.language_model.model.shared.weight"
if encoder_key in state_dict:
state_dict[shared_key] = state_dict[encoder_key]
# or deepcopy
# step 5: load into instance
missing, unexpected = instance.load_state_dict(state_dict, strict=True)
print("Loaded XVLA checkpoint")