feat(policies): add VLA-JEPA

This commit is contained in:
ginwind
2026-05-11 06:54:15 +00:00
committed by Maximellerbach
parent addb354296
commit da56489174
2 changed files with 328 additions and 4 deletions

View File

@@ -129,7 +129,9 @@ class VLAJEPAModel(nn.Module):
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
-with torch.autocast("cuda", dtype=torch.bfloat16):
+device_type = next(self.parameters()).device.type
+with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
qwen_outputs = self.qwen.model(
**qwen_inputs,
output_hidden_states=True,
@@ -201,7 +203,7 @@ class VLAJEPAModel(nn.Module):
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head (same as original) ----
-with torch.autocast("cuda", dtype=torch.float32):
+with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
) # [B, T_full, action_dim]
@@ -249,7 +251,9 @@ class VLAJEPAModel(nn.Module):
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
-with torch.autocast("cuda", dtype=torch.bfloat16):
+device_type = next(self.parameters()).device.type
+with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
qwen_outputs = self.qwen.model(
**qwen_inputs,
output_hidden_states=True,
@@ -266,7 +270,7 @@ class VLAJEPAModel(nn.Module):
device=last_hidden.device, dtype=torch.float32
)
-with torch.autocast("cuda", dtype=torch.float32):
+with torch.autocast(device_type=device_type, dtype=torch.float32):
# Cast embodied tokens to float32 for action model compatibility
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor