mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
cast float64 to float32 for mps
This commit is contained in:
@@ -37,6 +37,8 @@ from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
# Helper functions
|
||||
def get_safe_dtype(target_dtype, device_type): # see openpi `get_safe_dtype` (exact copy)
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
return torch.float32
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
|
||||
Reference in New Issue
Block a user