feature(pipeline): port tokenizer pipeline for VLA (#1645)

* feat(tokenizer): Introduce TokenizerProcessor for text tokenization

- Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer.
- Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings.
- Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor.
- Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor.

* feat(language): Enhance language processing in TokenizerProcessor

- Added OBS_LANGUAGE constant to define the observation language key.
- Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature.
- Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization.
- Modified tests to validate the integration of language tokens and attention masks in the observation structure.

* feat(tokenizer): Add padding configuration to TokenizerProcessor

- Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction.
- Updated the `make_pi0_processor` function to include the new padding configuration.
- Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios.

* feat(processor): Add state management methods to Pi0NewLineProcessor

* feat(normalization): Track normalization and unnormalization info in complementary data

- Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes.
- Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions.
- Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys.

* feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs

- Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations.
- Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization.

* feat(processors): Integrate RenameProcessor into various processor configurations

- Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency.
- Updated the input steps to ensure compatibility with the new RenameProcessor integration.

* feat(smolvla): Refactor language processing and introduce new line processor (#1658)

- Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant.
- Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility.
- Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling.

* feture(policies): add device processor (#1659)

* feat(processors): Integrate DeviceProcessor into multiple processor configurations

- Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines.
- Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Remove to() method for device management

- Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices.
- Removed associated unit tests that validated the functionality of the to() method across various scenarios.
- Streamlined the pipeline code by focusing on other device management strategies.

* feat(processor): Enhance DeviceProcessor with float dtype conversion

- Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types.
- Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype.
- Refactored tensor processing logic to streamline device movement and dtype conversion.
- Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios.

* feat(policies): Add new line processors and update module exports

* feat(processor): Enhance batch and device processors to handle index and task_index fields

- Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors.
- Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged.
- Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation.
This commit is contained in:
Adil Zouitine
2025-08-05 10:53:08 +02:00
committed by Steven Palma
parent a1734cf575
commit 5326ffe77e
26 changed files with 2776 additions and 232 deletions

View File

@@ -33,6 +33,7 @@ from .pipeline import (
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
from .tokenizer_processor import TokenizerProcessor
__all__ = [
"ActionProcessor",
@@ -51,6 +52,7 @@ __all__ = [
"RewardProcessor",
"RobotProcessor",
"ToBatchProcessor",
"TokenizerProcessor",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",

View File

@@ -106,6 +106,18 @@ class ToBatchProcessor:
if isinstance(task_value, str):
complementary_data["task"] = [task_value]
# Process index field - add batch dim if 0D
if "index" in complementary_data:
index_value = complementary_data["index"]
if isinstance(index_value, Tensor) and index_value.dim() == 0:
complementary_data["index"] = index_value.unsqueeze(0)
# Process task_index field - add batch dim if 0D
if "task_index" in complementary_data:
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}

View File

@@ -19,24 +19,61 @@ from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.utils.utils import get_safe_torch_device
@ProcessorStepRegistry.register("device_processor")
@dataclass
class DeviceProcessor:
"""Processes transitions by moving tensors to the specified device.
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
This processor ensures that all tensors in the transition are moved to the
specified device (CPU or GPU) before they are returned.
specified device (CPU or GPU) before they are returned. It can also convert
floating-point tensors to a specified dtype while preserving non-float types
(int, long, bool, etc.).
"""
device: torch.device = "cpu"
float_dtype: str | None = None
def __post_init__(self):
self.device = get_safe_torch_device(self.device)
self.non_blocking = "cuda" in str(self.device)
# Validate and convert float_dtype string to torch dtype
if self.float_dtype is not None:
dtype_mapping = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"half": torch.float16,
"float": torch.float32,
"double": torch.float64,
}
if self.float_dtype not in dtype_mapping:
available_dtypes = list(dtype_mapping.keys())
raise ValueError(
f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}"
)
self._target_float_dtype = dtype_mapping[self.float_dtype]
else:
self._target_float_dtype = None
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""Process a tensor by moving to device and optionally converting float dtype."""
# Move to device first
tensor = tensor.to(self.device, non_blocking=self.non_blocking)
# Convert float dtype if specified and tensor is floating point
if self._target_float_dtype is not None and tensor.is_floating_point():
tensor = tensor.to(dtype=self._target_float_dtype)
return tensor
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition
new_transition = transition.copy()
@@ -45,7 +82,7 @@ class DeviceProcessor:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_observation = {
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
new_transition[TransitionKey.OBSERVATION] = new_observation
@@ -53,30 +90,54 @@ class DeviceProcessor:
# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
new_transition[TransitionKey.ACTION] = self._process_tensor(action)
# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
new_transition[TransitionKey.REWARD] = self._process_tensor(reward)
# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
new_transition[TransitionKey.DONE] = self._process_tensor(done)
# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = truncated.to(
self.device, non_blocking=self.non_blocking
)
new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated)
# Process complementary data tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is not None:
new_complementary_data = {}
# Process all items in complementary_data
for key, value in complementary_data.items():
if isinstance(value, torch.Tensor):
new_complementary_data[key] = self._process_tensor(value)
else:
new_complementary_data[key] = value
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {"device": self.device}
return {"device": self.device, "float_dtype": self.float_dtype}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features

View File

@@ -116,7 +116,7 @@ class NormalizerProcessor:
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
self.normalize_keys = set(self.normalize_keys)
def _normalize_obs(self, observation):
def _normalize_obs(self, observation, normalized_info):
if observation is None:
return None
@@ -138,6 +138,7 @@ class NormalizerProcessor:
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
@@ -156,16 +157,18 @@ class NormalizerProcessor:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
normalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
normalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
return processed
def _normalize_action(self, action):
def _normalize_action(self, action, normalized_info):
if action is None:
return action
@@ -174,6 +177,7 @@ class NormalizerProcessor:
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
@@ -190,10 +194,12 @@ class NormalizerProcessor:
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
normalized_info["action"] = "MEAN_STD"
return (tensor - mean) / (std + self.eps)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
normalized_info["action"] = "MIN_MAX"
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
@@ -202,13 +208,24 @@ class NormalizerProcessor:
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._normalize_action(transition.get(TransitionKey.ACTION))
# Track what was normalized
normalized_info = {}
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info)
action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info)
# Create a new transition with normalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Add normalization info to complementary data
if normalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["normalized_keys"] = normalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
@@ -289,7 +306,7 @@ class UnnormalizerProcessor:
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
def _unnormalize_obs(self, observation):
def _unnormalize_obs(self, observation, unnormalized_info):
if observation is None:
return None
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
@@ -304,6 +321,7 @@ class UnnormalizerProcessor:
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
@@ -322,16 +340,18 @@ class UnnormalizerProcessor:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
unnormalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
unnormalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
return processed
def _unnormalize_action(self, action):
def _unnormalize_action(self, action, unnormalized_info):
if action is None:
return action
@@ -340,6 +360,7 @@ class UnnormalizerProcessor:
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
@@ -356,10 +377,12 @@ class UnnormalizerProcessor:
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
unnormalized_info["action"] = "MEAN_STD"
return tensor * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
unnormalized_info["action"] = "MIN_MAX"
return (tensor + 1) / 2 * (max_val - min_val) + min_val
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
@@ -368,13 +391,24 @@ class UnnormalizerProcessor:
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
# Track what was unnormalized
unnormalized_info = {}
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info)
action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info)
# Create a new transition with unnormalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Add unnormalization info to complementary data
if unnormalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["unnormalized_keys"] = unnormalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
@@ -413,3 +447,29 @@ def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, An
step.stats = stats
step._tensor_stats = _convert_stats_to_tensors(stats)
return robot_processor
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
"""Rename keys in the stats dictionary according to the provided mapping.
Args:
stats: The statistics dictionary with structure {feature_key: {stat_name: value}}
rename_map: Dictionary mapping old key names to new key names
Returns:
A new stats dictionary with renamed keys
Example:
>>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}}
>>> rename_map = {"observation.state": "observation.robot_state"}
>>> new_stats = rename_stats(stats, rename_map)
>>> # new_stats will have "observation.robot_state" instead of "observation.state"
"""
renamed_stats = {}
for old_key, sub_stats in stats.items():
# Use the new key if it exists in the rename map, otherwise keep the old key
new_key = rename_map.get(old_key, old_key)
renamed_stats[new_key] = deepcopy(sub_stats)
return renamed_stats

View File

@@ -201,10 +201,16 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
observation = observation_keys if observation_keys else None
# Extract padding and task keys for complementary data
# Extract padding, task, index, and task_index keys for complementary data
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
complementary_data = (
{**pad_keys, **task_key, **index_key, **task_index_key}
if pad_keys or task_key or index_key or task_index_key
else {}
)
transition: EnvTransition = {
TransitionKey.OBSERVATION: observation,
@@ -231,7 +237,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"info": transition.get(TransitionKey.INFO, {}),
}
# Add padding and task data from complementary_data
# Add padding, task, index, and task_index data from complementary_data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
@@ -240,6 +246,12 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
if "task" in complementary_data:
batch["task"] = complementary_data["task"]
if "index" in complementary_data:
batch["index"] = complementary_data["index"]
if "task_index" in complementary_data:
batch["task_index"] = complementary_data["task_index"]
# Handle observation - flatten dict to observation.* keys if it's a dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):

View File

@@ -0,0 +1,210 @@
"""
Tokenizer processor for handling text tokenization in robot transitions.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import torch
from transformers import AutoTokenizer
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessor:
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
This processor handles tokenization of task strings found in the complementary_data
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
to the observation data for model processing while preserving the original task string.
The processor supports both single strings and lists of strings as task inputs.
Args:
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
tokenizer: A tokenizer object (e.g., from transformers library) that implements
the __call__ method. If provided, tokenizer_name is ignored. This parameter
is not serialized and must be provided via overrides when loading.
max_length: Maximum sequence length for tokenization. Defaults to 512.
task_key: Key in complementary_data containing the task text. Defaults to "task".
padding: Padding strategy for tokenization. Defaults to "max_length".
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
Examples:
Using tokenizer name (auto-loaded):
```python
processor = TokenizerProcessor(tokenizer_name="bert-base-uncased", max_length=128)
```
Using custom tokenizer object:
```python
from transformers import AutoTokenizer
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = TokenizerProcessor(tokenizer=custom_tokenizer, max_length=128)
```
"""
tokenizer_name: str | None = None
tokenizer: AutoTokenizer | None = None
max_length: int = 512
task_key: str = "task"
padding_side: str = "right"
padding: str = "max_length"
truncation: bool = True
# Internal tokenizer instance (not serialized)
_tokenizer: Any = field(default=None, init=False, repr=False)
def __post_init__(self):
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
if self.tokenizer is not None:
# Use provided tokenizer object directly
self._tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
def get_task(self, transition: EnvTransition) -> list[str] | None:
"""Extract and normalize task from complementary data.
Args:
transition: Input transition containing complementary_data.
Returns:
List of task strings if task is present, None otherwise.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
if self.task_key not in complementary_data:
return None
task = complementary_data[self.task_key]
if task is None:
return None
# Convert to list of strings
if isinstance(task, str):
return [task]
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
return task
return None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process the transition by tokenizing the task text.
Args:
transition: Input transition containing complementary_data with task text.
Returns:
Modified transition with tokenized task added to observation.
Raises:
ValueError: If tokenizer initialization failed.
"""
task = self.get_task(transition)
if task is None:
return transition
# Tokenize the task
tokenized_prompt = self._tokenize_text(task)
# Get or create observation dict
if TransitionKey.OBSERVATION not in transition or transition[TransitionKey.OBSERVATION] is None:
transition[TransitionKey.OBSERVATION] = {}
observation = transition[TransitionKey.OBSERVATION]
# Add tokenized data to observation
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
dtype=torch.bool
)
return transition
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
"""Tokenize text using the configured tokenizer.
Args:
text: Text string or list of strings to tokenize.
Returns:
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
"""
return self._tokenizer(
text,
max_length=self.max_length,
truncation=self.truncation,
padding=self.padding,
padding_side=self.padding_side,
return_tensors="pt",
)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.
Note: Only tokenizer_name is saved, not the tokenizer object itself.
When loading, provide the tokenizer via overrides if needed.
"""
config = {
"max_length": self.max_length,
"task_key": self.task_key,
"padding_side": self.padding_side,
"padding": self.padding,
"truncation": self.truncation,
}
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
if self.tokenizer_name is not None:
config["tokenizer_name"] = self.tokenizer_name
return config
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the feature contract.
Args:
features: Input feature dictionary.
Returns:
Updated feature dictionary with tokenized task features added.
"""
# Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask
tokens_key = f"{OBS_LANGUAGE}.tokens"
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
if tokens_key not in features:
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if attention_mask_key not in features:
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
return features