diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index bb94ebb33..e808f7ce6 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -2373,6 +2373,8 @@ class PI05Policy(PreTrainedPolicy): Returns: The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim) """ + from transformers import AutoTokenizer + self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True) if self.action_tokenizer is None or self._paligemma_tokenizer is None: raise ValueError( "Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully." @@ -2388,13 +2390,28 @@ class PI05Policy(PreTrainedPolicy): self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) for seq in tokens ] - breakpoint() + # Get the token sequence for "Action: " to remove it + action_prefix_ids = self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False) + action_prefix_tokens = self._paligemma_tokenizer.convert_ids_to_tokens(action_prefix_ids) + action_prefix_len = len(action_prefix_tokens) + # Clean tokens by removing everything after the first "|" (end-of-action marker) + # and removing all occurrences of "Action: " token sequence cleaned_tokens = [] for token_seq in decoded_tokens: - # also remove the "Action:" prefix + # Remove everything after "|" if "|" in token_seq: token_seq = token_seq[:token_seq.index("|")] + + # Remove all occurrences of "Action: " token sequence + i = 0 + while i <= len(token_seq) - action_prefix_len: + if token_seq[i:i+action_prefix_len] == action_prefix_tokens: + # Found a match, remove it + token_seq = token_seq[:i] + token_seq[i+action_prefix_len:] + else: + i += 1 + cleaned_tokens.append(token_seq) # Convert token strings back to IDs diff --git a/src/lerobot/policies/pi05/train_multi.sh b/src/lerobot/policies/pi05/train_multi.sh index b4f5f5d7d..04f3e3bcc 100644 --- a/src/lerobot/policies/pi05/train_multi.sh +++ b/src/lerobot/policies/pi05/train_multi.sh @@ -9,7 +9,7 @@ accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \ $(which lerobot-train) \ --dataset.repo_id=local \ --dataset.root=/fsx/jade_choghari/data/libero \ - --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_mean \ + --output_dir=/fsx/jade_choghari/outputs/libero_training_fast_mean_1 \ --job_name=libero_training_fast \ --policy.repo_id=jade_choghari/pi05-fast-libero \ --policy.path=/fsx/jade_choghari/models/pi05-base \ @@ -23,6 +23,9 @@ accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \ --policy.scheduler_decay_steps=100000 \ --policy.scheduler_decay_lr=1e-5 \ --policy.gradient_checkpointing=true \ + --policy.chunk_size=10 \ + --policy.n_action_steps=10 \ + --policy.max_action_tokens=256 \ --rename_map='{ "observation.images.image1": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb",