mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
fix training
This commit is contained in:
@@ -1394,9 +1394,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
# For next-token prediction, we input tokens [0:T-1] to predict tokens [1:T]
|
||||
# So we remove the last token from input
|
||||
input_embs = prefix_embs[:, :-1]
|
||||
input_pad_masks = prefix_pad_masks[:, :-1]
|
||||
input_att_masks = prefix_att_masks[:, :-1, :-1]
|
||||
# input_embs = prefix_embs[:, :-1]
|
||||
# input_pad_masks = prefix_pad_masks[:, :-1]
|
||||
# input_att_masks = prefix_att_masks[:, :-1, :-1]
|
||||
|
||||
input_embs = prefix_embs
|
||||
input_pad_masks = prefix_pad_masks
|
||||
input_att_masks = prefix_att_masks
|
||||
|
||||
position_ids = torch.cumsum(input_pad_masks, dim=1) - 1
|
||||
att_2d_4d = self._prepare_attention_masks_4d(input_att_masks, dtype=input_embs.dtype)
|
||||
@@ -1438,6 +1442,24 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
fast_targets = fast_targets[:, 1:] # Shift targets right
|
||||
fast_action_masks = fast_action_masks[:, 1:] # Shift masks to match targets
|
||||
|
||||
# from transformers import AutoTokenizer
|
||||
# self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
|
||||
# "google/paligemma-3b-pt-224",
|
||||
# trust_remote_code=True,
|
||||
# add_eos_token=True,
|
||||
# add_bos_token=False
|
||||
# )
|
||||
# # remove
|
||||
# decoded_tokens = [
|
||||
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
# for seq in fast_targets
|
||||
# ]
|
||||
# corrected_tokens = [
|
||||
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
# for seq in fast_logits_for_pred.argmax(dim=-1)
|
||||
# ]
|
||||
# breakpoint()
|
||||
|
||||
# Compute cross-entropy loss
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
|
||||
@@ -1591,8 +1613,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# Predict first token from the last sequence position
|
||||
last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size)
|
||||
# Get BOS token and add it as the first token in action sequence
|
||||
bos_id = self._paligemma_tokenizer.bos_token_id
|
||||
bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device)
|
||||
|
||||
# Embed BOS token
|
||||
bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token)
|
||||
bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
# Track current sequence length for position IDs and maintain the padding mask
|
||||
current_seq_len = prefix_embs.shape[1]
|
||||
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
|
||||
current_pad_mask = prefix_pad_masks.clone() # (B, seq_len)
|
||||
|
||||
# Update padding mask for BOS token: it's always valid
|
||||
current_pad_mask = torch.cat([
|
||||
current_pad_mask,
|
||||
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
], dim=1) # (B, seq_len+1)
|
||||
|
||||
# Position ID for BOS token (continues from where prefix ended)
|
||||
bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device)
|
||||
|
||||
# Attention mask for BOS token: attends to all valid prefix positions
|
||||
bos_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1)
|
||||
bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype)
|
||||
|
||||
# Forward pass with BOS token (reusing cached KVs from prefix)
|
||||
(bos_out, _), past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=bos_att_4d,
|
||||
position_ids=bos_position_id,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[bos_token_emb, None],
|
||||
use_cache=True,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# Update sequence length to account for BOS token
|
||||
current_seq_len += 1
|
||||
|
||||
# Predict first action token from BOS token output
|
||||
last_logits = lm_head(bos_out[:, -1:, :]) # (B, 1, vocab_size)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1)
|
||||
@@ -1602,11 +1665,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
generated_action_tokens[:, 0] = next_token.squeeze(-1)
|
||||
|
||||
# Track current sequence length for position IDs and maintain the padding mask
|
||||
current_seq_len = prefix_embs.shape[1]
|
||||
# Keep track of valid positions: prefix_pad_masks tells us which positions are valid
|
||||
current_pad_mask = prefix_pad_masks.clone() # (B, seq_len)
|
||||
|
||||
# 3. Incremental Decoding Loop (using KV cache)
|
||||
for t in range(1, max_decoding_steps):
|
||||
# Embed the newly generated token
|
||||
@@ -2330,10 +2388,11 @@ class PI05Policy(PreTrainedPolicy):
|
||||
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
for seq in tokens
|
||||
]
|
||||
|
||||
breakpoint()
|
||||
# Clean tokens by removing everything after the first "|" (end-of-action marker)
|
||||
cleaned_tokens = []
|
||||
for token_seq in decoded_tokens:
|
||||
# also remove the "Action:" prefix
|
||||
if "|" in token_seq:
|
||||
token_seq = token_seq[:token_seq.index("|")]
|
||||
cleaned_tokens.append(token_seq)
|
||||
|
||||
@@ -5,7 +5,7 @@ lerobot-train \
|
||||
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
|
||||
--job_name=pi0_multi_training \
|
||||
--policy.repo_id=jadechoghari/pi0-base1 \
|
||||
--policy.path=/fsx/jade_choghari/outputs/libero_training_fast_4/checkpoints/last/pretrained_model/ \
|
||||
--policy.path=/fsx/jade_choghari/outputs/libero_training_fast_6/checkpoints/last/pretrained_model/ \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=50000 \
|
||||
--save_freq=5000 \
|
||||
|
||||
@@ -577,7 +577,10 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
bos_id = self._paligemma_tokenizer.bos_token_id
|
||||
# add bos
|
||||
tokens = torch.cat([
|
||||
torch.tensor([bos_id], device=action.device),
|
||||
torch.tensor(self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), device=action.device),
|
||||
self._act_tokens_to_paligemma_tokens(tokens),
|
||||
torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),
|
||||
|
||||
Reference in New Issue
Block a user