chore (batch handling): Enhance processing components with batch conversion utilities

This commit is contained in:
Adil Zouitine
2025-07-06 21:29:51 +02:00
parent c227107f60
commit b08149a113
6 changed files with 606 additions and 53 deletions

View File

@@ -33,6 +33,9 @@ class DeviceProcessor:
device: str = "cpu"
def __post_init__(self):
self.non_blocking = "cuda" in self.device
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION]
action = transition[TransitionIndex.ACTION]
@@ -43,7 +46,9 @@ class DeviceProcessor:
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
if observation is not None:
observation = {k: v.to(self.device) for k, v in observation.items()}
observation = {
k: v.to(self.device, non_blocking=self.non_blocking) for k, v in observation.items()
}
if action is not None:
action = action.to(self.device)