diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 0504aa22e..5fd7fcf56 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -21,8 +21,10 @@ from collections import deque from pathlib import Path from typing import TYPE_CHECKING, Literal, TypedDict +import numpy as np import torch import torch.nn.functional as F # noqa: N812 +from scipy.fftpack import idct from torch import Tensor, nn from typing_extensions import Unpack @@ -2046,6 +2048,34 @@ class PI05Policy(PreTrainedPolicy): except Exception as e: logging.warning(f"Could not load tokenizer for subtask decoding: {e}") self.tokenizer = None + + # Load FAST tokenizer for action detokenization (only if fast_only mode) + self.action_tokenizer = None + self._paligemma_tokenizer = None + self._fast_skip_tokens = 128 + + if config.fast_only: + try: + from transformers import AutoProcessor, AutoTokenizer + + # Load FAST tokenizer + self.action_tokenizer = AutoProcessor.from_pretrained( + "jadechoghari/fast-libero-tokenizer-mean-std", + trust_remote_code=True + ) + + # Load PaliGemma tokenizer for token conversion + self._paligemma_tokenizer = AutoTokenizer.from_pretrained( + "google/paligemma-3b-pt-224", + trust_remote_code=True, + add_eos_token=True, + add_bos_token=False + ) + + logging.info("Loaded FAST tokenizer for action detokenization") + except Exception as e: + logging.warning(f"Could not load FAST tokenizer for action detokenization: {e}") + logging.warning("Action tokens will be returned without detokenization") self.reset() @@ -2323,6 +2353,148 @@ class PI05Policy(PreTrainedPolicy): """Pad action""" actions = pad_vector(batch[ACTION], self.config.max_action_dim) return actions + + def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens). + + Args: + tokens: PaliGemma token IDs + + Returns: + Action token IDs + """ + return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens + + def decode_actions_with_fast( + self, + token_ids: list[int], + time_horizon: int, + action_dim: int, + relaxed_decoding: bool = True + ) -> np.ndarray: + """ + Decodes action token IDs back to continuous action values using the FAST tokenizer. + + Args: + token_ids: List of token IDs to decode. + time_horizon: The number of timesteps for actions. + action_dim: The dimensionality of each action. + relaxed_decoding: Whether to use relaxed decoding (allows partial sequences). + + Returns: + A numpy array representing the decoded actions. + """ + decoded_actions = [] + + for token in token_ids: + try: + decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token) + decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token + + if relaxed_decoding: + # expected sequence length + expected_seq_len = time_horizon * action_dim + diff = expected_seq_len - decoded_dct_coeff.shape[0] + + # apply truncation if too long + if diff < 0: + decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # truncate on the right + + # apply padding if too short + elif diff > 0: + decoded_dct_coeff = np.pad( + decoded_dct_coeff, (0, diff), mode="constant", constant_values=0 + ) + + decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim) + assert decoded_dct_coeff.shape == ( + time_horizon, + action_dim, + ), ( + f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})" + ) + + except Exception as e: + logging.warning(f"Error decoding tokens: {e}") + logging.warning(f"Tokens: {token}") + decoded_dct_coeff = np.zeros((time_horizon, action_dim)) + + decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho")) + + return np.stack(decoded_actions) + + def detokenize_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor: + """ + Detokenizes action tokens back to continuous actions. + + This method converts predicted action tokens from the model back to continuous action values + using the FAST tokenizer. It handles the conversion from PaliGemma token space to action token + space, then decodes the action tokens to continuous values using DCT decoding. + + Args: + tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,) + action_horizon: The number of timesteps for actions. + action_dim: The dimensionality of each action. + + Returns: + The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim) + """ + if self.action_tokenizer is None or self._paligemma_tokenizer is None: + raise ValueError( + "Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully." + ) + + # Handle single sample (add batch dimension) + single_sample = tokens.dim() == 1 + if single_sample: + tokens = tokens.unsqueeze(0) + + # Convert token IDs to token strings + decoded_tokens = [ + self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) + for seq in tokens + ] + + # Clean tokens by removing everything after the first "|" (end-of-action marker) + cleaned_tokens = [] + for token_seq in decoded_tokens: + if "|" in token_seq: + token_seq = token_seq[:token_seq.index("|")] + cleaned_tokens.append(token_seq) + + # Convert token strings back to IDs + raw_action_tokens = [ + torch.tensor( + self._paligemma_tokenizer.convert_tokens_to_ids(token_seq), + dtype=torch.long, + device=tokens.device, + ) + for token_seq in cleaned_tokens + ] + + # Convert PaliGemma tokens to action tokens + action_tokens = [ + self._paligemma_tokens_to_act_tokens(raw_action_token) + for raw_action_token in raw_action_tokens + ] + + # Decode action tokens to continuous actions + actions = self.decode_actions_with_fast( + action_tokens, + time_horizon=action_horizon, + action_dim=action_dim + ) + + # Convert to tensor and return + actions_tensor = torch.tensor(actions, dtype=torch.float32, device=tokens.device) + + # Remove batch dimension if input was single sample + if single_sample: + actions_tensor = actions_tensor.squeeze(0) + + breakpoint() + return actions_tensor @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -2364,9 +2536,18 @@ class PI05Policy(PreTrainedPolicy): max_decoding_steps=max_decoding_steps, temperature=temperature, ) - # Return the action tokens - these need to be decoded by the FAST tokenizer - # The caller is responsible for decoding tokens to continuous actions - return action_tokens + + # Detokenize action tokens to continuous actions + action_horizon = self.config.n_action_steps + action_dim = 7 + + continuous_actions = self.detokenize_actions( + action_tokens, + action_horizon=action_horizon, + action_dim=action_dim + ) + + return continuous_actions # Full mode: use flow matching with optional subtask generation # Use high_level_task tokens (WITHOUT subtask) for inference - we'll generate the subtask diff --git a/src/lerobot/policies/pi05/train_libero.sh b/src/lerobot/policies/pi05/train_libero.sh index dd74d22c1..9e2520dd9 100644 --- a/src/lerobot/policies/pi05/train_libero.sh +++ b/src/lerobot/policies/pi05/train_libero.sh @@ -2,7 +2,7 @@ export CUDA_LAUNCH_BLOCKING=1 lerobot-train \ --dataset.repo_id=local \ --dataset.root=/fsx/jade_choghari/data/libero \ - --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_5 \ + --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \ --job_name=libero_training_fast \ --policy.repo_id=jade_choghari/pi05-fast-libero \ --policy.path=/fsx/jade_choghari/models/libero-pi-fast \