feature(pipeline): port tokenizer pipeline for VLA (#1645)

* feat(tokenizer): Introduce TokenizerProcessor for text tokenization

- Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer.
- Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings.
- Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor.
- Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor.

* feat(language): Enhance language processing in TokenizerProcessor

- Added OBS_LANGUAGE constant to define the observation language key.
- Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature.
- Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization.
- Modified tests to validate the integration of language tokens and attention masks in the observation structure.
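
The newline handling described above can be sketched as a standalone function. This is a minimal sketch that mirrors the branching of Pi0NewLineProcessor on a bare task value rather than on a full transition; `ensure_trailing_newline` is a hypothetical helper name used for illustration only and is not part of the codebase.

```python
def ensure_trailing_newline(task):
    """Return the task with a trailing newline added where it is missing.

    Hypothetical helper: accepts a single string or a list of strings and
    returns any other value unchanged.
    """
    if isinstance(task, str):
        return task if task.endswith("\n") else f"{task}\n"
    if isinstance(task, list) and all(isinstance(t, str) for t in task):
        return [t if t.endswith("\n") else f"{t}\n" for t in task]
    # None, tensors, or mixed lists are left as they are.
    return task


single = ensure_trailing_newline("pick up the cube")
batch = ensure_trailing_newline(["open the drawer", "close the drawer\n"])
```

Strings that already end with a newline pass through untouched, so applying the step twice does not add a second newline.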

* feat(tokenizer): Add padding configuration to TokenizerProcessor

- Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction.
- Updated the `make_pi0_processor` function to include the new padding configuration.
- Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios.
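
To illustrate what the padding direction changes, here is a minimal pure-Python sketch of padding one sequence of token ids to a fixed length. It does not call the Hugging Face tokenizer; `pad_token_ids`, the `pad_id` default, and the unconditional truncation are assumptions made for the example (a real tokenizer truncates only when asked to).

```python
def pad_token_ids(token_ids, max_length, pad_id=0, padding_side="right"):
    """Pad one token-id list to max_length and build the matching attention mask.

    Hypothetical helper: 1 in the mask marks a real token, 0 marks padding.
    """
    if padding_side not in ("left", "right"):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    ids = list(token_ids)[:max_length]  # assumption: always truncate to max_length
    mask = [1] * len(ids)
    n_pad = max_length - len(ids)
    if padding_side == "right":
        return ids + [pad_id] * n_pad, mask + [0] * n_pad
    return [pad_id] * n_pad + ids, [0] * n_pad + mask


right_ids, right_mask = pad_token_ids([5, 6, 7], max_length=5, padding_side="right")
left_ids, left_mask = pad_token_ids([5, 6, 7], max_length=5, padding_side="left")
```

With right padding the real tokens stay at the start of the sequence, which is the setting the pi0 and SmolVLA pipelines in this commit pass to TokenizerProcessor.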

* feat(processor): Add state management methods to Pi0NewLineProcessor

* feat(normalization): Track normalization and unnormalization info in complementary data

- Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes.
- Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions.
- Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys.

* feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs

- Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations.
- Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization.
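
One detail worth noting about optional overrides: because the new keys are typed as `dict[str, Any] | None`, a lookup such as `kwargs.get("preprocessor_overrides", {})` returns `None` when the caller passes the key explicitly as `None`. Whether `from_pretrained` tolerates `None` is not shown in this commit. The sketch below shows one way to normalize both cases to an empty dict; `ProcessorKwargsSketch`, `resolve_overrides`, and the override contents are hypothetical names and values used for illustration only.

```python
from typing import Any, Dict, Optional, TypedDict


class ProcessorKwargsSketch(TypedDict, total=False):
    # Stand-in for the two new optional keys, not the real ProcessorConfigKwargs.
    preprocessor_overrides: Optional[Dict[str, Any]]
    postprocessor_overrides: Optional[Dict[str, Any]]


def resolve_overrides(kwargs: ProcessorKwargsSketch, key: str) -> Dict[str, Any]:
    """Return the overrides stored under key, or an empty dict.

    `or {}` covers both a missing key and an explicit None.
    """
    return kwargs.get(key) or {}


missing = resolve_overrides({}, "preprocessor_overrides")
explicit_none = resolve_overrides({"preprocessor_overrides": None}, "preprocessor_overrides")
# The nested keys below are made up; the real override schema is not shown here.
given = resolve_overrides(
    {"postprocessor_overrides": {"some_step": {"device": "cpu"}}},
    "postprocessor_overrides",
)
```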

* feat(processors): Integrate RenameProcessor into various processor configurations

- Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency.
- Updated the input steps to ensure compatibility with the new RenameProcessor integration.
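
The consolidation relies on plain dict unpacking, `{**config.input_features, **config.output_features}`. A minimal sketch of that merge is shown below; the feature values are placeholder strings rather than real PolicyFeature objects, and `merge_features` is a hypothetical helper name.

```python
def merge_features(input_features, output_features):
    """Merge two feature dicts into one, as the single NormalizerProcessor receives them.

    On a key collision the entry from output_features wins, because it is
    unpacked last.
    """
    return {**input_features, **output_features}


# Placeholder values standing in for PolicyFeature instances.
inputs = {"observation.state": "STATE", "observation.image": "VISUAL"}
outputs = {"action": "ACTION"}
merged = merge_features(inputs, outputs)
```

The merge builds a new dict and leaves both inputs unmodified. Input and output feature keys are normally disjoint, so the last-one-wins rule only matters if a key appears in both.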

* feat(smolvla): Refactor language processing and introduce new line processor (#1658)

- Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant.
- Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility.
- Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling.

* feature(policies): add device processor (#1659)

* feat(processors): Integrate DeviceProcessor into multiple processor configurations

- Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor.
- Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines.
- Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(pipeline): Remove to() method for device management

- Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices.
- Removed associated unit tests that validated the functionality of the to() method across various scenarios.
- Simplified RobotProcessor by leaving device placement to dedicated pipeline steps such as DeviceProcessor.

* feat(processor): Enhance DeviceProcessor with float dtype conversion

- Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types.
- Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype.
- Refactored tensor processing logic to streamline device movement and dtype conversion.
- Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios.

* feat(policies): Add new line processors and update module exports

* feat(processor): Enhance batch and device processors to handle index and task_index fields

- Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors.
- Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged.
- Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation.

Author: Adil Zouitine
Date: 2025-08-05 10:53:08 +02:00
Committed by: Steven Palma
Parent: a1734cf575
Commit: 5326ffe77e
26 changed files with 2776 additions and 232 deletions

@@ -24,6 +24,7 @@ class FeatureType(str, Enum):
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"
LANGUAGE = "LANGUAGE"
class NormalizationMode(str, Enum):

@@ -21,6 +21,7 @@ OBS_ENV_STATE = "observation.environment_state"
OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
OBS_LANGUAGE = "observation.language"
ACTION = "action"
REWARD = "next.reward"

@@ -15,6 +15,17 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0.processor_pi0 import Pi0NewLineProcessor
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"PI0Config",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
]

@@ -17,7 +17,9 @@ import torch
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
@@ -28,15 +30,17 @@ def make_act_processor(
config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),

@@ -18,7 +18,9 @@ import torch
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
@@ -29,15 +31,17 @@ def make_diffusion_processor(
config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),

@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
from typing import TypedDict
from typing import Any, TypedDict
from torch import nn
from typing_extensions import Unpack
@@ -111,6 +111,8 @@ class ProcessorConfigKwargs(TypedDict, total=False):
preprocessor_config_filename: str | None
postprocessor_config_filename: str | None
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
def make_processor(
@@ -142,10 +144,12 @@ def make_processor(
RobotProcessor.from_pretrained(
source=pretrained_path,
config_filename=kwargs.get("preprocessor_config_filename", "preprocessor.json"),
overrides=kwargs.get("preprocessor_overrides", {}),
),
RobotProcessor.from_pretrained(
source=pretrained_path,
config_filename=kwargs.get("postprocessor_config_filename", "postprocessor.json"),
overrides=kwargs.get("postprocessor_overrides", {}),
),
)

@@ -56,9 +56,8 @@ from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoTokenizer
from lerobot.constants import ACTION, OBS_STATE
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertConfig,
@@ -226,16 +225,12 @@ class PI0Policy(PreTrainedPolicy):
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
# TODO(azouitine): Add tokenizer to pipeline
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self.model = PI0FlowMatching(config)
self.reset()
@@ -280,7 +275,8 @@ class PI0Policy(PreTrainedPolicy):
if len(self._action_queue) == 0:
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
@@ -306,7 +302,8 @@ class PI0Policy(PreTrainedPolicy):
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
@@ -373,26 +370,6 @@ class PI0Policy(PreTrainedPolicy):
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
# PaliGemma prompt has to end with a new line
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding="max_length",
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
@@ -458,7 +435,7 @@ class PI0FlowMatching(nn.Module):
└──────────────────────────────┘
"""
def __init__(self, config):
def __init__(self, config: PI0Config):
super().__init__()
self.config = config

@@ -14,34 +14,107 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import (
EnvTransition,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
from lerobot.processor.rename_processor import RenameProcessor
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ProcessorStep):
"""Add a new line to the end of the task if it doesn't have one.
This is required for the PaliGemma tokenizer.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Check if complementary_data exists
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or "task" not in complementary_data:
return transition
task = complementary_data["task"]
if task is None:
return transition
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return transition
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Return the feature contract unchanged (this step adds no features)."""
return features
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def make_pi0_processor(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameProcessor(rename_map={}),  # No-op rename, kept so the steps match the pretrained processor
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
TokenizerProcessor(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessor(device=config.device),
]
output_steps = [
output_steps: list[ProcessorStep] = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
return RobotProcessor(steps=input_steps, name="pi0_preprocessor"), RobotProcessor(
steps=output_steps, name="pi0_postprocessor"
)

@@ -18,7 +18,9 @@ import torch
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
@@ -29,15 +31,17 @@ def make_pi0_processor(
config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),  # No-op rename, kept so the steps match the pretrained processor
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),

@@ -19,7 +19,9 @@ import torch
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
@@ -30,15 +32,17 @@ def make_sac_processor(
config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),

@@ -17,6 +17,7 @@ import torch
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.processor import (
DeviceProcessor,
IdentityProcessor,
NormalizerProcessor,
RobotProcessor,
@@ -33,8 +34,9 @@ def make_classifier_processor(
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessor(device=config.device),
]
output_steps = [IdentityProcessor()]
output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
return RobotProcessor(steps=input_steps, name="classifier_preprocessor"), RobotProcessor(
steps=output_steps, name="classifier_postprocessor"
)

@@ -53,17 +53,13 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
"""
import math
import os
import re
from collections import deque
import safetensors
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.constants import ACTION, OBS_STATE
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
@@ -72,102 +68,6 @@ from lerobot.policies.utils import (
)
from lerobot.utils.utils import get_safe_dtype
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
def canonicalise(k: str) -> str:
"""
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
normalisation-buffer key.
"""
return _VARIANT_RE.sub(".buffer_", k)
def standardise_state_dict(
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
) -> tuple[dict[str, torch.Tensor], list[str]]:
"""
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
• If several variant keys collapse to the same canonical name we keep the
first one and log the collision.
• Returns the new dict + a list of entries that could not be matched.
"""
out, collisions, unmatched = {}, {}, []
for k, v in checkpoint.items():
canon = canonicalise(k)
if canon in ref_keys:
if canon in out: # duplicate after collapsing
collisions.setdefault(canon, []).append(k)
else:
out[canon] = v
else:
unmatched.append(k)
if verbose:
for canon, variants in collisions.items():
print(f"[standardise_state_dict] '{canon}'{variants}")
if unmatched:
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
out.update({k: checkpoint[k] for k in unmatched})
return out, unmatched
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
"""
Renames keys in a checkpoint dictionary based on the given rename string.
Args:
checkpoint (dict): The checkpoint dictionary.
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
Returns:
dict: The modified checkpoint with renamed keys.
"""
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
new_checkpoint = {}
for k, v in checkpoint.items():
for old_key, new_key in rename_dict.items():
if old_key in k:
k = k.replace(old_key, new_key)
new_checkpoint[k] = v
return new_checkpoint
def load_smolvla(
model: torch.nn.Module,
filename: str | os.PathLike,
*,
device: str = "cpu",
checkpoint_keys_mapping: str = "",
) -> torch.nn.Module:
state_dict = safetensors.torch.load_file(filename, device=device)
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
raise RuntimeError(
"SmolVLA %d missing / %d unexpected keys",
len(missing),
len(unexpected),
)
return model
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
@@ -333,7 +233,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
self.model = VLAFlowMatching(config)
self.reset()
@@ -343,23 +242,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
@classmethod
def _load_as_safetensor(
cls,
model: "SmolVLAPolicy",
model_file: str,
map_location: str,
strict: bool,
):
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
return load_smolvla(
model,
model_file,
device=map_location,
checkpoint_keys_mapping="model._orig_mod.//model.",
)
def get_optim_params(self) -> dict:
return self.parameters()
@@ -375,7 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
@@ -435,7 +318,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
actions = self.prepare_action(batch)
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
@@ -499,30 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
img_masks.append(mask)
return images, img_masks
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
"""Tokenize the text input"""
device = batch[OBS_STATE].device
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
if len(tasks) == 1:
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
tokenized_prompt = self.language_tokenizer.__call__(
tasks,
padding=self.config.pad_language_to,
padding_side="right",
max_length=self.config.tokenizer_max_length,
return_tensors="pt",
)
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
return lang_tokens, lang_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:

@@ -13,30 +13,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
TokenizerProcessor,
UnnormalizerProcessor,
)
from lerobot.processor.pipeline import EnvTransition, ProcessorStep, ProcessorStepRegistry, TransitionKey
def make_smolvla_processor(
config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),  # No-op rename, kept so the steps match the pretrained processor
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
SmolVLANewLineProcessor(),
TokenizerProcessor(
tokenizer_name=config.vlm_model_name,
padding=config.pad_language_to,
padding_side="right",
max_length=config.tokenizer_max_length,
),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
@@ -44,3 +60,50 @@ def make_smolvla_processor(
return RobotProcessor(steps=input_steps, name="smolvla_preprocessor"), RobotProcessor(
steps=output_steps, name="smolvla_postprocessor"
)
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
class SmolVLANewLineProcessor(ProcessorStep):
"""Add a new line to the end of the task if it doesn't have one."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Check if complementary_data exists
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None or "task" not in complementary_data:
return transition
task = complementary_data["task"]
if task is None:
return transition
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return transition
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Return the feature contract unchanged (this step adds no features)."""
return features
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}

@@ -18,7 +18,9 @@ import torch
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
@@ -29,15 +31,17 @@ def make_tdmpc_processor(
config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),

@@ -19,7 +19,9 @@ import torch
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
DeviceProcessor,
NormalizerProcessor,
RenameProcessor,
RobotProcessor,
ToBatchProcessor,
UnnormalizerProcessor,
@@ -30,15 +32,17 @@ def make_vqbet_processor(
config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[RobotProcessor, RobotProcessor]:
input_steps = [
RenameProcessor(rename_map={}),  # Lets the user rename keys
NormalizerProcessor(
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
NormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
ToBatchProcessor(),
DeviceProcessor(device=config.device),
]
output_steps = [
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),

@@ -33,6 +33,7 @@ from .pipeline import (
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
from .tokenizer_processor import TokenizerProcessor
__all__ = [
"ActionProcessor",
@@ -51,6 +52,7 @@ __all__ = [
"RewardProcessor",
"RobotProcessor",
"ToBatchProcessor",
"TokenizerProcessor",
"TransitionKey",
"TruncatedProcessor",
"VanillaObservationProcessor",

@@ -106,6 +106,18 @@ class ToBatchProcessor:
if isinstance(task_value, str):
complementary_data["task"] = [task_value]
# Process index field - add batch dim if 0D
if "index" in complementary_data:
index_value = complementary_data["index"]
if isinstance(index_value, Tensor) and index_value.dim() == 0:
complementary_data["index"] = index_value.unsqueeze(0)
# Process task_index field - add batch dim if 0D
if "task_index" in complementary_data:
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}

@@ -19,24 +19,61 @@ from typing import Any
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
from lerobot.utils.utils import get_safe_torch_device
@ProcessorStepRegistry.register("device_processor")
@dataclass
class DeviceProcessor:
"""Processes transitions by moving tensors to the specified device.
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
This processor ensures that all tensors in the transition are moved to the
specified device (CPU or GPU) before they are returned.
specified device (CPU or GPU) before they are returned. It can also convert
floating-point tensors to a specified dtype while preserving non-float types
(int, long, bool, etc.).
"""
device: torch.device = "cpu"
float_dtype: str | None = None
def __post_init__(self):
self.device = get_safe_torch_device(self.device)
self.non_blocking = "cuda" in str(self.device)
# Validate and convert float_dtype string to torch dtype
if self.float_dtype is not None:
dtype_mapping = {
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
"bfloat16": torch.bfloat16,
"half": torch.float16,
"float": torch.float32,
"double": torch.float64,
}
if self.float_dtype not in dtype_mapping:
available_dtypes = list(dtype_mapping.keys())
raise ValueError(
f"Invalid float_dtype '{self.float_dtype}'. Available options: {available_dtypes}"
)
self._target_float_dtype = dtype_mapping[self.float_dtype]
else:
self._target_float_dtype = None
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""Process a tensor by moving to device and optionally converting float dtype."""
# Move to device first
tensor = tensor.to(self.device, non_blocking=self.non_blocking)
# Convert float dtype if specified and tensor is floating point
if self._target_float_dtype is not None and tensor.is_floating_point():
tensor = tensor.to(dtype=self._target_float_dtype)
return tensor
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Create a copy of the transition
new_transition = transition.copy()
@@ -45,7 +82,7 @@ class DeviceProcessor:
observation = transition.get(TransitionKey.OBSERVATION)
if observation is not None:
new_observation = {
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
for k, v in observation.items()
}
new_transition[TransitionKey.OBSERVATION] = new_observation
@@ -53,30 +90,54 @@ class DeviceProcessor:
# Process action tensor
action = transition.get(TransitionKey.ACTION)
if action is not None and isinstance(action, torch.Tensor):
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
new_transition[TransitionKey.ACTION] = self._process_tensor(action)
# Process reward tensor
reward = transition.get(TransitionKey.REWARD)
if reward is not None and isinstance(reward, torch.Tensor):
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
new_transition[TransitionKey.REWARD] = self._process_tensor(reward)
# Process done tensor
done = transition.get(TransitionKey.DONE)
if done is not None and isinstance(done, torch.Tensor):
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
new_transition[TransitionKey.DONE] = self._process_tensor(done)
# Process truncated tensor
truncated = transition.get(TransitionKey.TRUNCATED)
if truncated is not None and isinstance(truncated, torch.Tensor):
new_transition[TransitionKey.TRUNCATED] = truncated.to(
self.device, non_blocking=self.non_blocking
)
new_transition[TransitionKey.TRUNCATED] = self._process_tensor(truncated)
# Process complementary data tensors
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is not None:
new_complementary_data = {}
# Process all items in complementary_data
for key, value in complementary_data.items():
if isinstance(value, torch.Tensor):
new_complementary_data[key] = self._process_tensor(value)
else:
new_complementary_data[key] = value
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {"device": self.device}
return {"device": self.device, "float_dtype": self.float_dtype}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
return features
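
The `float_dtype` lookup in `__post_init__` can be exercised in isolation. The sketch below is a torch-free stand-in, not the processor itself: `resolve_float_dtype` and `FLOAT_DTYPE_ALIASES` are hypothetical names, and canonical dtype names are plain strings where the real code maps to `torch` dtypes. It only illustrates how the aliases collapse and how an unknown name is rejected.

```python
# Hypothetical, torch-free stand-in for the float_dtype lookup in
# DeviceProcessor.__post_init__. Values are strings here; the real
# mapping points at torch dtypes such as torch.float16.
FLOAT_DTYPE_ALIASES = {
    "float16": "float16",
    "float32": "float32",
    "float64": "float64",
    "bfloat16": "bfloat16",
    "half": "float16",
    "float": "float32",
    "double": "float64",
}


def resolve_float_dtype(name):
    """Return the canonical dtype name, or None when no conversion is requested."""
    if name is None:
        return None
    if name not in FLOAT_DTYPE_ALIASES:
        raise ValueError(
            f"Invalid float_dtype '{name}'. Available options: {list(FLOAT_DTYPE_ALIASES)}"
        )
    return FLOAT_DTYPE_ALIASES[name]
```

Integer names such as `"int8"` are rejected on purpose: the processor only ever casts floating-point tensors and leaves int, long, and bool tensors as they are.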

@@ -116,7 +116,7 @@ class NormalizerProcessor:
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
self.normalize_keys = set(self.normalize_keys)
def _normalize_obs(self, observation):
def _normalize_obs(self, observation, normalized_info):
if observation is None:
return None
@@ -138,6 +138,7 @@ class NormalizerProcessor:
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
@@ -156,16 +157,18 @@ class NormalizerProcessor:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = (tensor - mean) / (std + self.eps)
normalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
normalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
return processed
def _normalize_action(self, action):
def _normalize_action(self, action, normalized_info):
if action is None:
return action
@@ -174,6 +177,7 @@ class NormalizerProcessor:
# Skip normalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
normalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
@@ -190,10 +194,12 @@ class NormalizerProcessor:
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
normalized_info["action"] = "MEAN_STD"
return (tensor - mean) / (std + self.eps)
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
normalized_info["action"] = "MIN_MAX"
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
@@ -202,13 +208,24 @@ class NormalizerProcessor:
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._normalize_action(transition.get(TransitionKey.ACTION))
# Track what was normalized
normalized_info = {}
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION), normalized_info)
action = self._normalize_action(transition.get(TransitionKey.ACTION), normalized_info)
# Create a new transition with normalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Add normalization info to complementary data
if normalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["normalized_keys"] = normalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
@@ -289,7 +306,7 @@ class UnnormalizerProcessor:
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats)
def _unnormalize_obs(self, observation):
def _unnormalize_obs(self, observation, unnormalized_info):
if observation is None:
return None
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
@@ -304,6 +321,7 @@ class UnnormalizerProcessor:
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info[key] = "IDENTITY"
continue
# Skip if no stats available for this key
@@ -322,16 +340,18 @@ class UnnormalizerProcessor:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
processed[key] = tensor * std + mean
unnormalized_info[key] = "MEAN_STD"
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
unnormalized_info[key] = "MIN_MAX"
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
return processed
def _unnormalize_action(self, action):
def _unnormalize_action(self, action, unnormalized_info):
if action is None:
return action
@@ -340,6 +360,7 @@ class UnnormalizerProcessor:
# Skip unnormalization if mode is IDENTITY
if norm_mode is NormalizationMode.IDENTITY:
unnormalized_info["action"] = "IDENTITY"
return action
# Skip if no stats available for actions
@@ -356,10 +377,12 @@ class UnnormalizerProcessor:
if norm_mode is NormalizationMode.MEAN_STD:
if "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
unnormalized_info["action"] = "MEAN_STD"
return tensor * std + mean
elif norm_mode is NormalizationMode.MIN_MAX:
if "min" in stats and "max" in stats:
min_val, max_val = stats["min"], stats["max"]
unnormalized_info["action"] = "MIN_MAX"
return (tensor + 1) / 2 * (max_val - min_val) + min_val
else:
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
@@ -368,13 +391,24 @@ class UnnormalizerProcessor:
raise ValueError(f"Action stats must contain appropriate values for {norm_mode} normalization")
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
# Track what was unnormalized
unnormalized_info = {}
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION), unnormalized_info)
action = self._unnormalize_action(transition.get(TransitionKey.ACTION), unnormalized_info)
# Create a new transition with unnormalized values
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = observation
new_transition[TransitionKey.ACTION] = action
# Add unnormalization info to complementary data
if unnormalized_info:
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp_data = {} if comp_data is None else dict(comp_data)
comp_data["unnormalized_keys"] = unnormalized_info
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data
return new_transition
def get_config(self) -> dict[str, Any]:
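
The two normalization modes and their inverses can be checked with plain floats. This is a minimal sketch with hypothetical function names; `EPS = 1e-8` is an assumed value, since the diff only shows that the normalizer adds `self.eps` to the denominator. Because the unnormalizer does not add `eps`, the round trip is exact only up to an error of that order.

```python
import math

EPS = 1e-8  # assumed value; the diff only shows that self.eps is used


def normalize_mean_std(x, mean, std, eps=EPS):
    return (x - mean) / (std + eps)


def unnormalize_mean_std(y, mean, std):
    return y * std + mean


def normalize_min_max(x, lo, hi, eps=EPS):
    # Maps [lo, hi] to approximately [-1, 1].
    return 2 * (x - lo) / (hi - lo + eps) - 1


def unnormalize_min_max(y, lo, hi):
    return (y + 1) / 2 * (hi - lo) + lo


x = 0.75
y = normalize_min_max(x, lo=0.0, hi=1.0)
x_back = unnormalize_min_max(y, lo=0.0, hi=1.0)

z = normalize_mean_std(x, mean=0.5, std=0.25)
z_back = unnormalize_mean_std(z, mean=0.5, std=0.25)

# Both round trips recover x up to an eps-sized error.
print(math.isclose(x_back, x, rel_tol=1e-6), math.isclose(z_back, x, rel_tol=1e-6))
```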
@@ -413,3 +447,29 @@ def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, An
step.stats = stats
step._tensor_stats = _convert_stats_to_tensors(stats)
return robot_processor
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
"""Rename keys in the stats dictionary according to the provided mapping.
Args:
stats: The statistics dictionary with structure {feature_key: {stat_name: value}}
rename_map: Dictionary mapping old key names to new key names
Returns:
A new stats dictionary with renamed keys
Example:
>>> stats = {"observation.state": {"mean": 0.0, "std": 1.0}, "action": {"mean": 0.5, "std": 0.5}}
>>> rename_map = {"observation.state": "observation.robot_state"}
>>> new_stats = rename_stats(stats, rename_map)
>>> # new_stats will have "observation.robot_state" instead of "observation.state"
"""
renamed_stats = {}
for old_key, sub_stats in stats.items():
# Use the new key if it exists in the rename map, otherwise keep the old key
new_key = rename_map.get(old_key, old_key)
renamed_stats[new_key] = deepcopy(sub_stats)
return renamed_stats
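
A short usage sketch for `rename_stats`. The function body is repeated in condensed form so the snippet runs on its own; the sample stats use lists rather than tensors, which is an assumption made purely for illustration.

```python
from copy import deepcopy


def rename_stats(stats, rename_map):
    """Condensed copy of the function above so this snippet is standalone."""
    return {rename_map.get(key, key): deepcopy(sub) for key, sub in stats.items()}


stats = {
    "observation.state": {"mean": [0.0], "std": [1.0]},
    "action": {"mean": [0.5], "std": [0.5]},
}
renamed = rename_stats(stats, {"observation.state": "observation.robot_state"})

# Keys missing from the rename map ("action") are carried over unchanged.
print(sorted(renamed))

# The result is a deep copy, so mutating it leaves the input untouched.
renamed["action"]["mean"][0] = 9.9
print(stats["action"]["mean"])
```

One property worth keeping in mind: if two old keys map to the same new key, the later entry overwrites the earlier one, because the result is built by plain dictionary assignment.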

@@ -201,10 +201,16 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
observation = observation_keys if observation_keys else None
# Extract padding and task keys for complementary data
# Extract padding, task, index, and task_index keys for complementary data
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
complementary_data = (
{**pad_keys, **task_key, **index_key, **task_index_key}
if pad_keys or task_key or index_key or task_index_key
else {}
)
transition: EnvTransition = {
TransitionKey.OBSERVATION: observation,
@@ -231,7 +237,7 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"info": transition.get(TransitionKey.INFO, {}),
}
# Add padding and task data from complementary_data
# Add padding, task, index, and task_index data from complementary_data
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data:
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
@@ -240,6 +246,12 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
if "task" in complementary_data:
batch["task"] = complementary_data["task"]
if "index" in complementary_data:
batch["index"] = complementary_data["index"]
if "task_index" in complementary_data:
batch["task_index"] = complementary_data["task_index"]
# Handle observation - flatten dict to observation.* keys if it's a dict
observation = transition.get(TransitionKey.OBSERVATION)
if isinstance(observation, dict):
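
The key selection that `_default_batch_to_transition` now performs can be sketched without the pipeline. `extract_complementary_data` is a hypothetical helper, and plain Python values stand in for the tensors a real batch would carry.

```python
def extract_complementary_data(batch):
    """Hypothetical helper mirroring the key selection in _default_batch_to_transition."""
    # Every key containing "_is_pad" is treated as padding metadata.
    data = {k: v for k, v in batch.items() if "_is_pad" in k}
    # task, index, and task_index are copied only when present.
    for key in ("task", "index", "task_index"):
        if key in batch:
            data[key] = batch[key]
    return data


batch = {
    "observation.state": [0.1, 0.2],
    "action": [1.0],
    "action_is_pad": [False],
    "task": "pick_object",
    "index": 42,
    "task_index": 3,
}
complementary = extract_complementary_data(batch)
print(complementary)
```

A batch with none of these keys yields an empty dict, which matches the `else {}` branch in the diff.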

@@ -0,0 +1,210 @@
"""
Tokenizer processor for handling text tokenization in robot transitions.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import torch
from transformers import AutoTokenizer
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@dataclass
@ProcessorStepRegistry.register(name="tokenizer_processor")
class TokenizerProcessor:
"""Tokenizes text tasks in complementary data using a Hugging Face tokenizer.
This processor handles tokenization of task strings found in the complementary_data
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
to the observation data for model processing while preserving the original task string.
The processor supports both single strings and lists of strings as task inputs.
Args:
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
tokenizer: A tokenizer object (e.g., from transformers library) that implements
the __call__ method. If provided, tokenizer_name is ignored. This parameter
is not serialized and must be provided via overrides when loading.
max_length: Maximum sequence length for tokenization. Defaults to 512.
task_key: Key in complementary_data containing the task text. Defaults to "task".
padding_side: Side on which padding tokens are added ("left" or "right"). Defaults to "right".
padding: Padding strategy for tokenization. Defaults to "max_length".
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
Examples:
Using tokenizer name (auto-loaded):
```python
processor = TokenizerProcessor(tokenizer_name="bert-base-uncased", max_length=128)
```
Using custom tokenizer object:
```python
from transformers import AutoTokenizer
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = TokenizerProcessor(tokenizer=custom_tokenizer, max_length=128)
```
"""
tokenizer_name: str | None = None
tokenizer: AutoTokenizer | None = None
max_length: int = 512
task_key: str = "task"
padding_side: str = "right"
padding: str = "max_length"
truncation: bool = True
# Internal tokenizer instance (not serialized)
_tokenizer: Any = field(default=None, init=False, repr=False)
def __post_init__(self):
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
if self.tokenizer is not None:
# Use provided tokenizer object directly
self._tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
def get_task(self, transition: EnvTransition) -> list[str] | None:
"""Extract and normalize task from complementary data.
Args:
transition: Input transition containing complementary_data.
Returns:
List of task strings if task is present, None otherwise.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
if self.task_key not in complementary_data:
return None
task = complementary_data[self.task_key]
if task is None:
return None
# Convert to list of strings
if isinstance(task, str):
return [task]
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
return task
return None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Process the transition by tokenizing the task text.
Args:
transition: Input transition containing complementary_data with task text.
Returns:
Modified transition with tokenized task added to observation.
Raises:
ValueError: If tokenizer initialization failed.
"""
task = self.get_task(transition)
if task is None:
return transition
# Tokenize the task
tokenized_prompt = self._tokenize_text(task)
# Get or create observation dict
if TransitionKey.OBSERVATION not in transition or transition[TransitionKey.OBSERVATION] is None:
transition[TransitionKey.OBSERVATION] = {}
observation = transition[TransitionKey.OBSERVATION]
# Add tokenized data to observation
observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
dtype=torch.bool
)
return transition
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
"""Tokenize text using the configured tokenizer.
Args:
text: Text string or list of strings to tokenize.
Returns:
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
"""
return self._tokenizer(
text,
max_length=self.max_length,
truncation=self.truncation,
padding=self.padding,
padding_side=self.padding_side,
return_tensors="pt",
)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization.
Note: Only tokenizer_name is saved, not the tokenizer object itself.
When loading, provide the tokenizer via overrides if needed.
"""
config = {
"max_length": self.max_length,
"task_key": self.task_key,
"padding_side": self.padding_side,
"padding": self.padding,
"truncation": self.truncation,
}
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
if self.tokenizer_name is not None:
config["tokenizer_name"] = self.tokenizer_name
return config
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
"""Add tokenized task features to the feature contract.
Args:
features: Input feature dictionary.
Returns:
Updated feature dictionary with tokenized task features added.
"""
# Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask
tokens_key = f"{OBS_LANGUAGE}.tokens"
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
if tokens_key not in features:
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if attention_mask_key not in features:
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
return features
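
To make the effect of `padding_side`, `max_length`, and `truncation` concrete, here is a pure-Python sketch that pads a single sequence by hand. It does not call Hugging Face: `pad_token_ids` is a hypothetical helper, `pad_id=0` is an assumed padding id, and truncation from the right is an assumption about the tokenizer's configuration.

```python
def pad_token_ids(token_ids, max_length, padding_side="right", pad_id=0):
    """Truncate or pad one sequence to max_length and build a boolean attention mask."""
    if padding_side not in ("left", "right"):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    ids = list(token_ids[:max_length])  # keep at most max_length tokens
    n_pad = max_length - len(ids)
    mask = [True] * len(ids)  # True marks real tokens, False marks padding
    if padding_side == "right":
        return ids + [pad_id] * n_pad, mask + [False] * n_pad
    return [pad_id] * n_pad + ids, [False] * n_pad + mask


right_ids, right_mask = pad_token_ids([101, 7, 8, 102], max_length=6)
left_ids, left_mask = pad_token_ids([101, 7, 8, 102], max_length=6, padding_side="left")
print(right_ids, right_mask)
print(left_ids, left_mask)
```

With `padding="max_length"`, every sequence comes out with exactly `max_length` entries, which is why `feature_contract` can declare the shape `(self.max_length,)` for both the tokens and the attention mask.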

@@ -899,3 +899,231 @@ def test_task_preserves_other_keys():
assert processed_comp_data["motor_id"] == "motor_456"
assert processed_comp_data["config"] == {"speed": "slow", "precision": "high"}
assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0]
# Index and task_index specific tests
def test_index_scalar_to_1d():
"""Test that 0D index tensor gets unsqueezed to 1D."""
processor = ToBatchProcessor()
# Create 0D index tensor (scalar)
index_0d = torch.tensor(42, dtype=torch.int64)
complementary_data = {"index": index_0d}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["index"].shape == (1,)
assert processed_comp_data["index"].dtype == torch.int64
assert processed_comp_data["index"][0] == 42
def test_task_index_scalar_to_1d():
"""Test that 0D task_index tensor gets unsqueezed to 1D."""
processor = ToBatchProcessor()
# Create 0D task_index tensor (scalar)
task_index_0d = torch.tensor(7, dtype=torch.int64)
complementary_data = {"task_index": task_index_0d}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task_index"].shape == (1,)
assert processed_comp_data["task_index"].dtype == torch.int64
assert processed_comp_data["task_index"][0] == 7
def test_index_and_task_index_together():
"""Test processing both index and task_index together."""
processor = ToBatchProcessor()
# Create 0D tensors for both
index_0d = torch.tensor(100, dtype=torch.int64)
task_index_0d = torch.tensor(3, dtype=torch.int64)
complementary_data = {
"index": index_0d,
"task_index": task_index_0d,
"task": "pick_object",
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Check index
assert processed_comp_data["index"].shape == (1,)
assert processed_comp_data["index"][0] == 100
# Check task_index
assert processed_comp_data["task_index"].shape == (1,)
assert processed_comp_data["task_index"][0] == 3
# Check task is also processed
assert processed_comp_data["task"] == ["pick_object"]
def test_index_already_batched():
"""Test that already batched index tensors remain unchanged."""
processor = ToBatchProcessor()
# Create already batched tensors
index_1d = torch.tensor([42], dtype=torch.int64)
index_2d = torch.tensor([[42, 43]], dtype=torch.int64)
# Test 1D (already batched)
complementary_data = {"index": index_1d}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d)
# Test 2D
complementary_data = {"index": index_2d}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d)
def test_task_index_already_batched():
"""Test that already batched task_index tensors remain unchanged."""
processor = ToBatchProcessor()
# Create already batched tensors
task_index_1d = torch.tensor([7], dtype=torch.int64)
task_index_2d = torch.tensor([[7, 8]], dtype=torch.int64)
# Test 1D (already batched)
complementary_data = {"task_index": task_index_1d}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d)
# Test 2D
complementary_data = {"task_index": task_index_2d}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d)
def test_index_non_tensor_unchanged():
"""Test that non-tensor index values remain unchanged."""
processor = ToBatchProcessor()
complementary_data = {
"index": 42, # Plain int, not tensor
"task_index": [1, 2, 3], # List, not tensor
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["index"] == 42
assert processed_comp_data["task_index"] == [1, 2, 3]
def test_index_dtype_preservation():
"""Test that index and task_index dtype is preserved during processing."""
processor = ToBatchProcessor()
# Test different dtypes
dtypes = [torch.int32, torch.int64, torch.long]
for dtype in dtypes:
index_0d = torch.tensor(42, dtype=dtype)
task_index_0d = torch.tensor(7, dtype=dtype)
complementary_data = {
"index": index_0d,
"task_index": task_index_0d,
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["index"].dtype == dtype
assert processed_comp_data["task_index"].dtype == dtype
def test_index_with_full_transition():
"""Test index/task_index processing with full transition data."""
processor = ToBatchProcessor()
# Create full transition with all components
observation = {
OBS_STATE: torch.randn(7),
OBS_IMAGE: torch.randn(64, 64, 3),
}
action = torch.randn(4)
complementary_data = {
"task": "navigate_to_goal",
"index": torch.tensor(1000, dtype=torch.int64),
"task_index": torch.tensor(5, dtype=torch.int64),
"episode_id": 123,
}
transition = create_transition(
observation=observation,
action=action,
reward=0.5,
done=False,
complementary_data=complementary_data,
)
result = processor(transition)
# Check all components are processed correctly
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3)
assert result[TransitionKey.ACTION].shape == (1, 4)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["navigate_to_goal"]
assert processed_comp_data["index"].shape == (1,)
assert processed_comp_data["index"][0] == 1000
assert processed_comp_data["task_index"].shape == (1,)
assert processed_comp_data["task_index"][0] == 5
assert processed_comp_data["episode_id"] == 123 # Non-tensor field unchanged
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_index_device_compatibility():
"""Test processor works with index/task_index tensors on different devices."""
processor = ToBatchProcessor()
# Create tensors on GPU
index_0d = torch.tensor(42, dtype=torch.int64, device="cuda")
task_index_0d = torch.tensor(7, dtype=torch.int64, device="cuda")
complementary_data = {
"index": index_0d,
"task_index": task_index_0d,
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Check shapes and that tensors stayed on GPU
assert processed_comp_data["index"].shape == (1,)
assert processed_comp_data["task_index"].shape == (1,)
assert processed_comp_data["index"].device.type == "cuda"
assert processed_comp_data["task_index"].device.type == "cuda"
def test_empty_index_tensor():
"""Test handling of empty index tensors."""
processor = ToBatchProcessor()
# Empty 0D tensor doesn't make sense, but test empty 1D
index_empty = torch.tensor([], dtype=torch.int64)
complementary_data = {"index": index_empty}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
# Should remain unchanged (already 1D)
assert result[TransitionKey.COMPLEMENTARY_DATA]["index"].shape == (0,)

@@ -0,0 +1,874 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import pytest
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.processor import DeviceProcessor, RobotProcessor
from lerobot.processor.pipeline import TransitionKey
def create_transition(
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
"""Helper function to create a transition dictionary."""
transition = {}
if observation is not None:
transition[TransitionKey.OBSERVATION] = observation
if action is not None:
transition[TransitionKey.ACTION] = action
if reward is not None:
transition[TransitionKey.REWARD] = reward
if done is not None:
transition[TransitionKey.DONE] = done
if truncated is not None:
transition[TransitionKey.TRUNCATED] = truncated
if info is not None:
transition[TransitionKey.INFO] = info
if complementary_data is not None:
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return transition
def test_basic_functionality():
"""Test basic device processor functionality on CPU."""
processor = DeviceProcessor(device="cpu")
# Create a transition with CPU tensors
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
action = torch.randn(5)
reward = torch.tensor(1.0)
done = torch.tensor(False)
truncated = torch.tensor(False)
transition = create_transition(
observation=observation, action=action, reward=reward, done=done, truncated=truncated
)
result = processor(transition)
# Check that all tensors are on CPU
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu"
assert result[TransitionKey.ACTION].device.type == "cpu"
assert result[TransitionKey.REWARD].device.type == "cpu"
assert result[TransitionKey.DONE].device.type == "cpu"
assert result[TransitionKey.TRUNCATED].device.type == "cpu"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_functionality():
"""Test device processor functionality on CUDA."""
processor = DeviceProcessor(device="cuda")
# Create a transition with CPU tensors
observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)}
action = torch.randn(5)
reward = torch.tensor(1.0)
done = torch.tensor(False)
truncated = torch.tensor(False)
transition = create_transition(
observation=observation, action=action, reward=reward, done=done, truncated=truncated
)
result = processor(transition)
# Check that all tensors are on CUDA
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda"
assert result[TransitionKey.REWARD].device.type == "cuda"
assert result[TransitionKey.DONE].device.type == "cuda"
assert result[TransitionKey.TRUNCATED].device.type == "cuda"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_specific_cuda_device():
"""Test device processor with specific CUDA device."""
processor = DeviceProcessor(device="cuda:0")
observation = {"observation.state": torch.randn(10)}
action = torch.randn(5)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0
assert result[TransitionKey.ACTION].device.type == "cuda"
assert result[TransitionKey.ACTION].device.index == 0
def test_non_tensor_values():
"""Test that non-tensor values are preserved."""
processor = DeviceProcessor(device="cpu")
observation = {
"observation.state": torch.randn(10),
"observation.metadata": {"key": "value"}, # Non-tensor data
"observation.list": [1, 2, 3], # Non-tensor data
}
action = torch.randn(5)
info = {"episode": 1, "step": 42}
transition = create_transition(observation=observation, action=action, info=info)
result = processor(transition)
# Check tensors are processed
assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor)
assert isinstance(result[TransitionKey.ACTION], torch.Tensor)
# Check non-tensor values are preserved
assert result[TransitionKey.OBSERVATION]["observation.metadata"] == {"key": "value"}
assert result[TransitionKey.OBSERVATION]["observation.list"] == [1, 2, 3]
assert result[TransitionKey.INFO] == {"episode": 1, "step": 42}
def test_none_values():
"""Test handling of None values."""
processor = DeviceProcessor(device="cpu")
# Test with None observation
transition = create_transition(observation=None, action=torch.randn(5))
result = processor(transition)
assert TransitionKey.OBSERVATION not in result
assert result[TransitionKey.ACTION].device.type == "cpu"
# Test with None action
transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None)
result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
assert TransitionKey.ACTION not in result
def test_empty_observation():
"""Test handling of empty observation dictionary."""
processor = DeviceProcessor(device="cpu")
transition = create_transition(observation={}, action=torch.randn(5))
result = processor(transition)
assert result[TransitionKey.OBSERVATION] == {}
assert result[TransitionKey.ACTION].device.type == "cpu"
def test_scalar_tensors():
"""Test handling of scalar tensors."""
processor = DeviceProcessor(device="cpu")
observation = {"observation.scalar": torch.tensor(1.5)}
action = torch.tensor(2.0)
reward = torch.tensor(0.5)
transition = create_transition(observation=observation, action=action, reward=reward)
result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.scalar"].item() == 1.5
assert result[TransitionKey.ACTION].item() == 2.0
assert result[TransitionKey.REWARD].item() == 0.5
def test_dtype_preservation():
"""Test that tensor dtypes are preserved."""
processor = DeviceProcessor(device="cpu")
observation = {
"observation.float32": torch.randn(5, dtype=torch.float32),
"observation.float64": torch.randn(5, dtype=torch.float64),
"observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32),
"observation.bool": torch.tensor([True, False, True], dtype=torch.bool),
}
action = torch.randn(3, dtype=torch.float16)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32
assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float64
assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32
assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool
assert result[TransitionKey.ACTION].dtype == torch.float16
def test_shape_preservation():
"""Test that tensor shapes are preserved."""
processor = DeviceProcessor(device="cpu")
observation = {
"observation.1d": torch.randn(10),
"observation.2d": torch.randn(5, 10),
"observation.3d": torch.randn(3, 224, 224),
"observation.4d": torch.randn(2, 3, 224, 224),
}
action = torch.randn(2, 5, 3)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.1d"].shape == (10,)
assert result[TransitionKey.OBSERVATION]["observation.2d"].shape == (5, 10)
assert result[TransitionKey.OBSERVATION]["observation.3d"].shape == (3, 224, 224)
assert result[TransitionKey.OBSERVATION]["observation.4d"].shape == (2, 3, 224, 224)
assert result[TransitionKey.ACTION].shape == (2, 5, 3)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mixed_devices():
"""Test handling of tensors already on different devices."""
processor = DeviceProcessor(device="cuda")
# Create tensors on different devices
observation = {
"observation.cpu": torch.randn(5), # CPU
"observation.cuda": torch.randn(5).cuda(), # Already on CUDA
}
action = torch.randn(3).cuda() # Already on CUDA
transition = create_transition(observation=observation, action=action)
result = processor(transition)
# All should be on CUDA
assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.cuda"].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda"
def test_non_blocking_flag():
"""Test that non_blocking flag is set correctly."""
# CPU processor should have non_blocking=False
cpu_processor = DeviceProcessor(device="cpu")
assert cpu_processor.non_blocking is False
# CUDA processor should have non_blocking=True
cuda_processor = DeviceProcessor(device="cuda")
assert cuda_processor.non_blocking is True
cuda_0_processor = DeviceProcessor(device="cuda:0")
assert cuda_0_processor.non_blocking is True
def test_serialization_methods():
"""Test get_config, state_dict, load_state_dict, and reset methods."""
processor = DeviceProcessor(device="cuda")
# Test get_config
config = processor.get_config()
assert config == {"device": "cuda", "float_dtype": None}
# Test state_dict (should be empty)
state = processor.state_dict()
assert state == {}
# Test load_state_dict (should be no-op)
processor.load_state_dict({})
assert processor.device == "cuda"
# Test reset (should be no-op)
processor.reset()
assert processor.device == "cuda"
def test_feature_contract():
"""Test that feature_contract returns features unchanged."""
processor = DeviceProcessor(device="cpu")
features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
}
result = processor.feature_contract(features)
assert result == features
assert result is features # Should return the same object
def test_integration_with_robot_processor():
"""Test integration with RobotProcessor."""
from lerobot.processor import ToBatchProcessor
# Create a pipeline with DeviceProcessor
device_processor = DeviceProcessor(device="cpu")
batch_processor = ToBatchProcessor()
processor = RobotProcessor(steps=[batch_processor, device_processor], name="test_pipeline")
# Create test data
observation = {"observation.state": torch.randn(10)}
action = torch.randn(5)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
# Check that tensors are batched and on correct device
assert result[TransitionKey.OBSERVATION]["observation.state"].shape[0] == 1 # Batched
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
assert result[TransitionKey.ACTION].shape[0] == 1 # Batched
assert result[TransitionKey.ACTION].device.type == "cpu"
def test_save_and_load_pretrained():
"""Test saving and loading processor with DeviceProcessor."""
processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
robot_processor = RobotProcessor(steps=[processor], name="device_test_processor")
with tempfile.TemporaryDirectory() as tmpdir:
# Save
robot_processor.save_pretrained(tmpdir)
# Load
loaded_processor = RobotProcessor.from_pretrained(tmpdir)
assert len(loaded_processor.steps) == 1
loaded_device_processor = loaded_processor.steps[0]
assert isinstance(loaded_device_processor, DeviceProcessor)
assert loaded_device_processor.device == "cuda:0"
assert loaded_device_processor.float_dtype == "float16"
def test_registry_functionality():
"""Test that DeviceProcessor is properly registered."""
from lerobot.processor.pipeline import ProcessorStepRegistry
# Check that DeviceProcessor is registered
registered_class = ProcessorStepRegistry.get("device_processor")
assert registered_class is DeviceProcessor
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_performance_with_large_tensors():
"""Test that large tensors are moved to CUDA without errors (non_blocking transfer)."""
processor = DeviceProcessor(device="cuda")
# Create large tensors
observation = {
"observation.large_image": torch.randn(10, 3, 512, 512), # Large image batch
"observation.features": torch.randn(10, 2048), # Large feature vector
}
action = torch.randn(10, 100) # Large action space
transition = create_transition(observation=observation, action=action)
# Process should not raise any errors
result = processor(transition)
# Verify all tensors are on CUDA
assert result[TransitionKey.OBSERVATION]["observation.large_image"].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.features"].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda"
def test_reward_done_truncated_types():
"""Test handling of different types for reward, done, and truncated."""
processor = DeviceProcessor(device="cpu")
# Test with scalar values (not tensors)
transition = create_transition(
observation={"observation.state": torch.randn(5)},
action=torch.randn(3),
reward=1.0, # float
done=False, # bool
truncated=True, # bool
)
result = processor(transition)
# Non-tensor values should be preserved as-is
assert result[TransitionKey.REWARD] == 1.0
assert result[TransitionKey.DONE] is False
assert result[TransitionKey.TRUNCATED] is True
# Test with tensor values
transition = create_transition(
observation={"observation.state": torch.randn(5)},
action=torch.randn(3),
reward=torch.tensor(1.0),
done=torch.tensor(False),
truncated=torch.tensor(True),
)
result = processor(transition)
# Tensor values should be moved to device
assert isinstance(result[TransitionKey.REWARD], torch.Tensor)
assert isinstance(result[TransitionKey.DONE], torch.Tensor)
assert isinstance(result[TransitionKey.TRUNCATED], torch.Tensor)
assert result[TransitionKey.REWARD].device.type == "cpu"
assert result[TransitionKey.DONE].device.type == "cpu"
assert result[TransitionKey.TRUNCATED].device.type == "cpu"
def test_complementary_data_preserved():
"""Test that non-tensor entries in complementary_data are preserved unchanged."""
processor = DeviceProcessor(device="cpu")
complementary_data = {
"task": "pick_object",
"episode_id": 42,
"metadata": {"sensor": "camera_1"},
"observation_is_pad": torch.tensor([False, False, True]), # This should be moved to device
}
transition = create_transition(
observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data
)
result = processor(transition)
# Check that complementary_data is preserved
assert TransitionKey.COMPLEMENTARY_DATA in result
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick_object"
assert result[TransitionKey.COMPLEMENTARY_DATA]["episode_id"] == 42
assert result[TransitionKey.COMPLEMENTARY_DATA]["metadata"] == {"sensor": "camera_1"}
# Tensors in complementary_data are moved to the target device as well
# (see test_complementary_data_index_fields for the dedicated checks)
assert result[TransitionKey.COMPLEMENTARY_DATA]["observation_is_pad"].device.type == "cpu"
def test_float_dtype_conversion():
"""Test float dtype conversion functionality."""
processor = DeviceProcessor(device="cpu", float_dtype="float16")
# Create tensors of different types
observation = {
"observation.float32": torch.randn(5, dtype=torch.float32),
"observation.float64": torch.randn(5, dtype=torch.float64),
"observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32),
"observation.int64": torch.randint(0, 10, (5,), dtype=torch.int64),
"observation.bool": torch.tensor([True, False, True], dtype=torch.bool),
}
action = torch.randn(3, dtype=torch.float32)
reward = torch.tensor(1.0, dtype=torch.float32)
transition = create_transition(observation=observation, action=action, reward=reward)
result = processor(transition)
# Check that float tensors are converted to float16
assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16
assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float16
assert result[TransitionKey.ACTION].dtype == torch.float16
assert result[TransitionKey.REWARD].dtype == torch.float16
# Check that non-float tensors are preserved
assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32
assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64
assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool
def test_float_dtype_none():
"""Test that when float_dtype is None, no dtype conversion occurs."""
processor = DeviceProcessor(device="cpu", float_dtype=None)
observation = {
"observation.float32": torch.randn(5, dtype=torch.float32),
"observation.float64": torch.randn(5, dtype=torch.float64),
"observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32),
}
action = torch.randn(3, dtype=torch.float64)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
# Check that dtypes are preserved when float_dtype is None
assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32
assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float64
assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32
assert result[TransitionKey.ACTION].dtype == torch.float64
def test_float_dtype_bfloat16():
"""Test conversion to bfloat16."""
processor = DeviceProcessor(device="cpu", float_dtype="bfloat16")
observation = {"observation.state": torch.randn(5, dtype=torch.float32)}
action = torch.randn(3, dtype=torch.float64)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16
assert result[TransitionKey.ACTION].dtype == torch.bfloat16
def test_float_dtype_float64():
"""Test conversion to float64."""
processor = DeviceProcessor(device="cpu", float_dtype="float64")
observation = {"observation.state": torch.randn(5, dtype=torch.float16)}
action = torch.randn(3, dtype=torch.float32)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64
assert result[TransitionKey.ACTION].dtype == torch.float64
def test_float_dtype_invalid():
"""Test that invalid float_dtype raises ValueError."""
with pytest.raises(ValueError, match="Invalid float_dtype 'invalid_dtype'"):
DeviceProcessor(device="cpu", float_dtype="invalid_dtype")
def test_float_dtype_aliases():
"""Test that dtype aliases work correctly."""
# Test 'half' alias for float16
processor_half = DeviceProcessor(device="cpu", float_dtype="half")
assert processor_half._target_float_dtype == torch.float16
# Test 'float' alias for float32
processor_float = DeviceProcessor(device="cpu", float_dtype="float")
assert processor_float._target_float_dtype == torch.float32
# Test 'double' alias for float64
processor_double = DeviceProcessor(device="cpu", float_dtype="double")
assert processor_double._target_float_dtype == torch.float64
def test_float_dtype_with_mixed_tensors():
"""Test float dtype conversion with mixed tensor types."""
processor = DeviceProcessor(device="cpu", float_dtype="float32")
observation = {
"observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert
"observation.state": torch.randn(10, dtype=torch.float64), # Should convert
"observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert
"observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert
}
action = torch.randn(5, dtype=torch.float16) # Should convert
transition = create_transition(observation=observation, action=action)
result = processor(transition)
# Check conversions
assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8 # Unchanged
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted
assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged
assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged
assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted
def test_float_dtype_serialization():
"""Test that float_dtype is properly serialized in get_config."""
processor = DeviceProcessor(device="cuda", float_dtype="float16")
config = processor.get_config()
assert config == {"device": "cuda", "float_dtype": "float16"}
# Test with None float_dtype
processor_none = DeviceProcessor(device="cpu", float_dtype=None)
config_none = processor_none.get_config()
assert config_none == {"device": "cpu", "float_dtype": None}
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_float_dtype_with_cuda():
"""Test float dtype conversion combined with CUDA device."""
processor = DeviceProcessor(device="cuda", float_dtype="float16")
# Create tensors on CPU with different dtypes
observation = {
"observation.float32": torch.randn(5, dtype=torch.float32),
"observation.int64": torch.tensor([1, 2, 3], dtype=torch.int64),
}
action = torch.randn(3, dtype=torch.float64)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
# Check that tensors are on CUDA and float types are converted
assert result[TransitionKey.OBSERVATION]["observation.float32"].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16
assert result[TransitionKey.OBSERVATION]["observation.int64"].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64 # Unchanged
assert result[TransitionKey.ACTION].device.type == "cuda"
assert result[TransitionKey.ACTION].dtype == torch.float16
def test_complementary_data_index_fields():
"""Test processing of index and task_index fields in complementary_data."""
processor = DeviceProcessor(device="cpu")
# Create transition with index and task_index in complementary_data
complementary_data = {
"task": ["pick_cube"],
"index": torch.tensor([42], dtype=torch.int64),
"task_index": torch.tensor([3], dtype=torch.int64),
"episode_id": 123, # Non-tensor field
}
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
complementary_data=complementary_data,
)
result = processor(transition)
# Check that tensors in complementary_data are processed
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Check index tensor
assert isinstance(processed_comp_data["index"], torch.Tensor)
assert processed_comp_data["index"].device.type == "cpu"
assert torch.equal(processed_comp_data["index"], complementary_data["index"])
# Check task_index tensor
assert isinstance(processed_comp_data["task_index"], torch.Tensor)
assert processed_comp_data["task_index"].device.type == "cpu"
assert torch.equal(processed_comp_data["task_index"], complementary_data["task_index"])
# Check non-tensor fields remain unchanged
assert processed_comp_data["task"] == ["pick_cube"]
assert processed_comp_data["episode_id"] == 123
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_complementary_data_index_fields_cuda():
"""Test moving index and task_index fields to CUDA."""
processor = DeviceProcessor(device="cuda:0")
# Create CPU tensors
complementary_data = {
"index": torch.tensor([100, 101], dtype=torch.int64),
"task_index": torch.tensor([5], dtype=torch.int64),
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Check tensors moved to CUDA
assert processed_comp_data["index"].device.type == "cuda"
assert processed_comp_data["index"].device.index == 0
assert processed_comp_data["task_index"].device.type == "cuda"
assert processed_comp_data["task_index"].device.index == 0
def test_complementary_data_without_index_fields():
"""Test that complementary_data without index/task_index fields works correctly."""
processor = DeviceProcessor(device="cpu")
complementary_data = {
"task": ["navigate"],
"episode_id": 456,
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
# Should process without errors and preserve non-tensor fields
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["task"] == ["navigate"]
assert processed_comp_data["episode_id"] == 456
def test_complementary_data_mixed_tensors():
"""Test complementary_data with mix of tensors and non-tensors."""
processor = DeviceProcessor(device="cpu")
complementary_data = {
"task": ["pick_and_place"],
"index": torch.tensor([42], dtype=torch.int64),
"task_index": torch.tensor([3], dtype=torch.int64),
"metrics": [1.0, 2.0, 3.0], # List, not tensor
"config": {"speed": "fast"}, # Dict
"episode_id": 789, # Int
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Check tensors are processed
assert isinstance(processed_comp_data["index"], torch.Tensor)
assert isinstance(processed_comp_data["task_index"], torch.Tensor)
# Check non-tensors remain unchanged
assert processed_comp_data["task"] == ["pick_and_place"]
assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0]
assert processed_comp_data["config"] == {"speed": "fast"}
assert processed_comp_data["episode_id"] == 789
def test_complementary_data_float_dtype_conversion():
"""Test that float dtype conversion doesn't affect int tensors in complementary_data."""
processor = DeviceProcessor(device="cpu", float_dtype="float16")
complementary_data = {
"index": torch.tensor([42], dtype=torch.int64),
"task_index": torch.tensor([3], dtype=torch.int64),
"float_tensor": torch.tensor([1.5, 2.5], dtype=torch.float32), # Should be converted
}
transition = create_transition(complementary_data=complementary_data)
result = processor(transition)
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Int tensors should keep their dtype
assert processed_comp_data["index"].dtype == torch.int64
assert processed_comp_data["task_index"].dtype == torch.int64
# Float tensor should be converted
assert processed_comp_data["float_tensor"].dtype == torch.float16
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_complementary_data_full_pipeline_cuda():
"""Test full transition with complementary_data on CUDA."""
processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
# Create full transition with mixed CPU tensors
observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)}
action = torch.randn(1, 4, dtype=torch.float32)
reward = torch.tensor(1.5, dtype=torch.float32)
done = torch.tensor(False)
complementary_data = {
"task": ["reach_target"],
"index": torch.tensor([1000], dtype=torch.int64),
"task_index": torch.tensor([10], dtype=torch.int64),
}
transition = create_transition(
observation=observation,
action=action,
reward=reward,
done=done,
complementary_data=complementary_data,
)
result = processor(transition)
# Check all components moved to CUDA
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda"
assert result[TransitionKey.REWARD].device.type == "cuda"
assert result[TransitionKey.DONE].device.type == "cuda"
# Check complementary_data tensors
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
assert processed_comp_data["index"].device.type == "cuda"
assert processed_comp_data["task_index"].device.type == "cuda"
# Check float conversion happened for float tensors
assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16
assert result[TransitionKey.ACTION].dtype == torch.float16
assert result[TransitionKey.REWARD].dtype == torch.float16
# Check int tensors kept their dtype
assert processed_comp_data["index"].dtype == torch.int64
assert processed_comp_data["task_index"].dtype == torch.int64
def test_complementary_data_empty():
"""Test empty complementary_data handling."""
processor = DeviceProcessor(device="cpu")
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
complementary_data={},
)
result = processor(transition)
# Should have empty dict
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
def test_complementary_data_none():
"""Test None complementary_data handling."""
processor = DeviceProcessor(device="cpu")
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
complementary_data=None,
)
result = processor(transition)
# Complementary data should not be in the result (same as input)
assert TransitionKey.COMPLEMENTARY_DATA not in result
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_policy_processor_integration():
"""Test integration with policy processors - input on GPU, output on CPU."""
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor import NormalizerProcessor, ToBatchProcessor, UnnormalizerProcessor
# Create features and stats
features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
}
stats = {
"observation.state": {"mean": torch.zeros(10), "std": torch.ones(10)},
"action": {"mean": torch.zeros(5), "std": torch.ones(5)},
}
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MEAN_STD}
# Create input processor (preprocessor) that moves to GPU
input_processor = RobotProcessor(
steps=[
NormalizerProcessor(features=features, norm_map=norm_map, stats=stats),
ToBatchProcessor(),
DeviceProcessor(device="cuda"),
],
name="test_preprocessor",
)
# Create output processor (postprocessor) that moves to CPU
output_processor = RobotProcessor(
steps=[
DeviceProcessor(device="cpu"),
UnnormalizerProcessor(features={"action": features["action"]}, norm_map=norm_map, stats=stats),
],
name="test_postprocessor",
)
# Test data on CPU
observation = {"observation.state": torch.randn(10)}
action = torch.randn(5)
transition = create_transition(observation=observation, action=action)
# Process through input processor
input_result = input_processor(transition)
# Verify tensors are on GPU and batched
assert input_result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
assert input_result[TransitionKey.OBSERVATION]["observation.state"].shape[0] == 1
assert input_result[TransitionKey.ACTION].device.type == "cuda"
assert input_result[TransitionKey.ACTION].shape[0] == 1
# Simulate model output on GPU
model_output = create_transition(action=torch.randn(1, 5).cuda())
# Process through output processor
output_result = output_processor(model_output)
# Verify action is back on CPU and unnormalized
assert output_result[TransitionKey.ACTION].device.type == "cpu"
assert output_result[TransitionKey.ACTION].shape == (1, 5)

@@ -1260,6 +1260,273 @@ def test_hotswap_stats_with_different_data_types():
torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0))
def test_normalization_info_tracking():
"""Test that normalization info is tracked in complementary_data."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
FeatureType.ACTION: NormalizationMode.IDENTITY,
}
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"observation.state": {
"min": np.array([0.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
"action": {
"mean": np.array([0.0, 0.0]),
"std": np.array([1.0, 1.0]),
},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
action = torch.tensor([1.0, -0.5])
transition = create_transition(observation=observation, action=action)
# Process the transition
normalized_transition = normalizer(transition)
# Check that normalization info is added
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "normalized_keys" in comp_data
norm_info = comp_data["normalized_keys"]
assert norm_info["observation.image"] == "MEAN_STD"
assert norm_info["observation.state"] == "MIN_MAX"
assert norm_info["action"] == "IDENTITY"
def test_unnormalization_info_tracking():
"""Test that unnormalization info is tracked in complementary_data."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"action": {
"min": np.array([-1.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
}
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
action = torch.tensor([0.0, -0.5])
transition = create_transition(observation=observation, action=action)
# Process the transition
unnormalized_transition = unnormalizer(transition)
# Check that unnormalization info is added
comp_data = unnormalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "unnormalized_keys" in comp_data
unnorm_info = comp_data["unnormalized_keys"]
assert unnorm_info["observation.image"] == "MEAN_STD"
assert unnorm_info["action"] == "MIN_MAX"
def test_normalization_info_with_missing_stats():
"""Test normalization info when stats are missing for some keys."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
}
# Only provide stats for image, not state
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
# Process the transition
normalized_transition = normalizer(transition)
# Check that only keys with stats are in normalization info
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "normalized_keys" in comp_data
norm_info = comp_data["normalized_keys"]
assert norm_info["observation.image"] == "MEAN_STD"
# State should not be in the normalization info since it has no stats
assert "observation.state" not in norm_info
def test_normalization_info_with_selective_keys():
"""Test normalization info with selective normalization."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.STATE: NormalizationMode.MIN_MAX,
}
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"observation.state": {
"min": np.array([0.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
}
# Only normalize image
normalizer = NormalizerProcessor(
features=features, norm_map=norm_map, stats=stats, normalize_keys={"observation.image"}
)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
# Process the transition
normalized_transition = normalizer(transition)
# Check that only selected keys are in normalization info
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "normalized_keys" in comp_data
norm_info = comp_data["normalized_keys"]
assert norm_info["observation.image"] == "MEAN_STD"
# State should not be in the normalization info since it wasn't in normalize_keys
assert "observation.state" not in norm_info
def test_normalization_info_preserved_in_pipeline():
"""Test that normalization info is preserved when using RobotProcessor pipeline."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
stats = {
"observation.image": {
"mean": np.array([0.5, 0.5, 0.5]),
"std": np.array([0.2, 0.2, 0.2]),
},
"action": {
"min": np.array([-1.0, -1.0]),
"max": np.array([1.0, 1.0]),
},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
# Create pipeline
pipeline = RobotProcessor([normalizer, unnormalizer])
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
action = torch.tensor([0.5, -0.5])
transition = create_transition(observation=observation, action=action)
# Process through pipeline
result = pipeline(transition)
# Check that both normalization and unnormalization info are present
comp_data = result.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is not None
assert "normalized_keys" in comp_data
assert "unnormalized_keys" in comp_data
# Check normalization info
norm_info = comp_data["normalized_keys"]
assert norm_info["observation.image"] == "MEAN_STD"
assert norm_info["action"] == "MIN_MAX"
# Check unnormalization info
unnorm_info = comp_data["unnormalized_keys"]
assert unnorm_info["observation.image"] == "MEAN_STD"
assert unnorm_info["action"] == "MIN_MAX"
def test_normalization_info_empty_transition():
"""Test that no normalization info is added for empty transitions."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
stats = {
"observation.image": {"mean": [0.5], "std": [0.2]},
"action": {"min": [-1.0], "max": [1.0]},
}
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
# Empty transition
transition = create_transition()
# Process the transition
normalized_transition = normalizer(transition)
# Check that no normalization info is added
comp_data = normalized_transition.get(TransitionKey.COMPLEMENTARY_DATA)
assert comp_data is None or "normalized_keys" not in comp_data
def test_hotswap_stats_functional_test():
"""Test that hotswapped processor actually works functionally."""
# Create test data

@@ -1639,6 +1639,109 @@ def test_state_file_naming_with_multiple_processors():
assert loaded_post.steps[0].window_size == 10
def test_default_batch_to_transition_with_index_fields():
"""Test that _default_batch_to_transition handles index and task_index fields correctly."""
from lerobot.processor.pipeline import _default_batch_to_transition
# Create batch with index and task_index fields
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"next.reward": 1.5,
"next.done": False,
"task": ["pick_cube"],
"index": torch.tensor([42], dtype=torch.int64),
"task_index": torch.tensor([3], dtype=torch.int64),
}
transition = _default_batch_to_transition(batch)
# Check basic transition structure
assert TransitionKey.OBSERVATION in transition
assert TransitionKey.ACTION in transition
assert TransitionKey.COMPLEMENTARY_DATA in transition
# Check that index and task_index are in complementary_data
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
assert "index" in comp_data
assert "task_index" in comp_data
assert "task" in comp_data
# Verify values
assert torch.equal(comp_data["index"], batch["index"])
assert torch.equal(comp_data["task_index"], batch["task_index"])
assert comp_data["task"] == batch["task"]
def test_default_transition_to_batch_with_index_fields():
"""Test that _default_transition_to_batch handles index and task_index fields correctly."""
from lerobot.processor.pipeline import _default_transition_to_batch
# Create transition with index and task_index in complementary_data
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
reward=1.5,
done=False,
complementary_data={
"task": ["navigate"],
"index": torch.tensor([100], dtype=torch.int64),
"task_index": torch.tensor([5], dtype=torch.int64),
},
)
batch = _default_transition_to_batch(transition)
# Check that index and task_index are in the batch
assert "index" in batch
assert "task_index" in batch
assert "task" in batch
# Verify values
assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"])
assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"])
assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"]
def test_batch_to_transition_without_index_fields():
"""Test that batch-to-transition conversion works without index and task_index fields."""
from lerobot.processor.pipeline import _default_batch_to_transition
# Batch without index/task_index
batch = {
"observation.state": torch.randn(1, 7),
"action": torch.randn(1, 4),
"task": ["pick_cube"],
}
transition = _default_batch_to_transition(batch)
comp_data = transition[TransitionKey.COMPLEMENTARY_DATA]
# Should have task but not index/task_index
assert "task" in comp_data
assert "index" not in comp_data
assert "task_index" not in comp_data
def test_transition_to_batch_without_index_fields():
"""Test that transition-to-batch conversion works without index and task_index fields."""
from lerobot.processor.pipeline import _default_transition_to_batch
# Transition without index/task_index
transition = create_transition(
observation={"observation.state": torch.randn(1, 7)},
action=torch.randn(1, 4),
complementary_data={"task": ["navigate"]},
)
batch = _default_transition_to_batch(transition)
# Should have task but not index/task_index
assert "task" in batch
assert "index" not in batch
assert "task_index" not in batch
def test_override_with_device_strings():
"""Test overriding device parameters with string values."""

@@ -0,0 +1,699 @@
"""
Tests for the TokenizerProcessor class.
"""
import tempfile
from unittest.mock import patch
import pytest
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
from lerobot.processor.tokenizer_processor import TokenizerProcessor
def create_transition(
    observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
):
    """Helper function to create test transitions."""
    return {
        TransitionKey.OBSERVATION: observation,
        TransitionKey.ACTION: action,
        TransitionKey.REWARD: reward,
        TransitionKey.DONE: done,
        TransitionKey.TRUNCATED: truncated,
        TransitionKey.INFO: info,
        TransitionKey.COMPLEMENTARY_DATA: complementary_data,
    }
class MockTokenizer:
    """Mock tokenizer for testing that mimics transformers tokenizer interface."""

    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size

    def __call__(
        self,
        text: str | list[str],
        max_length: int = 512,
        truncation: bool = True,
        padding: str = "max_length",
        padding_side: str = "right",
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Mock tokenization that returns deterministic tokens based on text."""
        if isinstance(text, str):
            texts = [text]
        else:
            texts = text
        batch_size = len(texts)
        # Create mock input_ids and attention_mask
        input_ids = torch.zeros(batch_size, max_length, dtype=torch.long)
        attention_mask = torch.zeros(batch_size, max_length, dtype=torch.long)
        for i, txt in enumerate(texts):
            # Simple mock: derive tokens from hash(txt). Python salts str hashes per
            # process by default, so tokens are stable within one run, not across runs.
            text_hash = hash(txt) % self.vocab_size
            seq_len = min(len(txt.split()), max_length)
            # Fill input_ids with simple pattern based on text
            for j in range(seq_len):
                input_ids[i, j] = (text_hash + j) % self.vocab_size
            # Set attention mask for non-padded positions
            attention_mask[i, :seq_len] = 1
        result = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        # Drop the batch dimension for a single input so callers get 1-D tensors.
        # (Real transformers tokenizers keep a leading batch dimension of 1 with
        # return_tensors="pt"; this squeeze is a simplification of the mock.)
        if len(texts) == 1:
            result = {k: v.squeeze(0) for k, v in result.items()}
        return result
@pytest.fixture
def mock_tokenizer():
"""Provide a mock tokenizer for testing."""
return MockTokenizer(vocab_size=100)
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_basic_tokenization(mock_auto_tokenizer):
"""Test basic string tokenization functionality."""
# Mock AutoTokenizer.from_pretrained to return our mock tokenizer
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
transition = create_transition(complementary_data={"task": "pick up the red cube"})
result = processor(transition)
# Check that original task is preserved
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick up the red cube"
# Check that tokens were added to observation
observation = result[TransitionKey.OBSERVATION]
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
# Check token structure
tokens = observation[f"{OBS_LANGUAGE}.tokens"]
attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"]
assert isinstance(tokens, torch.Tensor)
assert isinstance(attention_mask, torch.Tensor)
assert tokens.shape == (10,)
assert attention_mask.shape == (10,)
def test_basic_tokenization_with_tokenizer_object():
"""Test basic string tokenization functionality using tokenizer object directly."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(complementary_data={"task": "pick up the red cube"})
result = processor(transition)
# Check that original task is preserved
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick up the red cube"
# Check that tokens were added to observation
observation = result[TransitionKey.OBSERVATION]
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
# Check token structure
tokens = observation[f"{OBS_LANGUAGE}.tokens"]
attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"]
assert isinstance(tokens, torch.Tensor)
assert isinstance(attention_mask, torch.Tensor)
assert tokens.shape == (10,)
assert attention_mask.shape == (10,)
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_list_of_strings_tokenization(mock_auto_tokenizer):
"""Test tokenization of a list of strings."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
transition = create_transition(complementary_data={"task": ["pick up cube", "place on table"]})
result = processor(transition)
# Check that original task is preserved
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["pick up cube", "place on table"]
# Check that tokens were added to observation
observation = result[TransitionKey.OBSERVATION]
tokens = observation[f"{OBS_LANGUAGE}.tokens"]
attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"]
assert tokens.shape == (2, 8) # batch_size=2, seq_len=8
assert attention_mask.shape == (2, 8)
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_custom_keys(mock_auto_tokenizer):
"""Test using custom task_key."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5)
transition = create_transition(complementary_data={"instruction": "move forward"})
result = processor(transition)
# Check that tokens are stored in observation regardless of task_key
observation = result[TransitionKey.OBSERVATION]
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
tokens = observation[f"{OBS_LANGUAGE}.tokens"]
assert tokens.shape == (5,)
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_none_complementary_data(mock_auto_tokenizer):
"""Test handling of None complementary_data."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
transition = create_transition(complementary_data=None)
result = processor(transition)
assert result == transition # Should return unchanged
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_missing_task_key(mock_auto_tokenizer):
"""Test handling when task key is missing."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
transition = create_transition(complementary_data={"other_field": "some value"})
result = processor(transition)
assert result == transition # Should return unchanged
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_none_task_value(mock_auto_tokenizer):
"""Test handling when task value is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
transition = create_transition(complementary_data={"task": None})
result = processor(transition)
assert result == transition # Should return unchanged
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_unsupported_task_type(mock_auto_tokenizer):
"""Test handling of unsupported task types."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
# Test with integer task
transition = create_transition(complementary_data={"task": 123})
result = processor(transition)
assert result == transition # Should return unchanged
# Test with mixed list
transition = create_transition(complementary_data={"task": ["text", 123, "more text"]})
result = processor(transition)
assert result == transition # Should return unchanged
def test_no_tokenizer_error():
"""Test that ValueError is raised when neither tokenizer nor tokenizer_name is provided."""
with pytest.raises(ValueError, match="Either 'tokenizer' or 'tokenizer_name' must be provided"):
TokenizerProcessor()
def test_invalid_tokenizer_name_error():
    """Test that an error is raised when an invalid tokenizer_name is provided."""
    with patch("lerobot.processor.tokenizer_processor.AutoTokenizer") as mock_auto_tokenizer:
        # Simulate from_pretrained failing to load the requested tokenizer
        mock_auto_tokenizer.from_pretrained.side_effect = Exception("Model not found")
        with pytest.raises(Exception, match="Model not found"):
            TokenizerProcessor(tokenizer_name="invalid-tokenizer")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_get_config_with_tokenizer_name(mock_auto_tokenizer):
"""Test configuration serialization when using tokenizer_name."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(
tokenizer_name="test-tokenizer",
max_length=256,
task_key="instruction",
padding="longest",
truncation=False,
)
config = processor.get_config()
expected = {
"tokenizer_name": "test-tokenizer",
"max_length": 256,
"task_key": "instruction",
"padding_side": "right",
"padding": "longest",
"truncation": False,
}
assert config == expected
def test_get_config_with_tokenizer_object():
"""Test configuration serialization when using tokenizer object."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(
tokenizer=mock_tokenizer,
max_length=256,
task_key="instruction",
padding="longest",
truncation=False,
)
config = processor.get_config()
# tokenizer_name should not be in config when tokenizer object is used
expected = {
"max_length": 256,
"task_key": "instruction",
"padding_side": "right",
"padding": "longest",
"truncation": False,
}
assert config == expected
assert "tokenizer_name" not in config
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_state_dict_methods(mock_auto_tokenizer):
"""Test state_dict and load_state_dict methods."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
# Should return empty dict
state = processor.state_dict()
assert state == {}
# load_state_dict should not raise error
processor.load_state_dict({})
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_reset_method(mock_auto_tokenizer):
"""Test reset method."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
# Should not raise error
processor.reset()
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_integration_with_robot_processor(mock_auto_tokenizer):
"""Test integration with RobotProcessor."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
robot_processor = RobotProcessor([tokenizer_processor])
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "test task"},
)
result = robot_processor(transition)
# Check that observation exists and tokenization was applied
assert TransitionKey.OBSERVATION in result
observation = result[TransitionKey.OBSERVATION]
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
tokens = observation[f"{OBS_LANGUAGE}.tokens"]
attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"]
assert tokens.shape == (6,)
assert attention_mask.shape == (6,)
# Check that other data is preserved
assert torch.equal(
result[TransitionKey.OBSERVATION]["state"], transition[TransitionKey.OBSERVATION]["state"]
)
assert torch.equal(result[TransitionKey.ACTION], transition[TransitionKey.ACTION])
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
    """Test saving and loading processor with tokenizer_name."""
    mock_tokenizer = MockTokenizer(vocab_size=100)
    mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
    original_processor = TokenizerProcessor(
        tokenizer_name="test-tokenizer", max_length=32, task_key="instruction"
    )
    robot_processor = RobotProcessor([original_processor])
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save processor
        robot_processor.save_pretrained(temp_dir)
        # Load processor - tokenizer will be recreated from saved config
        loaded_processor = RobotProcessor.from_pretrained(temp_dir)
        # Test that loaded processor works
        transition = create_transition(complementary_data={"instruction": "test instruction"})
        result = loaded_processor(transition)
        assert TransitionKey.OBSERVATION in result
        assert f"{OBS_LANGUAGE}.tokens" in result[TransitionKey.OBSERVATION]
        assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION]
def test_save_and_load_pretrained_with_tokenizer_object():
    """Test saving and loading processor with tokenizer object using overrides."""
    mock_tokenizer = MockTokenizer(vocab_size=100)
    original_processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=32, task_key="instruction")
    robot_processor = RobotProcessor([original_processor])
    with tempfile.TemporaryDirectory() as temp_dir:
        # Save processor
        robot_processor.save_pretrained(temp_dir)
        # Load processor with tokenizer override (since tokenizer object wasn't saved)
        loaded_processor = RobotProcessor.from_pretrained(
            temp_dir, overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}}
        )
        # Test that loaded processor works
        transition = create_transition(complementary_data={"instruction": "test instruction"})
        result = loaded_processor(transition)
        assert TransitionKey.OBSERVATION in result
        assert f"{OBS_LANGUAGE}.tokens" in result[TransitionKey.OBSERVATION]
        assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION]
def test_registry_functionality():
"""Test that the processor is properly registered."""
from lerobot.processor.pipeline import ProcessorStepRegistry
# Check that the processor is registered
assert "tokenizer_processor" in ProcessorStepRegistry.list()
# Check that we can retrieve it
retrieved_class = ProcessorStepRegistry.get("tokenizer_processor")
assert retrieved_class is TokenizerProcessor
def test_feature_contract_basic():
"""Test basic feature contract functionality."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=128)
input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
}
output_features = processor.feature_contract(input_features)
# Check that original features are preserved
assert "observation.state" in output_features
assert "action" in output_features
# Check that tokenized features are added
assert f"{OBS_LANGUAGE}.tokens" in output_features
assert f"{OBS_LANGUAGE}.attention_mask" in output_features
# Check feature properties
tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"]
attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"]
assert tokens_feature.type == FeatureType.LANGUAGE
assert tokens_feature.shape == (128,)
assert attention_mask_feature.type == FeatureType.LANGUAGE
assert attention_mask_feature.shape == (128,)
def test_feature_contract_with_custom_max_length():
"""Test feature contract with custom max_length."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=64)
input_features = {}
output_features = processor.feature_contract(input_features)
# Check that features use correct max_length
assert f"{OBS_LANGUAGE}.tokens" in output_features
assert f"{OBS_LANGUAGE}.attention_mask" in output_features
tokens_feature = output_features[f"{OBS_LANGUAGE}.tokens"]
attention_mask_feature = output_features[f"{OBS_LANGUAGE}.attention_mask"]
assert tokens_feature.shape == (64,)
assert attention_mask_feature.shape == (64,)
def test_feature_contract_existing_features():
"""Test feature contract when tokenized features already exist."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=256)
input_features = {
f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
}
output_features = processor.feature_contract(input_features)
# Should not overwrite existing features
assert output_features[f"{OBS_LANGUAGE}.tokens"].shape == (100,) # Original shape preserved
assert output_features[f"{OBS_LANGUAGE}.attention_mask"].shape == (100,)
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_tokenization_parameters(mock_auto_tokenizer):
    """Test that tokenization parameters are correctly passed to tokenizer."""

    # Create a custom mock that tracks calls
    class TrackingMockTokenizer:
        def __init__(self):
            self.last_call_args = None
            self.last_call_kwargs = None

        def __call__(self, *args, **kwargs):
            self.last_call_args = args
            self.last_call_kwargs = kwargs
            # Return minimal valid output
            return {
                "input_ids": torch.zeros(16, dtype=torch.long),
                "attention_mask": torch.ones(16, dtype=torch.long),
            }

    tracking_tokenizer = TrackingMockTokenizer()
    mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer
    processor = TokenizerProcessor(
        tokenizer_name="test-tokenizer",
        max_length=16,
        padding="longest",
        truncation=False,
        padding_side="left",
    )
    transition = create_transition(complementary_data={"task": "test task"})
    processor(transition)
    # Check that parameters were passed correctly (task is converted to list)
    assert tracking_tokenizer.last_call_args == (["test task"],)
    assert tracking_tokenizer.last_call_kwargs["max_length"] == 16
    assert tracking_tokenizer.last_call_kwargs["padding"] == "longest"
    assert tracking_tokenizer.last_call_kwargs["padding_side"] == "left"
    assert tracking_tokenizer.last_call_kwargs["truncation"] is False
    assert tracking_tokenizer.last_call_kwargs["return_tensors"] == "pt"
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_preserves_other_complementary_data(mock_auto_tokenizer):
"""Test that other complementary data fields are preserved."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
transition = create_transition(
complementary_data={
"task": "test task",
"episode_id": 123,
"timestamp": 456.789,
"other_field": {"nested": "data"},
}
)
result = processor(transition)
comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
# Check that all original fields are preserved
assert comp_data["task"] == "test task"
assert comp_data["episode_id"] == 123
assert comp_data["timestamp"] == 456.789
assert comp_data["other_field"] == {"nested": "data"}
# Check that tokens were added to observation
observation = result[TransitionKey.OBSERVATION]
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_deterministic_tokenization(mock_auto_tokenizer):
"""Test that tokenization is deterministic for the same input."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
transition = create_transition(complementary_data={"task": "consistent test"})
result1 = processor(transition)
result2 = processor(transition)
tokens1 = result1[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
attention_mask1 = result1[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
tokens2 = result2[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"]
attention_mask2 = result2[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"]
# Results should be identical
assert torch.equal(tokens1, tokens2)
assert torch.equal(attention_mask1, attention_mask2)
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_empty_string_task(mock_auto_tokenizer):
"""Test handling of empty string task."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
transition = create_transition(complementary_data={"task": ""})
result = processor(transition)
# Should still tokenize (mock tokenizer handles empty strings)
observation = result[TransitionKey.OBSERVATION]
assert f"{OBS_LANGUAGE}.tokens" in observation
tokens = observation[f"{OBS_LANGUAGE}.tokens"]
assert tokens.shape == (8,)
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_very_long_task(mock_auto_tokenizer):
"""Test handling of very long task strings."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=5, truncation=True)
long_task = " ".join(["word"] * 100) # Very long task
transition = create_transition(complementary_data={"task": long_task})
result = processor(transition)
# Should be truncated to max_length
observation = result[TransitionKey.OBSERVATION]
tokens = observation[f"{OBS_LANGUAGE}.tokens"]
attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"]
assert tokens.shape == (5,)
assert attention_mask.shape == (5,)
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_custom_padding_side(mock_auto_tokenizer):
    """Test using custom padding_side parameter."""

    # Create a mock tokenizer that tracks padding_side calls
    class PaddingSideTrackingTokenizer:
        def __init__(self):
            self.padding_side_calls = []

        def __call__(
            self,
            text,
            max_length=512,
            truncation=True,
            padding="max_length",
            padding_side="right",
            return_tensors="pt",
            **kwargs,
        ):
            self.padding_side_calls.append(padding_side)
            # Return minimal valid output
            return {
                "input_ids": torch.zeros(max_length, dtype=torch.long),
                "attention_mask": torch.ones(max_length, dtype=torch.long),
            }

    tracking_tokenizer = PaddingSideTrackingTokenizer()
    mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer
    # Test left padding
    processor_left = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="left")
    transition = create_transition(complementary_data={"task": "test task"})
    processor_left(transition)
    assert tracking_tokenizer.padding_side_calls[-1] == "left"
    # Test right padding (default)
    processor_right = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="right")
    processor_right(transition)
    assert tracking_tokenizer.padding_side_calls[-1] == "right"