mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
add more changes
This commit is contained in:
@@ -2373,6 +2373,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
Returns:
|
||||
The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim)
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
self._paligemma_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224", trust_remote_code=True)
|
||||
if self.action_tokenizer is None or self._paligemma_tokenizer is None:
|
||||
raise ValueError(
|
||||
"Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully."
|
||||
@@ -2388,13 +2390,28 @@ class PI05Policy(PreTrainedPolicy):
|
||||
self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
|
||||
for seq in tokens
|
||||
]
|
||||
breakpoint()
|
||||
# Get the token sequence for "Action: " to remove it
|
||||
action_prefix_ids = self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False)
|
||||
action_prefix_tokens = self._paligemma_tokenizer.convert_ids_to_tokens(action_prefix_ids)
|
||||
action_prefix_len = len(action_prefix_tokens)
|
||||
|
||||
# Clean tokens by removing everything after the first "|" (end-of-action marker)
|
||||
# and removing all occurrences of "Action: " token sequence
|
||||
cleaned_tokens = []
|
||||
for token_seq in decoded_tokens:
|
||||
# also remove the "Action:" prefix
|
||||
# Remove everything after "|"
|
||||
if "|" in token_seq:
|
||||
token_seq = token_seq[:token_seq.index("|")]
|
||||
|
||||
# Remove all occurrences of "Action: " token sequence
|
||||
i = 0
|
||||
while i <= len(token_seq) - action_prefix_len:
|
||||
if token_seq[i:i+action_prefix_len] == action_prefix_tokens:
|
||||
# Found a match, remove it
|
||||
token_seq = token_seq[:i] + token_seq[i+action_prefix_len:]
|
||||
else:
|
||||
i += 1
|
||||
|
||||
cleaned_tokens.append(token_seq)
|
||||
|
||||
# Convert token strings back to IDs
|
||||
|
||||
@@ -9,7 +9,7 @@ accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
|
||||
$(which lerobot-train) \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/data/libero \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_mean \
|
||||
--output_dir=/fsx/jade_choghari/outputs/libero_training_fast_mean_1 \
|
||||
--job_name=libero_training_fast \
|
||||
--policy.repo_id=jade_choghari/pi05-fast-libero \
|
||||
--policy.path=/fsx/jade_choghari/models/pi05-base \
|
||||
@@ -23,6 +23,9 @@ accelerate launch --mixed_precision=bf16 --multi_gpu --num_processes=8 \
|
||||
--policy.scheduler_decay_steps=100000 \
|
||||
--policy.scheduler_decay_lr=1e-5 \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.chunk_size=10 \
|
||||
--policy.n_action_steps=10 \
|
||||
--policy.max_action_tokens=256 \
|
||||
--rename_map='{
|
||||
"observation.images.image1": "observation.images.base_0_rgb",
|
||||
"observation.images.image2": "observation.images.left_wrist_0_rgb",
|
||||
|
||||
Reference in New Issue
Block a user