fix training

This commit is contained in:
Jade Choghari
2025-12-27 10:43:00 +00:00
parent 4b40153c32
commit 4434c863b4
3 changed files with 74 additions and 12 deletions

View File

@@ -1394,9 +1394,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# For next-token prediction, we input tokens [0:T-1] to predict tokens [1:T]
# So we remove the last token from input
input_embs = prefix_embs[:, :-1]
input_pad_masks = prefix_pad_masks[:, :-1]
input_att_masks = prefix_att_masks[:, :-1, :-1]
# input_embs = prefix_embs[:, :-1]
# input_pad_masks = prefix_pad_masks[:, :-1]
# input_att_masks = prefix_att_masks[:, :-1, :-1]
input_embs = prefix_embs
input_pad_masks = prefix_pad_masks
input_att_masks = prefix_att_masks
position_ids = torch.cumsum(input_pad_masks, dim=1) - 1
att_2d_4d = self._prepare_attention_masks_4d(input_att_masks, dtype=input_embs.dtype)
@@ -1438,6 +1442,24 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
fast_targets = fast_targets[:, 1:] # Shift targets right
fast_action_masks = fast_action_masks[:, 1:] # Shift masks to match targets
# from transformers import AutoTokenizer
# self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
# "google/paligemma-3b-pt-224",
# trust_remote_code=True,
# add_eos_token=True,
# add_bos_token=False
# )
# # remove
# decoded_tokens = [
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
# for seq in fast_targets
# ]
# corrected_tokens = [
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
# for seq in fast_logits_for_pred.argmax(dim=-1)
# ]
# breakpoint()
# Compute cross-entropy loss
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
@@ -1591,8 +1613,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
adarms_cond=[None, None],
)
# Predict first token from the last sequence position
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
# Get BOS token and add it as the first token in action sequence
bos_id = self._paligemma_tokenizer.bos_token_id
bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)
# Embed BOS token
bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
if prefix_embs.dtype == torch.bfloat16:
bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)
# Track current sequence length for position IDs and maintain the padding mask
current_seq_len = prefix_embs.shape[1]
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
current_pad_mask = prefix_pad_masks.clone() # (B, seq_len)
# Update padding mask for BOS token: it's always valid
current_pad_mask = torch.cat([
current_pad_mask,
torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1) # (B, seq_len+1)
# Position ID for BOS token (continues from where prefix ended)
bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
# Attention mask for BOS token: attends to all valid prefix positions
bos_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)
# Forward pass with BOS token (reusing cached KVs from prefix)
(bos_out, _), past_key_values = self.paligemma_with_expert.forward(
attention_mask=bos_att_4d,
position_ids=bos_position_id,
past_key_values=past_key_values,
inputs_embeds=[bos_token_emb, None],
use_cache=True,
adarms_cond=[None, None],
)
# Update sequence length to account for BOS token
current_seq_len += 1
# Predict first action token from BOS token output
last_logits = lm_head(bos_out[:, -1:, :]) # (B, 1, vocab_size)
if temperature > 0:
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
@@ -1602,11 +1665,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
generated_action_tokens[:, 0] = next_token.squeeze(-1)
# Track current sequence length for position IDs and maintain the padding mask
current_seq_len = prefix_embs.shape[1]
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
current_pad_mask = prefix_pad_masks.clone() # (B, seq_len)
# 3. Incremental Decoding Loop (using KV cache)
for t in range(1, max_decoding_steps):
# Embed the newly generated token
@@ -2330,10 +2388,11 @@ class PI05Policy(PreTrainedPolicy):
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
for seq in tokens
]
breakpoint()
# Clean tokens by removing everything after the first "|" (end-of-action marker)
cleaned_tokens = []
for token_seq in decoded_tokens:
# also remove the "Action:" prefix
if "|" in token_seq:
token_seq = token_seq[:token_seq.index("|")]
cleaned_tokens.append(token_seq)

View File

@@ -5,7 +5,7 @@ lerobot-train \
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
--job_name=pi0_multi_training \
--policy.repo_id=jadechoghari/pi0-base1 \
--policy.path=/fsx/jade_choghari/outputs/libero_training_fast_4/checkpoints/last/pretrained_model/ \
--policy.path=/fsx/jade_choghari/outputs/libero_training_fast_6/checkpoints/last/pretrained_model/ \
--policy.dtype=bfloat16 \
--steps=50000 \
--save_freq=5000 \

View File

@@ -577,7 +577,10 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
if tokens.dim() > 1:
tokens = tokens.flatten()
bos_id = self._paligemma_tokenizer.bos_token_id
# add bos
tokens = torch.cat([
torch.tensor([bos_id], device=action.device),
torch.tensor(self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), device=action.device),
self._act_tokens_to_paligemma_tokens(tokens),
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),