add more changges

This commit is contained in:
Jade Choghari
2025-12-27 21:15:30 +00:00
parent 7556c7fd70
commit 7d897daeb2
2 changed files with 23 additions and 3 deletions

View File

@@ -2373,6 +2373,8 @@ class PI05Policy(PreTrainedPolicy):
Returns:
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
"""
from transformers import AutoTokenizer
self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True)
if self.action_tokenizer is None or self._paligemma_tokenizer is None:
raise ValueError(
"Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully."
@@ -2388,13 +2390,28 @@ class PI05Policy(PreTrainedPolicy):
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
for seq in tokens
]
breakpoint()
# Get the token sequence for "Action: " to remove it
action_prefix_ids = self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False)
action_prefix_tokens = self._paligemma_tokenizer.convert_ids_to_tokens(action_prefix_ids)
action_prefix_len = len(action_prefix_tokens)
# Clean tokens by removing everything after the first "|" (end-of-action marker)
# and removing all occurrences of "Action: " token sequence
cleaned_tokens = []
for token_seq in decoded_tokens:
# also remove the "Action:" prefix
# Remove everything after "|"
if "|" in token_seq:
token_seq = token_seq[:token_seq.index("|")]
# Remove all occurrences of "Action: " token sequence
i = 0
while i <= len(token_seq) - action_prefix_len:
if token_seq[i:i+action_prefix_len] == action_prefix_tokens:
# Found a match, remove it
token_seq = token_seq[:i] + token_seq[i+action_prefix_len:]
else:
i += 1
cleaned_tokens.append(token_seq)
# Convert token strings back to IDs

View File

@@ -9,7 +9,7 @@ accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
$(which lerobot-train) \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/data/libero \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_mean \
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_mean_1 \
--job_name=libero_training_fast \
--policy.repo_id=jade_choghari/pi05-fast-libero \
--policy.path=/fsx/jade_choghari/models/pi05-base \
@@ -23,6 +23,9 @@ accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
--policy.scheduler_decay_steps=100000 \
--policy.scheduler_decay_lr=1e-5 \
--policy.gradient_checkpointing=true \
--policy.chunk_size=10 \
--policy.n_action_steps=10 \
--policy.max_action_tokens=256 \
--rename_map='{
"observation.images.image1": "observation.images.base_0_rgb",
"observation.images.image2": "observation.images.left_wrist_0_rgb",