detoknize action at policy level

This commit is contained in:
Jade Choghari
2025-12-26 06:45:38 +00:00
parent e682ef05f9
commit 8edd544bbe
2 changed files with 185 additions and 4 deletions

View File

@@ -21,8 +21,10 @@ from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from scipy.fftpack import idct
from torch import Tensor, nn
from typing_extensions import Unpack
@@ -2046,6 +2048,34 @@ class PI05Policy(PreTrainedPolicy):
except Exception as e:
logging.warning(f"Could not load tokenizer for subtask decoding: {e}")
self.tokenizer = None
# Load FAST tokenizer for action detokenization (only if fast_only mode)
self.action_tokenizer = None
self._paligemma_tokenizer = None
self._fast_skip_tokens = 128
if config.fast_only:
try:
from transformers import AutoProcessor, AutoTokenizer
# Load FAST tokenizer
self.action_tokenizer = AutoProcessor.from_pretrained(
"jadechoghari/fast-libero-tokenizer-mean-std",
trust_remote_code=True
)
# Load PaliGemma tokenizer for token conversion
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224",
trust_remote_code=True,
add_eos_token=True,
add_bos_token=False
)
logging.info("Loaded FAST tokenizer for action detokenization")
except Exception as e:
logging.warning(f"Could not load FAST tokenizer for action detokenization: {e}")
logging.warning("Action tokens will be returned without detokenization")
self.reset()
@@ -2323,6 +2353,148 @@ class PI05Policy(PreTrainedPolicy):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
"""
Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
Args:
tokens: PaliGemma token IDs
Returns:
Action token IDs
"""
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
def decode_actions_with_fast(
self,
token_ids: list[int],
time_horizon: int,
action_dim: int,
relaxed_decoding: bool = True
) -> np.ndarray:
"""
Decodes action token IDs back to continuous action values using the FAST tokenizer.
Args:
token_ids: List of token IDs to decode.
time_horizon: The number of timesteps for actions.
action_dim: The dimensionality of each action.
relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
Returns:
A numpy array representing the decoded actions.
"""
decoded_actions = []
for token in token_ids:
try:
decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token
if relaxed_decoding:
# expected sequence length
expected_seq_len = time_horizon * action_dim
diff = expected_seq_len - decoded_dct_coeff.shape[0]
# apply truncation if too long
if diff < 0:
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # truncate on the right
# apply padding if too short
elif diff > 0:
decoded_dct_coeff = np.pad(
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
)
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
assert decoded_dct_coeff.shape == (
time_horizon,
action_dim,
), (
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
)
except Exception as e:
logging.warning(f"Error decoding tokens: {e}")
logging.warning(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((time_horizon, action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
def detokenize_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
"""
Detokenizes action tokens back to continuous actions.
This method converts predicted action tokens from the model back to continuous action values
using the FAST tokenizer. It handles the conversion from PaliGemma token space to action token
space, then decodes the action tokens to continuous values using DCT decoding.
Args:
tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)
action_horizon: The number of timesteps for actions.
action_dim: The dimensionality of each action.
Returns:
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
"""
if self.action_tokenizer is None or self._paligemma_tokenizer is None:
raise ValueError(
"Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully."
)
# Handle single sample (add batch dimension)
single_sample = tokens.dim() == 1
if single_sample:
tokens = tokens.unsqueeze(0)
# Convert token IDs to token strings
decoded_tokens = [
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
for seq in tokens
]
# Clean tokens by removing everything after the first "|" (end-of-action marker)
cleaned_tokens = []
for token_seq in decoded_tokens:
if "|" in token_seq:
token_seq = token_seq[:token_seq.index("|")]
cleaned_tokens.append(token_seq)
# Convert token strings back to IDs
raw_action_tokens = [
torch.tensor(
self._paligemma_tokenizer.convert_tokens_to_ids(token_seq),
dtype=torch.long,
device=tokens.device,
)
for token_seq in cleaned_tokens
]
# Convert PaliGemma tokens to action tokens
action_tokens = [
self._paligemma_tokens_to_act_tokens(raw_action_token)
for raw_action_token in raw_action_tokens
]
# Decode action tokens to continuous actions
actions = self.decode_actions_with_fast(
action_tokens,
time_horizon=action_horizon,
action_dim=action_dim
)
# Convert to tensor and return
actions_tensor = torch.tensor(actions, dtype=torch.float32, device=tokens.device)
# Remove batch dimension if input was single sample
if single_sample:
actions_tensor = actions_tensor.squeeze(0)
breakpoint()
return actions_tensor
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -2364,9 +2536,18 @@ class PI05Policy(PreTrainedPolicy):
max_decoding_steps=max_decoding_steps,
temperature=temperature,
)
# Return the action tokens - these need to be decoded by the FAST tokenizer
# The caller is responsible for decoding tokens to continuous actions
return action_tokens
# Detokenize action tokens to continuous actions
action_horizon = self.config.n_action_steps
action_dim = 7
continuous_actions = self.detokenize_actions(
action_tokens,
action_horizon=action_horizon,
action_dim=action_dim
)
return continuous_actions
# Full mode: use flow matching with optional subtask generation
# Use high_level_task tokens (WITHOUT subtask) for inference - we'll generate the subtask

View File

@@ -2,7 +2,7 @@ export CUDA_LAUNCH_BLOCKING=1
lerobot-train \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_5 \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \
--job_name=libero_training_fast \
--policy.repo_id=jade_choghari/pi05-fast-libero \
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \