mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
detoknize action at policy level
This commit is contained in:
@@ -21,8 +21,10 @@ from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from scipy.fftpack import idct
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
@@ -2046,6 +2048,34 @@ class PI05Policy(PreTrainedPolicy):
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not load tokenizer for subtask decoding: {e}")
|
||||
self.tokenizer = None
|
||||
|
||||
# Load FAST tokenizer for action detokenization (only if fast_only mode)
|
||||
self.action_tokenizer = None
|
||||
self._paligemma_tokenizer = None
|
||||
self._fast_skip_tokens = 128
|
||||
|
||||
if config.fast_only:
|
||||
try:
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
|
||||
# Load FAST tokenizer
|
||||
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||
"jadechoghari/fast-libero-tokenizer-mean-std",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Load PaliGemma tokenizer for token conversion
|
||||
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/paligemma-3b-pt-224",
|
||||
trust_remote_code=True,
|
||||
add_eos_token=True,
|
||||
add_bos_token=False
|
||||
)
|
||||
|
||||
logging.info("Loaded FAST tokenizer for action detokenization")
|
||||
except Exception as e:
|
||||
logging.warning(f"Could not load FAST tokenizer for action detokenization: {e}")
|
||||
logging.warning("Action tokens will be returned without detokenization")
|
||||
|
||||
self.reset()
|
||||
|
||||
@@ -2323,6 +2353,148 @@ class PI05Policy(PreTrainedPolicy):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
|
||||
|
||||
Args:
|
||||
tokens: PaliGemma token IDs
|
||||
|
||||
Returns:
|
||||
Action token IDs
|
||||
"""
|
||||
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
|
||||
|
||||
def decode_actions_with_fast(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
time_horizon: int,
|
||||
action_dim: int,
|
||||
relaxed_decoding: bool = True
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Decodes action token IDs back to continuous action values using the FAST tokenizer.
|
||||
|
||||
Args:
|
||||
token_ids: List of token IDs to decode.
|
||||
time_horizon: The number of timesteps for actions.
|
||||
action_dim: The dimensionality of each action.
|
||||
relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
|
||||
|
||||
Returns:
|
||||
A numpy array representing the decoded actions.
|
||||
"""
|
||||
decoded_actions = []
|
||||
|
||||
for token in token_ids:
|
||||
try:
|
||||
decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token
|
||||
|
||||
if relaxed_decoding:
|
||||
# expected sequence length
|
||||
expected_seq_len = time_horizon * action_dim
|
||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||
|
||||
# apply truncation if too long
|
||||
if diff < 0:
|
||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # truncate on the right
|
||||
|
||||
# apply padding if too short
|
||||
elif diff > 0:
|
||||
decoded_dct_coeff = np.pad(
|
||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
time_horizon,
|
||||
action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Error decoding tokens: {e}")
|
||||
logging.warning(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((time_horizon, action_dim))
|
||||
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho"))
|
||||
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
def detokenize_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
|
||||
"""
|
||||
Detokenizes action tokens back to continuous actions.
|
||||
|
||||
This method converts predicted action tokens from the model back to continuous action values
|
||||
using the FAST tokenizer. It handles the conversion from PaliGemma token space to action token
|
||||
space, then decodes the action tokens to continuous values using DCT decoding.
|
||||
|
||||
Args:
|
||||
tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)
|
||||
action_horizon: The number of timesteps for actions.
|
||||
action_dim: The dimensionality of each action.
|
||||
|
||||
Returns:
|
||||
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
|
||||
"""
|
||||
if self.action_tokenizer is None or self._paligemma_tokenizer is None:
|
||||
raise ValueError(
|
||||
"Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully."
|
||||
)
|
||||
|
||||
# Handle single sample (add batch dimension)
|
||||
single_sample = tokens.dim() == 1
|
||||
if single_sample:
|
||||
tokens = tokens.unsqueeze(0)
|
||||
|
||||
# Convert token IDs to token strings
|
||||
decoded_tokens = [
|
||||
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
for seq in tokens
|
||||
]
|
||||
|
||||
# Clean tokens by removing everything after the first "|" (end-of-action marker)
|
||||
cleaned_tokens = []
|
||||
for token_seq in decoded_tokens:
|
||||
if "|" in token_seq:
|
||||
token_seq = token_seq[:token_seq.index("|")]
|
||||
cleaned_tokens.append(token_seq)
|
||||
|
||||
# Convert token strings back to IDs
|
||||
raw_action_tokens = [
|
||||
torch.tensor(
|
||||
self._paligemma_tokenizer.convert_tokens_to_ids(token_seq),
|
||||
dtype=torch.long,
|
||||
device=tokens.device,
|
||||
)
|
||||
for token_seq in cleaned_tokens
|
||||
]
|
||||
|
||||
# Convert PaliGemma tokens to action tokens
|
||||
action_tokens = [
|
||||
self._paligemma_tokens_to_act_tokens(raw_action_token)
|
||||
for raw_action_token in raw_action_tokens
|
||||
]
|
||||
|
||||
# Decode action tokens to continuous actions
|
||||
actions = self.decode_actions_with_fast(
|
||||
action_tokens,
|
||||
time_horizon=action_horizon,
|
||||
action_dim=action_dim
|
||||
)
|
||||
|
||||
# Convert to tensor and return
|
||||
actions_tensor = torch.tensor(actions, dtype=torch.float32, device=tokens.device)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
actions_tensor = actions_tensor.squeeze(0)
|
||||
|
||||
breakpoint()
|
||||
return actions_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -2364,9 +2536,18 @@ class PI05Policy(PreTrainedPolicy):
|
||||
max_decoding_steps=max_decoding_steps,
|
||||
temperature=temperature,
|
||||
)
|
||||
# Return the action tokens - these need to be decoded by the FAST tokenizer
|
||||
# The caller is responsible for decoding tokens to continuous actions
|
||||
return action_tokens
|
||||
|
||||
# Detokenize action tokens to continuous actions
|
||||
action_horizon = self.config.n_action_steps
|
||||
action_dim = 7
|
||||
|
||||
continuous_actions = self.detokenize_actions(
|
||||
action_tokens,
|
||||
action_horizon=action_horizon,
|
||||
action_dim=action_dim
|
||||
)
|
||||
|
||||
return continuous_actions
|
||||
|
||||
# Full mode: use flow matching with optional subtask generation
|
||||
# Use high_level_task tokens (WITHOUT subtask) for inference - we'll generate the subtask
|
||||
|
||||
@@ -2,7 +2,7 @@ export CUDA_LAUNCH_BLOCKING=1
|
||||
lerobot-train \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/data/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_5 \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \
|
||||
--job_name=libero_training_fast \
|
||||
--policy.repo_id=jade_choghari/pi05-fast-libero \
|
||||
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
|
||||
|
||||
Reference in New Issue
Block a user