cast float64 to float32 for mps

This commit is contained in:
Pepijn
2025-09-16 16:48:08 +02:00
parent 5924d4d9eb
commit 8f624f1c1e

View File

@@ -37,6 +37,8 @@ from lerobot.policies.pretrained import PreTrainedPolicy
# Helper functions
def get_safe_dtype(target_dtype, device_type): # see openpi `get_safe_dtype` (exact copy)
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
return torch.float32
if device_type == "cpu":
# CPU doesn't support bfloat16, use float32 instead
if target_dtype == torch.bfloat16: