From 8f624f1c1eb5909dd7242bc7eb223b46dcf79bb5 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 16 Sep 2025 16:48:08 +0200 Subject: [PATCH] cast float64 to float32 for mps --- src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index 73e3a0f44..b97fb2acf 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -37,6 +37,8 @@ from lerobot.policies.pretrained import PreTrainedPolicy # Helper functions def get_safe_dtype(target_dtype, device_type): # see openpi `get_safe_dtype` (exact copy) """Get a safe dtype for the given device type.""" + if device_type == "mps" and target_dtype == torch.float64: + return torch.float32 if device_type == "cpu": # CPU doesn't support bfloat16, use float32 instead if target_dtype == torch.bfloat16: