mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
fix(pi052): supervise only FAST action-code tokens
Mask the FAST auxiliary loss to discrete action-code tokens so wrapper formatting tokens do not affect action co-training. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -32,6 +32,7 @@ import torch
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
ACTION_CODE_TOKEN_MASK,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -412,14 +413,15 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# During inference, no action is available, skip tokenization
|
||||
return new_transition
|
||||
|
||||
# Tokenize and get both tokens and mask
|
||||
tokens, mask = self._tokenize_action(action)
|
||||
# Tokenize and get masks for the full formatted sequence and the discrete action codes.
|
||||
tokens, mask, code_mask = self._tokenize_action(action)
|
||||
|
||||
# Store mask in complementary data
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if complementary_data is None:
|
||||
complementary_data = {}
|
||||
complementary_data[ACTION_TOKEN_MASK] = mask
|
||||
complementary_data[ACTION_CODE_TOKEN_MASK] = code_mask
|
||||
complementary_data[ACTION_TOKENS] = tokens
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
return new_transition
|
||||
@@ -430,7 +432,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
|
||||
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenizes the action tensor and creates a mask.
|
||||
|
||||
@@ -459,6 +461,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# The fast tokenizer expects action data and returns token IDs
|
||||
tokens_list = []
|
||||
masks_list = []
|
||||
code_masks_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
||||
@@ -476,19 +479,26 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
action_code_tokens = self._act_tokens_to_paligemma_tokens(tokens)
|
||||
bos_id = self._paligemma_tokenizer.bos_token_id
|
||||
# add bos
|
||||
prompt_tokens = torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
)
|
||||
end_tokens = torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)
|
||||
|
||||
code_start = 1 + len(prompt_tokens)
|
||||
code_end = code_start + len(action_code_tokens)
|
||||
tokens = torch.cat(
|
||||
[
|
||||
torch.tensor([bos_id], device=action.device),
|
||||
torch.tensor(
|
||||
self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False),
|
||||
device=action.device,
|
||||
),
|
||||
self._act_tokens_to_paligemma_tokens(tokens),
|
||||
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
|
||||
prompt_tokens,
|
||||
action_code_tokens,
|
||||
end_tokens,
|
||||
]
|
||||
)
|
||||
code_mask = torch.zeros(len(tokens), dtype=torch.bool, device=action.device)
|
||||
code_mask[code_start:code_end] = True
|
||||
|
||||
# Truncate or pad to max_action_tokens
|
||||
if len(tokens) > self.max_action_tokens:
|
||||
@@ -497,44 +507,49 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"Consider increasing the `max_action_tokens` in your model config if this happens frequently."
|
||||
)
|
||||
tokens = tokens[: self.max_action_tokens]
|
||||
code_mask = code_mask[: self.max_action_tokens]
|
||||
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
||||
else:
|
||||
pad_len = self.max_action_tokens - len(tokens)
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
||||
torch.zeros(
|
||||
self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device
|
||||
),
|
||||
torch.zeros(pad_len, dtype=torch.bool, device=action.device),
|
||||
]
|
||||
)
|
||||
code_mask = torch.nn.functional.pad(code_mask, (0, pad_len), value=False)
|
||||
# Pad tokens with zeros
|
||||
tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0)
|
||||
tokens = torch.nn.functional.pad(tokens, (0, pad_len), value=0)
|
||||
|
||||
tokens_list.append(tokens)
|
||||
masks_list.append(mask)
|
||||
code_masks_list.append(code_mask)
|
||||
|
||||
# Stack into batched tensors
|
||||
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
||||
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
||||
code_masks_batch = torch.stack(code_masks_list, dim=0) # (B, max_action_tokens)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
tokens_batch = tokens_batch.squeeze(0)
|
||||
masks_batch = masks_batch.squeeze(0)
|
||||
code_masks_batch = code_masks_batch.squeeze(0)
|
||||
|
||||
# Move to the same device as the input
|
||||
if device is not None:
|
||||
tokens_batch = tokens_batch.to(device)
|
||||
masks_batch = masks_batch.to(device)
|
||||
code_masks_batch = code_masks_batch.to(device)
|
||||
|
||||
return tokens_batch, masks_batch
|
||||
return tokens_batch, masks_batch, code_masks_batch
|
||||
|
||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method is not used since we override __call__.
|
||||
Required by ActionProcessorStep ABC.
|
||||
"""
|
||||
tokens, _ = self._tokenize_action(action)
|
||||
tokens, _, _ = self._tokenize_action(action)
|
||||
return tokens
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user