mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
feat(DeviceProcessor): Enhance tensor processing with device detection and float dtype conversion
- Improved the _process_tensor method to preserve GPU placement for tensors already on a GPU, facilitating multi-GPU training scenarios.
- Introduced a new _detect_device method in TokenizerProcessor to ensure tokenized tensors match the device of existing tensors in transitions.
- Added comprehensive unit tests to validate the functionality of device detection and float dtype conversion across various scenarios.
This commit is contained in:
@@ -66,9 +66,26 @@ class DeviceProcessor:
|
||||
self._target_float_dtype = None
|
||||
|
||||
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Process a tensor by moving to device and optionally converting float dtype."""
|
||||
# Move to device first
|
||||
tensor = tensor.to(self.device, non_blocking=self.non_blocking)
|
||||
"""Process a tensor by moving to device and optionally converting float dtype.
|
||||
|
||||
If the tensor is already on a GPU and we're configured for a GPU, it preserves
|
||||
that GPU placement (useful for multi-GPU training with Accelerate).
|
||||
Otherwise, it moves to the configured device.
|
||||
"""
|
||||
# Determine target device
|
||||
if tensor.is_cuda and self._device.type == "cuda":
|
||||
# Both tensor and target are on GPU - preserve tensor's GPU placement
|
||||
# This handles multi-GPU scenarios where Accelerate has already placed
|
||||
# tensors on the correct GPU for each process
|
||||
target_device = tensor.device
|
||||
else:
|
||||
# Either tensor is on CPU, or we're configured for CPU
|
||||
# In both cases, use the configured device
|
||||
target_device = self._device
|
||||
|
||||
# Only move if necessary
|
||||
if tensor.device != target_device:
|
||||
tensor = tensor.to(target_device, non_blocking=self.non_blocking)
|
||||
|
||||
# Convert float dtype if specified and tensor is floating point
|
||||
if self._target_float_dtype is not None and tensor.is_floating_point():
|
||||
|
||||
Reference in New Issue
Block a user