make fast work

This commit is contained in:
Jade Choghari
2025-12-25 20:59:32 +00:00
parent 9b5ac4387c
commit e682ef05f9
13 changed files with 4233 additions and 4326 deletions

View File

@@ -2,25 +2,26 @@
"repo_id": "local",
"vocab_size": 1024,
"scale": 10.0,
"encoded_dims": "0:15",
"encoded_dims": "0:7",
"encoded_dim_ranges": [
[
0,
15
7
]
],
"total_encoded_dims": 15,
"total_encoded_dims": 7,
"delta_dims": null,
"delta_dim_list": null,
"use_delta_transform": false,
"state_key": "observation.state",
"action_horizon": 50,
"num_training_chunks": 4900,
"normalization_mode": "MEAN_STD",
"action_horizon": 10,
"num_training_chunks": 25065,
"compression_stats": {
"compression_ratio": 15.85791309863622,
"mean_token_length": 47.295,
"p99_token_length": 90.0,
"compression_ratio": 2.8901734104046244,
"mean_token_length": 24.22,
"p99_token_length": 40.0,
"min_token_length": 9.0,
"max_token_length": 109.0
"max_token_length": 46.0
}
}

View File

@@ -1,11 +1,11 @@
{
"action_dim": 15,
"action_dim": 7,
"auto_map": {
"AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
},
"min_token": -71,
"min_token": -203,
"processor_class": "UniversalActionProcessor",
"scale": 10.0,
"time_horizon": 50,
"time_horizon": 10,
"vocab_size": 1024
}

File diff suppressed because it is too large Load Diff

View File

@@ -537,17 +537,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
# FAST action token embedding and prediction head
self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
# # FAST action token embedding and prediction head
# self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
# self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
# Apply dtype conversion to FAST layers to match model precision
if config.dtype == "bfloat16":
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
elif config.dtype == "float32":
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
# # Apply dtype conversion to FAST layers to match model precision
# if config.dtype == "bfloat16":
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
# elif config.dtype == "float32":
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
@@ -1280,7 +1280,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Process FAST action tokens (discrete token IDs)
if fast_action_tokens is not None:
def fast_action_embed_func(fast_action_tokens):
fast_emb = self.fast_action_embedding(fast_action_tokens)
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
fast_emb_dim = fast_emb.shape[-1]
return fast_emb * math.sqrt(fast_emb_dim)
@@ -1411,7 +1411,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Get logits for FAST action tokens using the FAST LM head
# We only compute logits for the positions that predict FAST tokens
fast_logits = self.fast_action_lm_head(prefix_out) # (B, T-1, fast_vocab_size)
lm_head = self.paligemma_with_expert.paligemma.lm_head
# The FAST tokens start at position (total_T_images + num_lang_tokens)
# For next-token prediction:
@@ -1422,7 +1422,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Extract logits for FAST token prediction
# Input positions [fast_start-1 : fast_start-1+num_fast_embs] predict FAST tokens
fast_logits_for_pred = fast_logits[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, fast_vocab_size)
fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, hidden_dim)
fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size)
# Targets are the FAST action tokens
fast_targets = fast_action_tokens # (B, num_fast_embs)
@@ -1438,7 +1439,140 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Apply mask and compute mean loss
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
breakpoint()
# breakpoint()
# from transformers import AutoTokenizer, AutoProcessor
# _paligemma_tokenizer = AutoTokenizer.from_pretrained(
# "google/paligemma-3b-pt-224",
# trust_remote_code=True,
# add_eos_token=True,
# add_bos_token=False
# )
# # 257152
# # # Decode predicted output tokens
# # # fast_logits_for_pred.argmax(dim=-1)
# def _paligemma_tokens_to_act_tokens(tokens: torch.Tensor) -> torch.Tensor:
# """
# Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
# """
# return _paligemma_tokenizer.vocab_size - 1 - 128 - tokens
# # # target = _paligemma_tokens_to_act_tokens(fast_targets)
# decoded_tokens = _paligemma_tokenizer.batch_decode(fast_targets, skip_special_tokens=False)
# decoded_tokens = [
# _paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
# for seq in fast_logits_for_pred.argmax(dim=-1)
# ]
# cleaned_tokens = []
# for token_seq in decoded_tokens:
# if "|" in token_seq:
# token_seq = token_seq[:token_seq.index("|")]
# cleaned_tokens.append(token_seq)
# raw_action_tokens = [
# torch.tensor(
# _paligemma_tokenizer.convert_tokens_to_ids(token_seq),
# dtype=torch.long,
# device=fast_targets.device,
# )
# for token_seq in cleaned_tokens
# ]
# action_tokens = [
# _paligemma_tokens_to_act_tokens(raw_action_token)
# for raw_action_token in raw_action_tokens
# ]
# breakpoint()
# # Clean the decoded tokens by removing "Action:" prefix and extracting the relevant part
# cleaned_tokens = [
# tokens_sequence.strip().split("|")[0].strip()
# for tokens_sequence in decoded_tokens
# ]
# # Re-encode the cleaned text to get raw action tokens
# raw_action_tokens = [
# _paligemma_tokenizer.encode(sample_tokens, return_tensors="pt", padding=False).squeeze(0)
# for sample_tokens in cleaned_tokens
# ]
# # Convert PaliGemma tokens back to action tokens
# action_tokens = [
# _paligemma_tokens_to_act_tokens(raw_action_token)
# for raw_action_token in raw_action_tokens
# ]
# # # Decode each sample's tokens to continuous actions
# action_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# # breakpoint()
# decoded_actions = action_tokenizer.decode(
# action_tokens,
# time_horizon=self.config.chunk_size,
# action_dim=6
# )
# breakpoint()
# def decode_actions_with_fast(
# token_ids: list[int],
# time_horizon: int,
# action_dim: int,
# relaxed_decoding: bool = False
# ) -> list:
# """
# Decodes action token IDs back to continuous action values using the FAST tokenizer.
# Args:
# token_ids: List of token IDs to decode.
# time_horizon: The number of timesteps for actions.
# action_dim: The dimensionality of each action.
# relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
# Returns:
# A list representing the decoded actions.
# """
# # Use the action tokenizer's decode method
# # The FAST tokenizer should have a decode method that converts tokens back to actions
# try:
# decoded_actions = action_tokenizer.decode(
# token_ids,
# time_horizon=time_horizon,
# action_dim=action_dim
# )
# return decoded_actions
# except Exception as e:
# if relaxed_decoding:
# # If relaxed decoding is enabled, try to decode as much as possible
# import logging
# logging.warning(f"Relaxed decoding: {e}. Returning partial decode.")
# try:
# # Try to decode with whatever tokens we have
# partial_decoded = action_tokenizer.decode(
# token_ids[:len(token_ids)],
# time_horizon=time_horizon,
# action_dim=action_dim
# )
# return partial_decoded
# except:
# # Return zeros if decoding completely fails
# return [[0.0] * action_dim for _ in range(time_horizon)]
# else:
# raise e
# valid = fast_logits_for_pred.argmax(dim=-1) <= (self._paligemma_tokenizer.vocab_size - 1 - 128)
# fast_region = fast_logits_for_pred.argmax(dim=-1).masked_fill(~valid, 0)
# fast_tokens = _paligemma_tokens_to_act_tokens(fast_region)
# actions = decode_actions_with_fast(fast_tokens.tolist(), time_horizon=self.config.chunk_size, action_dim=7, relaxed_decoding=True)[0]
# breakpoint()
# decoded_actions = [
# torch.tensor(
# decode_actions_with_fast(
# tok[0].tolist(),
# time_horizon=self.config.chunk_size,
# action_dim=7,
# relaxed_decoding=True,
# ),
# device=tokens.device,
# ).squeeze(0)
# for tok in action_tokens
# ]
# breakpoint()
# # Stack into a batch
# result = torch.stack(decoded_actions, dim=0)
# breakpoint()
return {
"fast_loss": fast_loss,
"loss": fast_loss,
@@ -1453,50 +1587,126 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
masks,
max_decoding_steps=None,
temperature=0.0,
) -> Tensor:
"""Sample actions using autoregressive decoding for FAST-only mode.
This implements the Pi0FAST inference: autoregressively decode action tokens.
Args:
images: List of image tensors
img_masks: List of image masks
tokens: Language instruction tokens
masks: Attention masks for tokens
max_decoding_steps: Maximum number of tokens to decode
temperature: Sampling temperature (0 = greedy)
Returns:
Decoded action tokens [B, max_decoding_steps]
) -> torch.Tensor:
"""
Inefficient but safe autoregressive decoding for FAST tokens.
Matches the pattern of _generate_subtask_tokens.
"""
if max_decoding_steps is None:
max_decoding_steps = self.config.max_action_tokens
bsize = tokens.shape[0]
device = tokens.device
lm_head = self.paligemma_with_expert.paligemma.lm_head
# Embed prefix (images + language) without FAST tokens
# 1. Initial Embedding (Matches Training Prefix)
# prefix_embs will include [Images, Language Prompt]
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
images, img_masks, tokens, masks,
fast_action_tokens=None,
fast_action_masks=None
)
# Convert to bfloat16 if needed
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
# Initial forward pass to get KV cache
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_2d_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
# 2. Decoding Loop (Re-computes full sequence every step)
for t in range(max_decoding_steps):
# Always re-calculate position IDs from the current pad mask (matches training)
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
# Full forward pass (No KV Cache)
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
adarms_cond=[None, None],
)
# Predict next token from the very last sequence position
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
if temperature > 0:
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
generated_action_tokens[:, t] = next_token.squeeze(-1)
# 3. Update Sequence for next iteration (unless it's the last step)
if t < max_decoding_steps - 1:
# Embed the newly generated token
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
# Append to embeddings
prefix_embs = torch.cat([prefix_embs, next_token_emb], dim=1)
# Update padding mask (New token is always valid/1)
prefix_pad_masks = torch.cat([
prefix_pad_masks,
torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1)
# Update 2D attention mask: Grow the matrix
old_len = prefix_att_masks.shape[1]
new_len = old_len + 1
new_att_masks = torch.zeros((bsize, new_len, new_len), dtype=torch.bool, device=device)
new_att_masks[:, :old_len, :old_len] = prefix_att_masks
# New token attends to all non-padding tokens in the updated sequence
new_att_masks[:, -1, :] = prefix_pad_masks
prefix_att_masks = new_att_masks
return generated_action_tokens
@torch.no_grad()
def sample_actions_fast_kv_cache(
self,
images,
img_masks,
tokens,
masks,
max_decoding_steps=None,
temperature=0.0,
) -> torch.Tensor:
"""
Efficient autoregressive decoding for FAST tokens using KV-caching.
Only computes the prefix once, then incrementally generates tokens.
"""
if max_decoding_steps is None:
max_decoding_steps = self.config.max_action_tokens
bsize = tokens.shape[0]
device = tokens.device
lm_head = self.paligemma_with_expert.paligemma.lm_head
# 1. Initial Embedding (Matches Training Prefix)
# prefix_embs will include [Images, Language Prompt]
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
images, img_masks, tokens, masks,
fast_action_tokens=None,
fast_action_masks=None
)
if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
# 2. Initial forward pass to populate KV cache
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
# First forward pass with full prefix (caching enabled)
(prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
attention_mask=att_2d_4d,
attention_mask=att_4d,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
@@ -1504,62 +1714,71 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
adarms_cond=[None, None],
)
# Get initial logits from last position
last_hidden = prefix_out[:, -1:]
logits = self.fast_action_lm_head(last_hidden) # (B, 1, fast_vocab_size)
# Predict first token from the last sequence position
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
# Autoregressive decoding
output_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
prefix_len = prefix_pad_masks.shape[1]
if temperature > 0:
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
generated_action_tokens[:, 0] = next_token.squeeze(-1)
# Track current sequence length for position IDs and maintain the padding mask
current_seq_len = prefix_embs.shape[1]
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
current_pad_mask = prefix_pad_masks.clone() # (B, seq_len)
# 3. Incremental Decoding Loop (using KV cache)
for t in range(1, max_decoding_steps):
# Embed the newly generated token
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
# Update padding mask: new generated token is always valid
current_pad_mask = torch.cat([
current_pad_mask,
torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1) # (B, seq_len+1)
# Position ID for the new token (continues from where we left off)
new_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
# For KV-cache: attention mask for the new token should only attend to valid positions
# Shape: (B, 1, past_len+1) where the new token attends to valid prefix + all generated tokens
new_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
att_4d_incremental = self._prepare_attention_masks_4d(new_att_mask_2d, dtype=next_token_emb.dtype)
# Forward pass with only the new token embedding (reusing cached KVs)
(new_out, _), past_key_values = self.paligemma_with_expert.forward(
attention_mask=att_4d_incremental,
position_ids=new_position_id,
past_key_values=past_key_values,
inputs_embeds=[next_token_emb, None],
use_cache=True,
adarms_cond=[None, None],
)
# Predict next token
last_logits = lm_head(new_out[:, -1:, :]) # (B, 1, vocab_size)
for step in range(max_decoding_steps):
# Sample next token
if temperature > 0:
probs = F.softmax(logits[:, -1] / temperature, dim=-1)
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
output_tokens[:, step] = next_token.squeeze(-1)
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
# Check for EOS token (token ID 1 in many tokenizers)
# You may want to adjust this based on your FAST tokenizer
# For now, we decode all max_decoding_steps tokens
generated_action_tokens[:, t] = next_token.squeeze(-1)
if step < max_decoding_steps - 1:
# Embed the new token
def next_token_embed_func(next_token):
next_emb = self.fast_action_embedding(next_token)
return next_emb * math.sqrt(next_emb.shape[-1])
next_emb = next_token_embed_func(next_token)
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
next_emb = next_emb.to(dtype=torch.bfloat16)
# Update sequence length
current_seq_len += 1
# Update position ids
new_position_ids = torch.full((bsize, 1), prefix_len + step, dtype=torch.long, device=device)
# Create attention mask for the new token (attends to all previous)
new_att_mask = torch.ones(bsize, 1, prefix_len + step + 1, dtype=torch.bool, device=device)
new_att_4d = self._prepare_attention_masks_4d(new_att_mask, dtype=next_emb.dtype)
# Forward pass with KV cache
(next_out, _), past_key_values = self.paligemma_with_expert.forward(
attention_mask=new_att_4d,
position_ids=new_position_ids,
past_key_values=past_key_values,
inputs_embeds=[next_emb, None],
use_cache=True,
adarms_cond=[None, None],
)
logits = self.fast_action_lm_head(next_out)
return output_tokens
return generated_action_tokens
@torch.no_grad()
def _generate_subtask_tokens(
@@ -1903,7 +2122,6 @@ class PI05Policy(PreTrainedPolicy):
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
remapped_state_dict = {}
remap_count = 0
@@ -2002,6 +2220,9 @@ class PI05Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if key == "model.paligemma_with_expert.paligemma.lm_head.weight":
fixed_state_dict["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict
@@ -2135,15 +2356,14 @@ class PI05Policy(PreTrainedPolicy):
# Get optional parameters
temperature = kwargs.get("temperature", 0.0)
max_decoding_steps = kwargs.get("max_decoding_steps", self.config.max_action_tokens)
max_decoding_steps = 256
# Sample action tokens autoregressively
action_tokens = self.model.sample_actions_fast(
action_tokens = self.model.sample_actions_fast_kv_cache(
images, img_masks, tokens, masks,
max_decoding_steps=max_decoding_steps,
temperature=temperature,
)
# Return the action tokens - these need to be decoded by the FAST tokenizer
# The caller is responsible for decoding tokens to continuous actions
return action_tokens

View File

@@ -1,18 +1,13 @@
rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
lerobot-train \
--dataset.repo_id=local\
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
--job_name=pi0_multi_training \
--policy.repo_id=jadechoghari/pi0-base1 \
--policy.path=lerobot/pi05_base \
--policy.path=/fsx/jade_choghari/outputs/libero_training_fast_4/checkpoints/last/pretrained_model/ \
--policy.dtype=bfloat16 \
--steps=50000 \
--save_freq=5000 \
--rename_map='{
"observation.images.base": "observation.images.base_0_rgb",
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
}' \
--batch_size=32 \
--batch_size=4 \
--policy.device=cuda \

View File

@@ -1,6 +1,12 @@
# Train the FAST action tokenizer and optionally push it to the Hugging Face Hub.
# Every option line must end with a backslash so the whole invocation is a single
# command; previously the --output_dir line had none, which detached the
# --push_to_hub / --hub_repo_id / --normalization_mode flags from the python call.
# NOTE(review): --repo_id, --action_horizon and --encoded_dims each appear twice
# in this view (earlier and later values) — confirm which set is intended.
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "lerobot/libero" \
--action_horizon 50 \
--encoded_dims "0:6" \
--repo_id "local" \
--root /fsx/jade_choghari/data/libero \
--action_horizon 10 \
--encoded_dims "0:7" \
--vocab_size 1024 \
--output_dir "/fsx/jade_choghari/outputs/fast_tokenizer" \
--push_to_hub \
--hub_repo_id jadechoghari/fast-libero-tokenizer-mean-std \
--normalization_mode MEAN_STD \
# python train_fast_tokenizer.py --repo_id my_dataset

View File

@@ -15,6 +15,8 @@ from pathlib import Path
from transformers import AutoProcessor
import torch
from huggingface_hub import HfApi
from lerobot.configs.types import NormalizationMode
from lerobot.datasets.lerobot_dataset import LeRobotDataset
@@ -39,6 +41,64 @@ def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: li
return delta_actions
def apply_normalization(
    data: np.ndarray,
    stats: dict[str, np.ndarray],
    mode: NormalizationMode,
    eps: float = 1e-8,
) -> np.ndarray:
    """Normalize ``data`` using the statistics selected by ``mode``.

    Supported modes:
        * IDENTITY   - return ``data`` untouched.
        * MEAN_STD   - ``(data - mean) / max(std, eps)``.
        * MIN_MAX    - linear map sending ``min`` to -1 and ``max`` to +1
                       (no clipping, so values outside the range stay outside).
        * QUANTILES  - clip to ``[q01, q99]``, then map that range to [-1, 1].
        * QUANTILE10 - clip to ``[q10, q90]``, then map that range to [-1, 1].

    Args:
        data: Values to normalize; the statistics are combined with it through
            ordinary NumPy broadcasting.
        stats: Mapping from statistic name to its value. Only the two entries
            required by ``mode`` are read.
        mode: Which normalization scheme to apply.
        eps: Lower bound applied to every denominator to avoid division by zero.

    Returns:
        The normalized data.

    Raises:
        ValueError: If a statistic required by ``mode`` is missing from
            ``stats``, or if ``mode`` is not one of the supported modes.
    """

    def _required_pair(first_key: str, second_key: str, label: str):
        # Fetch both statistics for a mode, failing loudly if either is absent.
        first, second = stats.get(first_key), stats.get(second_key)
        if first is None or second is None:
            raise ValueError(
                f"{label} mode requires '{first_key}' and '{second_key}' in stats"
            )
        return first, second

    def _to_symmetric_unit(values, low, high):
        # Affine map sending `low` to -1 and `high` to +1; the span is floored at eps.
        span = np.maximum(high - low, eps)
        return 2.0 * (values - low) / span - 1.0

    if mode == NormalizationMode.IDENTITY:
        return data

    if mode == NormalizationMode.MEAN_STD:
        center, spread = _required_pair("mean", "std", "MEAN_STD")
        return (data - center) / np.maximum(spread, eps)

    if mode == NormalizationMode.MIN_MAX:
        lower, upper = _required_pair("min", "max", "MIN_MAX")
        return _to_symmetric_unit(data, lower, upper)

    if mode == NormalizationMode.QUANTILES:
        lower, upper = _required_pair("q01", "q99", "QUANTILES")
        # Outliers are clipped to the quantile range before rescaling.
        return _to_symmetric_unit(np.clip(data, lower, upper), lower, upper)

    if mode == NormalizationMode.QUANTILE10:
        lower, upper = _required_pair("q10", "q90", "QUANTILE10")
        # Outliers are clipped to the quantile range before rescaling.
        return _to_symmetric_unit(np.clip(data, lower, upper), lower, upper)

    raise ValueError(f"Unsupported normalization mode: {mode}")
def process_episode(args):
"""Process single episode and return action chunks."""
dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args
@@ -237,9 +297,13 @@ def main(
delta_dims: str | None = None,
use_delta_transform: bool = False,
state_key: str = "observation.state",
normalization_mode: str = "QUANTILES",
vocab_size: int = 1024,
scale: float = 10.0,
output_dir: str | None = None,
push_to_hub: bool = False,
hub_repo_id: str | None = None,
hub_private: bool = False,
):
"""
Train FAST tokenizer for action encoding.
@@ -254,15 +318,29 @@ def main(
delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
state_key: Dataset key for state observations (default: "observation.state")
normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
vocab_size: FAST vocabulary size (BPE vocab size)
scale: DCT scaling factor (default: 10.0)
output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
push_to_hub: Whether to push the tokenizer to Hugging Face Hub
hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
hub_private: Whether to create a private repository on the Hub
"""
# Load dataset
print(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id=repo_id, root=root)
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
# Parse normalization mode
try:
norm_mode = NormalizationMode(normalization_mode)
except ValueError:
raise ValueError(
f"Invalid normalization_mode: {normalization_mode}. "
f"Must be one of: {', '.join([m.value for m in NormalizationMode])}"
)
print(f"Normalization mode: {norm_mode.value}")
# Parse encoded dimensions
encoded_dim_ranges = []
for range_str in encoded_dims.split(','):
@@ -317,13 +395,12 @@ def main(
encoded_chunks = np.concatenate(encoded_chunks, axis=-1) # [N, H, D_encoded]
print(f"Extracted {encoded_chunks.shape[-1]} encoded dimensions")
# Apply normalization to encoded dimensions only
# NOTE: For FAST, we ALWAYS use QUANTILE normalization (no per-timestamp)
# This clips outliers and provides consistent [-1, 1] range for DCT compression
# Apply normalization to encoded dimensions
print(f"\nBefore normalization - overall stats:")
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
# Get normalization stats from dataset
norm_stats = dataset.meta.stats
if norm_stats is not None and "action" in norm_stats:
action_stats = norm_stats["action"]
@@ -334,19 +411,31 @@ def main(
encoded_dim_indices.extend(range(start, end))
encoded_dim_indices = np.array(encoded_dim_indices)
# Use QUANTILE normalization: clip to [q01, q99] and map to [-1, 1]
if "q01" in action_stats and "q99" in action_stats:
q01 = np.array(action_stats["q01"])[encoded_dim_indices] # [D_encoded]
q99 = np.array(action_stats["q99"])[encoded_dim_indices] # [D_encoded]
# Extract stats for encoded dimensions only
encoded_stats = {}
for stat_name, stat_values in action_stats.items():
if isinstance(stat_values, (list, np.ndarray)):
stat_array = np.array(stat_values)
if len(stat_array) > max(encoded_dim_indices):
encoded_stats[stat_name] = stat_array[encoded_dim_indices]
if encoded_stats:
print(f"\nNormalization stats for encoded dimensions (mode: {norm_mode.value}):")
for stat_name, stat_values in encoded_stats.items():
print(f" {stat_name}: shape={stat_values.shape}, "
f"range=[{np.min(stat_values):.4f}, {np.max(stat_values):.4f}]")
print(f"\nNormalization stats (q01, q99) for encoded dimensions:")
for i, dim_idx in enumerate(encoded_dim_indices):
print(f" Orig dim {dim_idx}: q01={q01[i]:7.4f}, q99={q99[i]:7.4f}, range={q99[i]-q01[i]:7.4f}")
# Clip to quantile range and normalize to [-1, 1]
encoded_chunks = np.clip(encoded_chunks, q01, q99)
encoded_chunks = 2.0 * (encoded_chunks - q01) / np.maximum(q99 - q01, 1e-6) - 1.0
print(f"\nApplied quantile normalization [q01, q99] → [-1, 1]")
# Apply normalization based on mode
try:
encoded_chunks = apply_normalization(
encoded_chunks,
encoded_stats,
norm_mode,
eps=1e-8
)
print(f"\nApplied {norm_mode.value} normalization")
except ValueError as e:
print(f"Warning: {e}. Using raw actions without normalization.")
print(f"\nAfter normalization - overall stats:")
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
@@ -358,9 +447,9 @@ def main(
print(f" Dim {d}: min={np.min(dim_data):7.4f}, max={np.max(dim_data):7.4f}, "
f"mean={np.mean(dim_data):7.4f}, std={np.std(dim_data):7.4f}")
else:
print("Warning: q01/q99 stats not found, using raw actions")
print("Warning: Could not extract stats for encoded dimensions, using raw actions")
else:
print("Warning: No normalization stats found, using raw actions")
print("Warning: No normalization stats found in dataset, using raw actions")
print(f"Encoded chunks shape: {encoded_chunks.shape}")
@@ -394,6 +483,7 @@ def main(
'delta_dim_list': delta_dim_list,
'use_delta_transform': use_delta_transform,
'state_key': state_key,
'normalization_mode': norm_mode.value,
'action_horizon': action_horizon,
'num_training_chunks': len(encoded_chunks),
'compression_stats': compression_stats,
@@ -402,8 +492,41 @@ def main(
with open(output_path / "metadata.json", 'w') as f:
json.dump(metadata, f, indent=2)
print(f"\nSaved FAST tokenizer to {output_path}")
print(f"\nSaved FAST tokenizer to {output_path}")
print(f"Metadata: {json.dumps(metadata, indent=2)}")
# Push to Hugging Face Hub if requested
if push_to_hub:
# Determine the hub repository ID
if hub_repo_id is None:
hub_repo_id = output_path.name
print(f"\nNo hub_repo_id provided, using: {hub_repo_id}")
print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}")
print(f" Private: {hub_private}")
try:
# Use the tokenizer's push_to_hub method
tokenizer.push_to_hub(
repo_id=hub_repo_id,
private=hub_private,
commit_message=f"Upload FAST tokenizer trained on {repo_id}"
)
# Also upload the metadata.json file separately
api = HfApi()
api.upload_file(
path_or_fileobj=str(output_path / "metadata.json"),
path_in_repo="metadata.json",
repo_id=hub_repo_id,
repo_type="model",
commit_message="Upload tokenizer metadata"
)
print(f"Successfully pushed tokenizer to: https://huggingface.co/{hub_repo_id}")
except Exception as e:
print(f"Error pushing to hub: {e}")
print(" Make sure you're logged in with `huggingface-cli login`")
if __name__ == "__main__":

View File

@@ -0,0 +1,28 @@
#!/bin/bash
# FSDP training script for PI05 with aggressive memory optimization
# Use this for large models that OOM with standard DDP
accelerate launch --config_file /admin/home/jade_choghari/lerobot/fsdp_config.yaml \
$(which lerobot-train) \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fsdp \
--job_name=libero_training_fsdp \
--policy.repo_id=jade_choghari/pi05-fast-libero-fsdp \
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
--policy.dtype=bfloat16 \
--steps=100000 \
--save_freq=10 \
--batch_size=8 \
--policy.device=cuda \
--policy.fast_only=true \
--policy.scheduler_warmup_steps=2000 \
--policy.scheduler_decay_steps=60000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=false \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=pi05-libero-training-fsdp

View File

@@ -2,16 +2,19 @@ export CUDA_LAUNCH_BLOCKING=1
lerobot-train \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_1 \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_5 \
--job_name=libero_training_fast \
--policy.repo_id=jade_choghari/pi05-fast-libero \
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
--policy.dtype=bfloat16 \
--steps=200000 \
--save_freq=30000 \
--batch_size=16 \
--steps=100000 \
--save_freq=20000 \
--batch_size=4 \
--policy.device=cuda \
--policy.fast_only=true \
--policy.scheduler_warmup_steps=1000 \
--policy.scheduler_decay_steps=30000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=true \
# --wandb.enable=true \
# --wandb.disable_artifact=true \

View File

@@ -0,0 +1,15 @@
#!/bin/bash
# Slurm batch script: runs the PI05 multi-GPU training launcher (train_multi.sh)
# inside a container image via srun.
#
# Resources requested below: 8 GPUs (gres), 256G of memory, a 24-hour time limit,
# the "high" QoS and the hopper-prod partition.
#SBATCH --job-name=pi05-train
#SBATCH --time=24:00:00
#SBATCH --qos=high
#SBATCH --gres=gpu:8
#SBATCH --mem=256G
#SBATCH --partition=hopper-prod
#SBATCH --output=/fsx/jade_choghari/logs/%x-%j.out
#SBATCH --error=/fsx/jade_choghari/logs/%x-%j.err
# stdout/stderr are written per job: %x is the job name and %j the job id.
# NOTE(review): the --container-* options are presumably provided by the cluster's
# srun container plugin (e.g. Pyxis) — confirm.
# NOTE(review): the launcher lives under /admin/home while only /fsx/jade_choghari is
# listed in --container-mounts — confirm the script is visible inside the container.
srun \
--container-image=/fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh \
--container-mounts=/fsx/jade_choghari \
--container-workdir=$HOME/lerobot \
bash /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/train_multi.sh

View File

@@ -1,16 +1,29 @@
accelerate launch --multi_gpu --num_processes=2 \
#!/bin/bash
set -euxo pipefail
# Source YOUR Miniforge conda (mounted from FSX)
source /fsx/jade_choghari/miniforge3/etc/profile.d/conda.sh
conda activate lerobot
accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
$(which lerobot-train) \
--dataset.repo_id=lerobot/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \
--job_name=libero_training_fast \
--policy.repo_id=jade_choghari/pi05-fast-libero \
--policy.repo_id=jade_choghari/pi05-fast-libero-8 \
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
--policy.dtype=bfloat16 \
--steps=200000 \
--save_freq=30000 \
--batch_size=16 \
--steps=60000 \
--save_freq=10000 \
--batch_size=8 \
--policy.compile_model=false \
--policy.device=cuda \
--policy.fast_only=true \
--policy.scheduler_warmup_steps=2000 \
--policy.scheduler_decay_steps=60000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=false \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=pi05-libero-training \
--wandb.project=pi05-libero-training \

View File

@@ -29,7 +29,9 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from scipy.fft import idct
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import (
@@ -223,7 +225,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
task = self.get_task(self.transition)
if task is None:
raise ValueError("Task cannot be None")
# Tokenize the task (this will create CPU tensors)
tokenized_prompt = self._tokenize_text(task)
@@ -534,7 +535,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
Tokenizes the action tensor and creates a mask.
Args:
action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
action: The input action tensor to tokenize. Shape: (B, H, action_dim) or (H, action_dim,)
Returns:
A tuple of (tokens, mask) where:
@@ -576,7 +577,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
# Flatten to 1D if needed
if tokens.dim() > 1:
tokens = tokens.flatten()
tokens = torch.cat([self._act_tokens_to_paligemma_tokens(tokens), torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)])
# Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens:
@@ -674,3 +675,510 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
}
return features
@dataclass
@ProcessorStepRegistry.register(name="action_detokenizer_processor_1")
class ActionDetokenizerProcessorStep1(ActionProcessorStep):
    """
    Processor step to detokenize action tokens back to continuous actions.

    This step takes tokenized actions (e.g., from model predictions), decodes them using
    a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
    and returns the continuous action tensor.

    This variant goes through text: the predicted ids are decoded to a string with the
    PaliGemma tokenizer, the part before the first "|" is kept, re-encoded to ids, mapped
    back to FAST token ids and finally handed to the action tokenizer's `decode`.

    This is the inverse operation of ActionTokenizerProcessorStep and is typically used
    during inference to convert predicted tokens back to executable actions.

    Requires the `transformers` library to be installed.

    Attributes:
        tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
        tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
        trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
        action_horizon: The number of timesteps for actions.
        action_dim: The dimensionality of each action.
        relaxed_decoding: Whether to use relaxed decoding for actions. When enabled, a decoding
            failure is logged and replaced by an all-zero action chunk instead of raising.
        action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
    """

    tokenizer_name: str | None = None
    tokenizer: Any | None = None
    trust_remote_code: bool = True
    action_horizon: int = 1
    action_dim: int = 7
    relaxed_decoding: bool = False

    # Internal tokenizer instance (not part of the config)
    action_tokenizer: Any = field(default=None, init=False, repr=False)
    # PaliGemma text tokenizer used to map between model ids and text
    _paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
    # Number of ids at the end of the PaliGemma vocabulary that are skipped by the id mapping
    _fast_skip_tokens: int = field(default=128, init=False, repr=False)

    def __post_init__(self):
        """
        Initializes the action detokenizer after the dataclass is created.

        It checks for the availability of the `transformers` library and loads the tokenizer
        either from a provided object or by name from the Hugging Face Hub. It also loads the
        PaliGemma tokenizer ("google/paligemma-3b-pt-224") used for the id <-> text round trip.

        Raises:
            ImportError: If the `transformers` library is not installed.
            ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
        """
        if not _transformers_available:
            raise ImportError(
                "The 'transformers' library is not installed. "
                "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
            )

        if self.tokenizer is not None:
            # Use provided tokenizer object directly
            self.action_tokenizer = self.tokenizer
        elif self.tokenizer_name is not None:
            if AutoProcessor is None:
                raise ImportError("AutoProcessor is not available")
            self.action_tokenizer = AutoProcessor.from_pretrained(
                self.tokenizer_name, trust_remote_code=self.trust_remote_code
            )
        else:
            raise ValueError(
                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

        self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
            "google/paligemma-3b-pt-224",
            trust_remote_code=True,
            add_eos_token=True,
            add_bos_token=False,
        )
        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).

        The mapping mirrors ids around the end of the PaliGemma vocabulary, after skipping
        the last `_fast_skip_tokens` entries.
        """
        return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens

    def decode_actions_with_fast(
        self,
        token_ids: list[int],
        time_horizon: int,
        action_dim: int,
        relaxed_decoding: bool = False,
    ) -> list:
        """
        Decodes action token IDs back to continuous action values using the FAST tokenizer.

        Args:
            token_ids: Token IDs to decode. `extract_actions` passes a batch of one
                sequence (a nested list), which is forwarded unchanged to the tokenizer.
            time_horizon: The number of timesteps for actions.
            action_dim: The dimensionality of each action.
            relaxed_decoding: If True, a decoding failure is logged and an all-zero
                `time_horizon x action_dim` nested list is returned instead of raising.

        Returns:
            Whatever the action tokenizer's `decode` returns, or the zero fallback
            described above.

        Raises:
            Exception: Re-raises the tokenizer's error when `relaxed_decoding` is False.
        """
        try:
            return self.action_tokenizer.decode(
                token_ids,
                time_horizon=time_horizon,
                action_dim=action_dim,
            )
        except Exception as e:
            if not relaxed_decoding:
                # Bare `raise` keeps the original traceback intact.
                raise

            import logging

            # Retrying with the very same tokens cannot succeed, so fall back
            # straight to a zero action chunk.
            logging.warning(f"Relaxed decoding: {e}. Returning zero actions.")
            return [[0.0] * action_dim for _ in range(time_horizon)]

    def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Extracts actions from predicted output tokens using the FAST model.

        Args:
            tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)

        Returns:
            The extracted actions as a tensor of shape (B, action_horizon, action_dim) or (action_horizon, action_dim).
        """
        # Handle single sample (add batch dimension)
        single_sample = tokens.dim() == 1
        if single_sample:
            tokens = tokens.unsqueeze(0)

        # Remember where the input lives so the result is created on the same device.
        device = tokens.device

        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)

        # Clean the decoded tokens by removing "Action:" prefix and extracting the relevant part
        cleaned_tokens = [
            tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
            for tokens_sequence in decoded_tokens
        ]

        # Re-encode the cleaned text to get raw action tokens
        raw_action_tokens = [
            self._paligemma_tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
            for sample_tokens in cleaned_tokens
        ]

        # Convert PaliGemma tokens back to action tokens
        action_tokens = [
            self._paligemma_tokens_to_act_tokens(raw_action_token)
            for raw_action_token in raw_action_tokens
        ]

        # Decode each sample's tokens to continuous actions.
        # The decoder presumably returns a batch of one, (1, horizon, dim), while the
        # relaxed fallback is (horizon, dim); reshape normalises both to (horizon, dim)
        # (a plain squeeze(0) would drop the horizon axis when action_horizon == 1).
        decoded_actions = [
            torch.tensor(
                self.decode_actions_with_fast(
                    tok.tolist(),
                    time_horizon=self.action_horizon,
                    action_dim=self.action_dim,
                    relaxed_decoding=self.relaxed_decoding,
                ),
                device=device,
            ).reshape(self.action_horizon, self.action_dim)
            for tok in action_tokens
        ]

        # Stack into a batch
        result = torch.stack(decoded_actions, dim=0)

        # Remove batch dimension if input was single sample
        if single_sample:
            result = result.squeeze(0)

        return result

    def action(self, action: torch.Tensor) -> torch.Tensor:
        """
        Detokenizes action tokens back to continuous actions.

        Args:
            action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)

        Returns:
            The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
        """
        return self.extract_actions(action)

    def get_config(self) -> dict[str, Any]:
        """
        Returns the serializable configuration of the processor.

        Note: The tokenizer object itself is not serialized. If the processor was initialized
        with a tokenizer name, that name will be included in the config.

        Returns:
            A dictionary with the processor's configuration parameters.
        """
        config = {
            "trust_remote_code": self.trust_remote_code,
            "action_horizon": self.action_horizon,
            "action_dim": self.action_dim,
            "relaxed_decoding": self.relaxed_decoding,
        }

        # Only save tokenizer_name if it was used to create the tokenizer
        if self.tokenizer_name is not None and self.tokenizer is None:
            config["tokenizer_name"] = self.tokenizer_name

        return config

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Updates feature definitions to reflect detokenized actions.

        This updates the policy features dictionary to indicate that the action
        has been detokenized from token IDs back to continuous values. The input
        dictionary is modified in place and returned.

        Args:
            features: The dictionary of existing policy features.

        Returns:
            The updated dictionary of policy features.
        """
        # Update the action feature to reflect the continuous action shape
        if PipelineFeatureType.ACTION in features:
            # Replace the action feature with the detokenized version
            # NOTE(review): the feature is typed STATE although it describes an action —
            # confirm whether FeatureType.ACTION was intended.
            features[PipelineFeatureType.ACTION] = {
                "action": PolicyFeature(
                    type=FeatureType.STATE,  # Continuous action
                    shape=(self.action_horizon, self.action_dim),
                )
            }

        return features
@dataclass
@ProcessorStepRegistry.register(name="action_detokenizer_processor")
class ActionDetokenizerProcessorStep(ActionProcessorStep):
    """
    Processor step to detokenize action tokens back to continuous actions.

    This step takes tokenized actions (e.g., from model predictions), maps them from
    PaliGemma ids back to FAST token ids, decodes those with the action tokenizer's BPE
    tokenizer into DCT coefficients and applies an inverse DCT to obtain continuous actions.

    This is the inverse operation of ActionTokenizerProcessorStep and is typically used
    during inference to convert predicted tokens back to executable actions.

    Requires the `transformers` library to be installed.

    Attributes:
        tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
        tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
        trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
        action_horizon: The number of timesteps for actions.
        action_dim: The dimensionality of each action.
        relaxed_decoding: Whether to use relaxed decoding for actions (allows graceful handling of partial sequences).
            NOTE(review): `extract_actions` does not forward this flag, so the decoder's own
            default (relaxed) is what is actually used — confirm the intended behaviour.
        action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
    """

    tokenizer_name: str | None = None
    tokenizer: Any | None = None
    trust_remote_code: bool = True
    action_horizon: int = 1
    action_dim: int = 7
    relaxed_decoding: bool = False

    # Internal tokenizer instance (not part of the config)
    action_tokenizer: Any = field(default=None, init=False, repr=False)
    # PaliGemma text tokenizer used to interpret the model's output ids
    _paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
    # Number of ids at the end of the PaliGemma vocabulary that are skipped by the id mapping
    _fast_skip_tokens: int = field(default=128, init=False, repr=False)

    def __post_init__(self):
        """
        Initializes the action detokenizer after the dataclass is created.

        It checks for the availability of the `transformers` library and loads the tokenizer
        either from a provided object or by name from the Hugging Face Hub. It also loads the
        PaliGemma tokenizer ("google/paligemma-3b-pt-224") used to interpret the model ids.

        Raises:
            ImportError: If the `transformers` library is not installed.
            ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
        """
        if not _transformers_available:
            raise ImportError(
                "The 'transformers' library is not installed. "
                "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
            )

        if self.tokenizer is not None:
            # Use provided tokenizer object directly
            self.action_tokenizer = self.tokenizer
        elif self.tokenizer_name is not None:
            if AutoProcessor is None:
                raise ImportError("AutoProcessor is not available")
            self.action_tokenizer = AutoProcessor.from_pretrained(
                self.tokenizer_name, trust_remote_code=self.trust_remote_code
            )
        else:
            raise ValueError(
                "Either 'tokenizer' or 'tokenizer_name' must be provided. "
                "Pass a tokenizer object directly or a tokenizer name to auto-load."
            )

        self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
            "google/paligemma-3b-pt-224",
            trust_remote_code=True,
            add_eos_token=True,
            add_bos_token=False,
        )
        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).

        The mapping mirrors ids around the end of the PaliGemma vocabulary, after skipping
        the last `_fast_skip_tokens` entries.
        """
        return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens

    def decode_actions_with_fast(
        self,
        token_ids: list[int],
        time_horizon: int,
        action_dim: int,
        relaxed_decoding: bool = True,
    ) -> np.ndarray:
        """
        Decodes action token IDs back to continuous action values using the FAST tokenizer.

        Each sequence is decoded to a string by the BPE tokenizer, every character is turned
        into a DCT coefficient (`ord(char) + min_token`), the coefficients are reshaped to
        (time_horizon, action_dim), divided by the tokenizer's `scale` and passed through an
        orthonormal inverse DCT along the time axis.

        Args:
            token_ids: Iterable of token-id sequences, one per sample.
            time_horizon: The number of timesteps for actions.
            action_dim: The dimensionality of each action.
            relaxed_decoding: If True, coefficient sequences that are too long are truncated
                on the right and sequences that are too short are zero-padded.

        Returns:
            An array of shape (num_sequences, time_horizon, action_dim). A sequence that
            cannot be decoded contributes all-zero coefficients (and a logged warning).
        """
        import logging

        expected_seq_len = time_horizon * action_dim
        decoded_actions = []
        for token in token_ids:
            try:
                decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token)
                decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token
                if relaxed_decoding:
                    diff = expected_seq_len - decoded_dct_coeff.shape[0]
                    if diff < 0:
                        # Too long: truncate on the right
                        decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]
                    elif diff > 0:
                        # Too short: pad with zero coefficients
                        decoded_dct_coeff = np.pad(
                            decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
                        )
                decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
                # Explicit check instead of `assert`, which is stripped under `python -O`.
                if decoded_dct_coeff.shape != (time_horizon, action_dim):
                    raise ValueError(
                        f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
                    )
            except Exception as e:
                # Best effort: a bad sequence becomes a zero chunk rather than aborting the batch.
                logging.warning(f"Error decoding tokens: {e}")
                logging.warning(f"Tokens: {token}")
                decoded_dct_coeff = np.zeros((time_horizon, action_dim))
            decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho"))

        if not decoded_actions:
            # np.stack rejects an empty list; return an empty batch with the right trailing shape.
            return np.zeros((0, time_horizon, action_dim))
        return np.stack(decoded_actions)

    def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Extracts actions from predicted output tokens using the FAST model.

        Args:
            tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)

        Returns:
            The extracted actions as a tensor of shape (B, action_horizon, action_dim) or (action_horizon, action_dim).
        """
        # Handle single sample (add batch dimension)
        single_sample = tokens.dim() == 1
        if single_sample:
            tokens = tokens.unsqueeze(0)

        # Map ids to token strings so the "|" terminator can be located.
        decoded_tokens = [
            self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
            for seq in tokens
        ]

        # Keep only what precedes the first "|" in each sequence.
        cleaned_tokens = []
        for token_seq in decoded_tokens:
            if "|" in token_seq:
                token_seq = token_seq[:token_seq.index("|")]
            cleaned_tokens.append(token_seq)

        # Back to ids, then from PaliGemma ids to FAST action-token ids.
        raw_action_tokens = [
            torch.tensor(
                self._paligemma_tokenizer.convert_tokens_to_ids(token_seq),
                dtype=torch.long,
                device=tokens.device,
            )
            for token_seq in cleaned_tokens
        ]
        action_tokens = [
            self._paligemma_tokens_to_act_tokens(raw_action_token)
            for raw_action_token in raw_action_tokens
        ]

        # NOTE(review): self.relaxed_decoding is not forwarded here, so the decoder's
        # default (relaxed_decoding=True) applies — confirm this is intended.
        actions = self.decode_actions_with_fast(
            action_tokens,
            time_horizon=self.action_horizon,
            action_dim=self.action_dim,
        )
        result = torch.tensor(actions, device=tokens.device)

        # Remove batch dimension if input was single sample
        if single_sample:
            result = result.squeeze(0)

        return result

    def action(self, action: torch.Tensor) -> torch.Tensor:
        """
        Detokenizes action tokens back to continuous actions.

        Args:
            action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)

        Returns:
            The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
        """
        return self.extract_actions(action)

    def get_config(self) -> dict[str, Any]:
        """
        Returns the serializable configuration of the processor.

        Note: The tokenizer object itself is not serialized. If the processor was initialized
        with a tokenizer name, that name will be included in the config.

        Returns:
            A dictionary with the processor's configuration parameters.
        """
        config = {
            "trust_remote_code": self.trust_remote_code,
            "action_horizon": self.action_horizon,
            "action_dim": self.action_dim,
            "relaxed_decoding": self.relaxed_decoding,
        }

        # Only save tokenizer_name if it was used to create the tokenizer
        if self.tokenizer_name is not None and self.tokenizer is None:
            config["tokenizer_name"] = self.tokenizer_name

        return config

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Updates feature definitions to reflect detokenized actions.

        This updates the policy features dictionary to indicate that the action
        has been detokenized from token IDs back to continuous values. The input
        dictionary is modified in place and returned.

        Args:
            features: The dictionary of existing policy features.

        Returns:
            The updated dictionary of policy features.
        """
        # Update the action feature to reflect the continuous action shape
        if PipelineFeatureType.ACTION in features:
            # Replace the action feature with the detokenized version
            # NOTE(review): the feature is typed STATE although it describes an action —
            # confirm whether FeatureType.ACTION was intended.
            features[PipelineFeatureType.ACTION] = {
                "action": PolicyFeature(
                    type=FeatureType.STATE,  # Continuous action
                    shape=(self.action_horizon, self.action_dim),
                )
            }

        return features

View File

@@ -62,6 +62,7 @@ def update_policy(
accelerator: Accelerator,
lr_scheduler=None,
lock=None,
postprocessor = None,
) -> tuple[MetricsTracker, dict]:
"""
Performs a single training step to update the policy's weights.
@@ -90,6 +91,10 @@ def update_policy(
# Let accelerator handle mixed precision
with accelerator.autocast():
loss, output_dict = policy.forward(batch)
action = policy.predict_action_chunk(batch)
if postprocessor is not None:
action = postprocessor(action)
breakpoint()
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
# Use accelerator's backward method
@@ -151,7 +156,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
accelerator = Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=4, kwargs_handlers=[ddp_kwargs])
init_logging(accelerator=accelerator)
@@ -245,6 +250,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
**postprocessor_kwargs,
)
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
@@ -344,6 +350,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
postprocessor=postprocessor,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we