mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
make fast work
This commit is contained in:
@@ -2,25 +2,26 @@
|
||||
"repo_id": "local",
|
||||
"vocab_size": 1024,
|
||||
"scale": 10.0,
|
||||
"encoded_dims": "0:15",
|
||||
"encoded_dims": "0:7",
|
||||
"encoded_dim_ranges": [
|
||||
[
|
||||
0,
|
||||
15
|
||||
7
|
||||
]
|
||||
],
|
||||
"total_encoded_dims": 15,
|
||||
"total_encoded_dims": 7,
|
||||
"delta_dims": null,
|
||||
"delta_dim_list": null,
|
||||
"use_delta_transform": false,
|
||||
"state_key": "observation.state",
|
||||
"action_horizon": 50,
|
||||
"num_training_chunks": 4900,
|
||||
"normalization_mode": "MEAN_STD",
|
||||
"action_horizon": 10,
|
||||
"num_training_chunks": 25065,
|
||||
"compression_stats": {
|
||||
"compression_ratio": 15.85791309863622,
|
||||
"mean_token_length": 47.295,
|
||||
"p99_token_length": 90.0,
|
||||
"compression_ratio": 2.8901734104046244,
|
||||
"mean_token_length": 24.22,
|
||||
"p99_token_length": 40.0,
|
||||
"min_token_length": 9.0,
|
||||
"max_token_length": 109.0
|
||||
"max_token_length": 46.0
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
{
|
||||
"action_dim": 15,
|
||||
"action_dim": 7,
|
||||
"auto_map": {
|
||||
"AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
|
||||
},
|
||||
"min_token": -71,
|
||||
"min_token": -203,
|
||||
"processor_class": "UniversalActionProcessor",
|
||||
"scale": 10.0,
|
||||
"time_horizon": 50,
|
||||
"time_horizon": 10,
|
||||
"vocab_size": 1024
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -537,17 +537,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
|
||||
# FAST action token embedding and prediction head
|
||||
self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
|
||||
self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
|
||||
# # FAST action token embedding and prediction head
|
||||
# self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
|
||||
# self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
|
||||
|
||||
# Apply dtype conversion to FAST layers to match model precision
|
||||
if config.dtype == "bfloat16":
|
||||
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
|
||||
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
|
||||
elif config.dtype == "float32":
|
||||
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
|
||||
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
|
||||
# # Apply dtype conversion to FAST layers to match model precision
|
||||
# if config.dtype == "bfloat16":
|
||||
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
|
||||
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
|
||||
# elif config.dtype == "float32":
|
||||
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
|
||||
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
|
||||
|
||||
# Initialize gradient checkpointing flag
|
||||
self.gradient_checkpointing_enabled = False
|
||||
@@ -1280,7 +1280,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process FAST action tokens (discrete token IDs)
|
||||
if fast_action_tokens is not None:
|
||||
def fast_action_embed_func(fast_action_tokens):
|
||||
fast_emb = self.fast_action_embedding(fast_action_tokens)
|
||||
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||
fast_emb_dim = fast_emb.shape[-1]
|
||||
return fast_emb * math.sqrt(fast_emb_dim)
|
||||
|
||||
@@ -1411,7 +1411,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
# Get logits for FAST action tokens using the FAST LM head
|
||||
# We only compute logits for the positions that predict FAST tokens
|
||||
fast_logits = self.fast_action_lm_head(prefix_out) # (B, T-1, fast_vocab_size)
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
# The FAST tokens start at position (total_T_images + num_lang_tokens)
|
||||
# For next-token prediction:
|
||||
@@ -1422,7 +1422,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
# Extract logits for FAST token prediction
|
||||
# Input positions [fast_start-1 : fast_start-1+num_fast_embs] predict FAST tokens
|
||||
fast_logits_for_pred = fast_logits[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, fast_vocab_size)
|
||||
fast_hidden = prefix_out[:, fast_start-1:fast_start-1+num_fast_embs, :] # (B, num_fast_embs, hidden_dim)
|
||||
fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size)
|
||||
|
||||
# Targets are the FAST action tokens
|
||||
fast_targets = fast_action_tokens # (B, num_fast_embs)
|
||||
@@ -1438,7 +1439,140 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Apply mask and compute mean loss
|
||||
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
|
||||
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
|
||||
breakpoint()
|
||||
|
||||
# breakpoint()
|
||||
# from transformers import AutoTokenizer, AutoProcessor
|
||||
# _paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
||||
# "google/paligemma-3b-pt-224",
|
||||
# trust_remote_code=True,
|
||||
# add_eos_token=True,
|
||||
# add_bos_token=False
|
||||
# )
|
||||
# # 257152
|
||||
# # # Decode predicted output tokens
|
||||
# # # fast_logits_for_pred.argmax(dim=-1)
|
||||
# def _paligemma_tokens_to_act_tokens(tokens: torch.Tensor) -> torch.Tensor:
|
||||
# """
|
||||
# Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
|
||||
# """
|
||||
# return _paligemma_tokenizer.vocab_size - 1 - 128 - tokens
|
||||
# # # target = _paligemma_tokens_to_act_tokens(fast_targets)
|
||||
# decoded_tokens = _paligemma_tokenizer.batch_decode(fast_targets, skip_special_tokens=False)
|
||||
# decoded_tokens = [
|
||||
# _paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
# for seq in fast_logits_for_pred.argmax(dim=-1)
|
||||
# ]
|
||||
# cleaned_tokens = []
|
||||
# for token_seq in decoded_tokens:
|
||||
# if "|" in token_seq:
|
||||
# token_seq = token_seq[:token_seq.index("|")]
|
||||
# cleaned_tokens.append(token_seq)
|
||||
# raw_action_tokens = [
|
||||
# torch.tensor(
|
||||
# _paligemma_tokenizer.convert_tokens_to_ids(token_seq),
|
||||
# dtype=torch.long,
|
||||
# device=fast_targets.device,
|
||||
# )
|
||||
# for token_seq in cleaned_tokens
|
||||
# ]
|
||||
|
||||
# action_tokens = [
|
||||
# _paligemma_tokens_to_act_tokens(raw_action_token)
|
||||
# for raw_action_token in raw_action_tokens
|
||||
# ]
|
||||
# breakpoint()
|
||||
# # Clean the decoded tokens by removing "Action:" prefix and extracting the relevant part
|
||||
# cleaned_tokens = [
|
||||
# tokens_sequence.strip().split("|")[0].strip()
|
||||
# for tokens_sequence in decoded_tokens
|
||||
# ]
|
||||
|
||||
# # Re-encode the cleaned text to get raw action tokens
|
||||
# raw_action_tokens = [
|
||||
# _paligemma_tokenizer.encode(sample_tokens, return_tensors="pt", padding=False).squeeze(0)
|
||||
# for sample_tokens in cleaned_tokens
|
||||
# ]
|
||||
# # Convert PaliGemma tokens back to action tokens
|
||||
# action_tokens = [
|
||||
# _paligemma_tokens_to_act_tokens(raw_action_token)
|
||||
# for raw_action_token in raw_action_tokens
|
||||
# ]
|
||||
# # # Decode each sample's tokens to continuous actions
|
||||
# action_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
# # breakpoint()
|
||||
# decoded_actions = action_tokenizer.decode(
|
||||
# action_tokens,
|
||||
# time_horizon=self.config.chunk_size,
|
||||
# action_dim=6
|
||||
# )
|
||||
# breakpoint()
|
||||
# def decode_actions_with_fast(
|
||||
# token_ids: list[int],
|
||||
# time_horizon: int,
|
||||
# action_dim: int,
|
||||
# relaxed_decoding: bool = False
|
||||
# ) -> list:
|
||||
# """
|
||||
# Decodes action token IDs back to continuous action values using the FAST tokenizer.
|
||||
|
||||
# Args:
|
||||
# token_ids: List of token IDs to decode.
|
||||
# time_horizon: The number of timesteps for actions.
|
||||
# action_dim: The dimensionality of each action.
|
||||
# relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
|
||||
|
||||
# Returns:
|
||||
# A list representing the decoded actions.
|
||||
# """
|
||||
# # Use the action tokenizer's decode method
|
||||
# # The FAST tokenizer should have a decode method that converts tokens back to actions
|
||||
# try:
|
||||
# decoded_actions = action_tokenizer.decode(
|
||||
# token_ids,
|
||||
# time_horizon=time_horizon,
|
||||
# action_dim=action_dim
|
||||
# )
|
||||
# return decoded_actions
|
||||
# except Exception as e:
|
||||
# if relaxed_decoding:
|
||||
# # If relaxed decoding is enabled, try to decode as much as possible
|
||||
# import logging
|
||||
# logging.warning(f"Relaxed decoding: {e}. Returning partial decode.")
|
||||
# try:
|
||||
# # Try to decode with whatever tokens we have
|
||||
# partial_decoded = action_tokenizer.decode(
|
||||
# token_ids[:len(token_ids)],
|
||||
# time_horizon=time_horizon,
|
||||
# action_dim=action_dim
|
||||
# )
|
||||
# return partial_decoded
|
||||
# except:
|
||||
# # Return zeros if decoding completely fails
|
||||
# return [[0.0] * action_dim for _ in range(time_horizon)]
|
||||
# else:
|
||||
# raise e
|
||||
|
||||
# valid = fast_logits_for_pred.argmax(dim=-1) <= (self._paligemma_tokenizer.vocab_size - 1 - 128)
|
||||
# fast_region = fast_logits_for_pred.argmax(dim=-1).masked_fill(~valid, 0)
|
||||
# fast_tokens = _paligemma_tokens_to_act_tokens(fast_region)
|
||||
# actions = decode_actions_with_fast(fast_tokens.tolist(), time_horizon=self.config.chunk_size, action_dim=7, relaxed_decoding=True)[0]
|
||||
# breakpoint()
|
||||
# decoded_actions = [
|
||||
# torch.tensor(
|
||||
# decode_actions_with_fast(
|
||||
# tok[0].tolist(),
|
||||
# time_horizon=self.config.chunk_size,
|
||||
# action_dim=7,
|
||||
# relaxed_decoding=True,
|
||||
# ),
|
||||
# device=tokens.device,
|
||||
# ).squeeze(0)
|
||||
# for tok in action_tokens
|
||||
# ]
|
||||
# breakpoint()
|
||||
# # Stack into a batch
|
||||
# result = torch.stack(decoded_actions, dim=0)
|
||||
# breakpoint()
|
||||
return {
|
||||
"fast_loss": fast_loss,
|
||||
"loss": fast_loss,
|
||||
@@ -1453,50 +1587,126 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
masks,
|
||||
max_decoding_steps=None,
|
||||
temperature=0.0,
|
||||
) -> Tensor:
|
||||
"""Sample actions using autoregressive decoding for FAST-only mode.
|
||||
|
||||
This implements the Pi0FAST inference: autoregressively decode action tokens.
|
||||
|
||||
Args:
|
||||
images: List of image tensors
|
||||
img_masks: List of image masks
|
||||
tokens: Language instruction tokens
|
||||
masks: Attention masks for tokens
|
||||
max_decoding_steps: Maximum number of tokens to decode
|
||||
temperature: Sampling temperature (0 = greedy)
|
||||
|
||||
Returns:
|
||||
Decoded action tokens [B, max_decoding_steps]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Inefficient but safe autoregressive decoding for FAST tokens.
|
||||
Matches the pattern of _generate_subtask_tokens.
|
||||
"""
|
||||
if max_decoding_steps is None:
|
||||
max_decoding_steps = self.config.max_action_tokens
|
||||
|
||||
bsize = tokens.shape[0]
|
||||
device = tokens.device
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
# Embed prefix (images + language) without FAST tokens
|
||||
# 1. Initial Embedding (Matches Training Prefix)
|
||||
# prefix_embs will include [Images, Language Prompt]
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
|
||||
images, img_masks, tokens, masks,
|
||||
fast_action_tokens=None,
|
||||
fast_action_masks=None
|
||||
)
|
||||
|
||||
# Convert to bfloat16 if needed
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
# Initial forward pass to get KV cache
|
||||
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_2d_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
|
||||
|
||||
# 2. Decoding Loop (Re-computes full sequence every step)
|
||||
for t in range(max_decoding_steps):
|
||||
# Always re-calculate position IDs from the current pad mask (matches training)
|
||||
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
# Full forward pass (No KV Cache)
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# Predict next token from the very last sequence position
|
||||
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
|
||||
|
||||
generated_action_tokens[:, t] = next_token.squeeze(-1)
|
||||
|
||||
# 3. Update Sequence for next iteration (unless it's the last step)
|
||||
if t < max_decoding_steps - 1:
|
||||
# Embed the newly generated token
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Append to embeddings
|
||||
prefix_embs = torch.cat([prefix_embs, next_token_emb], dim=1)
|
||||
|
||||
# Update padding mask (New token is always valid/1)
|
||||
prefix_pad_masks = torch.cat([
|
||||
prefix_pad_masks,
|
||||
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
], dim=1)
|
||||
|
||||
# Update 2D attention mask: Grow the matrix
|
||||
old_len = prefix_att_masks.shape[1]
|
||||
new_len = old_len + 1
|
||||
new_att_masks = torch.zeros((bsize, new_len, new_len), dtype=torch.bool, device=device)
|
||||
new_att_masks[:, :old_len, :old_len] = prefix_att_masks
|
||||
# New token attends to all non-padding tokens in the updated sequence
|
||||
new_att_masks[:, -1, :] = prefix_pad_masks
|
||||
prefix_att_masks = new_att_masks
|
||||
return generated_action_tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions_fast_kv_cache(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
max_decoding_steps=None,
|
||||
temperature=0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Efficient autoregressive decoding for FAST tokens using KV-caching.
|
||||
Only computes the prefix once, then incrementally generates tokens.
|
||||
"""
|
||||
if max_decoding_steps is None:
|
||||
max_decoding_steps = self.config.max_action_tokens
|
||||
|
||||
bsize = tokens.shape[0]
|
||||
device = tokens.device
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
# 1. Initial Embedding (Matches Training Prefix)
|
||||
# prefix_embs will include [Images, Language Prompt]
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _ = self.embed_prefix_fast(
|
||||
images, img_masks, tokens, masks,
|
||||
fast_action_tokens=None,
|
||||
fast_action_masks=None
|
||||
)
|
||||
|
||||
if self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
|
||||
|
||||
# 2. Initial forward pass to populate KV cache
|
||||
position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
# First forward pass with full prefix (caching enabled)
|
||||
(prefix_out, _), past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_4d,
|
||||
attention_mask=att_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
@@ -1504,62 +1714,71 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# Get initial logits from last position
|
||||
last_hidden = prefix_out[:, -1:]
|
||||
logits = self.fast_action_lm_head(last_hidden) # (B, 1, fast_vocab_size)
|
||||
# Predict first token from the last sequence position
|
||||
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
|
||||
|
||||
# Autoregressive decoding
|
||||
output_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device)
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
|
||||
|
||||
generated_action_tokens[:, 0] = next_token.squeeze(-1)
|
||||
|
||||
# Track current sequence length for position IDs and maintain the padding mask
|
||||
current_seq_len = prefix_embs.shape[1]
|
||||
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
|
||||
current_pad_mask = prefix_pad_masks.clone() # (B, seq_len)
|
||||
|
||||
# 3. Incremental Decoding Loop (using KV cache)
|
||||
for t in range(1, max_decoding_steps):
|
||||
# Embed the newly generated token
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Update padding mask: new generated token is always valid
|
||||
current_pad_mask = torch.cat([
|
||||
current_pad_mask,
|
||||
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
], dim=1) # (B, seq_len+1)
|
||||
|
||||
# Position ID for the new token (continues from where we left off)
|
||||
new_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
|
||||
|
||||
# For KV-cache: attention mask for the new token should only attend to valid positions
|
||||
# Shape: (B, 1, past_len+1) where the new token attends to valid prefix + all generated tokens
|
||||
new_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
|
||||
att_4d_incremental = self._prepare_attention_masks_4d(new_att_mask_2d, dtype=next_token_emb.dtype)
|
||||
|
||||
# Forward pass with only the new token embedding (reusing cached KVs)
|
||||
(new_out, _), past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_4d_incremental,
|
||||
position_ids=new_position_id,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[next_token_emb, None],
|
||||
use_cache=True,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# Predict next token
|
||||
last_logits = lm_head(new_out[:, -1:, :]) # (B, 1, vocab_size)
|
||||
|
||||
for step in range(max_decoding_steps):
|
||||
# Sample next token
|
||||
if temperature > 0:
|
||||
probs = F.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
|
||||
|
||||
output_tokens[:, step] = next_token.squeeze(-1)
|
||||
next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True)
|
||||
|
||||
# Check for EOS token (token ID 1 in many tokenizers)
|
||||
# You may want to adjust this based on your FAST tokenizer
|
||||
# For now, we decode all max_decoding_steps tokens
|
||||
generated_action_tokens[:, t] = next_token.squeeze(-1)
|
||||
|
||||
if step < max_decoding_steps - 1:
|
||||
# Embed the new token
|
||||
def next_token_embed_func(next_token):
|
||||
next_emb = self.fast_action_embedding(next_token)
|
||||
return next_emb * math.sqrt(next_emb.shape[-1])
|
||||
|
||||
next_emb = next_token_embed_func(next_token)
|
||||
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
next_emb = next_emb.to(dtype=torch.bfloat16)
|
||||
# Update sequence length
|
||||
current_seq_len += 1
|
||||
|
||||
# Update position ids
|
||||
new_position_ids = torch.full((bsize, 1), prefix_len + step, dtype=torch.long, device=device)
|
||||
|
||||
# Create attention mask for the new token (attends to all previous)
|
||||
new_att_mask = torch.ones(bsize, 1, prefix_len + step + 1, dtype=torch.bool, device=device)
|
||||
new_att_4d = self._prepare_attention_masks_4d(new_att_mask, dtype=next_emb.dtype)
|
||||
|
||||
# Forward pass with KV cache
|
||||
(next_out, _), past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=new_att_4d,
|
||||
position_ids=new_position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[next_emb, None],
|
||||
use_cache=True,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
logits = self.fast_action_lm_head(next_out)
|
||||
|
||||
return output_tokens
|
||||
return generated_action_tokens
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _generate_subtask_tokens(
|
||||
@@ -1903,7 +2122,6 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
|
||||
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
|
||||
|
||||
# Then add "model." prefix for all keys that don't already have it
|
||||
remapped_state_dict = {}
|
||||
remap_count = 0
|
||||
@@ -2002,6 +2220,9 @@ class PI05Policy(PreTrainedPolicy):
|
||||
# Some checkpoints might have this, but current model expects different structure
|
||||
logging.warning(f"Vision embedding key might need handling: {key}")
|
||||
|
||||
if key == "model.paligemma_with_expert.paligemma.lm_head.weight":
|
||||
fixed_state_dict["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"] = value.clone()
|
||||
|
||||
fixed_state_dict[new_key] = value
|
||||
|
||||
return fixed_state_dict
|
||||
@@ -2135,15 +2356,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# Get optional parameters
|
||||
temperature = kwargs.get("temperature", 0.0)
|
||||
max_decoding_steps = kwargs.get("max_decoding_steps", self.config.max_action_tokens)
|
||||
max_decoding_steps = 256
|
||||
|
||||
# Sample action tokens autoregressively
|
||||
action_tokens = self.model.sample_actions_fast(
|
||||
action_tokens = self.model.sample_actions_fast_kv_cache(
|
||||
images, img_masks, tokens, masks,
|
||||
max_decoding_steps=max_decoding_steps,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
# Return the action tokens - these need to be decoded by the FAST tokenizer
|
||||
# The caller is responsible for decoding tokens to continuous actions
|
||||
return action_tokens
|
||||
|
||||
@@ -1,18 +1,13 @@
|
||||
rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
|
||||
lerobot-train \
|
||||
--dataset.repo_id=local\
|
||||
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
|
||||
--dataset.root=/fsx/jade_choghari/data/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
|
||||
--job_name=pi0_multi_training \
|
||||
--policy.repo_id=jadechoghari/pi0-base1 \
|
||||
--policy.path=lerobot/pi05_base \
|
||||
--policy.path=/fsx/jade_choghari/outputs/libero_training_fast_4/checkpoints/last/pretrained_model/ \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=50000 \
|
||||
--save_freq=5000 \
|
||||
--rename_map='{
|
||||
"observation.images.base": "observation.images.base_0_rgb",
|
||||
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
|
||||
}' \
|
||||
--batch_size=32 \
|
||||
--batch_size=4 \
|
||||
--policy.device=cuda \
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
|
||||
--repo_id "lerobot/libero" \
|
||||
--action_horizon 50 \
|
||||
--encoded_dims "0:6" \
|
||||
--repo_id "local" \
|
||||
--root /fsx/jade_choghari/data/libero \
|
||||
--action_horizon 10 \
|
||||
--encoded_dims "0:7" \
|
||||
--vocab_size 1024 \
|
||||
--output_dir "/fsx/jade_choghari/outputs/fast_tokenizer"
|
||||
--push_to_hub \
|
||||
--hub_repo_id jadechoghari/fast-libero-tokenizer-mean-std \
|
||||
--normalization_mode MEAN_STD \
|
||||
|
||||
|
||||
# python train_fast_tokenizer.py --repo_id my_dataset
|
||||
@@ -15,6 +15,8 @@ from pathlib import Path
|
||||
from transformers import AutoProcessor
|
||||
import torch
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
@@ -39,6 +41,64 @@ def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: li
|
||||
return delta_actions
|
||||
|
||||
|
||||
def apply_normalization(
|
||||
data: np.ndarray,
|
||||
stats: dict[str, np.ndarray],
|
||||
mode: NormalizationMode,
|
||||
eps: float = 1e-8,
|
||||
) -> np.ndarray:
|
||||
"""Apply normalization to data based on the specified mode.
|
||||
|
||||
Args:
|
||||
data: Data to normalize [N, H, D] or [D]
|
||||
stats: Dictionary of statistics (mean, std, min, max, q01, q99, q10, q90)
|
||||
mode: Normalization mode to apply
|
||||
eps: Small epsilon for numerical stability
|
||||
|
||||
Returns:
|
||||
Normalized data with the same shape as input
|
||||
"""
|
||||
if mode == NormalizationMode.IDENTITY:
|
||||
return data
|
||||
|
||||
if mode == NormalizationMode.MEAN_STD:
|
||||
mean = stats.get("mean")
|
||||
std = stats.get("std")
|
||||
if mean is None or std is None:
|
||||
raise ValueError("MEAN_STD mode requires 'mean' and 'std' in stats")
|
||||
return (data - mean) / np.maximum(std, eps)
|
||||
|
||||
if mode == NormalizationMode.MIN_MAX:
|
||||
min_val = stats.get("min")
|
||||
max_val = stats.get("max")
|
||||
if min_val is None or max_val is None:
|
||||
raise ValueError("MIN_MAX mode requires 'min' and 'max' in stats")
|
||||
denom = np.maximum(max_val - min_val, eps)
|
||||
return 2.0 * (data - min_val) / denom - 1.0
|
||||
|
||||
if mode == NormalizationMode.QUANTILES:
|
||||
q01 = stats.get("q01")
|
||||
q99 = stats.get("q99")
|
||||
if q01 is None or q99 is None:
|
||||
raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats")
|
||||
denom = np.maximum(q99 - q01, eps)
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q01, q99)
|
||||
return 2.0 * (clipped - q01) / denom - 1.0
|
||||
|
||||
if mode == NormalizationMode.QUANTILE10:
|
||||
q10 = stats.get("q10")
|
||||
q90 = stats.get("q90")
|
||||
if q10 is None or q90 is None:
|
||||
raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats")
|
||||
denom = np.maximum(q90 - q10, eps)
|
||||
# Clip to quantile range then normalize to [-1, 1]
|
||||
clipped = np.clip(data, q10, q90)
|
||||
return 2.0 * (clipped - q10) / denom - 1.0
|
||||
|
||||
raise ValueError(f"Unsupported normalization mode: {mode}")
|
||||
|
||||
|
||||
def process_episode(args):
|
||||
"""Process single episode and return action chunks."""
|
||||
dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args
|
||||
@@ -237,9 +297,13 @@ def main(
|
||||
delta_dims: str | None = None,
|
||||
use_delta_transform: bool = False,
|
||||
state_key: str = "observation.state",
|
||||
normalization_mode: str = "QUANTILES",
|
||||
vocab_size: int = 1024,
|
||||
scale: float = 10.0,
|
||||
output_dir: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
hub_repo_id: str | None = None,
|
||||
hub_private: bool = False,
|
||||
):
|
||||
"""
|
||||
Train FAST tokenizer for action encoding.
|
||||
@@ -254,15 +318,29 @@ def main(
|
||||
delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
|
||||
use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
|
||||
state_key: Dataset key for state observations (default: "observation.state")
|
||||
normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
|
||||
vocab_size: FAST vocabulary size (BPE vocab size)
|
||||
scale: DCT scaling factor (default: 10.0)
|
||||
output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
|
||||
push_to_hub: Whether to push the tokenizer to Hugging Face Hub
|
||||
hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name
|
||||
hub_private: Whether to create a private repository on the Hub
|
||||
"""
|
||||
# Load dataset
|
||||
print(f"Loading dataset: {repo_id}")
|
||||
dataset = LeRobotDataset(repo_id=repo_id, root=root)
|
||||
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
# Parse normalization mode
|
||||
try:
|
||||
norm_mode = NormalizationMode(normalization_mode)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid normalization_mode: {normalization_mode}. "
|
||||
f"Must be one of: {', '.join([m.value for m in NormalizationMode])}"
|
||||
)
|
||||
print(f"Normalization mode: {norm_mode.value}")
|
||||
|
||||
# Parse encoded dimensions
|
||||
encoded_dim_ranges = []
|
||||
for range_str in encoded_dims.split(','):
|
||||
@@ -317,13 +395,12 @@ def main(
|
||||
encoded_chunks = np.concatenate(encoded_chunks, axis=-1) # [N, H, D_encoded]
|
||||
print(f"Extracted {encoded_chunks.shape[-1]} encoded dimensions")
|
||||
|
||||
# Apply normalization to encoded dimensions only
|
||||
# NOTE: For FAST, we ALWAYS use QUANTILE normalization (no per-timestamp)
|
||||
# This clips outliers and provides consistent [-1, 1] range for DCT compression
|
||||
# Apply normalization to encoded dimensions
|
||||
print(f"\nBefore normalization - overall stats:")
|
||||
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
|
||||
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
|
||||
|
||||
# Get normalization stats from dataset
|
||||
norm_stats = dataset.meta.stats
|
||||
if norm_stats is not None and "action" in norm_stats:
|
||||
action_stats = norm_stats["action"]
|
||||
@@ -334,19 +411,31 @@ def main(
|
||||
encoded_dim_indices.extend(range(start, end))
|
||||
encoded_dim_indices = np.array(encoded_dim_indices)
|
||||
|
||||
# Use QUANTILE normalization: clip to [q01, q99] and map to [-1, 1]
|
||||
if "q01" in action_stats and "q99" in action_stats:
|
||||
q01 = np.array(action_stats["q01"])[encoded_dim_indices] # [D_encoded]
|
||||
q99 = np.array(action_stats["q99"])[encoded_dim_indices] # [D_encoded]
|
||||
# Extract stats for encoded dimensions only
|
||||
encoded_stats = {}
|
||||
for stat_name, stat_values in action_stats.items():
|
||||
if isinstance(stat_values, (list, np.ndarray)):
|
||||
stat_array = np.array(stat_values)
|
||||
if len(stat_array) > max(encoded_dim_indices):
|
||||
encoded_stats[stat_name] = stat_array[encoded_dim_indices]
|
||||
|
||||
if encoded_stats:
|
||||
print(f"\nNormalization stats for encoded dimensions (mode: {norm_mode.value}):")
|
||||
for stat_name, stat_values in encoded_stats.items():
|
||||
print(f" {stat_name}: shape={stat_values.shape}, "
|
||||
f"range=[{np.min(stat_values):.4f}, {np.max(stat_values):.4f}]")
|
||||
|
||||
print(f"\nNormalization stats (q01, q99) for encoded dimensions:")
|
||||
for i, dim_idx in enumerate(encoded_dim_indices):
|
||||
print(f" Orig dim {dim_idx}: q01={q01[i]:7.4f}, q99={q99[i]:7.4f}, range={q99[i]-q01[i]:7.4f}")
|
||||
|
||||
# Clip to quantile range and normalize to [-1, 1]
|
||||
encoded_chunks = np.clip(encoded_chunks, q01, q99)
|
||||
encoded_chunks = 2.0 * (encoded_chunks - q01) / np.maximum(q99 - q01, 1e-6) - 1.0
|
||||
print(f"\nApplied quantile normalization [q01, q99] → [-1, 1]")
|
||||
# Apply normalization based on mode
|
||||
try:
|
||||
encoded_chunks = apply_normalization(
|
||||
encoded_chunks,
|
||||
encoded_stats,
|
||||
norm_mode,
|
||||
eps=1e-8
|
||||
)
|
||||
print(f"\nApplied {norm_mode.value} normalization")
|
||||
except ValueError as e:
|
||||
print(f"Warning: {e}. Using raw actions without normalization.")
|
||||
|
||||
print(f"\nAfter normalization - overall stats:")
|
||||
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
|
||||
@@ -358,9 +447,9 @@ def main(
|
||||
print(f" Dim {d}: min={np.min(dim_data):7.4f}, max={np.max(dim_data):7.4f}, "
|
||||
f"mean={np.mean(dim_data):7.4f}, std={np.std(dim_data):7.4f}")
|
||||
else:
|
||||
print("Warning: q01/q99 stats not found, using raw actions")
|
||||
print("Warning: Could not extract stats for encoded dimensions, using raw actions")
|
||||
else:
|
||||
print("Warning: No normalization stats found, using raw actions")
|
||||
print("Warning: No normalization stats found in dataset, using raw actions")
|
||||
|
||||
print(f"Encoded chunks shape: {encoded_chunks.shape}")
|
||||
|
||||
@@ -394,6 +483,7 @@ def main(
|
||||
'delta_dim_list': delta_dim_list,
|
||||
'use_delta_transform': use_delta_transform,
|
||||
'state_key': state_key,
|
||||
'normalization_mode': norm_mode.value,
|
||||
'action_horizon': action_horizon,
|
||||
'num_training_chunks': len(encoded_chunks),
|
||||
'compression_stats': compression_stats,
|
||||
@@ -402,8 +492,41 @@ def main(
|
||||
with open(output_path / "metadata.json", 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
print(f"\n✅ Saved FAST tokenizer to {output_path}")
|
||||
print(f"\nSaved FAST tokenizer to {output_path}")
|
||||
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||
|
||||
# Push to Hugging Face Hub if requested
|
||||
if push_to_hub:
|
||||
# Determine the hub repository ID
|
||||
if hub_repo_id is None:
|
||||
hub_repo_id = output_path.name
|
||||
print(f"\nNo hub_repo_id provided, using: {hub_repo_id}")
|
||||
|
||||
print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}")
|
||||
print(f" Private: {hub_private}")
|
||||
|
||||
try:
|
||||
# Use the tokenizer's push_to_hub method
|
||||
tokenizer.push_to_hub(
|
||||
repo_id=hub_repo_id,
|
||||
private=hub_private,
|
||||
commit_message=f"Upload FAST tokenizer trained on {repo_id}"
|
||||
)
|
||||
|
||||
# Also upload the metadata.json file separately
|
||||
api = HfApi()
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_path / "metadata.json"),
|
||||
path_in_repo="metadata.json",
|
||||
repo_id=hub_repo_id,
|
||||
repo_type="model",
|
||||
commit_message="Upload tokenizer metadata"
|
||||
)
|
||||
|
||||
print(f"Successfully pushed tokenizer to: https://huggingface.co/{hub_repo_id}")
|
||||
except Exception as e:
|
||||
print(f"Error pushing to hub: {e}")
|
||||
print(" Make sure you're logged in with `huggingface-cli login`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
28
src/lerobot/policies/pi05/train_fsdp.sh
Executable file
28
src/lerobot/policies/pi05/train_fsdp.sh
Executable file
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
|
||||
# FSDP training script for PI05 with aggressive memory optimization
|
||||
# Use this for large models that OOM with standard DDP
|
||||
|
||||
accelerate launch --config_file /admin/home/jade_choghari/lerobot/fsdp_config.yaml \
|
||||
$(which lerobot-train) \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/data/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fsdp \
|
||||
--job_name=libero_training_fsdp \
|
||||
--policy.repo_id=jade_choghari/pi05-fast-libero-fsdp \
|
||||
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=100000 \
|
||||
--save_freq=10 \
|
||||
--batch_size=8 \
|
||||
--policy.device=cuda \
|
||||
--policy.fast_only=true \
|
||||
--policy.scheduler_warmup_steps=2000 \
|
||||
--policy.scheduler_decay_steps=60000 \
|
||||
--policy.scheduler_decay_lr=1e-5 \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--wandb.project=pi05-libero-training-fsdp
|
||||
|
||||
|
||||
@@ -2,16 +2,19 @@ export CUDA_LAUNCH_BLOCKING=1
|
||||
lerobot-train \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/data/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_1 \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_5 \
|
||||
--job_name=libero_training_fast \
|
||||
--policy.repo_id=jade_choghari/pi05-fast-libero \
|
||||
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=200000 \
|
||||
--save_freq=30000 \
|
||||
--batch_size=16 \
|
||||
--steps=100000 \
|
||||
--save_freq=20000 \
|
||||
--batch_size=4 \
|
||||
--policy.device=cuda \
|
||||
--policy.fast_only=true \
|
||||
--policy.scheduler_warmup_steps=1000 \
|
||||
--policy.scheduler_decay_steps=30000 \
|
||||
--policy.scheduler_decay_lr=1e-5 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
# --wandb.enable=true \
|
||||
# --wandb.disable_artifact=true \
|
||||
|
||||
15
src/lerobot/policies/pi05/train_multi.sbatch
Normal file
15
src/lerobot/policies/pi05/train_multi.sbatch
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=pi05-train
|
||||
#SBATCH --time=24:00:00
|
||||
#SBATCH --qos=high
|
||||
#SBATCH --gres=gpu:8
|
||||
#SBATCH --mem=256G
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --output=/fsx/jade_choghari/logs/%x-%j.out
|
||||
#SBATCH --error=/fsx/jade_choghari/logs/%x-%j.err
|
||||
|
||||
srun \
|
||||
--container-image=/fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh \
|
||||
--container-mounts=/fsx/jade_choghari \
|
||||
--container-workdir=$HOME/lerobot \
|
||||
bash /admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/train_multi.sh
|
||||
@@ -1,16 +1,29 @@
|
||||
accelerate launch --multi_gpu --num_processes=2 \
|
||||
#!/bin/bash
|
||||
set -euxo pipefail
|
||||
|
||||
# Source YOUR Miniforge conda (mounted from FSX)
|
||||
source /fsx/jade_choghari/miniforge3/etc/profile.d/conda.sh
|
||||
|
||||
conda activate lerobot
|
||||
accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
|
||||
$(which lerobot-train) \
|
||||
--dataset.repo_id=lerobot/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/data/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_4 \
|
||||
--job_name=libero_training_fast \
|
||||
--policy.repo_id=jade_choghari/pi05-fast-libero \
|
||||
--policy.repo_id=jade_choghari/pi05-fast-libero-8 \
|
||||
--policy.path=/fsx/jade_choghari/models/libero-pi-fast \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=200000 \
|
||||
--save_freq=30000 \
|
||||
--batch_size=16 \
|
||||
--steps=60000 \
|
||||
--save_freq=10000 \
|
||||
--batch_size=8 \
|
||||
--policy.compile_model=false \
|
||||
--policy.device=cuda \
|
||||
--policy.fast_only=true \
|
||||
--policy.scheduler_warmup_steps=2000 \
|
||||
--policy.scheduler_decay_steps=60000 \
|
||||
--policy.scheduler_decay_lr=1e-5 \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--wandb.project=pi05-libero-training \
|
||||
--wandb.project=pi05-libero-training \
|
||||
@@ -29,7 +29,9 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.fft import idct
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import (
|
||||
@@ -223,7 +225,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
task = self.get_task(self.transition)
|
||||
if task is None:
|
||||
raise ValueError("Task cannot be None")
|
||||
|
||||
# Tokenize the task (this will create CPU tensors)
|
||||
tokenized_prompt = self._tokenize_text(task)
|
||||
|
||||
@@ -534,7 +535,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
Tokenizes the action tensor and creates a mask.
|
||||
|
||||
Args:
|
||||
action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
|
||||
action: The input action tensor to tokenize. Shape: (B, H, action_dim) or (H, action_dim,)
|
||||
|
||||
Returns:
|
||||
A tuple of (tokens, mask) where:
|
||||
@@ -576,7 +577,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
# Flatten to 1D if needed
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
|
||||
tokens = torch.cat([self._act_tokens_to_paligemma_tokens(tokens), torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device)])
|
||||
# Truncate or pad to max_action_tokens
|
||||
if len(tokens) > self.max_action_tokens:
|
||||
@@ -674,3 +675,510 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="action_detokenizer_processor_1")
|
||||
class ActionDetokenizerProcessorStep1(ActionProcessorStep):
|
||||
"""
|
||||
Processor step to detokenize action tokens back to continuous actions.
|
||||
|
||||
This step takes tokenized actions (e.g., from model predictions), decodes them using
|
||||
a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
|
||||
and returns the continuous action tensor.
|
||||
|
||||
This is the inverse operation of ActionTokenizerProcessorStep and is typically used
|
||||
during inference to convert predicted tokens back to executable actions.
|
||||
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
||||
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||
action_horizon: The number of timesteps for actions.
|
||||
action_dim: The dimensionality of each action.
|
||||
relaxed_decoding: Whether to use relaxed decoding for actions (allows graceful handling of partial sequences).
|
||||
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||
"""
|
||||
|
||||
tokenizer_name: str | None = None
|
||||
tokenizer: Any | None = None
|
||||
trust_remote_code: bool = True
|
||||
action_horizon: int = 1
|
||||
action_dim: int = 7
|
||||
relaxed_decoding: bool = False
|
||||
|
||||
# Internal tokenizer instance (not part of the config)
|
||||
action_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
_paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
_fast_skip_tokens: int = field(default=128, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Initializes the action detokenizer after the dataclass is created.
|
||||
|
||||
It checks for the availability of the `transformers` library and loads the tokenizer
|
||||
either from a provided object or by name from the Hugging Face Hub.
|
||||
|
||||
Raises:
|
||||
ImportError: If the `transformers` library is not installed.
|
||||
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
|
||||
"""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
# Use provided tokenizer object directly
|
||||
self.action_tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
if AutoProcessor is None:
|
||||
raise ImportError("AutoProcessor is not available")
|
||||
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||
self.tokenizer_name, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
|
||||
"Pass a tokenizer object directly or a tokenizer name to auto-load."
|
||||
)
|
||||
|
||||
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/paligemma-3b-pt-224",
|
||||
trust_remote_code=True,
|
||||
add_eos_token=True,
|
||||
add_bos_token=False
|
||||
)
|
||||
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
|
||||
|
||||
def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
|
||||
"""
|
||||
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
|
||||
|
||||
def decode_actions_with_fast(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
time_horizon: int,
|
||||
action_dim: int,
|
||||
relaxed_decoding: bool = False
|
||||
) -> list:
|
||||
"""
|
||||
Decodes action token IDs back to continuous action values using the FAST tokenizer.
|
||||
|
||||
Args:
|
||||
token_ids: List of token IDs to decode.
|
||||
time_horizon: The number of timesteps for actions.
|
||||
action_dim: The dimensionality of each action.
|
||||
relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
|
||||
|
||||
Returns:
|
||||
A list representing the decoded actions.
|
||||
"""
|
||||
# Use the action tokenizer's decode method
|
||||
# The FAST tokenizer should have a decode method that converts tokens back to actions
|
||||
try:
|
||||
decoded_actions = self.action_tokenizer.decode(
|
||||
token_ids,
|
||||
time_horizon=time_horizon,
|
||||
action_dim=action_dim
|
||||
)
|
||||
return decoded_actions
|
||||
except Exception as e:
|
||||
if relaxed_decoding:
|
||||
# If relaxed decoding is enabled, try to decode as much as possible
|
||||
import logging
|
||||
logging.warning(f"Relaxed decoding: {e}. Returning partial decode.")
|
||||
try:
|
||||
# Try to decode with whatever tokens we have
|
||||
partial_decoded = self.action_tokenizer.decode(
|
||||
token_ids[:len(token_ids)],
|
||||
time_horizon=time_horizon,
|
||||
action_dim=action_dim
|
||||
)
|
||||
return partial_decoded
|
||||
except:
|
||||
# Return zeros if decoding completely fails
|
||||
return [[0.0] * action_dim for _ in range(time_horizon)]
|
||||
else:
|
||||
raise e
|
||||
|
||||
def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Extracts actions from predicted output tokens using the FAST model.
|
||||
|
||||
Args:
|
||||
tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)
|
||||
|
||||
Returns:
|
||||
The extracted actions as a tensor of shape (B, action_horizon, action_dim) or (action_horizon, action_dim).
|
||||
"""
|
||||
# Handle single sample (add batch dimension)
|
||||
single_sample = tokens.dim() == 1
|
||||
if single_sample:
|
||||
tokens = tokens.unsqueeze(0)
|
||||
|
||||
# Decode predicted output tokens
|
||||
decoded_tokens = self._paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
|
||||
|
||||
# Clean the decoded tokens by removing "Action:" prefix and extracting the relevant part
|
||||
cleaned_tokens = [
|
||||
tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
|
||||
for tokens_sequence in decoded_tokens
|
||||
]
|
||||
|
||||
# Re-encode the cleaned text to get raw action tokens
|
||||
raw_action_tokens = [
|
||||
self._paligemma_tokenizer.encode(sample_tokens, return_tensors="pt", padding=False)
|
||||
for sample_tokens in cleaned_tokens
|
||||
]
|
||||
|
||||
# Convert PaliGemma tokens back to action tokens
|
||||
action_tokens = [
|
||||
self._paligemma_tokens_to_act_tokens(raw_action_token)
|
||||
for raw_action_token in raw_action_tokens
|
||||
]
|
||||
tokens = [t.flatten().tolist() for t in action_tokens]
|
||||
breakpoint()
|
||||
# Decode each sample's tokens to continuous actions
|
||||
decoded_actions = [
|
||||
torch.tensor(
|
||||
self.decode_actions_with_fast(
|
||||
tok.tolist(),
|
||||
time_horizon=self.action_horizon,
|
||||
action_dim=self.action_dim,
|
||||
relaxed_decoding=self.relaxed_decoding,
|
||||
),
|
||||
device=tokens.device,
|
||||
).squeeze(0)
|
||||
for tok in action_tokens
|
||||
]
|
||||
breakpoint()
|
||||
# Stack into a batch
|
||||
result = torch.stack(decoded_actions, dim=0)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
result = result.squeeze(0)
|
||||
|
||||
return result
|
||||
|
||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Detokenizes action tokens back to continuous actions.
|
||||
|
||||
Args:
|
||||
action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)
|
||||
|
||||
Returns:
|
||||
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
|
||||
"""
|
||||
return self.extract_actions(action)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the serializable configuration of the processor.
|
||||
|
||||
Note: The tokenizer object itself is not serialized. If the processor was initialized
|
||||
with a tokenizer name, that name will be included in the config.
|
||||
|
||||
Returns:
|
||||
A dictionary with the processor's configuration parameters.
|
||||
"""
|
||||
config = {
|
||||
"trust_remote_code": self.trust_remote_code,
|
||||
"action_horizon": self.action_horizon,
|
||||
"action_dim": self.action_dim,
|
||||
"relaxed_decoding": self.relaxed_decoding,
|
||||
}
|
||||
|
||||
# Only save tokenizer_name if it was used to create the tokenizer
|
||||
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||
config["tokenizer_name"] = self.tokenizer_name
|
||||
|
||||
return config
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Updates feature definitions to reflect detokenized actions.
|
||||
|
||||
This updates the policy features dictionary to indicate that the action
|
||||
has been detokenized from token IDs back to continuous values.
|
||||
|
||||
Args:
|
||||
features: The dictionary of existing policy features.
|
||||
|
||||
Returns:
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
# Update the action feature to reflect the continuous action shape
|
||||
if PipelineFeatureType.ACTION in features:
|
||||
# Replace the action feature with the detokenized version
|
||||
features[PipelineFeatureType.ACTION] = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.STATE, # Continuous action
|
||||
shape=(self.action_horizon, self.action_dim)
|
||||
)
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="action_detokenizer_processor")
|
||||
class ActionDetokenizerProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
Processor step to detokenize action tokens back to continuous actions.
|
||||
|
||||
This step takes tokenized actions (e.g., from model predictions), decodes them using
|
||||
a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
|
||||
and returns the continuous action tensor.
|
||||
|
||||
This is the inverse operation of ActionTokenizerProcessorStep and is typically used
|
||||
during inference to convert predicted tokens back to executable actions.
|
||||
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
||||
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||
action_horizon: The number of timesteps for actions.
|
||||
action_dim: The dimensionality of each action.
|
||||
relaxed_decoding: Whether to use relaxed decoding for actions (allows graceful handling of partial sequences).
|
||||
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||
"""
|
||||
|
||||
tokenizer_name: str | None = None
|
||||
tokenizer: Any | None = None
|
||||
trust_remote_code: bool = True
|
||||
action_horizon: int = 1
|
||||
action_dim: int = 7
|
||||
relaxed_decoding: bool = False
|
||||
|
||||
# Internal tokenizer instance (not part of the config)
|
||||
action_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
_paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
_fast_skip_tokens: int = field(default=128, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Initializes the action detokenizer after the dataclass is created.
|
||||
|
||||
It checks for the availability of the `transformers` library and loads the tokenizer
|
||||
either from a provided object or by name from the Hugging Face Hub.
|
||||
|
||||
Raises:
|
||||
ImportError: If the `transformers` library is not installed.
|
||||
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
|
||||
"""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionDetokenizerProcessorStep."
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
# Use provided tokenizer object directly
|
||||
self.action_tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
if AutoProcessor is None:
|
||||
raise ImportError("AutoProcessor is not available")
|
||||
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||
self.tokenizer_name, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
|
||||
"Pass a tokenizer object directly or a tokenizer name to auto-load."
|
||||
)
|
||||
|
||||
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/paligemma-3b-pt-224",
|
||||
trust_remote_code=True,
|
||||
add_eos_token=True,
|
||||
add_bos_token=False
|
||||
)
|
||||
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
|
||||
|
||||
def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens).
|
||||
"""
|
||||
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
|
||||
|
||||
|
||||
def decode_actions_with_fast(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
time_horizon: int,
|
||||
action_dim: int,
|
||||
relaxed_decoding: bool = True
|
||||
) -> list:
|
||||
"""
|
||||
Decodes action token IDs back to continuous action values using the FAST tokenizer.
|
||||
|
||||
Args:
|
||||
token_ids: List of token IDs to decode.
|
||||
time_horizon: The number of timesteps for actions.
|
||||
action_dim: The dimensionality of each action.
|
||||
relaxed_decoding: Whether to use relaxed decoding (allows partial sequences).
|
||||
|
||||
Returns:
|
||||
A list representing the decoded actions.
|
||||
"""
|
||||
decoded_actions = []
|
||||
|
||||
for token in token_ids:
|
||||
try:
|
||||
decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token
|
||||
|
||||
if relaxed_decoding:
|
||||
# expected sequence length
|
||||
expected_seq_len = time_horizon * action_dim
|
||||
diff = expected_seq_len - decoded_dct_coeff.shape[0]
|
||||
|
||||
# apply truncation if too long
|
||||
if diff < 0:
|
||||
decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # tsruncate on the right
|
||||
|
||||
# apply padding if too short
|
||||
elif diff > 0:
|
||||
decoded_dct_coeff = np.pad(
|
||||
decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
|
||||
)
|
||||
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
|
||||
assert decoded_dct_coeff.shape == (
|
||||
time_horizon,
|
||||
action_dim,
|
||||
), (
|
||||
f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((time_horizon, action_dim))
|
||||
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho"))
|
||||
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
def extract_actions(self, tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Extracts actions from predicted output tokens using the FAST model.
|
||||
|
||||
Args:
|
||||
tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,)
|
||||
|
||||
Returns:
|
||||
The extracted actions as a tensor of shape (B, action_horizon, action_dim) or (action_horizon, action_dim).
|
||||
"""
|
||||
# Handle single sample (add batch dimension)
|
||||
single_sample = tokens.dim() == 1
|
||||
if single_sample:
|
||||
tokens = tokens.unsqueeze(0)
|
||||
|
||||
# valid = tokens <= (self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens)
|
||||
# fast_region = tokens.masked_fill(~valid, 0)
|
||||
# fast_tokens = self._paligemma_tokens_to_act_tokens(fast_region)
|
||||
# actions = self.decode_actions_with_fast(fast_tokens.tolist(), time_horizon=self.action_horizon, action_dim=self.action_dim, relaxed_decoding=self.relaxed_decoding)[0]
|
||||
decoded_tokens = [
|
||||
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
for seq in tokens
|
||||
]
|
||||
cleaned_tokens = []
|
||||
for token_seq in decoded_tokens:
|
||||
if "|" in token_seq:
|
||||
token_seq = token_seq[:token_seq.index("|")]
|
||||
cleaned_tokens.append(token_seq)
|
||||
raw_action_tokens = [
|
||||
torch.tensor(
|
||||
self._paligemma_tokenizer.convert_tokens_to_ids(token_seq),
|
||||
dtype=torch.long,
|
||||
device=tokens.device,
|
||||
)
|
||||
for token_seq in cleaned_tokens
|
||||
]
|
||||
|
||||
action_tokens = [
|
||||
self._paligemma_tokens_to_act_tokens(raw_action_token)
|
||||
for raw_action_token in raw_action_tokens
|
||||
]
|
||||
actions = self.decode_actions_with_fast(
|
||||
action_tokens,
|
||||
time_horizon=self.action_horizon,
|
||||
action_dim=self.action_dim
|
||||
)
|
||||
|
||||
return torch.tensor(actions, device=tokens.device)
|
||||
|
||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Detokenizes action tokens back to continuous actions.
|
||||
|
||||
Args:
|
||||
action: The tokenized action tensor. Shape: (B, max_action_tokens) or (max_action_tokens,)
|
||||
|
||||
Returns:
|
||||
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
|
||||
"""
|
||||
return self.extract_actions(action)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the serializable configuration of the processor.
|
||||
|
||||
Note: The tokenizer object itself is not serialized. If the processor was initialized
|
||||
with a tokenizer name, that name will be included in the config.
|
||||
|
||||
Returns:
|
||||
A dictionary with the processor's configuration parameters.
|
||||
"""
|
||||
config = {
|
||||
"trust_remote_code": self.trust_remote_code,
|
||||
"action_horizon": self.action_horizon,
|
||||
"action_dim": self.action_dim,
|
||||
"relaxed_decoding": self.relaxed_decoding,
|
||||
}
|
||||
|
||||
# Only save tokenizer_name if it was used to create the tokenizer
|
||||
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||
config["tokenizer_name"] = self.tokenizer_name
|
||||
|
||||
return config
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Updates feature definitions to reflect detokenized actions.
|
||||
|
||||
This updates the policy features dictionary to indicate that the action
|
||||
has been detokenized from token IDs back to continuous values.
|
||||
|
||||
Args:
|
||||
features: The dictionary of existing policy features.
|
||||
|
||||
Returns:
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
# Update the action feature to reflect the continuous action shape
|
||||
if PipelineFeatureType.ACTION in features:
|
||||
# Replace the action feature with the detokenized version
|
||||
features[PipelineFeatureType.ACTION] = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.STATE, # Continuous action
|
||||
shape=(self.action_horizon, self.action_dim)
|
||||
)
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
@@ -62,6 +62,7 @@ def update_policy(
|
||||
accelerator: Accelerator,
|
||||
lr_scheduler=None,
|
||||
lock=None,
|
||||
postprocessor = None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
@@ -90,6 +91,10 @@ def update_policy(
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
action = policy.predict_action_chunk(batch)
|
||||
if postprocessor is not None:
|
||||
action = postprocessor(action)
|
||||
breakpoint()
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
# Use accelerator's backward method
|
||||
@@ -151,7 +156,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
||||
accelerator = Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=4, kwargs_handlers=[ddp_kwargs])
|
||||
|
||||
init_logging(accelerator=accelerator)
|
||||
|
||||
@@ -245,6 +250,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
@@ -344,6 +350,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
|
||||
Reference in New Issue
Block a user