feat(DeviceProcessor): Enhance tensor processing with device detection and float dtype conversion

- Improved the _process_tensor method to preserve GPU placement for tensors already on a GPU, facilitating multi-GPU training scenarios.
- Introduced a new _detect_device method in TokenizerProcessor to ensure tokenized tensors match the device of existing tensors in transitions.
- Added comprehensive unit tests to validate the functionality of device detection and float dtype conversion across various scenarios.
This commit is contained in:
AdilZouitine
2025-08-08 19:33:24 +02:00
parent 8bde9d0ab7
commit 5ca3920611
4 changed files with 468 additions and 4 deletions

View File

@@ -134,9 +134,19 @@ class TokenizerProcessor:
if task is None:
return transition
# Tokenize the task
# Tokenize the task (creates CPU tensors)
tokenized_prompt = self._tokenize_text(task)
# Detect device from existing tensors in the transition
target_device = self._detect_device(transition)
# Move tokenized tensors to match the device of other data
if target_device is not None:
tokenized_prompt = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_prompt.items()
}
# Get or create observation dict
observation = transition.get(TransitionKey.OBSERVATION)
if observation is None:
@@ -153,6 +163,45 @@ class TokenizerProcessor:
transition[TransitionKey.OBSERVATION.value] = observation # type: ignore[misc]
return transition
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
"""Detect device from existing tensors in the transition.
This allows the tokenized tensors to match the device of other data,
which is especially important for multi-GPU training with Accelerate.
Args:
transition: The transition to search for existing tensors.
Returns:
The device of the first tensor found, or None if no tensors exist.
"""
# Check observation tensors first (most likely to exist)
observation = transition.get(TransitionKey.OBSERVATION)
if observation:
for value in observation.values():
if isinstance(value, torch.Tensor):
return value.device
# Check action tensor
action = transition.get(TransitionKey.ACTION)
if isinstance(action, torch.Tensor):
return action.device
# Check other tensor fields
for key in [TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED]:
value = transition.get(key)
if isinstance(value, torch.Tensor):
return value.device
# Check complementary data for tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
for value in complementary_data.values():
if isinstance(value, torch.Tensor):
return value.device
return None # No tensors found, keep on CPU
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
"""Tokenize text using the configured tokenizer.