fix(pi052): supervise only FAST action-code tokens

Mask the FAST auxiliary loss to discrete action-code tokens so wrapper formatting tokens do not affect action co-training.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-18 17:38:34 +00:00
parent 474c5478d9
commit 0e2dc1b76f
4 changed files with 134 additions and 30 deletions

View File

@@ -32,6 +32,7 @@ import torch
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
from lerobot.utils.constants import (
ACTION_CODE_TOKEN_MASK,
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -412,14 +413,15 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# During inference, no action is available, skip tokenization
return new_transition
# Tokenize and get both tokens and mask
tokens, mask = self._tokenize_action(action)
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
tokens, mask, code_mask = self._tokenize_action(action)
# Store mask in complementary data
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if complementary_data is None:
complementary_data = {}
complementary_data[ACTION_TOKEN_MASK] = mask
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
complementary_data[ACTION_TOKENS] = tokens
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
@@ -430,7 +432,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"""
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Tokenizes the action tensor and creates a mask.
@@ -459,6 +461,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# The fast tokenizer expects action data and returns token IDs
tokens_list = []
masks_list = []
code_masks_list = []
for i in range(batch_size):
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
@@ -476,19 +479,26 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
if tokens.dim() > 1:
tokens = tokens.flatten()
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
bos_id = self._paligemma_tokenizer.bos_token_id
# add bos
prompt_tokens = torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
)
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
code_start = 1 + len(prompt_tokens)
code_end = code_start + len(action_code_tokens)
tokens = torch.cat(
[
torch.tensor([bos_id], device=action.device),
torch.tensor(
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
device=action.device,
),
self._act_tokens_to_paligemma_tokens(tokens),
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
prompt_tokens,
action_code_tokens,
end_tokens,
]
)
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
code_mask[code_start:code_end] = True
# Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens:
@@ -497,44 +507,49 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
)
tokens = tokens[: self.max_action_tokens]
code_mask = code_mask[: self.max_action_tokens]
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
else:
pad_len = self.max_action_tokens - len(tokens)
mask = torch.cat(
[
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
torch.zeros(
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
),
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
]
)
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
# Pad tokens with zeros
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
tokens_list.append(tokens)
masks_list.append(mask)
code_masks_list.append(code_mask)
# Stack into batched tensors
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
# Remove batch dimension if input was single sample
if single_sample:
tokens_batch = tokens_batch.squeeze(0)
masks_batch = masks_batch.squeeze(0)
code_masks_batch = code_masks_batch.squeeze(0)
# Move to the same device as the input
if device is not None:
tokens_batch = tokens_batch.to(device)
masks_batch = masks_batch.to(device)
code_masks_batch = code_masks_batch.to(device)
return tokens_batch, masks_batch
return tokens_batch, masks_batch, code_masks_batch
def action(self, action: torch.Tensor) -> torch.Tensor:
"""
This method is not used since we override __call__.
Required by ActionProcessorStep ABC.
"""
tokens, _ = self._tokenize_action(action)
tokens, _, _ = self._tokenize_action(action)
return tokens
def get_config(self) -> dict[str, Any]: