mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 19:31:25 +00:00
chore (batch handling): Enhance processing components with batch conversion utilities
This commit is contained in:
@@ -33,6 +33,9 @@ class DeviceProcessor:
|
||||
|
||||
device: str = "cpu"
|
||||
|
||||
def __post_init__(self):
|
||||
self.non_blocking = "cuda" in self.device
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION]
|
||||
action = transition[TransitionIndex.ACTION]
|
||||
@@ -43,7 +46,9 @@ class DeviceProcessor:
|
||||
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||
|
||||
if observation is not None:
|
||||
observation = {k: v.to(self.device) for k, v in observation.items()}
|
||||
observation = {
|
||||
k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items()
|
||||
}
|
||||
if action is not None:
|
||||
action = action.to(self.device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user