diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 164484cae..b39a26596 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -1394,9 +1394,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` # For next-token prediction, we input tokens [0:T-1] to predict tokens [1:T] # So we remove the last token from input - input_embs = prefix_embs[:, :-1] - input_pad_masks = prefix_pad_masks[:, :-1] - input_att_masks = prefix_att_masks[:, :-1, :-1] + # input_embs = prefix_embs[:, :-1] + # input_pad_masks = prefix_pad_masks[:, :-1] + # input_att_masks = prefix_att_masks[:, :-1, :-1] + + input_embs = prefix_embs + input_pad_masks = prefix_pad_masks + input_att_masks = prefix_att_masks position_ids = torch.cumsum(input_pad_masks, dim=1) - 1 att_2d_4d = self._prepare_attention_masks_4d(input_att_masks, dtype=input_embs.dtype) @@ -1438,6 +1442,24 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` fast_targets = fast_targets[:, 1:] # Shift targets right fast_action_masks = fast_action_masks[:, 1:] # Shift masks to match targets + # from transformers import AutoTokenizer + # self._paligemma_tokenizer = AutoTokenizer.from_pretrained( + # "google/paligemma-3b-pt-224", + # trust_remote_code=True, + # add_eos_token=True, + # add_bos_token=False + # ) + # # remove + # decoded_tokens = [ + # self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) + # for seq in fast_targets + # ] + # corrected_tokens = [ + # self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) + # for seq in fast_logits_for_pred.argmax(dim=-1) + # ] + # breakpoint() + # Compute cross-entropy loss loss_fct = torch.nn.CrossEntropyLoss(reduction='none') fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) @@ -1591,8 +1613,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` adarms_cond=[None, None], ) - # Predict first token from the last sequence position - last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size) + # Get BOS token and add it as the first token in action sequence + bos_id = self._paligemma_tokenizer.bos_token_id + bos_token = torch.full((bsize, 1), bos_id, dtype=torch.long, device=device) + + # Embed BOS token + bos_token_emb = self.paligemma_with_expert.embed_language_tokens(bos_token) + bos_token_emb = bos_token_emb * math.sqrt(bos_token_emb.shape[-1]) + if prefix_embs.dtype == torch.bfloat16: + bos_token_emb = bos_token_emb.to(dtype=torch.bfloat16) + + # Track current sequence length for position IDs and maintain the padding mask + current_seq_len = prefix_embs.shape[1] + # Keep track of valid positions: prefix_pad_masks tells us which positions are valid + current_pad_mask = prefix_pad_masks.clone() # (B, seq_len) + + # Update padding mask for BOS token: it's always valid + current_pad_mask = torch.cat([ + current_pad_mask, + torch.ones((bsize, 1), dtype=torch.bool, device=device) + ], dim=1) # (B, seq_len+1) + + # Position ID for BOS token (continues from where prefix ended) + bos_position_id = torch.full((bsize, 1), current_seq_len, dtype=torch.long, device=device) + + # Attention mask for BOS token: attends to all valid prefix positions + bos_att_mask_2d = current_pad_mask.unsqueeze(1) # (B, 1, seq_len+1) + bos_att_4d = self._prepare_attention_masks_4d(bos_att_mask_2d, dtype=bos_token_emb.dtype) + + # Forward pass with BOS token (reusing cached KVs from prefix) + (bos_out, _), past_key_values = self.paligemma_with_expert.forward( + attention_mask=bos_att_4d, + position_ids=bos_position_id, + past_key_values=past_key_values, + inputs_embeds=[bos_token_emb, None], + use_cache=True, + adarms_cond=[None, None], + ) + + # Update sequence length to account for BOS token + current_seq_len += 1 + + # Predict first action token from BOS token output + last_logits = lm_head(bos_out[:, -1:, :]) # (B, 1, vocab_size) if temperature > 0: probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1) @@ -1602,11 +1665,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` generated_action_tokens[:, 0] = next_token.squeeze(-1) - # Track current sequence length for position IDs and maintain the padding mask - current_seq_len = prefix_embs.shape[1] - # Keep track of valid positions: prefix_pad_masks tells us which positions are valid - current_pad_mask = prefix_pad_masks.clone() # (B, seq_len) - # 3. Incremental Decoding Loop (using KV cache) for t in range(1, max_decoding_steps): # Embed the newly generated token @@ -2330,10 +2388,11 @@ class PI05Policy(PreTrainedPolicy): self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) for seq in tokens ] - + breakpoint() # Clean tokens by removing everything after the first "|" (end-of-action marker) cleaned_tokens = [] for token_seq in decoded_tokens: + # also remove the "Action:" prefix if "|" in token_seq: token_seq = token_seq[:token_seq.index("|")] cleaned_tokens.append(token_seq) diff --git a/src/lerobot/policies/pi05/train2.sh b/src/lerobot/policies/pi05/train2.sh index 879618d20..df7330870 100644 --- a/src/lerobot/policies/pi05/train2.sh +++ b/src/lerobot/policies/pi05/train2.sh @@ -5,7 +5,7 @@ lerobot-train \ --output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \ --job_name=pi0_multi_training \ --policy.repo_id=jadechoghari/pi0-base1 \ - --policy.path=/fsx/jade_choghari/outputs/libero_training_fast_4/checkpoints/last/pretrained_model/ \ + --policy.path=/fsx/jade_choghari/outputs/libero_training_fast_6/checkpoints/last/pretrained_model/ \ --policy.dtype=bfloat16 \ --steps=50000 \ --save_freq=5000 \ diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index d1a1893f4..8f9489918 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -577,7 +577,10 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): if tokens.dim() > 1: tokens = tokens.flatten() + bos_id = self._paligemma_tokenizer.bos_token_id + # add bos tokens = torch.cat([ + torch.tensor([bos_id], device=action.device), torch.tensor(self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), device=action.device), self._act_tokens_to_paligemma_tokens(tokens), torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device),