mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
refactor(pipeline): minor improvements (#1684)
* chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes
This commit is contained in:
@@ -20,6 +20,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,10 +31,11 @@ class DeviceProcessor:
|
||||
specified device (CPU or GPU) before they are returned.
|
||||
"""
|
||||
|
||||
device: str = "cpu"
|
||||
device: torch.device = "cpu"
|
||||
|
||||
def __post_init__(self):
|
||||
self.non_blocking = "cuda" in self.device
|
||||
self.device = get_safe_torch_device(self.device)
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Create a copy of the transition
|
||||
|
||||
Reference in New Issue
Block a user