refactor(pipeline): Utilize get_safe_torch_device for device assignment

- Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling.
- This change enhances code readability and maintains consistency in device management across the RobotProcessor class.
This commit is contained in:
Adil Zouitine
2025-07-22 11:03:28 +02:00
parent fb9139b882
commit ae7a54de57

View File

@@ -28,6 +28,8 @@ import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
from safetensors.torch import load_file, save_file
from lerobot.utils.utils import get_safe_torch_device
class TransitionKey(str, Enum):
"""Keys for accessing EnvTransition dictionary components."""
@@ -465,7 +467,7 @@ class RobotProcessor(ModelHubMixin):
to the target device, and reload it. Only works for steps that implement
both state_dict() and load_state_dict() methods.
"""
device = torch.device(device)
device = get_safe_torch_device(device)
for step in self.steps:
if hasattr(step, "state_dict") and hasattr(step, "load_state_dict"):