diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 888ec3b36..ea08e2d1a 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -28,6 +28,8 @@ import torch from huggingface_hub import ModelHubMixin, hf_hub_download from safetensors.torch import load_file, save_file +from lerobot.utils.utils import get_safe_torch_device + class TransitionKey(str, Enum): """Keys for accessing EnvTransition dictionary components.""" @@ -465,7 +467,7 @@ class RobotProcessor(ModelHubMixin): to the target device, and reload it. Only works for steps that implement both state_dict() and load_state_dict() methods. """ - device = torch.device(device) + device = get_safe_torch_device(device) for step in self.steps: if hasattr(step, "state_dict") and hasattr(step, "load_state_dict"):