mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
refactor(pipeline): Utilize get_safe_torch_device for device assignment
- Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling. - This change enhances code readability and maintains consistency in device management across the RobotProcessor class.
This commit is contained in:
@@ -28,6 +28,8 @@ import torch
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
@@ -465,7 +467,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
to the target device, and reload it. Only works for steps that implement
|
||||
both state_dict() and load_state_dict() methods.
|
||||
"""
|
||||
device = torch.device(device)
|
||||
device = get_safe_torch_device(device)
|
||||
|
||||
for step in self.steps:
|
||||
if hasattr(step, "state_dict") and hasattr(step, "load_state_dict"):
|
||||
|
||||
Reference in New Issue
Block a user