Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-30 10:21:24 +00:00)
Revert the attribute name `self.soft_prompted_transformer` back to `self.transformer`
This commit is contained in:
@@ -68,7 +68,7 @@ class XVLAModel(nn.Module):
|
||||
if projection_dim is None:
|
||||
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
|
||||
|
||||
self.soft_prompted_transformer = SoftPromptedTransformer(
|
||||
self.transformer = SoftPromptedTransformer(
|
||||
hidden_size=config.hidden_size,
|
||||
multi_modal_input_size=projection_dim,
|
||||
depth=config.depth,
|
||||
@@ -140,7 +140,7 @@ class XVLAModel(nn.Module):
|
||||
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
|
||||
|
||||
pred_action = self.soft_prompted_transformer(
|
||||
pred_action = self.transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=action_noisy_m,
|
||||
t=t,
|
||||
@@ -173,7 +173,7 @@ class XVLAModel(nn.Module):
|
||||
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=proprio.dtype)
|
||||
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
|
||||
action = self.soft_prompted_transformer(
|
||||
action = self.transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=x_t_m,
|
||||
proprio=proprio_m,
|
||||
@@ -422,6 +422,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
shared_key = "model.vlm.language_model.model.shared.weight"
|
||||
if encoder_key in state_dict:
|
||||
state_dict[shared_key] = state_dict[encoder_key]
|
||||
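# NOTE: the assignment above stores the same object under both keys (no copy); deep-copy it instead if the two entries must be independent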
|
||||
# step 5: load into instance
|
||||
missing, unexpected = instance.load_state_dict(state_dict, strict=True)
|
||||
print("Loaded XVLA checkpoint")
|
||||
|
||||
Reference in New Issue
Block a user