mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
(feat)policies: add VLA-JEPA
This commit is contained in:
@@ -129,7 +129,9 @@ class VLAJEPAModel(nn.Module):
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
qwen_outputs = self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=True,
|
||||
@@ -201,7 +203,7 @@ class VLAJEPAModel(nn.Module):
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head (same as original) ----
|
||||
with torch.autocast("cuda", dtype=torch.float32):
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, T_full, action_dim]
|
||||
@@ -249,7 +251,9 @@ class VLAJEPAModel(nn.Module):
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
with torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
qwen_outputs = self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=True,
|
||||
@@ -266,7 +270,7 @@ class VLAJEPAModel(nn.Module):
|
||||
device=last_hidden.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
with torch.autocast("cuda", dtype=torch.float32):
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
# Cast embodied tokens to float32 for action model compatibility
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor
|
||||
|
||||
Reference in New Issue
Block a user