feat(policies): add autoregressive VLAs with tokenization PiFast (#2734)

This commit is contained in:
Jade Choghari
2026-01-09 23:08:37 +01:00
committed by GitHub
parent ba3d2148a3
commit 1d86c9b7f2
15 changed files with 3214 additions and 5 deletions

View File

@@ -23,22 +23,29 @@ token IDs and attention masks, which are then added to the observation dictionar
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
)
from lerobot.utils.import_utils import _transformers_available
from .core import EnvTransition, TransitionKey
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers import AutoProcessor, AutoTokenizer
else:
AutoProcessor = None
AutoTokenizer = None
@@ -268,3 +275,256 @@ class TokenizerProcessorStep(ObservationProcessorStep):
)
return features
@dataclass
@ProcessorStepRegistry.register(name="action_tokenizer_processor")
class ActionTokenizerProcessorStep(ActionProcessorStep):
    """
    Processor step to tokenize action data using a fast action tokenizer.

    This step takes action tensors from an `EnvTransition`, tokenizes them using
    a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
    wraps the result in a PaliGemma-vocabulary prompt (`<bos>Action: ... |`), and stores the padded
    token IDs and their validity mask in the transition's complementary data.

    Requires the `transformers` library to be installed.

    Attributes:
        action_tokenizer_name: The name of a pretrained processor from the Hugging Face Hub
            (e.g., "physical-intelligence/fast"). Ignored if `action_tokenizer_input_object` is given.
        action_tokenizer_input_object: A pre-initialized processor/tokenizer object. If provided,
            `action_tokenizer_name` is ignored.
        trust_remote_code: Whether to trust remote code when loading the tokenizers
            (required for some tokenizers).
        max_action_tokens: Fixed length every token sequence is truncated or padded to.
        fast_skip_tokens: Offset subtracted when mapping action token IDs into the PaliGemma vocabulary.
        paligemma_tokenizer_name: The name of a pretrained PaliGemma tokenizer from the Hugging Face Hub
            (e.g., "google/paligemma-3b-pt-224").
        action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
    """

    action_tokenizer_name: str | None = None
    action_tokenizer_input_object: Any | None = None
    # NOTE(security): `True` lets `from_pretrained` run code shipped with the Hub repository.
    # Only point the tokenizer names at repositories you trust.
    trust_remote_code: bool = True
    max_action_tokens: int = 256
    fast_skip_tokens: int = 128
    paligemma_tokenizer_name: str = "google/paligemma-3b-pt-224"

    # Internal tokenizer instances (not part of the config)
    action_tokenizer: Any = field(default=None, init=False, repr=False)
    _paligemma_tokenizer: Any = field(default=None, init=False, repr=False)

    def __post_init__(self):
        """
        Initializes the tokenizers after the dataclass is created.

        It checks for the availability of the `transformers` library, then takes the action
        tokenizer either from the provided object or by name from the Hugging Face Hub, and
        finally loads the PaliGemma tokenizer named by `paligemma_tokenizer_name`.

        Raises:
            ImportError: If the `transformers` library is not installed.
            ValueError: If neither `action_tokenizer_input_object` nor `action_tokenizer_name`
                is provided.
        """
        if not _transformers_available:
            raise ImportError(
                "The 'transformers' library is not installed. "
                "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
            )

        if self.action_tokenizer_input_object is not None:
            # A ready-made tokenizer takes precedence over loading by name.
            self.action_tokenizer = self.action_tokenizer_input_object
        elif self.action_tokenizer_name is not None:
            if AutoProcessor is None:
                raise ImportError("AutoProcessor is not available")
            self.action_tokenizer = AutoProcessor.from_pretrained(
                self.action_tokenizer_name, trust_remote_code=self.trust_remote_code
            )
        else:
            raise ValueError(
                "Either 'action_tokenizer_input_object' or 'action_tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

        # With these flags a plain `encode(...)` appends EOS and never prepends BOS;
        # BOS is added explicitly when the action prompt is assembled.
        self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
            self.paligemma_tokenizer_name,
            trust_remote_code=self.trust_remote_code,
            add_eos_token=True,
            add_bos_token=False,
        )

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """
        Applies action tokenization to the transition.

        This overrides the base class to handle both tokens and mask.

        Args:
            transition: The input transition with action data.

        Returns:
            A copy of the transition with tokenized actions and their mask stored in
            complementary data under `ACTION_TOKENS` and `ACTION_TOKEN_MASK`. When the
            transition has no action, the copy is returned untouched.
        """
        self._current_transition = transition.copy()
        new_transition = self._current_transition

        action = new_transition.get(TransitionKey.ACTION)
        if action is None:
            # During inference, no action is available, skip tokenization
            return new_transition

        # Tokenize and get both tokens and mask
        tokens, mask = self._tokenize_action(action)

        # `transition.copy()` above is presumably shallow, so the nested complementary-data
        # dict would still be shared with the caller's transition. Work on a fresh dict so
        # the input transition is left unmodified.
        complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
        complementary_data[ACTION_TOKEN_MASK] = mask
        complementary_data[ACTION_TOKENS] = tokens
        new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
        return new_transition

    def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Converts action tokens to PaliGemma tokens.

        Action token ID `t` is mapped to `vocab_size - 1 - fast_skip_tokens - t`, i.e. IDs are
        counted downwards from the end of the PaliGemma vocabulary after skipping its last
        `fast_skip_tokens` entries.
        """
        return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens

    def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenizes the action tensor and creates a mask.

        Each sample becomes `<bos> "Action: " <mapped action tokens> "|"` (the PaliGemma
        tokenizer's `encode("|")` output is used as-is for the suffix), then is truncated or
        zero-padded to `max_action_tokens`.

        Args:
            action: The input action tensor to tokenize. A 1-D tensor is treated as a single
                sample; otherwise the first dimension is iterated as the batch dimension.
                NOTE(review): an unbatched `(H, action_dim)` input is therefore tokenized row
                by row, not as one sample - confirm callers always pass `(B, H, action_dim)`.

        Returns:
            A tuple of (tokens, mask) where:
                - tokens: Tensor of token IDs with shape (B, max_action_tokens)
                - mask: Boolean mask with shape (B, max_action_tokens), True for real tokens,
                  False for padding
            For a 1-D input the leading batch dimension is removed from both.

        Raises:
            ValueError: If `action` is None.
        """
        if action is None:
            raise ValueError("Action cannot be None")

        device = action.device

        # Handle single sample (add batch dimension)
        single_sample = action.dim() == 1
        if single_sample:
            action = action.unsqueeze(0)

        # The prompt prefix/suffix do not depend on the sample: encode them once.
        text_tokenizer = self._paligemma_tokenizer
        prefix = torch.tensor(
            [text_tokenizer.bos_token_id, *text_tokenizer.encode("Action: ", add_special_tokens=False)],
            dtype=torch.long,
            device=device,
        )
        suffix = torch.tensor(text_tokenizer.encode("|"), dtype=torch.long, device=device)

        max_len = self.max_action_tokens
        positions = torch.arange(max_len, device=device)

        tokens_list = []
        masks_list = []
        for i in range(action.shape[0]):
            # Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
            raw = self.action_tokenizer(action[i : i + 1].cpu())

            if isinstance(raw, torch.Tensor):
                raw = raw.to(device=device)
            else:
                raw = torch.tensor(raw, dtype=torch.long, device=device)

            # The tokenizer may return a nested (batch-of-one) structure: collapse to 1-D.
            raw = raw.flatten()

            sequence = torch.cat([prefix, self._act_tokens_to_paligemma_tokens(raw), suffix])
            length = sequence.shape[0]

            if length > max_len:
                logging.warning(
                    f"Token length ({length}) exceeds max length ({self.max_action_tokens}), truncating. "
                    "Consider increasing the `max_action_tokens` in your model config if this happens frequently."
                )
                # Truncation drops the tail of the sequence, including the "|" suffix.
                sequence = sequence[:max_len]
                length = max_len

            # True for real tokens, False for the zero padding appended below.
            mask = positions < length
            sequence = torch.nn.functional.pad(sequence, (0, max_len - length), value=0)

            tokens_list.append(sequence)
            masks_list.append(mask)

        # Stack into batched tensors
        tokens_batch = torch.stack(tokens_list, dim=0)  # (B, max_action_tokens)
        masks_batch = torch.stack(masks_list, dim=0)  # (B, max_action_tokens)

        # Remove batch dimension if input was single sample
        if single_sample:
            tokens_batch = tokens_batch.squeeze(0)
            masks_batch = masks_batch.squeeze(0)

        return tokens_batch, masks_batch

    def action(self, action: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes `action` and returns only the token IDs (the mask is discarded).

        This method is not used by `__call__`, which is overridden to keep the mask as well.
        Required by ActionProcessorStep ABC.
        """
        tokens, _ = self._tokenize_action(action)
        return tokens

    def get_config(self) -> dict[str, Any]:
        """
        Returns the serializable configuration of the processor.

        Note: The tokenizer object itself is not serialized. If the processor was initialized
        with a tokenizer name, that name will be included in the config.

        Returns:
            A dictionary with the processor's configuration parameters.
        """
        config = {
            "trust_remote_code": self.trust_remote_code,
            "max_action_tokens": self.max_action_tokens,
            # Both values change the produced token IDs, so they must survive a save/load
            # round trip instead of silently falling back to the defaults.
            "fast_skip_tokens": self.fast_skip_tokens,
            "paligemma_tokenizer_name": self.paligemma_tokenizer_name,
        }

        # Only save tokenizer_name if it was used to create the tokenizer
        if self.action_tokenizer_name is not None and self.action_tokenizer_input_object is None:
            config["action_tokenizer_name"] = self.action_tokenizer_name

        return config

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Returns the feature definitions unchanged.

        The tokenized actions and their mask are written to complementary data by `__call__`,
        and this step does not register any new or modified feature for them.

        Args:
            features: The dictionary of existing policy features.

        Returns:
            The same `features` dictionary that was passed in.
        """
        return features