mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Compare commits
1 Commits
feat/add-h
...
feat/test_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f8aa7d03b |
@@ -1,138 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""
|
||||
Example demonstrating how to use the ActionTokenizerProcessorStep to tokenize actions.
|
||||
|
||||
This example shows how to:
|
||||
1. Load a dataset with action data
|
||||
2. Apply the action tokenizer processor to tokenize actions with proper padding/truncation
|
||||
3. Access both the tokenized actions and the attention mask
|
||||
4. Decode tokenized actions back to their original form
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.processor.tokenizer_processor import ActionTokenizerProcessorStep
|
||||
from lerobot.utils.constants import ACTION_TOKEN_MASK
|
||||
|
||||
# Define delta timestamps for the dataset
|
||||
delta_timestamps = {
|
||||
'action': [
|
||||
0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333,
|
||||
0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3,
|
||||
0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335,
|
||||
0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6,
|
||||
0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333,
|
||||
0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9,
|
||||
0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334,
|
||||
1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2,
|
||||
1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333,
|
||||
1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5,
|
||||
1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333
|
||||
]
|
||||
}
|
||||
|
||||
# Load the dataset
|
||||
print("Loading dataset...")
|
||||
dataset = LeRobotDataset(
|
||||
repo_id="local",
|
||||
root="/fsx/jade_choghari/outputs/pgen_annotations1",
|
||||
delta_timestamps=delta_timestamps
|
||||
)
|
||||
|
||||
# Create a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
# Get a batch of data
|
||||
batch = next(iter(dataloader))
|
||||
action_data = batch["action"] # Shape: (batch_size, action_horizon, action_dim)
|
||||
|
||||
print(f"\nOriginal action shape: {action_data.shape}")
|
||||
print(f"Original action data (first sample, first timestep):\n{action_data[0, 0]}")
|
||||
|
||||
# Method 1: Using the tokenizer directly (as in fast_tokenize.py)
|
||||
print("\n" + "="*80)
|
||||
print("Method 1: Direct tokenizer usage")
|
||||
print("="*80)
|
||||
|
||||
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
|
||||
# Tokenize directly
|
||||
tokens = tokenizer(action_data)
|
||||
print(f"\nDirect tokenization result type: {type(tokens)}")
|
||||
print(f"Tokens shape/length: {tokens.shape if isinstance(tokens, torch.Tensor) else len(tokens)}")
|
||||
|
||||
# Decode
|
||||
decoded_actions = tokenizer.decode(tokens)
|
||||
print(f"Decoded actions shape: {decoded_actions.shape}")
|
||||
reconstruction_error = torch.abs(action_data - decoded_actions).mean()
|
||||
print(f"Mean absolute reconstruction error: {reconstruction_error.item():.6f}")
|
||||
|
||||
# Method 2: Using the ActionTokenizerProcessorStep with proper padding/truncation
|
||||
print("\n" + "="*80)
|
||||
print("Method 2: Using ActionTokenizerProcessorStep (with padding & mask)")
|
||||
print("="*80)
|
||||
|
||||
# Create the action tokenizer processor step
|
||||
action_tokenizer_processor = ActionTokenizerProcessorStep(
|
||||
tokenizer_name="physical-intelligence/fast",
|
||||
trust_remote_code=True,
|
||||
max_action_tokens=32, # Maximum number of tokens per action
|
||||
)
|
||||
|
||||
# Create a transition with the action data
|
||||
transition = {
|
||||
TransitionKey.ACTION: action_data,
|
||||
TransitionKey.OBSERVATION: {}, # Empty for this example
|
||||
}
|
||||
|
||||
# Apply the processor
|
||||
processed_transition = action_tokenizer_processor(transition)
|
||||
|
||||
# Extract tokenized actions and mask
|
||||
tokenized_actions = processed_transition[TransitionKey.ACTION]
|
||||
complementary_data = processed_transition[TransitionKey.COMPLEMENTARY_DATA]
|
||||
action_mask = complementary_data[ACTION_TOKEN_MASK]
|
||||
|
||||
print(f"\nTokenized actions shape: {tokenized_actions.shape}") # (batch_size, max_action_tokens)
|
||||
print(f"Action mask shape: {action_mask.shape}") # (batch_size, max_action_tokens)
|
||||
print(f"Tokenized actions dtype: {tokenized_actions.dtype}")
|
||||
print(f"Action mask dtype: {action_mask.dtype}")
|
||||
|
||||
# Show token statistics
|
||||
print(f"\nFirst sample tokens: {tokenized_actions[0]}")
|
||||
print(f"First sample mask: {action_mask[0]}")
|
||||
num_real_tokens = action_mask[0].sum().item()
|
||||
print(f"Number of real tokens (non-padding): {num_real_tokens}")
|
||||
print(f"Number of padding tokens: {action_mask.shape[1] - num_real_tokens}")
|
||||
|
||||
# Decode using the mask
|
||||
print("\nDecoding tokenized actions...")
|
||||
decoded_with_processor = tokenizer.decode(tokenized_actions)
|
||||
print(f"Decoded actions shape: {decoded_with_processor.shape}")
|
||||
|
||||
# Calculate reconstruction error
|
||||
reconstruction_error_processor = torch.abs(action_data - decoded_with_processor).mean()
|
||||
print(f"Mean absolute reconstruction error: {reconstruction_error_processor.item():.6f}")
|
||||
|
||||
# Show that masking works correctly
|
||||
print("\n" + "="*80)
|
||||
print("Mask demonstration")
|
||||
print("="*80)
|
||||
for i in range(min(4, tokenized_actions.shape[0])):
|
||||
mask_i = action_mask[i]
|
||||
num_real = mask_i.sum().item()
|
||||
print(f"Sample {i}: {num_real} real tokens, {len(mask_i) - num_real} padding tokens")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("Action tokenization example completed successfully!")
|
||||
print("="*80)
|
||||
|
||||
@@ -1402,13 +1402,6 @@ def main():
|
||||
action="store_true",
|
||||
help="Push modified dataset to HuggingFace Hub",
|
||||
)
|
||||
# add image key
|
||||
parser.add_argument(
|
||||
"--image-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Image observation key to use for image mode (default: None)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
console = Console()
|
||||
@@ -1450,10 +1443,7 @@ def main():
|
||||
)
|
||||
|
||||
# Get image keys (for image mode)
|
||||
if args.image_key:
|
||||
image_keys = [args.image_key]
|
||||
else:
|
||||
image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
|
||||
image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
|
||||
if not args.video_mode:
|
||||
console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")
|
||||
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
import numpy as np
|
||||
from transformers import AutoProcessor
|
||||
import torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=4,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
# Load the tokenizer from the Hugging Face hub
|
||||
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
|
||||
|
||||
# Tokenize & decode action chunks (we use dummy data here)
|
||||
action_data = batch["action"] # one batch of action chunks
|
||||
tokens = tokenizer(action_data) # tokens = list[int]
|
||||
decoded_actions = tokenizer.decode(tokens)
|
||||
print("tokenized actions: ", tokens)
|
||||
@@ -10,19 +10,17 @@ from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
|
||||
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
||||
)
|
||||
cfg.dtype = "bfloat16"
|
||||
|
||||
pre_processor, post_processor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
|
||||
pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
|
||||
)
|
||||
|
||||
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
|
||||
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
|
||||
|
||||
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")
|
||||
# rename map --rename_map='{
|
||||
# "observation.images.side": "observation.images.base_0_rgb",
|
||||
# "observation.images.up": "observation.images.left_wrist_0_rgb"
|
||||
@@ -45,47 +43,16 @@ dataloader = torch.utils.data.DataLoader(
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
batch = pre_processor(batch)
|
||||
|
||||
# Test training forward pass
|
||||
policy.train()
|
||||
# run inference
|
||||
# action = policy.select_action(batch)
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
breakpoint()
|
||||
# import requests
|
||||
# from PIL import Image
|
||||
# from transformers import AutoProcessor
|
||||
# model = policy.model.paligemma_with_expert.paligemma
|
||||
# model = model.to(device="cuda", dtype=torch.bfloat16)
|
||||
# model.eval()
|
||||
# prompt = "Describe this image."
|
||||
# url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
# image = Image.open(requests.get(url, stream=True).raw)
|
||||
# processor = AutoProcessor.from_pretrained(
|
||||
# "google/paligemma-3b-pt-224",
|
||||
# )
|
||||
# inputs = processor(image, prompt, return_tensors="pt").to(model.device)
|
||||
# print("generating...")
|
||||
# output = model.generate(
|
||||
# **inputs,
|
||||
# max_new_tokens=50,
|
||||
# use_cache=True, # default dynamic cache
|
||||
# )
|
||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||
print(f"Training loss: {loss_dict}")
|
||||
|
||||
|
||||
# # other model
|
||||
# from transformers import PaliGemmaForConditionalGeneration
|
||||
# model = PaliGemmaForConditionalGeneration.from_pretrained(
|
||||
# "google/paligemma2-3b-pt-224",
|
||||
# torch_dtype=torch.bfloat16,
|
||||
# device_map="auto",
|
||||
# )
|
||||
# model.eval()
|
||||
# print("generating...")
|
||||
# output = model.generate(
|
||||
# **inputs,
|
||||
# max_new_tokens=100,
|
||||
# use_cache=True, # default dynamic cache
|
||||
# )
|
||||
# print("Model 2 output:")
|
||||
# print(processor.decode(output[0], skip_special_tokens=True))
|
||||
# Test inference
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
actions = policy.predict_action_chunk(batch)
|
||||
print(f"Predicted actions shape: {actions.shape}")
|
||||
@@ -1,159 +0,0 @@
|
||||
## One-sentence answer
|
||||
|
||||
> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**.
|
||||
|
||||
Everything else you’ve seen so far was just metadata.
|
||||
|
||||
---
|
||||
|
||||
## What goes in
|
||||
|
||||
### Inputs
|
||||
|
||||
```python
|
||||
prefix_pad_masks # shape [B, L]
|
||||
prefix_att_masks # shape [B, L]
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
* `prefix_pad_masks[b, i] = True`
|
||||
→ token `i` exists (not padding)
|
||||
|
||||
* `prefix_att_masks[b, i] = False`
|
||||
→ token `i` is **bidirectional**
|
||||
|
||||
* `prefix_att_masks[b, i] = True`
|
||||
→ token `i` is **causal (autoregressive)**
|
||||
|
||||
---
|
||||
|
||||
## What comes out
|
||||
|
||||
```python
|
||||
att_2d_prefix # shape [B, L, L]
|
||||
```
|
||||
|
||||
Each entry:
|
||||
|
||||
```text
|
||||
att_2d_prefix[b, i, j] = True
|
||||
```
|
||||
|
||||
means:
|
||||
|
||||
> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.”
|
||||
|
||||
---
|
||||
|
||||
## How it is constructed (conceptually)
|
||||
|
||||
For **each batch b**, **each query position i**, **each key position j**:
|
||||
|
||||
```python
|
||||
if not prefix_pad_masks[b, j]:
|
||||
att[b, i, j] = False # cannot attend to padding
|
||||
else if not prefix_att_masks[b, i]:
|
||||
att[b, i, j] = True # bidirectional token → can see all real tokens
|
||||
else:
|
||||
att[b, i, j] = (j <= i) # causal token → can see only past + itself
|
||||
```
|
||||
|
||||
That’s it.
|
||||
|
||||
---
|
||||
|
||||
## Tiny concrete example (exactly matching your code)
|
||||
|
||||
Suppose:
|
||||
|
||||
```python
|
||||
prefix_pad_masks[0] = [T, T, T, T, T, F]
|
||||
prefix_att_masks[0] = [F, F, F, T, T, T]
|
||||
```
|
||||
|
||||
Tokens:
|
||||
|
||||
```
|
||||
0: IMG
|
||||
1: IMG
|
||||
2: LANG
|
||||
3: SUB0
|
||||
4: SUB1
|
||||
5: PAD
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Resulting `att_2d_prefix[0]`
|
||||
|
||||
`✓ = True, ✗ = False`
|
||||
|
||||
| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 |
|
||||
| ---------- | - | - | - | - | - | - |
|
||||
| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
|
||||
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
|
||||
| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
|
||||
|
||||
---
|
||||
|
||||
## Why this matters for your training code
|
||||
|
||||
This line:
|
||||
|
||||
```python
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
|
||||
```
|
||||
|
||||
Converts `[B, L, L] → [B, 1, L, L]` and possibly flips True/False to `0/-inf`.
|
||||
|
||||
This is **exactly what Paligemma uses inside self-attention**.
|
||||
|
||||
---
|
||||
|
||||
## Key implications (VERY important)
|
||||
|
||||
### 1️⃣ This mask does **not isolate token groups**
|
||||
|
||||
* Bidirectional tokens can attend to **everything**
|
||||
* Causal tokens only restrict *their own row*
|
||||
|
||||
So **flow/action tokens must be blocked separately**.
|
||||
|
||||
---
|
||||
|
||||
### 2️⃣ This is why your AR subtask prediction works
|
||||
|
||||
* Subtask tokens are causal
|
||||
* Output at position `i` predicts token `i+1`
|
||||
* Padding is fully ignored
|
||||
|
||||
---
|
||||
|
||||
### 3️⃣ Inference behavior
|
||||
|
||||
When `subtask_tokens = None`:
|
||||
|
||||
* `prefix_att_masks` contains only `False`
|
||||
* `att_2d_prefix` becomes **fully bidirectional**
|
||||
* No AR behavior remains
|
||||
|
||||
Exactly what you want.
|
||||
|
||||
---
|
||||
|
||||
## One-sentence takeaway (commit this)
|
||||
|
||||
> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.
|
||||
|
||||
If you want next, I can:
|
||||
|
||||
* inspect `make_att_2d_masks()` source with you
|
||||
* show how to block **flow → subtask** attention
|
||||
* explain how this changes when suffix tokens are added
|
||||
* help you refactor this into a cleaner “grouped attention” API
|
||||
|
||||
You’re now at the point where the model’s behavior should feel *predictable*, not magical.
|
||||
@@ -1,11 +1,10 @@
|
||||
python examples/dataset/annotate.py \
|
||||
--repo-id jadechoghari/collect-data \
|
||||
--video-key observation.images.base \
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
|
||||
--episodes 16 22
|
||||
|
||||
# python examples/dataset/annotate.py \
|
||||
# --repo-id lerobot/svla_so101_pickplace \
|
||||
# --video-key observation.images.side \
|
||||
# --model Qwen/Qwen3-VL-30B-A3B-Instruct \
|
||||
# --episodes 5
|
||||
|
||||
python examples/dataset/annotate.py \
|
||||
--repo-id lerobot/svla_so101_pickplace \
|
||||
--video-key observation.images.side \
|
||||
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
|
||||
--episodes 3 5 7 44
|
||||
@@ -4,12 +4,12 @@
|
||||
# This generates user prompts and robot utterances for hierarchical policy training
|
||||
|
||||
# Configuration
|
||||
REPO_ID="jadechoghari/collect-data"
|
||||
REPO_ID="lerobot/svla_so101_pickplace"
|
||||
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
|
||||
# Alternative: MODEL="Qwen/Qwen2-VL-7B-Instruct"
|
||||
|
||||
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/collect-data-pgen"
|
||||
OUTPUT_DIR="/fsx/jade_choghari/outputs/pgen_annotations1"
|
||||
BATCH_SIZE=32
|
||||
TEMPERATURE=0.9
|
||||
SAMPLE_INTERVAL=5.0 # Generate dialogue every 1 second (all episodes processed)
|
||||
@@ -22,7 +22,6 @@ python examples/dataset/annotate_pgen.py \
|
||||
--temperature "$TEMPERATURE" \
|
||||
--batch-size "$BATCH_SIZE" \
|
||||
--sample-interval "$SAMPLE_INTERVAL" \
|
||||
--image-key observation.images.base \
|
||||
--num-image-views-per-sample 1
|
||||
|
||||
# For faster testing, increase sample interval:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
srun --time 12:00:00 --qos=high --gres=gpu:1 --mem=24G --partition=hopper-prod --container-image /fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh --container-mounts /fsx/jade_choghari
|
||||
@@ -1,47 +0,0 @@
|
||||
# Voice Assistant Examples
|
||||
|
||||
Voice-enabled robot assistant examples using speech-to-text (STT), and text-to-speech (TTS).
|
||||
|
||||
## Overview
|
||||
|
||||
These examples demonstrate how to build a voice interface for robot control:
|
||||
|
||||
1. **Hold SPACE** → Push-to-talk recording starts
|
||||
2. **Release SPACE** → Recording stops
|
||||
3. **STT (Whisper)** → Converts speech to text (high-level task prompt)
|
||||
4. **Pi0.5** → Generates robot response/utterance
|
||||
5. **TTS (Kokoro)** → Speaks the response back
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install torch transformers sounddevice numpy pynput kokoro>=0.9.2
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### With Pi0.5 Model
|
||||
|
||||
```bash
|
||||
python examples/voice_assistant/voice_assistant_pi05.py \
|
||||
--pretrained_path path/to/pi05/checkpoint
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Pi0.5 Voice Integration
|
||||
|
||||
Pi0.5 can generate robot utterances as part of its subtask prediction. The flow:
|
||||
|
||||
1. **High-level prompt**: User voice command is transcribed and formatted as a task prompt
|
||||
2. **Subtask generation**: Pi0.5 autoregressively generates a response
|
||||
3. **Utterance extraction**: If the response contains `<utterance>...</utterance>` tags, the content is extracted
|
||||
4. **TTS output**: The response is spoken back to the user
|
||||
|
||||
## Configuration Options
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `--pretrained_path` | None | Path to Pi0.5 checkpoint |
|
||||
| `--record_seconds` | 5.0 | Audio recording duration |
|
||||
| `--max_response_tokens` | 100 | Max tokens in generated response |
|
||||
@@ -1,336 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Voice Assistant with Pi0.5: Microphone → STT → Pi0.5 → TTS → Speaker
|
||||
|
||||
This example demonstrates how to use Pi0.5 as a conversational robot assistant:
|
||||
1. Hold SPACE to record your voice command
|
||||
2. Speech-to-text (Whisper) converts speech to text
|
||||
3. Text is fed as a high-level prompt to Pi0.5
|
||||
4. Pi0.5 generates a response (robot utterance)
|
||||
5. Text-to-speech (Kokoro) speaks the response back
|
||||
|
||||
Requirements:
|
||||
pip install torch transformers sounddevice numpy pynput kokoro>=0.9.2
|
||||
|
||||
Usage:
|
||||
python examples/voice_assistant/voice_assistant_pi05.py \
|
||||
--pretrained_path lerobot/pi0.5-base
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
import torch
|
||||
from pynput import keyboard
|
||||
from transformers import AutoTokenizer, WhisperForConditionalGeneration, WhisperProcessor
|
||||
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
|
||||
|
||||
def get_device():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
class Pi05VoiceAssistant:
|
||||
"""Voice assistant using Pi0.5 for generating robot utterances."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretrained_path: str | None = None,
|
||||
max_response_tokens: int = 100,
|
||||
max_record_seconds: float = 30.0,
|
||||
):
|
||||
self.device = get_device()
|
||||
self.dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
|
||||
self.max_response_tokens = max_response_tokens
|
||||
self.max_record_seconds = max_record_seconds
|
||||
|
||||
# Push-to-talk state
|
||||
self._recording = False
|
||||
self._audio_chunks: list[np.ndarray] = []
|
||||
self._stream: sd.InputStream | None = None
|
||||
|
||||
print(f"Using device: {self.device}")
|
||||
self._load_models(pretrained_path)
|
||||
|
||||
def _load_models(self, pretrained_path: str | None):
|
||||
print("Loading STT (Whisper tiny)...")
|
||||
self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
self.stt_model = WhisperForConditionalGeneration.from_pretrained(
|
||||
"openai/whisper-tiny.en", torch_dtype=self.dtype
|
||||
).to(self.device)
|
||||
|
||||
print("Loading Pi0.5 model...")
|
||||
self._load_pi05(pretrained_path)
|
||||
|
||||
print("Loading tokenizer...")
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
self._load_tts()
|
||||
print("Ready!\n")
|
||||
|
||||
def _load_pi05(self, pretrained_path: str | None):
|
||||
"""Load Pi0.5 model for utterance generation."""
|
||||
config = PI05Config()
|
||||
config.dtype = "float32" if self.device.type == "mps" else "bfloat16"
|
||||
|
||||
self.pi05_model = PI05Pytorch(config)
|
||||
|
||||
if pretrained_path:
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
state_dict = load_file(f"{pretrained_path}/model.safetensors")
|
||||
self.pi05_model.load_state_dict(state_dict, strict=False)
|
||||
print(f"✓ Loaded Pi0.5 weights from {pretrained_path}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not load pretrained weights: {e}")
|
||||
print("Using randomly initialized model for demo purposes")
|
||||
|
||||
self.pi05_model = self.pi05_model.to(self.device)
|
||||
self.pi05_model.eval()
|
||||
|
||||
def _load_tts(self):
|
||||
try:
|
||||
print("Loading TTS (Kokoro 82M)...")
|
||||
from kokoro import KPipeline
|
||||
|
||||
self.tts_pipeline = KPipeline(lang_code="a") # American English
|
||||
self.tts_voice = "af_heart"
|
||||
self.tts_type = "kokoro"
|
||||
print("Kokoro loaded!")
|
||||
except Exception as e:
|
||||
print(f"Kokoro not available ({e})")
|
||||
print("Using macOS `say` for TTS")
|
||||
self.tts_pipeline = None
|
||||
self.tts_type = "system"
|
||||
|
||||
def _audio_callback(self, indata, frames, time_info, status):
|
||||
"""Callback for audio stream - collects chunks while recording."""
|
||||
if self._recording:
|
||||
self._audio_chunks.append(indata.copy())
|
||||
|
||||
def _start_recording(self):
|
||||
"""Start recording audio."""
|
||||
if self._recording:
|
||||
return
|
||||
self._recording = True
|
||||
self._audio_chunks = []
|
||||
print("🎤 Recording... (release SPACE to stop)")
|
||||
|
||||
def _stop_recording(self) -> np.ndarray | None:
|
||||
"""Stop recording and return the audio."""
|
||||
if not self._recording:
|
||||
return None
|
||||
self._recording = False
|
||||
|
||||
if not self._audio_chunks:
|
||||
return None
|
||||
|
||||
audio = np.concatenate(self._audio_chunks, axis=0).flatten()
|
||||
duration = len(audio) / SAMPLE_RATE
|
||||
volume = np.abs(audio).max()
|
||||
print(f"Recorded {duration:.1f}s, volume: {volume:.4f}")
|
||||
|
||||
if volume < 0.001:
|
||||
print("⚠️ Very low audio - check microphone permissions!")
|
||||
return None
|
||||
|
||||
return audio
|
||||
|
||||
def wait_for_spacebar(self) -> np.ndarray | None:
|
||||
"""Wait for spacebar press, record while held, return audio on release."""
|
||||
audio_result = None
|
||||
recording_done = threading.Event()
|
||||
|
||||
def on_press(key):
|
||||
if key == keyboard.Key.space:
|
||||
self._start_recording()
|
||||
|
||||
def on_release(key):
|
||||
nonlocal audio_result
|
||||
if key == keyboard.Key.space and self._recording:
|
||||
audio_result = self._stop_recording()
|
||||
recording_done.set()
|
||||
return False # Stop listener
|
||||
|
||||
# Start audio stream
|
||||
self._stream = sd.InputStream(
|
||||
samplerate=SAMPLE_RATE,
|
||||
channels=1,
|
||||
dtype="float32",
|
||||
callback=self._audio_callback,
|
||||
blocksize=int(SAMPLE_RATE * 0.1), # 100ms blocks
|
||||
)
|
||||
|
||||
with self._stream:
|
||||
print("\n⏳ Press and hold SPACE to speak...")
|
||||
with keyboard.Listener(on_press=on_press, on_release=on_release) as listener:
|
||||
# Wait for recording to complete or timeout
|
||||
recording_done.wait(timeout=self.max_record_seconds)
|
||||
if self._recording:
|
||||
audio_result = self._stop_recording()
|
||||
|
||||
return audio_result
|
||||
|
||||
def transcribe(self, audio: np.ndarray) -> str:
|
||||
start = time.perf_counter()
|
||||
inputs = self.stt_processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
|
||||
input_features = inputs.input_features.to(self.device, dtype=self.dtype)
|
||||
tokens = self.stt_model.generate(input_features)
|
||||
text = self.stt_processor.batch_decode(tokens, skip_special_tokens=True)[0]
|
||||
print(f"STT: {time.perf_counter() - start:.2f}s")
|
||||
return text.strip()
|
||||
|
||||
def _create_dummy_images(self, batch_size: int = 1) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
"""Create placeholder images for Pi0.5 when no camera is available."""
|
||||
image_shape = (batch_size, 3, 224, 224)
|
||||
dummy_image = torch.zeros(image_shape, dtype=torch.float32, device=self.device)
|
||||
dummy_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)
|
||||
return [dummy_image], [dummy_mask]
|
||||
|
||||
def _tokenize_prompt(self, text: str) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Tokenize the user prompt for Pi0.5."""
|
||||
prompt = f"User request: {text}\nRobot response:"
|
||||
tokenized = self.tokenizer(
|
||||
[prompt],
|
||||
max_length=200,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = tokenized["input_ids"].to(self.device)
|
||||
masks = tokenized["attention_mask"].to(self.device, dtype=torch.bool)
|
||||
return tokens, masks
|
||||
|
||||
def generate_response(self, user_text: str) -> str:
|
||||
"""Generate robot utterance using Pi0.5's language generation."""
|
||||
start = time.perf_counter()
|
||||
|
||||
images, img_masks = self._create_dummy_images()
|
||||
tokens, masks = self._tokenize_prompt(user_text)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_tokens = self.pi05_model._generate_subtask_tokens(
|
||||
images=images,
|
||||
img_masks=img_masks,
|
||||
tokens=tokens,
|
||||
masks=masks,
|
||||
tokenizer=self.tokenizer,
|
||||
max_length=self.max_response_tokens,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Decode generated tokens
|
||||
valid_tokens = generated_tokens[0][generated_tokens[0] != 0]
|
||||
response = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
||||
|
||||
# Extract utterance if marked with special tokens
|
||||
response = self._extract_utterance(response)
|
||||
|
||||
print(f"Pi0.5: {time.perf_counter() - start:.2f}s")
|
||||
return response.strip()
|
||||
|
||||
def _extract_utterance(self, text: str) -> str:
|
||||
"""Extract utterance from between <utterance> tokens if present."""
|
||||
pattern = r"<utterance>(.*?)</utterance>"
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
return text
|
||||
|
||||
def speak(self, text: str):
|
||||
start = time.perf_counter()
|
||||
if self.tts_type == "kokoro":
|
||||
generator = self.tts_pipeline(text, voice=self.tts_voice)
|
||||
audio_chunks = [audio for _, _, audio in generator]
|
||||
if audio_chunks:
|
||||
audio = np.concatenate(audio_chunks)
|
||||
sd.play(audio, 24000)
|
||||
sd.wait()
|
||||
else:
|
||||
subprocess.run(["say", text], check=True)
|
||||
print(f"TTS: {time.perf_counter() - start:.2f}s")
|
||||
|
||||
def run(self):
|
||||
print("=" * 50)
|
||||
print("Pi0.5 Voice Assistant")
|
||||
print("=" * 50)
|
||||
print("• Hold SPACE to record your voice command")
|
||||
print("• Release SPACE when done speaking")
|
||||
print("• Press Ctrl+C to exit")
|
||||
print("=" * 50)
|
||||
|
||||
while True:
|
||||
try:
|
||||
audio = self.wait_for_spacebar()
|
||||
|
||||
if audio is None:
|
||||
print("(no audio captured)\n")
|
||||
continue
|
||||
|
||||
user_text = self.transcribe(audio)
|
||||
|
||||
if not user_text:
|
||||
print("(no speech detected)\n")
|
||||
continue
|
||||
|
||||
print(f"You: {user_text}")
|
||||
|
||||
response = self.generate_response(user_text)
|
||||
print(f"Robot: {response}\n")
|
||||
|
||||
self.speak(response)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye!")
|
||||
break
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Pi0.5 Voice Assistant")
|
||||
parser.add_argument(
|
||||
"--pretrained_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained Pi0.5 model (optional)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_response_tokens",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum tokens in generated response",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_record_seconds",
|
||||
type=float,
|
||||
default=30.0,
|
||||
help="Maximum recording duration in seconds",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assistant = Pi05VoiceAssistant(
|
||||
pretrained_path=args.pretrained_path,
|
||||
max_response_tokens=args.max_response_tokens,
|
||||
max_record_seconds=args.max_record_seconds,
|
||||
)
|
||||
assistant.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,26 +0,0 @@
|
||||
{
|
||||
"repo_id": "local",
|
||||
"vocab_size": 1024,
|
||||
"scale": 10.0,
|
||||
"encoded_dims": "0:15",
|
||||
"encoded_dim_ranges": [
|
||||
[
|
||||
0,
|
||||
15
|
||||
]
|
||||
],
|
||||
"total_encoded_dims": 15,
|
||||
"delta_dims": null,
|
||||
"delta_dim_list": null,
|
||||
"use_delta_transform": false,
|
||||
"state_key": "observation.state",
|
||||
"action_horizon": 50,
|
||||
"num_training_chunks": 4900,
|
||||
"compression_stats": {
|
||||
"compression_ratio": 15.85791309863622,
|
||||
"mean_token_length": 47.295,
|
||||
"p99_token_length": 90.0,
|
||||
"min_token_length": 9.0,
|
||||
"max_token_length": 109.0
|
||||
}
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
import logging
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
from scipy.fft import dct
|
||||
from scipy.fft import idct
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
from tokenizers.trainers import BpeTrainer
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
|
||||
class UniversalActionProcessor(ProcessorMixin):
|
||||
attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
|
||||
bpe_tokenizer_class: str = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bpe_tokenizer: PreTrainedTokenizerFast,
|
||||
scale: float = 10,
|
||||
vocab_size: int = 1024,
|
||||
min_token: int = 0,
|
||||
*,
|
||||
action_dim: int | None = None,
|
||||
time_horizon: int | None = None,
|
||||
):
|
||||
self.scale = scale
|
||||
self.vocab_size = vocab_size
|
||||
self.min_token = min_token
|
||||
|
||||
# Action horizon and dimension needed during decoding. These can be specified
|
||||
# in three ways (in order of priority):
|
||||
# 1. passed in as kwargs to decode()
|
||||
# 2. in the constructor
|
||||
# 3. cached from the last time decode() was called
|
||||
self.time_horizon = time_horizon
|
||||
self.action_dim = action_dim
|
||||
self.called_time_horizon = time_horizon
|
||||
self.called_action_dim = action_dim
|
||||
|
||||
super().__init__(bpe_tokenizer)
|
||||
|
||||
def __call__(self, action_chunk: np.array) -> np.array:
|
||||
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
|
||||
if action_chunk.ndim == 2:
|
||||
action_chunk = action_chunk[None, ...]
|
||||
|
||||
# Cache the time horizon and action dimension for decoding
|
||||
self.called_time_horizon = action_chunk.shape[-2]
|
||||
self.called_action_dim = action_chunk.shape[-1]
|
||||
|
||||
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
|
||||
dct_coeff = np.around(dct_coeff * self.scale)
|
||||
tokens = []
|
||||
for elem in dct_coeff:
|
||||
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
|
||||
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
|
||||
return tokens
|
||||
|
||||
def decode(
|
||||
self,
|
||||
tokens: list[list[int]],
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
) -> np.array:
|
||||
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
|
||||
self.action_dim = action_dim or self.action_dim or self.called_action_dim
|
||||
|
||||
# Cache the time horizon and action dimension for the next call
|
||||
self.called_time_horizon = self.time_horizon
|
||||
self.called_action_dim = self.action_dim
|
||||
|
||||
assert (
|
||||
self.time_horizon is not None and self.action_dim is not None
|
||||
), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
|
||||
|
||||
decoded_actions = []
|
||||
for token in tokens:
|
||||
try:
|
||||
decoded_tokens = self.bpe_tokenizer.decode(token)
|
||||
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
|
||||
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
|
||||
assert (
|
||||
decoded_dct_coeff.shape
|
||||
== (
|
||||
self.time_horizon,
|
||||
self.action_dim,
|
||||
)
|
||||
), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
|
||||
except Exception as e:
|
||||
print(f"Error decoding tokens: {e}")
|
||||
print(f"Tokens: {token}")
|
||||
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
|
||||
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
|
||||
return np.stack(decoded_actions)
|
||||
|
||||
@classmethod
|
||||
def fit(
|
||||
cls,
|
||||
action_data: list[np.array],
|
||||
scale: float = 10,
|
||||
vocab_size: int = 1024,
|
||||
*,
|
||||
time_horizon: int | None = None,
|
||||
action_dim: int | None = None,
|
||||
) -> "UniversalActionProcessor":
|
||||
# Run DCT over all inputs
|
||||
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
|
||||
|
||||
# Quantize and find min token
|
||||
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
|
||||
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
|
||||
min_vocab_size = max_token - min_token
|
||||
|
||||
assert (
|
||||
min_vocab_size <= vocab_size
|
||||
), f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
|
||||
if min_vocab_size + 100 > vocab_size:
|
||||
logging.warning(
|
||||
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
|
||||
f"size {vocab_size}, consider increasing vocab size"
|
||||
)
|
||||
|
||||
# Make token iterator for BPE training
|
||||
def _token_iter():
|
||||
for tokens in dct_tokens:
|
||||
rounded_tokens = np.around(tokens * scale) - min_token
|
||||
rounded_tokens = rounded_tokens.astype(int)
|
||||
string = "".join(map(chr, rounded_tokens))
|
||||
yield string
|
||||
|
||||
# Train BPE tokenizer
|
||||
bpe = ByteLevelBPETokenizer()
|
||||
|
||||
# Set up the entire range of possible tokens as the initial alphabet
|
||||
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
|
||||
trainer = BpeTrainer(
|
||||
vocab_size=vocab_size,
|
||||
min_frequency=2,
|
||||
show_progress=True,
|
||||
special_tokens=[],
|
||||
initial_alphabet=alphabet,
|
||||
max_token_length=10000,
|
||||
)
|
||||
|
||||
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
|
||||
# because it doesn't support custom alphabets)
|
||||
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
|
||||
|
||||
return cls(
|
||||
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
|
||||
scale=scale,
|
||||
vocab_size=vocab_size,
|
||||
min_token=min_token,
|
||||
time_horizon=time_horizon,
|
||||
action_dim=action_dim,
|
||||
)
|
||||
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"action_dim": 15,
|
||||
"auto_map": {
|
||||
"AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
|
||||
},
|
||||
"min_token": -71,
|
||||
"processor_class": "UniversalActionProcessor",
|
||||
"scale": 10.0,
|
||||
"time_horizon": 50,
|
||||
"vocab_size": 1024
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
{}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"added_tokens_decoder": {},
|
||||
"auto_map": {
|
||||
"AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
|
||||
},
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"extra_special_tokens": {},
|
||||
"model_max_length": 1000000000000000019884624838656,
|
||||
"processor_class": "UniversalActionProcessor",
|
||||
"tokenizer_class": "PreTrainedTokenizerFast"
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
# FAST Tokenizer Training for LeRobotDataset
|
||||
|
||||
This directory contains tools for training a FAST (Factorized Action Sequence Tokenizer) on LeRobot datasets.
|
||||
|
||||
## Files
|
||||
|
||||
- **`train_fast_tokenizer.py`**: Main training script (refactored for LeRobotDataset)
|
||||
- **`train_fast_tokenizer_example.md`**: Usage examples and parameter documentation
|
||||
- **`MIGRATION_NOTES.md`**: Migration guide from B1K to LeRobotDataset
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Basic usage
|
||||
python train_fast_tokenizer.py \
|
||||
--repo_id "lerobot/aloha_sim_insertion_human" \
|
||||
--action_horizon 10 \
|
||||
--encoded_dims "0:14"
|
||||
|
||||
# With delta transform
|
||||
python train_fast_tokenizer.py \
|
||||
--repo_id "lerobot/aloha_sim_insertion_human" \
|
||||
--action_horizon 10 \
|
||||
--encoded_dims "0:14" \
|
||||
--delta_dims "0,1,2,3,4,5,6,7,8,9,10,11,12,13" \
|
||||
--state_key "observation.state" \
|
||||
--vocab_size 1024
|
||||
```
|
||||
|
||||
## What is FAST?
|
||||
|
||||
FAST is a tokenizer for robotic action sequences that:
|
||||
1. Applies DCT (Discrete Cosine Transform) to action chunks
|
||||
2. Quantizes DCT coefficients
|
||||
3. Uses BPE (Byte-Pair Encoding) to compress the quantized sequence
|
||||
4. Achieves high compression ratios (e.g., 10-20x) while maintaining accuracy
|
||||
|
||||
This enables efficient storage and processing of long action sequences in vision-language-action models.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.10+
|
||||
- LeRobot dataset (either local or from HuggingFace Hub)
|
||||
- transformers (for AutoProcessor)
|
||||
- numpy
|
||||
- torch
|
||||
- tyro
|
||||
|
||||
## Workflow
|
||||
|
||||
```
|
||||
LeRobotDataset → Extract Episodes → Apply Delta Transform
|
||||
↓
|
||||
Select Dimensions → Normalize (q01, q99) → Create Chunks
|
||||
↓
|
||||
Train FAST Tokenizer → Compute Stats → Save
|
||||
```
|
||||
|
||||
## Parameters Guide
|
||||
|
||||
### Essential Parameters
|
||||
|
||||
- **`repo_id`**: HuggingFace dataset repository ID
|
||||
- Example: `"lerobot/aloha_sim_insertion_human"`
|
||||
|
||||
- **`action_horizon`**: Length of action sequences to tokenize
|
||||
- Typical: 10-16 steps
|
||||
|
||||
- **`encoded_dims`**: Which action dimensions to encode
|
||||
- Format: `"start:end,start:end"`
|
||||
- Example: `"0:7"` = dimensions 0-6
|
||||
- Example: `"0:3,7:10"` = dimensions 0-2 and 7-9
|
||||
|
||||
### Optional Parameters
|
||||
|
||||
- **`delta_dims`**: Apply delta transform (action - state) to these dimensions
|
||||
- Format: `"0,1,2,3,4,5"`
|
||||
- Use for position-based actions
|
||||
|
||||
- **`state_key`**: Dataset key containing state observations
|
||||
- Default: `"observation.state"`
|
||||
|
||||
- **`vocab_size`**: BPE vocabulary size
|
||||
- Default: 1024
|
||||
- Larger = better compression but more memory
|
||||
|
||||
- **`scale`**: DCT quantization scale
|
||||
- Default: 10.0
|
||||
- Smaller = finer quantization, larger = coarser
|
||||
|
||||
- **`sample_fraction`**: Fraction of action chunks to use per episode
|
||||
- Default: 0.1 (10%)
|
||||
- Increase for small datasets, decrease for large datasets
|
||||
|
||||
## Output
|
||||
|
||||
The script creates a directory (default: `./fast_tokenizer_{repo_id}`) containing:
|
||||
|
||||
1. **Tokenizer files**: Can be loaded with `AutoProcessor.from_pretrained()`
|
||||
2. **`metadata.json`**: Contains:
|
||||
- Training configuration
|
||||
- Compression statistics
|
||||
- Dataset information
|
||||
|
||||
## Example Output
|
||||
|
||||
```
|
||||
Loading dataset: lerobot/aloha_sim_insertion_human
|
||||
Dataset loaded: 50 episodes, 5000 frames
|
||||
Encoding 14 dimensions: 0:14
|
||||
Delta dimensions: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
Action horizon: 10
|
||||
Processing 50 episodes...
|
||||
Collected 4500 action chunks
|
||||
Extracted 14 encoded dimensions
|
||||
|
||||
Before normalization - overall stats:
|
||||
Min: -2.3451, Max: 3.1234, Mean: 0.0234, Std: 0.8765
|
||||
|
||||
Applied quantile normalization [q01, q99] → [-1, 1]
|
||||
|
||||
After normalization - overall stats:
|
||||
Min: -1.0000, Max: 1.0000, Mean: 0.0156, Std: 0.4321
|
||||
|
||||
Training FAST tokenizer on 4500 action chunks...
|
||||
Action chunk shape: (4500, 10, 14)
|
||||
Vocab size: 1024
|
||||
DCT scale: 10.0
|
||||
✓ Tokenizer training complete!
|
||||
|
||||
Compression Statistics:
|
||||
Average compression ratio: 14.23x
|
||||
Mean token length: 9.8
|
||||
P99 token length: 15
|
||||
Min token length: 6
|
||||
Max token length: 18
|
||||
|
||||
✅ Saved FAST tokenizer to ./fast_tokenizer_lerobot_aloha_sim_insertion_human
|
||||
```
|
||||
|
||||
## Using the Trained Tokenizer
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoProcessor.from_pretrained(
|
||||
"./fast_tokenizer_lerobot_aloha_sim_insertion_human",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Encode action chunk [horizon, action_dim]
|
||||
action_chunk = np.random.randn(10, 14) # Example
|
||||
tokens = tokenizer(action_chunk[None])[0] # Returns token IDs
|
||||
|
||||
# Decode tokens back to actions
|
||||
reconstructed = tokenizer.decode(tokens)
|
||||
```
|
||||
|
||||
## Tips
|
||||
|
||||
1. **Start Small**: Use `--max_episodes 10` for initial testing
|
||||
2. **Check Dimensions**: Verify encoded dimensions match your robot's action space
|
||||
3. **Delta Transform**: Use for position-based actions, not velocity-based
|
||||
4. **Normalization**: Ensure dataset has proper statistics computed
|
||||
5. **Compression Ratio**: Aim for 10-20x for good balance of compression and accuracy
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Issue**: "No normalization stats found"
|
||||
- **Solution**: Compute dataset statistics first, or use raw actions
|
||||
|
||||
**Issue**: "Episode too short for action horizon"
|
||||
- **Solution**: Reduce `--action_horizon` or filter short episodes
|
||||
|
||||
**Issue**: "State key not found"
|
||||
- **Solution**: Check dataset features and use correct `--state_key`
|
||||
|
||||
**Issue**: Memory error with large datasets
|
||||
- **Solution**: Reduce `--sample_fraction` or `--max_episodes`
|
||||
|
||||
## Citation
|
||||
|
||||
If you use FAST in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{black2023fast,
|
||||
title={FAST: Factorized Action Sequence Tokenizer for Vision-Language-Action Models},
|
||||
author={Black, Kevin and others},
|
||||
journal={arXiv preprint},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -37,9 +37,6 @@ class PI05Config(PreTrainedConfig):
|
||||
# Shorter state and action vectors will be padded to these dimensions
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
max_action_tokens: int = 32
|
||||
fast_vocab_size: int = 2048
|
||||
|
||||
|
||||
# Flow matching parameters: see openpi `PI0Pytorch`
|
||||
num_inference_steps: int = 10
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
lerobot-train \
|
||||
--dataset.repo_id=lerobot \
|
||||
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
|
||||
--output_dir=/fsx/jade_choghari/outputs/pi0test1 \
|
||||
--job_name=pi0_training \
|
||||
--policy.repo_id=jade_choghari/pi0-base \
|
||||
--policy.path=/fsx/jade_choghari/outputs/pi0_fast_fruit1/checkpoints/last/pretrained_model \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=3000 \
|
||||
--save_freq=1000 \
|
||||
--rename_map='{
|
||||
"observation.images.base": "observation.images.base_0_rgb",
|
||||
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
|
||||
}' \
|
||||
--batch_size=4 \
|
||||
--policy.device=cuda \
|
||||
# --wandb.enable=true \
|
||||
# --wandb.disable_artifact=true \
|
||||
# --wandb.project=pi05hi-training \
|
||||
|
||||
@@ -48,10 +48,10 @@ from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS,
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_PROMPT_TOKENS,
|
||||
OBS_LANGUAGE_PROMPT_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TARGET_TOKENS,
|
||||
OBS_LANGUAGE_TARGET_ATTENTION_MASK,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -537,18 +537,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
|
||||
# FAST action token embedding and prediction head
|
||||
self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
|
||||
self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
|
||||
|
||||
# Apply dtype conversion to FAST layers to match model precision
|
||||
if config.dtype == "bfloat16":
|
||||
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
|
||||
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
|
||||
elif config.dtype == "float32":
|
||||
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
|
||||
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
|
||||
|
||||
# Initialize gradient checkpointing flag
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
@@ -604,194 +592,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
result = result.to(dtype=dtype)
|
||||
return result
|
||||
|
||||
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
|
||||
"""Create custom 2D attention mask for the new attention pattern.
|
||||
|
||||
Attention rules:
|
||||
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||
- FAST: attend to images + language + subtask, causal among themselves
|
||||
|
||||
Args:
|
||||
att_mask_segments: List of (type, length) tuples
|
||||
pad_masks: Padding masks [B, total_seq_len]
|
||||
bsize: Batch size
|
||||
|
||||
Returns:
|
||||
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
|
||||
"""
|
||||
total_len = sum(length for _, length in att_mask_segments)
|
||||
device = pad_masks.device
|
||||
|
||||
# Initialize attention mask as False (cannot attend)
|
||||
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
||||
|
||||
# Track positions for each segment
|
||||
positions = []
|
||||
current_pos = 0
|
||||
for seg_type, seg_len in att_mask_segments:
|
||||
positions.append((seg_type, current_pos, current_pos + seg_len))
|
||||
current_pos += seg_len
|
||||
|
||||
# Apply attention rules
|
||||
for i, (query_type, query_start, query_end) in enumerate(positions):
|
||||
for j, (key_type, key_start, key_end) in enumerate(positions):
|
||||
# Images and Language can attend to each other bidirectionally
|
||||
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# Subtask tokens attend to images + language
|
||||
elif query_type == 'subtask' and key_type in ['image', 'language']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# Subtask tokens attend causally to themselves
|
||||
elif query_type == 'subtask' and key_type == 'subtask':
|
||||
# Create causal mask for subtask tokens
|
||||
subtask_len = query_end - query_start
|
||||
causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device))
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||
|
||||
# FAST tokens attend to images + language + subtask
|
||||
elif query_type == 'fast' and key_type in ['image', 'language', 'subtask']:
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
|
||||
|
||||
# FAST tokens attend causally to themselves
|
||||
elif query_type == 'fast' and key_type == 'fast':
|
||||
fast_len = query_end - query_start
|
||||
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
|
||||
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
|
||||
|
||||
# Apply padding masks
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
|
||||
return att_2d_masks
|
||||
|
||||
def visualize_attention_mask(
|
||||
self,
|
||||
att_mask_segments,
|
||||
att_2d_masks,
|
||||
save_path,
|
||||
batch_idx=0,
|
||||
dpi=150,
|
||||
max_display_tokens=None
|
||||
):
|
||||
"""Visualize the attention mask with labeled segments.
|
||||
|
||||
Args:
|
||||
att_mask_segments: List of (type, length) tuples defining the segments
|
||||
att_2d_masks: 2D attention mask tensor [B, total_seq_len, total_seq_len]
|
||||
save_path: Path where to save the visualization image
|
||||
batch_idx: Which batch item to visualize (default: 0)
|
||||
dpi: DPI for the saved image (default: 150)
|
||||
max_display_tokens: Maximum number of tokens to display (for very long sequences)
|
||||
"""
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
except ImportError:
|
||||
logging.warning("matplotlib not available, skipping attention mask visualization")
|
||||
return
|
||||
|
||||
# Extract the mask for the specified batch
|
||||
mask = att_2d_masks[batch_idx].cpu().float().numpy()
|
||||
|
||||
# If sequence is too long, downsample for visualization
|
||||
if max_display_tokens is not None and mask.shape[0] > max_display_tokens:
|
||||
# Simple downsampling by taking every Nth token
|
||||
step = mask.shape[0] // max_display_tokens
|
||||
mask = mask[::step, ::step]
|
||||
# Adjust segments accordingly
|
||||
att_mask_segments = [(seg_type, max(1, seg_len // step)) for seg_type, seg_len in att_mask_segments]
|
||||
|
||||
# Calculate positions for each segment
|
||||
positions = []
|
||||
current_pos = 0
|
||||
for seg_type, seg_len in att_mask_segments:
|
||||
positions.append((seg_type, current_pos, current_pos + seg_len))
|
||||
current_pos += seg_len
|
||||
|
||||
# Create figure
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
|
||||
# Create custom colormap: white for False (no attention), blue for True (attention)
|
||||
colors = ['white', '#2E86AB']
|
||||
n_bins = 2
|
||||
cmap = LinearSegmentedColormap.from_list('attention', colors, N=n_bins)
|
||||
|
||||
# Display the mask
|
||||
im = ax.imshow(mask, cmap=cmap, aspect='auto', interpolation='nearest', vmin=0, vmax=1)
|
||||
|
||||
# Add colorbar
|
||||
cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
||||
cbar.set_label('Attention Enabled', rotation=270, labelpad=20)
|
||||
cbar.set_ticks([0.25, 0.75])
|
||||
cbar.set_ticklabels(['No', 'Yes'])
|
||||
|
||||
# Define colors for each segment type
|
||||
segment_colors = {
|
||||
'image': '#A23B72',
|
||||
'language': '#F18F01',
|
||||
'subtask': '#C73E1D',
|
||||
'fast': '#6A994E'
|
||||
}
|
||||
|
||||
# Draw segment boundaries and labels
|
||||
for seg_type, start, end in positions:
|
||||
color = segment_colors.get(seg_type, '#666666')
|
||||
|
||||
# Draw vertical lines for columns (keys)
|
||||
ax.axvline(x=start - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||
ax.axvline(x=end - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||
|
||||
# Draw horizontal lines for rows (queries)
|
||||
ax.axhline(y=start - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||
ax.axhline(y=end - 0.5, color=color, linewidth=2, alpha=0.7)
|
||||
|
||||
# Add labels at the top
|
||||
mid_pos = (start + end) / 2
|
||||
ax.text(mid_pos, -mask.shape[0] * 0.02, f"{seg_type.upper()}\n({end - start})",
|
||||
ha='center', va='top', fontsize=10, fontweight='bold', color=color)
|
||||
|
||||
# Add labels on the left
|
||||
ax.text(-mask.shape[1] * 0.02, mid_pos, f"{seg_type.upper()}\n({end - start})",
|
||||
ha='right', va='center', fontsize=10, fontweight='bold', color=color, rotation=0)
|
||||
|
||||
# Set axis labels
|
||||
ax.set_xlabel('Key Position (tokens being attended to)', fontsize=12, fontweight='bold')
|
||||
ax.set_ylabel('Query Position (tokens attending)', fontsize=12, fontweight='bold')
|
||||
ax.set_title('Attention Mask Pattern\n(White = No Attention, Blue = Attention Allowed)',
|
||||
fontsize=14, fontweight='bold', pad=20)
|
||||
|
||||
# Create legend for segment types
|
||||
legend_patches = []
|
||||
attention_rules = {
|
||||
'image': 'Bidirectional with lang',
|
||||
'language': 'Bidirectional with images',
|
||||
'subtask': 'Attends to img+lang, causal self',
|
||||
'fast': 'Attends to all, causal self'
|
||||
}
|
||||
for seg_type, color in segment_colors.items():
|
||||
if any(seg[0] == seg_type for seg in att_mask_segments):
|
||||
rule = attention_rules.get(seg_type, '')
|
||||
legend_patches.append(mpatches.Patch(color=color, label=f'{seg_type.upper()}: {rule}'))
|
||||
|
||||
ax.legend(handles=legend_patches, loc='upper right', bbox_to_anchor=(1.15, 1.0),
|
||||
framealpha=0.9, fontsize=9)
|
||||
|
||||
# Adjust layout and save
|
||||
plt.tight_layout()
|
||||
|
||||
# Ensure the directory exists
|
||||
save_path = Path(save_path)
|
||||
save_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
logging.info(f"Attention mask visualization saved to: {save_path}")
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
mean=0.0,
|
||||
@@ -807,43 +607,30 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
|
||||
def embed_prefix(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
subtask_tokens,
|
||||
masks,
|
||||
subtask_masks,
|
||||
fast_action_tokens=None,
|
||||
fast_action_masks=None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
|
||||
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
|
||||
self, images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks=None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
||||
"""Embed images with SigLIP, prompt tokens, and optionally target tokens with embedding layer.
|
||||
|
||||
Args:
|
||||
images: List of image tensors
|
||||
img_masks: List of image masks
|
||||
tokens: Language instruction tokens
|
||||
subtask_tokens: Subtask tokens to predict (can be None for inference)
|
||||
masks: Attention masks for tokens
|
||||
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
|
||||
fast_action_masks: Padding masks for FAST action tokens (can be None)
|
||||
prompt_tokens: Prompt tokens (input for generation)
|
||||
target_tokens: Target tokens to predict (can be None for inference)
|
||||
prompt_masks: Attention masks for prompt tokens
|
||||
target_masks: Attention masks for target tokens
|
||||
|
||||
Returns:
|
||||
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
|
||||
embs: Concatenated embeddings [images, prompt_tokens, (target_tokens if provided)]
|
||||
pad_masks: Padding masks
|
||||
att_masks: Custom 2D attention mask implementing the required pattern
|
||||
att_masks: Attention masks (with causal masking for target prediction if target_tokens provided)
|
||||
total_T_images: Total number of image tokens
|
||||
num_subtask_embs: Number of subtask token embeddings
|
||||
num_fast_embs: Number of FAST action token embeddings
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_mask_segments = [] # Store info about each segment for custom mask creation
|
||||
att_masks = []
|
||||
total_T_images = 0
|
||||
num_subtask_embs = 0
|
||||
num_fast_embs = 0
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
@@ -856,79 +643,48 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
att_mask_segments.append(('image', num_img_embs))
|
||||
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
|
||||
total_T_images += num_img_embs
|
||||
|
||||
# Process language instruction tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
# Process prompt tokens
|
||||
def prompt_embed_func(prompt_tokens):
|
||||
prompt_emb = self.paligemma_with_expert.embed_language_tokens(prompt_tokens)
|
||||
prompt_emb_dim = prompt_emb.shape[-1]
|
||||
return prompt_emb * math.sqrt(prompt_emb_dim)
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(masks)
|
||||
prompt_emb = self._apply_checkpoint(prompt_embed_func, prompt_tokens)
|
||||
embs.append(prompt_emb)
|
||||
pad_masks.append(prompt_masks)
|
||||
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_mask_segments.append(('language', num_lang_embs))
|
||||
num_prompt_embs = prompt_emb.shape[1]
|
||||
att_masks += [0] * num_prompt_embs # Prompt tokens can attend to all previous tokens (images + prompt)
|
||||
|
||||
# Process subtask tokens if provided (these are predicted, so use causal masking)
|
||||
if subtask_tokens is not None:
|
||||
def subtask_embed_func(subtask_tokens):
|
||||
subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens)
|
||||
subtask_emb_dim = subtask_emb.shape[-1]
|
||||
return subtask_emb * math.sqrt(subtask_emb_dim)
|
||||
# Process target tokens if provided (these are predicted, so use causal masking)
|
||||
if target_tokens is not None:
|
||||
def target_embed_func(target_tokens):
|
||||
target_emb = self.paligemma_with_expert.embed_language_tokens(target_tokens)
|
||||
target_emb_dim = target_emb.shape[-1]
|
||||
return target_emb * math.sqrt(target_emb_dim)
|
||||
|
||||
subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens)
|
||||
embs.append(subtask_emb)
|
||||
target_emb = self._apply_checkpoint(target_embed_func, target_tokens)
|
||||
embs.append(target_emb)
|
||||
|
||||
# Create subtask pad masks (non-zero tokens are valid)
|
||||
pad_masks.append(subtask_masks)
|
||||
# Create target pad masks (non-zero tokens are valid)
|
||||
pad_masks.append(target_masks)
|
||||
|
||||
num_subtask_embs = subtask_emb.shape[1]
|
||||
att_mask_segments.append(('subtask', num_subtask_embs))
|
||||
# Process FAST action tokens if provided (these are discrete token IDs)
|
||||
if fast_action_tokens is not None:
|
||||
def fast_action_embed_func(fast_action_tokens):
|
||||
fast_emb = self.fast_action_embedding(fast_action_tokens)
|
||||
fast_emb_dim = fast_emb.shape[-1]
|
||||
return fast_emb * math.sqrt(fast_emb_dim)
|
||||
|
||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||
embs.append(fast_action_emb)
|
||||
|
||||
# Use provided mask or create default (all valid)
|
||||
if fast_action_masks is not None:
|
||||
fast_pad_mask = fast_action_masks
|
||||
else:
|
||||
bsize = fast_action_tokens.shape[0]
|
||||
num_fast_embs = fast_action_tokens.shape[1]
|
||||
fast_pad_mask = torch.ones(bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device)
|
||||
|
||||
num_fast_embs = fast_action_tokens.shape[1]
|
||||
pad_masks.append(fast_pad_mask)
|
||||
att_mask_segments.append(('fast', num_fast_embs))
|
||||
num_target_embs = target_emb.shape[1]
|
||||
# Causal masking for target tokens: each target token can attend to images, all prompt tokens,
|
||||
# and previous target tokens
|
||||
att_masks += [1] * num_target_embs # Use 1 for causal attention on target tokens
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
|
||||
# Create custom 2D attention mask
|
||||
# Attention rules:
|
||||
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
|
||||
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
|
||||
# - FAST: attend to images + language + subtask, causal among themselves
|
||||
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
|
||||
# # Optionally visualize the attention mask
|
||||
# self.visualize_attention_mask(
|
||||
# att_mask_segments=att_mask_segments,
|
||||
# att_2d_masks=att_masks,
|
||||
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
|
||||
# batch_idx=0,
|
||||
# max_display_tokens=512 # Limit display for very long sequences
|
||||
# )
|
||||
bsize = pad_masks.shape[0]
|
||||
att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0])
|
||||
|
||||
return embs, pad_masks, att_masks, total_T_images, num_subtask_embs, num_fast_embs
|
||||
return embs, pad_masks, att_masks, total_T_images
|
||||
|
||||
def embed_suffix(self, noisy_actions, timestep):
|
||||
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
@@ -977,20 +733,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions, fast_action_tokens, fast_action_masks)
|
||||
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, fast_action_tokens=None, fast_action_masks=None, noise=None, time=None) -> Tensor:
|
||||
def forward(self, images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss.
|
||||
|
||||
Args:
|
||||
images: List of image tensors
|
||||
img_masks: List of image masks
|
||||
high_level_task: Instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
|
||||
high_level_task_masks: Attention masks for high_level_task
|
||||
subtask_tokens: Subtask tokens to predict (e.g., tokens for "pick up the cup")
|
||||
subtask_masks: Attention masks for subtask_tokens
|
||||
actions: Ground truth actions [B, chunk_size, action_dim]
|
||||
fast_action_tokens: Discrete action token IDs [B, max_action_tokens]
|
||||
fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens]
|
||||
prompt_tokens: Prompt tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:")
|
||||
prompt_masks: Attention masks for prompt_tokens
|
||||
target_tokens: Target tokens to predict (e.g., tokens for "pick up the cup")
|
||||
target_masks: Attention masks for target_tokens
|
||||
actions: Ground truth actions
|
||||
noise: Optional noise for flow matching
|
||||
time: Optional time for flow matching
|
||||
"""
|
||||
@@ -1003,184 +756,72 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
# Initialize FAST loss to 0 (will be computed only if FAST tokens are provided)
|
||||
fast_loss = torch.tensor(0.0, device=actions.device, dtype=actions.dtype)
|
||||
|
||||
# ========== PASS 1: Prefix with FAST tokens for subtask + FAST prediction ==========
|
||||
# Only run this pass if FAST action tokens are provided
|
||||
if fast_action_tokens is not None and fast_action_masks is not None:
|
||||
# Embed prefix (images + high_level_task + subtask_tokens + FAST tokens)
|
||||
# FAST tokens are provided as discrete token IDs
|
||||
prefix_with_fast_embs, prefix_with_fast_pad_masks, prefix_with_fast_att_masks, total_T_images, num_subtask_embs, num_fast_embs = self.embed_prefix(
|
||||
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
|
||||
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
|
||||
)
|
||||
|
||||
# Convert embeddings to bfloat16 if needed for the model
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_with_fast_embs = prefix_with_fast_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
# Prepare attention masks for prefix pass with FAST tokens
|
||||
position_ids_prefix_with_fast = torch.cumsum(prefix_with_fast_pad_masks, dim=1) - 1
|
||||
att_2d_prefix_with_fast_4d = self._prepare_attention_masks_4d(prefix_with_fast_att_masks, dtype=prefix_with_fast_embs.dtype)
|
||||
|
||||
# Forward pass through paligemma for subtask + FAST prediction
|
||||
(prefix_with_fast_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_with_fast_4d,
|
||||
position_ids=position_ids_prefix_with_fast,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_with_fast_embs, None], # SUFFIX = None
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# LM HEAD → SUBTASK LOGITS
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
logits = lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, vocab)
|
||||
|
||||
# Extract logits for subtask token prediction
|
||||
T_high_level_task = high_level_task.size(1)
|
||||
T_subtask = subtask_tokens.size(1)
|
||||
start_index = total_T_images + T_high_level_task
|
||||
end_index = start_index + T_subtask
|
||||
logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab)
|
||||
|
||||
targets = subtask_tokens # (B, T_subtask)
|
||||
# Compute cross-entropy loss for subtask
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
|
||||
targets_flat = targets.reshape(-1)
|
||||
loss_per_token = loss_fct(logits_flat, targets_flat)
|
||||
loss_per_token = loss_per_token.reshape(targets.shape)
|
||||
masked_loss = loss_per_token * subtask_masks.float()
|
||||
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
|
||||
|
||||
# Extract outputs for FAST action token prediction and compute auxiliary loss
|
||||
# FAST outputs start after subtask tokens
|
||||
# Similar to subtask, we use autoregressive prediction where position i predicts token i+1
|
||||
fast_start_index = end_index
|
||||
fast_end_index = fast_start_index + num_fast_embs
|
||||
|
||||
# Get logits for FAST action tokens using the FAST LM head
|
||||
fast_logits = self.fast_action_lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, fast_vocab_size)
|
||||
|
||||
# Extract logits for FAST token prediction (autoregressive: position i predicts token i+1)
|
||||
# - Position (fast_start_index-1) predicts fast_action_tokens[0]
|
||||
# - Position (fast_start_index) predicts fast_action_tokens[1], etc.
|
||||
fast_logits_for_pred = fast_logits[:, fast_start_index-1:fast_end_index-1, :] # (B, max_action_tokens, fast_vocab_size)
|
||||
|
||||
# Compute cross-entropy loss for FAST action tokens
|
||||
fast_targets = fast_action_tokens # (B, max_action_tokens)
|
||||
loss_fct_fast = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) # (B*max_action_tokens, fast_vocab_size)
|
||||
fast_targets_flat = fast_targets.reshape(-1) # (B*max_action_tokens)
|
||||
|
||||
fast_loss_per_token = loss_fct_fast(fast_logits_flat, fast_targets_flat) # (B*max_action_tokens)
|
||||
fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape) # (B, max_action_tokens)
|
||||
|
||||
# Apply mask and compute mean loss over valid tokens
|
||||
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
|
||||
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
|
||||
else:
|
||||
# If no FAST tokens provided, compute subtask loss without FAST tokens
|
||||
# This is the fallback for backward compatibility
|
||||
prefix_embs_for_subtask, prefix_pad_masks_for_subtask, prefix_att_masks_for_subtask, total_T_images, _, _ = self.embed_prefix(
|
||||
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
|
||||
fast_action_tokens=None, fast_action_masks=None
|
||||
)
|
||||
|
||||
# Convert embeddings to bfloat16 if needed for the model
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs_for_subtask = prefix_embs_for_subtask.to(dtype=torch.bfloat16)
|
||||
|
||||
position_ids_prefix = torch.cumsum(prefix_pad_masks_for_subtask, dim=1) - 1
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks_for_subtask, dtype=prefix_embs_for_subtask.dtype)
|
||||
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_4d,
|
||||
position_ids=position_ids_prefix,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs_for_subtask, None],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
logits = lm_head(prefix_out)
|
||||
|
||||
T_high_level_task = high_level_task.size(1)
|
||||
T_subtask = subtask_tokens.size(1)
|
||||
start_index = total_T_images + T_high_level_task
|
||||
end_index = start_index + T_subtask
|
||||
logits_subtask = logits[:, start_index-1:end_index-1, :]
|
||||
|
||||
targets = subtask_tokens
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
|
||||
targets_flat = targets.reshape(-1)
|
||||
loss_per_token = loss_fct(logits_flat, targets_flat)
|
||||
loss_per_token = loss_per_token.reshape(targets.shape)
|
||||
masked_loss = loss_per_token * subtask_masks.float()
|
||||
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
|
||||
|
||||
# ========== PASS 2: Full forward WITHOUT FAST tokens for flow matching ==========
|
||||
# Embed prefix WITHOUT FAST tokens (images + high_level_task + subtask_tokens)
|
||||
prefix_embs_no_fast, prefix_pad_masks_no_fast, prefix_att_masks_no_fast, _, _, _ = self.embed_prefix(
|
||||
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
|
||||
fast_action_tokens=None, fast_action_masks=None
|
||||
|
||||
# Embed prefix (images + prompt_tokens + target_tokens)
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
||||
images, img_masks, prompt_tokens, target_tokens, prompt_masks, target_masks
|
||||
)
|
||||
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
|
||||
|
||||
# Prepare attention masks for prefix-only pass (for target token prediction)
|
||||
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
||||
|
||||
# prefix-only transformer run for target token prediction
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_4d,
|
||||
position_ids=position_ids_prefix,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None], # SUFFIX = None
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
# LM HEAD → TARGET LOGITS
|
||||
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_prompt + T_target
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
||||
|
||||
# Extract logits for target token prediction (shifted by 1 for autoregressive training)
|
||||
# Position i predicts token i+1, so we take logits from positions before target tokens:
|
||||
# - Position (start_index-1) (last prompt token) predicts target_tokens[0]
|
||||
# - Position (start_index) (first target token) predicts target_tokens[1], etc.
|
||||
T_prompt = prompt_tokens.size(1)
|
||||
T_target = target_tokens.size(1)
|
||||
start_index = total_T_images + T_prompt
|
||||
end_index = start_index + T_target
|
||||
logits_target = logits[:, start_index-1:end_index-1, :] # (B, T_target, vocab)
|
||||
|
||||
# Compute cross-entropy loss
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
# Reshape for loss computation
|
||||
logits_flat = logits_target.reshape(-1, logits_target.size(-1)) # (B*T_target, vocab)
|
||||
targets_flat = target_tokens.reshape(-1) # (B*T_target)
|
||||
|
||||
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_target)
|
||||
loss_per_token = loss_per_token.reshape(target_tokens.shape) # (B, T_target)
|
||||
|
||||
# Apply mask and compute mean loss over valid tokens
|
||||
masked_loss = loss_per_token * target_masks.float()
|
||||
target_loss = masked_loss.sum() / target_masks.sum().clamp(min=1)
|
||||
# Convert embeddings to bfloat16 if needed for the model
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs_no_fast = prefix_embs_no_fast.to(dtype=torch.bfloat16)
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
# For the flow matching pass, we need custom attention where:
|
||||
# - prefix follows the custom pattern (images+lang bidirectional, subtask causal, no cross-attention)
|
||||
# - suffix attends to all prefix + causal to itself
|
||||
# We'll construct this by extending prefix_att_masks_no_fast to include suffix
|
||||
|
||||
# prefix_att_masks_no_fast is already a 2D boolean mask [B, prefix_len, prefix_len]
|
||||
# We need to extend it to [B, prefix_len + suffix_len, prefix_len + suffix_len]
|
||||
|
||||
bsize = prefix_pad_masks_no_fast.shape[0]
|
||||
prefix_len = prefix_pad_masks_no_fast.shape[1]
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
total_len = prefix_len + suffix_len
|
||||
device = prefix_pad_masks_no_fast.device
|
||||
|
||||
# Create full attention mask
|
||||
full_att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
|
||||
|
||||
# Copy prefix attention pattern
|
||||
full_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_masks_no_fast
|
||||
|
||||
# Suffix attends to all prefix
|
||||
full_att_2d_masks[:, prefix_len:, :prefix_len] = True
|
||||
|
||||
# Suffix has causal attention among itself
|
||||
suffix_causal_mask = torch.tril(torch.ones(suffix_len, suffix_len, dtype=torch.bool, device=device))
|
||||
full_att_2d_masks[:, prefix_len:, prefix_len:] = suffix_causal_mask[None, :, :]
|
||||
|
||||
# Apply padding masks
|
||||
pad_masks = torch.cat([prefix_pad_masks_no_fast, suffix_pad_masks], dim=1)
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
full_att_2d_masks = full_att_2d_masks & pad_2d_masks
|
||||
|
||||
# Concatenate prefix (images + prompt_tokens + target_tokens) and suffix (actions) masks
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
# Prepare attention masks for full forward pass (prefix + suffix)
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=prefix_embs_no_fast.dtype)
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
@@ -1191,10 +832,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
# prefix_out to be used for the language head
|
||||
return suffix_out
|
||||
|
||||
suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs_no_fast, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
@@ -1209,82 +851,79 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return {
|
||||
"flow_loss": fm_loss,
|
||||
"subtask_loss": subtask_loss,
|
||||
"fast_loss": fast_loss,
|
||||
"loss": fm_loss.mean() + 0.1 * subtask_loss + 0.05 * fast_loss, # ref: b1k winner
|
||||
"target_loss": target_loss,
|
||||
"loss": 10 * fm_loss.mean() + target_loss,
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _generate_subtask_tokens(
|
||||
self, images, img_masks, tokens, masks, tokenizer, max_length, device
|
||||
def _generate_target_tokens(
|
||||
self, images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_length, device
|
||||
):
|
||||
bsize = tokens.shape[0]
|
||||
"""Generate target tokens autoregressively using next token prediction."""
|
||||
bsize = prompt_tokens.shape[0]
|
||||
|
||||
# Get lm_head for token generation
|
||||
lm_head = self.paligemma_with_expert.paligemma.lm_head
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _, _ = self.embed_prefix(
|
||||
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None,
|
||||
fast_action_tokens=None, fast_action_masks=None
|
||||
# Embed prefix without target tokens first
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
|
||||
images, img_masks, prompt_tokens, target_tokens=None, prompt_masks=prompt_masks, target_masks=None
|
||||
)
|
||||
|
||||
|
||||
# Initialize generated tokens list
|
||||
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
|
||||
|
||||
# tracking mask: False = still generating, True = finished
|
||||
finished = torch.zeros(bsize, dtype=torch.bool, device=device)
|
||||
|
||||
for t in range(max_length):
|
||||
# Prepare attention masks for current prefix
|
||||
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||
|
||||
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
|
||||
|
||||
# Forward pass through model to get logits
|
||||
(prefix_out, _), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_prefix_4d,
|
||||
attention_mask=att_2d_prefix_4d,
|
||||
position_ids=position_ids_prefix,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
# ...
|
||||
use_cache=False,
|
||||
adarms_cond=[None, None],
|
||||
)
|
||||
|
||||
logits = lm_head(prefix_out)
|
||||
next_token_logits = logits[:, -1, :]
|
||||
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
|
||||
# Get logits from the last position
|
||||
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
|
||||
next_token_logits = logits[:, -1, :] # (B, vocab)
|
||||
|
||||
# 1. if a row was already finished, force the next token to be PAD (0)
|
||||
next_token = torch.where(finished, torch.tensor(0, device=device), next_token)
|
||||
# Greedy decoding - take the most likely token
|
||||
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
|
||||
|
||||
# 2. store the token
|
||||
# Store generated token
|
||||
generated_tokens[:, t] = next_token
|
||||
|
||||
# 3. update the finished mask
|
||||
# Check for EOS token - if all batches have generated EOS, stop
|
||||
if tokenizer.eos_token_id is not None:
|
||||
finished |= (next_token == tokenizer.eos_token_id)
|
||||
if (next_token == tokenizer.eos_token_id).all():
|
||||
break
|
||||
|
||||
# 4. break only if everyone is finished
|
||||
if finished.all():
|
||||
break
|
||||
|
||||
next_token_unsqueezed = next_token.unsqueeze(1)
|
||||
# Embed the generated token and append to prefix
|
||||
next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1)
|
||||
|
||||
def next_token_embed_func(next_token_unsqueezed):
|
||||
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
|
||||
return next_emb * math.sqrt(next_emb.shape[-1])
|
||||
next_emb_dim = next_emb.shape[-1]
|
||||
return next_emb * math.sqrt(next_emb_dim)
|
||||
|
||||
next_emb = self._apply_checkpoint(next_token_embed_func, next_token_unsqueezed)
|
||||
|
||||
# update embeddings
|
||||
# Append to prefix embeddings
|
||||
prefix_embs = torch.cat([prefix_embs, next_emb], dim=1)
|
||||
|
||||
# update padding masks
|
||||
# Update masks - new token is valid and uses causal attention
|
||||
prefix_pad_masks = torch.cat([
|
||||
prefix_pad_masks,
|
||||
torch.ones((bsize, 1), dtype=torch.bool, device=device)
|
||||
], dim=1)
|
||||
|
||||
# update attention masks
|
||||
old_seq_len = prefix_att_masks.shape[1]
|
||||
new_seq_len = old_seq_len + 1
|
||||
new_att_masks = torch.zeros((bsize, new_seq_len, new_seq_len), dtype=torch.bool, device=device)
|
||||
new_att_masks[:, :old_seq_len, :old_seq_len] = prefix_att_masks
|
||||
new_att_masks[:, -1, :] = prefix_pad_masks
|
||||
prefix_att_masks = new_att_masks
|
||||
|
||||
prefix_att_masks = torch.cat([prefix_att_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
|
||||
|
||||
return generated_tokens
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
@@ -1292,20 +931,20 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
prompt_tokens,
|
||||
prompt_masks,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
tokenizer=None,
|
||||
max_subtask_tokens=50,
|
||||
max_target_tokens=50,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
|
||||
bsize = tokens.shape[0]
|
||||
device = tokens.device
|
||||
bsize = prompt_tokens.shape[0]
|
||||
device = prompt_tokens.device
|
||||
|
||||
if noise is None:
|
||||
# Sample noise with padded dimension as expected by action_in_proj
|
||||
@@ -1316,41 +955,33 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
) # Use config max_action_dim for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
# Generate subtask tokens autoregressively during inference
|
||||
generated_subtask_tokens = None
|
||||
# Generate target tokens autoregressively during inference (if tokenizer provided)
|
||||
generated_target_tokens = None
|
||||
target_masks = None
|
||||
if tokenizer is not None:
|
||||
generated_subtask_tokens = self._generate_subtask_tokens(
|
||||
images, img_masks, tokens, masks, tokenizer, max_subtask_tokens, device
|
||||
generated_target_tokens = self._generate_target_tokens(
|
||||
images, img_masks, prompt_tokens, prompt_masks, tokenizer, max_target_tokens, device
|
||||
)
|
||||
|
||||
# Decode and print the generated subtask tokens
|
||||
# Decode and print the generated target tokens
|
||||
for i in range(bsize):
|
||||
# Remove padding tokens (0) and special tokens
|
||||
valid_tokens = generated_subtask_tokens[i][generated_subtask_tokens[i] != 0]
|
||||
valid_tokens = generated_target_tokens[i][generated_target_tokens[i] != 0]
|
||||
decoded_text = tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
||||
print(f"[Inference] Generated subtask {i}: {decoded_text}")
|
||||
print(f"[Inference] Generated target {i}: {decoded_text}")
|
||||
|
||||
# Create mask for generated tokens (all valid)
|
||||
subtask_masks = torch.ones_like(generated_subtask_tokens, dtype=torch.bool)
|
||||
# Create mask for generated tokens (all valid where token != 0)
|
||||
target_masks = generated_target_tokens != 0
|
||||
|
||||
# During inference, we don't have subtask_tokens yet, so pass None
|
||||
# Also no FAST tokens during inference
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, _, _, _ = self.embed_prefix(
|
||||
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks,
|
||||
fast_action_tokens=None, fast_action_masks=None
|
||||
# Embed prefix with prompt and optionally generated target tokens
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
|
||||
images, img_masks, prompt_tokens, target_tokens=generated_target_tokens,
|
||||
prompt_masks=prompt_masks, target_masks=target_masks
|
||||
)
|
||||
|
||||
# Convert embeddings to bfloat16 if needed for the model
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
# prefix_att_masks is already a 2D attention mask from embed_prefix
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype)
|
||||
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
@@ -1575,7 +1206,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
print(f"Remapped {remap_count} state dict keys")
|
||||
|
||||
# Load the remapped state dict into the model
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
|
||||
@@ -1780,13 +1411,13 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
# Use high_level_task tokens (WITHOUT subtask) for inference - we'll generate the subtask
|
||||
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
|
||||
high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"]
|
||||
# Use prompt tokens (WITHOUT target) for inference - we'll generate the target
|
||||
prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"]
|
||||
prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"]
|
||||
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, high_level_task, high_level_task_masks,
|
||||
images, img_masks, prompt_tokens, prompt_masks,
|
||||
tokenizer=self.tokenizer,
|
||||
**kwargs
|
||||
)
|
||||
@@ -1802,32 +1433,15 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
|
||||
high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"]
|
||||
subtask_tokens, subtask_masks = batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_TOKENS}"], batch[f"{OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK}"]
|
||||
prompt_tokens = batch[f"{OBS_LANGUAGE_PROMPT_TOKENS}"]
|
||||
prompt_masks = batch[f"{OBS_LANGUAGE_PROMPT_ATTENTION_MASK}"]
|
||||
target_tokens, target_masks = batch[f"{OBS_LANGUAGE_TARGET_TOKENS}"], batch[f"{OBS_LANGUAGE_TARGET_ATTENTION_MASK}"]
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Decode and print ground truth subtask tokens during training
|
||||
# if self.tokenizer is not None and self.training:
|
||||
# bsize = subtask_tokens.shape[0]
|
||||
# for i in range(bsize):
|
||||
# # Remove padding tokens (0) and special tokens
|
||||
# valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
|
||||
# # if len(valid_tokens) > 0:
|
||||
# # decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
|
||||
# # print(f"[Training] Ground truth subtask {i}: {decoded_text}")
|
||||
|
||||
# Get FAST action tokens from batch
|
||||
fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens)
|
||||
fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens)
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
# high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
|
||||
# subtask_tokens = subtask tokens to predict (e.g., "pick up the cup")
|
||||
# fast_action_tokens = discrete action token IDs to predict
|
||||
loss_dict = self.model.forward(
|
||||
images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions,
|
||||
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
|
||||
)
|
||||
# Compute loss
|
||||
# prompt_tokens = instruction tokens WITHOUT target (e.g., "High level task: X; State: Y; Subtask:")
|
||||
# target_tokens = target tokens to predict (e.g., "pick up the cup")
|
||||
loss_dict = self.model.forward(images, img_masks, prompt_tokens, prompt_masks, target_tokens, target_masks, actions)
|
||||
|
||||
# Extract the total loss
|
||||
loss = loss_dict["loss"]
|
||||
@@ -1836,8 +1450,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
detailed_loss_dict = {
|
||||
"loss": loss.item(),
|
||||
"flow_loss": loss_dict["flow_loss"].mean().item(),
|
||||
"subtask_loss": loss_dict["subtask_loss"].item(),
|
||||
"fast_loss": loss_dict["fast_loss"].item(),
|
||||
"target_loss": loss_dict["target_loss"].item(),
|
||||
}
|
||||
|
||||
return loss, detailed_loss_dict
|
||||
|
||||
@@ -33,7 +33,6 @@ from lerobot.processor import (
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
ActionTokenizerProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
@@ -55,8 +54,8 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||
|
||||
max_state_dim: int = 32
|
||||
task_key: str = "task"
|
||||
high_level_task_key: str = "user_prompt"
|
||||
subtask_only_key: str = "subtask"
|
||||
prompt_key: str = "prompt"
|
||||
target_key: str = "target"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
@@ -68,7 +67,7 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.high_level_task_key)
|
||||
high_level_tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("user_prompt")
|
||||
|
||||
# TODO: check if this necessary
|
||||
state = deepcopy(state)
|
||||
@@ -87,36 +86,27 @@ class Pi05PrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
|
||||
for high_level_task in high_level_tasks:
|
||||
cleaned_high_level_tasks.append(high_level_task.strip().replace("_", " ").replace("\n", " "))
|
||||
|
||||
# Process low level tasks with state information
|
||||
low_level_prompts = []
|
||||
subtask_only_prompts = [] # Store only the subtask text for prediction
|
||||
# Process tasks to create prompts (input) and targets (what to predict)
|
||||
prompts = [] # Input prompts ending with "Subtask:"
|
||||
targets = [] # Target text to predict (the subtask)
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
|
||||
# Store only the subtask text (used as prediction target)
|
||||
subtask_only_prompts.append(cleaned_text)
|
||||
# Store the subtask text as target for prediction
|
||||
targets.append(cleaned_text)
|
||||
|
||||
if cleaned_high_level_tasks:
|
||||
cleaned_high_level_task = cleaned_high_level_tasks[i]
|
||||
full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask: {cleaned_text}"
|
||||
# Prompt ends with "Subtask:" - model will predict the target
|
||||
prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:"
|
||||
else:
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
raise ValueError("No high level tasks found")
|
||||
|
||||
low_level_prompts.append(full_prompt)
|
||||
prompts.append(prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = low_level_prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_only_key] = subtask_only_prompts
|
||||
|
||||
# Process high level tasks without state information (if available)
|
||||
if high_level_tasks is not None:
|
||||
high_level_prompts = []
|
||||
for i, cleaned_high_level_task in enumerate(cleaned_high_level_tasks):
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"High level task: {cleaned_high_level_task}; State: {state_str}; Subtask:"
|
||||
high_level_prompts.append(full_prompt)
|
||||
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.high_level_task_key] = high_level_prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.prompt_key] = prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA][self.target_key] = targets
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
@@ -159,6 +149,7 @@ def make_pi05_pre_post_processors(
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
"""
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
@@ -177,9 +168,6 @@ def make_pi05_pre_post_processors(
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
ActionTokenizerProcessorStep(
|
||||
tokenizer_name="/fsx/jade_choghari/outputs/fast_tokenizer", # TODO: jade put the PI
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
@@ -189,7 +177,7 @@ def make_pi05_pre_post_processors(
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
export CUDA_LAUNCH_BLOCKING=1
|
||||
lerobot-train \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
|
||||
--output_dir=/fsx/jade_choghari/outputs/pi0_fast_fruit1 \
|
||||
--job_name=pi0_training \
|
||||
--policy.repo_id=jade_choghari/pi0-base1 \
|
||||
--policy.path=lerobot/pi05_base \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=200000 \
|
||||
--save_freq=5000 \
|
||||
--rename_map='{
|
||||
"observation.images.base": "observation.images.base_0_rgb",
|
||||
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
|
||||
}' \
|
||||
--batch_size=4 \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--wandb.project=pi05hi-training \
|
||||
# /fsx/jade_choghari/.cache/huggingface/lerobot/jadechoghari/collect-data
|
||||
@@ -1,18 +0,0 @@
|
||||
rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
|
||||
lerobot-train \
|
||||
--dataset.repo_id=local\
|
||||
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
|
||||
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
|
||||
--job_name=pi0_multi_training \
|
||||
--policy.repo_id=jadechoghari/pi0-base1 \
|
||||
--policy.path=lerobot/pi05_base \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=50000 \
|
||||
--save_freq=5000 \
|
||||
--rename_map='{
|
||||
"observation.images.base": "observation.images.base_0_rgb",
|
||||
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
|
||||
}' \
|
||||
--batch_size=32 \
|
||||
--policy.device=cuda \
|
||||
@@ -1,9 +0,0 @@
|
||||
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
|
||||
--repo_id "local" \
|
||||
--root "/fsx/jade_choghari/outputs/collect-data-pgen" \
|
||||
--action_horizon 16 \
|
||||
--encoded_dims "0:15" \
|
||||
--action_horizon 50 \
|
||||
--vocab_size 1024 \
|
||||
--scale 10.0 \
|
||||
--output_dir "/fsx/jade_choghari/outputs/fast_tokenizer"
|
||||
@@ -1,410 +0,0 @@
|
||||
"""Train FAST tokenizer for action encoding.
|
||||
|
||||
This script:
|
||||
1. Loads action chunks from LeRobotDataset (with sampling)
|
||||
2. Applies delta transforms and per-timestamp normalization
|
||||
3. Trains FAST tokenizer on specified action dimensions
|
||||
4. Saves tokenizer to assets directory
|
||||
5. Reports compression statistics
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import tyro
|
||||
from pathlib import Path
|
||||
from transformers import AutoProcessor
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
|
||||
"""Apply delta transform to specified dimensions.
|
||||
|
||||
Args:
|
||||
state: Current state [D]
|
||||
actions: Future actions [D]
|
||||
delta_dims: List of dimension indices to apply delta transform to
|
||||
|
||||
Returns:
|
||||
Transformed actions [D]
|
||||
"""
|
||||
if delta_dims is None or len(delta_dims) == 0:
|
||||
return actions
|
||||
|
||||
delta_actions = actions.copy()
|
||||
for dim in delta_dims:
|
||||
delta_actions[dim] = actions[dim] - state[dim]
|
||||
|
||||
return delta_actions
|
||||
|
||||
|
||||
def process_episode(args):
|
||||
"""Process single episode and return action chunks."""
|
||||
dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args
|
||||
|
||||
try:
|
||||
# Get episode info
|
||||
ep_info = dataset.meta.episodes[ep_idx]
|
||||
from_idx = ep_info["dataset_from_index"]
|
||||
to_idx = ep_info["dataset_to_index"]
|
||||
ep_length = to_idx - from_idx
|
||||
|
||||
if ep_length < action_horizon:
|
||||
return None
|
||||
|
||||
# Load all frames in episode
|
||||
# If dataset has episode filtering, we need to use the mapping
|
||||
states = []
|
||||
actions = []
|
||||
|
||||
for abs_idx in range(from_idx, to_idx):
|
||||
# Map absolute index to relative index if needed
|
||||
if dataset._absolute_to_relative_idx is not None:
|
||||
if abs_idx not in dataset._absolute_to_relative_idx:
|
||||
# This episode's frames aren't in the filtered dataset
|
||||
return None
|
||||
rel_idx = dataset._absolute_to_relative_idx[abs_idx]
|
||||
else:
|
||||
rel_idx = abs_idx
|
||||
|
||||
frame = dataset.hf_dataset[rel_idx]
|
||||
|
||||
# Get state (could be from observation.state or other state key)
|
||||
if state_key in frame:
|
||||
state = frame[state_key].numpy() if torch.is_tensor(frame[state_key]) else np.array(frame[state_key])
|
||||
else:
|
||||
# If no state key, use zeros (no delta transform)
|
||||
state = np.zeros_like(frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"]))
|
||||
|
||||
action = frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"])
|
||||
|
||||
states.append(state)
|
||||
actions.append(action)
|
||||
|
||||
states = np.array(states)
|
||||
actions = np.array(actions)
|
||||
|
||||
# Create action chunks (sliding window)
|
||||
# All actions in a chunk are relative to the FIRST state in that chunk
|
||||
action_chunks = []
|
||||
|
||||
for i in range(len(states) - action_horizon + 1):
|
||||
current_state = states[i] # First state in chunk
|
||||
future_absolute_actions = actions[i:i + action_horizon]
|
||||
|
||||
if use_delta_transform:
|
||||
# Relative actions
|
||||
delta_chunk = np.zeros_like(future_absolute_actions)
|
||||
for t in range(action_horizon):
|
||||
delta_chunk[t] = apply_delta_transform(
|
||||
current_state,
|
||||
future_absolute_actions[t],
|
||||
delta_dims,
|
||||
)
|
||||
action_chunks.append(delta_chunk)
|
||||
else:
|
||||
# Absolute actions (NO delta)
|
||||
action_chunks.append(future_absolute_actions)
|
||||
|
||||
if len(action_chunks) == 0:
|
||||
return None
|
||||
|
||||
action_chunks = np.array(action_chunks)
|
||||
|
||||
# Sample chunks
|
||||
if sample_fraction < 1.0:
|
||||
n_chunks = len(action_chunks)
|
||||
n_samples = max(1, int(n_chunks * sample_fraction))
|
||||
episode_seed = hash(ep_idx) % (2**31)
|
||||
rng = np.random.RandomState(episode_seed)
|
||||
indices = rng.choice(n_chunks, size=n_samples, replace=False)
|
||||
action_chunks = action_chunks[indices]
|
||||
|
||||
return action_chunks
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing episode {ep_idx}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
def train_fast_tokenizer(
|
||||
action_chunks: np.ndarray,
|
||||
vocab_size: int = 1024,
|
||||
scale: float = 10.0,
|
||||
) -> AutoProcessor:
|
||||
"""
|
||||
Train FAST tokenizer (BPE on DCT coefficients) on action chunks.
|
||||
|
||||
Uses the .fit() method to train a new tokenizer on the provided data.
|
||||
|
||||
Args:
|
||||
action_chunks: Array of action chunks [N, H, D] where N=num_chunks, H=horizon, D=action_dim
|
||||
vocab_size: BPE vocabulary size
|
||||
scale: DCT scaling factor for quantization
|
||||
|
||||
Returns:
|
||||
Trained FAST tokenizer
|
||||
"""
|
||||
print(f"Training FAST tokenizer on {len(action_chunks)} action chunks...")
|
||||
print(f"Action chunk shape: {action_chunks.shape}")
|
||||
print(f"Vocab size: {vocab_size}")
|
||||
print(f"DCT scale: {scale}")
|
||||
|
||||
# Download the tokenizer source code (not pretrained weights)
|
||||
# We'll train a new tokenizer on our own data
|
||||
base_tokenizer = AutoProcessor.from_pretrained(
|
||||
"physical-intelligence/fast",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# Convert action_chunks array to list of arrays (expected by .fit())
|
||||
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
|
||||
|
||||
# Train the new tokenizer on our action data using .fit()
|
||||
# This trains the BPE tokenizer on DCT coefficients
|
||||
print("Training new tokenizer (this may take a few minutes)...")
|
||||
tokenizer = base_tokenizer.fit(
|
||||
action_data_list,
|
||||
scale=scale,
|
||||
vocab_size=vocab_size,
|
||||
time_horizon=action_chunks.shape[1], # action_horizon
|
||||
action_dim=action_chunks.shape[2], # encoded dimensions
|
||||
)
|
||||
print("✓ Tokenizer training complete!")
|
||||
|
||||
# Validate it works
|
||||
sample_chunk = action_chunks[0]
|
||||
encoded = tokenizer(sample_chunk[None])[0]
|
||||
if isinstance(encoded, list):
|
||||
encoded = np.array(encoded)
|
||||
print(f"Sample encoding: {len(encoded)} tokens for chunk shape {sample_chunk.shape}")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def compute_compression_stats(tokenizer, action_chunks: np.ndarray):
|
||||
"""Compute compression statistics."""
|
||||
print("\nComputing compression statistics...")
|
||||
|
||||
# Sample for stats (use max 1000 chunks for speed)
|
||||
sample_size = min(1000, len(action_chunks))
|
||||
sample_indices = np.random.RandomState(42).choice(len(action_chunks), size=sample_size, replace=False)
|
||||
sample_chunks = action_chunks[sample_indices]
|
||||
|
||||
token_lengths = []
|
||||
for chunk in sample_chunks:
|
||||
encoded = tokenizer(chunk[None])[0]
|
||||
if isinstance(encoded, list):
|
||||
token_lengths.append(len(encoded))
|
||||
else:
|
||||
token_lengths.append(encoded.shape[0] if hasattr(encoded, 'shape') else len(encoded))
|
||||
|
||||
token_lengths = np.array(token_lengths)
|
||||
|
||||
# Compression ratio: (H * D) / avg_tokens
|
||||
input_size = action_chunks.shape[1] * action_chunks.shape[2]
|
||||
avg_tokens = np.mean(token_lengths)
|
||||
compression_ratio = input_size / avg_tokens
|
||||
|
||||
stats = {
|
||||
'compression_ratio': float(compression_ratio),
|
||||
'mean_token_length': float(np.mean(token_lengths)),
|
||||
'p99_token_length': float(np.percentile(token_lengths, 99)),
|
||||
'min_token_length': float(np.min(token_lengths)),
|
||||
'max_token_length': float(np.max(token_lengths)),
|
||||
}
|
||||
|
||||
print(f"Compression Statistics:")
|
||||
print(f" Average compression ratio: {stats['compression_ratio']:.2f}x")
|
||||
print(f" Mean token length: {stats['mean_token_length']:.1f}")
|
||||
print(f" P99 token length: {stats['p99_token_length']:.0f}")
|
||||
print(f" Min token length: {stats['min_token_length']:.0f}")
|
||||
print(f" Max token length: {stats['max_token_length']:.0f}")
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def main(
|
||||
repo_id: str,
|
||||
root: str | None = None,
|
||||
action_horizon: int = 10,
|
||||
max_episodes: int | None = None,
|
||||
sample_fraction: float = 0.1,
|
||||
encoded_dims: str = "0:6,7:23",
|
||||
delta_dims: str | None = None,
|
||||
use_delta_transform: bool = False,
|
||||
state_key: str = "observation.state",
|
||||
vocab_size: int = 1024,
|
||||
scale: float = 10.0,
|
||||
output_dir: str | None = None,
|
||||
):
|
||||
"""
|
||||
Train FAST tokenizer for action encoding.
|
||||
|
||||
Args:
|
||||
repo_id: LeRobot dataset repository ID
|
||||
root: Root directory for dataset (default: ~/.cache/huggingface/lerobot)
|
||||
action_horizon: Number of future actions in each chunk
|
||||
max_episodes: Max episodes to use (None = all episodes in dataset)
|
||||
sample_fraction: Fraction of chunks to sample per episode
|
||||
encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
|
||||
delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
|
||||
use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
|
||||
state_key: Dataset key for state observations (default: "observation.state")
|
||||
vocab_size: FAST vocabulary size (BPE vocab size)
|
||||
scale: DCT scaling factor (default: 10.0)
|
||||
output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
|
||||
"""
|
||||
# Load dataset
|
||||
print(f"Loading dataset: {repo_id}")
|
||||
dataset = LeRobotDataset(repo_id=repo_id, root=root)
|
||||
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||||
|
||||
# Parse encoded dimensions
|
||||
encoded_dim_ranges = []
|
||||
for range_str in encoded_dims.split(','):
|
||||
start, end = map(int, range_str.strip().split(':'))
|
||||
encoded_dim_ranges.append((start, end))
|
||||
|
||||
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
|
||||
print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}")
|
||||
|
||||
# Parse delta dimensions
|
||||
delta_dim_list = None
|
||||
if delta_dims is not None and delta_dims.strip():
|
||||
delta_dim_list = [int(d.strip()) for d in delta_dims.split(',')]
|
||||
print(f"Delta dimensions: {delta_dim_list}")
|
||||
else:
|
||||
print("No delta dimensions specified")
|
||||
|
||||
print(f"Use delta transform: {use_delta_transform}")
|
||||
if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
|
||||
print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")
|
||||
|
||||
print(f"Action horizon: {action_horizon}")
|
||||
print(f"State key: {state_key}")
|
||||
|
||||
# Determine episodes to process
|
||||
num_episodes = dataset.num_episodes
|
||||
if max_episodes is not None:
|
||||
num_episodes = min(max_episodes, num_episodes)
|
||||
|
||||
print(f"Processing {num_episodes} episodes...")
|
||||
|
||||
# Process episodes sequentially (to avoid pickling issues with dataset)
|
||||
all_chunks = []
|
||||
for ep_idx in range(num_episodes):
|
||||
if ep_idx % 10 == 0:
|
||||
print(f" Processing episode {ep_idx}/{num_episodes}...")
|
||||
|
||||
chunks = process_episode(
|
||||
(dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform)
|
||||
)
|
||||
if chunks is not None:
|
||||
all_chunks.append(chunks)
|
||||
|
||||
# Concatenate all chunks
|
||||
all_chunks = np.concatenate(all_chunks, axis=0)
|
||||
print(f"Collected {len(all_chunks)} action chunks")
|
||||
|
||||
# Extract only encoded dimensions FIRST (before normalization)
|
||||
encoded_chunks = []
|
||||
for start, end in encoded_dim_ranges:
|
||||
encoded_chunks.append(all_chunks[:, :, start:end])
|
||||
encoded_chunks = np.concatenate(encoded_chunks, axis=-1) # [N, H, D_encoded]
|
||||
print(f"Extracted {encoded_chunks.shape[-1]} encoded dimensions")
|
||||
|
||||
# Apply normalization to encoded dimensions only
|
||||
# NOTE: For FAST, we ALWAYS use QUANTILE normalization (no per-timestamp)
|
||||
# This clips outliers and provides consistent [-1, 1] range for DCT compression
|
||||
print(f"\nBefore normalization - overall stats:")
|
||||
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
|
||||
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
|
||||
|
||||
norm_stats = dataset.meta.stats
|
||||
if norm_stats is not None and "action" in norm_stats:
|
||||
action_stats = norm_stats["action"]
|
||||
|
||||
# Build encoded dimension indices
|
||||
encoded_dim_indices = []
|
||||
for start, end in encoded_dim_ranges:
|
||||
encoded_dim_indices.extend(range(start, end))
|
||||
encoded_dim_indices = np.array(encoded_dim_indices)
|
||||
|
||||
# Use QUANTILE normalization: clip to [q01, q99] and map to [-1, 1]
|
||||
if "q01" in action_stats and "q99" in action_stats:
|
||||
q01 = np.array(action_stats["q01"])[encoded_dim_indices] # [D_encoded]
|
||||
q99 = np.array(action_stats["q99"])[encoded_dim_indices] # [D_encoded]
|
||||
|
||||
print(f"\nNormalization stats (q01, q99) for encoded dimensions:")
|
||||
for i, dim_idx in enumerate(encoded_dim_indices):
|
||||
print(f" Orig dim {dim_idx}: q01={q01[i]:7.4f}, q99={q99[i]:7.4f}, range={q99[i]-q01[i]:7.4f}")
|
||||
|
||||
# Clip to quantile range and normalize to [-1, 1]
|
||||
encoded_chunks = np.clip(encoded_chunks, q01, q99)
|
||||
encoded_chunks = 2.0 * (encoded_chunks - q01) / np.maximum(q99 - q01, 1e-6) - 1.0
|
||||
print(f"\nApplied quantile normalization [q01, q99] → [-1, 1]")
|
||||
|
||||
print(f"\nAfter normalization - overall stats:")
|
||||
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
|
||||
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
|
||||
|
||||
print(f"\nPer-dimension stats (after normalization):")
|
||||
for d in range(encoded_chunks.shape[-1]):
|
||||
dim_data = encoded_chunks[:, :, d]
|
||||
print(f" Dim {d}: min={np.min(dim_data):7.4f}, max={np.max(dim_data):7.4f}, "
|
||||
f"mean={np.mean(dim_data):7.4f}, std={np.std(dim_data):7.4f}")
|
||||
else:
|
||||
print("Warning: q01/q99 stats not found, using raw actions")
|
||||
else:
|
||||
print("Warning: No normalization stats found, using raw actions")
|
||||
|
||||
print(f"Encoded chunks shape: {encoded_chunks.shape}")
|
||||
|
||||
# Train FAST tokenizer
|
||||
tokenizer = train_fast_tokenizer(
|
||||
encoded_chunks,
|
||||
vocab_size=vocab_size,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
# Compute compression statistics
|
||||
compression_stats = compute_compression_stats(tokenizer, encoded_chunks)
|
||||
|
||||
# Save tokenizer
|
||||
if output_dir is None:
|
||||
output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}"
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
tokenizer.save_pretrained(output_path)
|
||||
|
||||
# Save metadata
|
||||
metadata = {
|
||||
'repo_id': repo_id,
|
||||
'vocab_size': vocab_size,
|
||||
'scale': scale,
|
||||
'encoded_dims': encoded_dims,
|
||||
'encoded_dim_ranges': encoded_dim_ranges,
|
||||
'total_encoded_dims': total_encoded_dims,
|
||||
'delta_dims': delta_dims,
|
||||
'delta_dim_list': delta_dim_list,
|
||||
'use_delta_transform': use_delta_transform,
|
||||
'state_key': state_key,
|
||||
'action_horizon': action_horizon,
|
||||
'num_training_chunks': len(encoded_chunks),
|
||||
'compression_stats': compression_stats,
|
||||
}
|
||||
|
||||
with open(output_path / "metadata.json", 'w') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
print(f"\n✅ Saved FAST tokenizer to {output_path}")
|
||||
print(f"Metadata: {json.dumps(metadata, indent=2)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(main)
|
||||
@@ -1,101 +0,0 @@
|
||||
# Train FAST Tokenizer - Usage Examples
|
||||
|
||||
This script trains a FAST (Factorized Action Sequence Tokenizer) on LeRobotDataset action data.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
|
||||
--repo_id "lerobot/aloha_sim_insertion_human" \
|
||||
--action_horizon 10 \
|
||||
--encoded_dims "0:7" \
|
||||
--vocab_size 1024 \
|
||||
--scale 10.0
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
### Required
|
||||
- `--repo_id`: LeRobot dataset repository ID (e.g., "lerobot/aloha_sim_insertion_human")
|
||||
|
||||
### Optional
|
||||
- `--root`: Root directory for dataset (default: ~/.cache/huggingface/lerobot)
|
||||
- `--action_horizon`: Number of future actions in each chunk (default: 10)
|
||||
- `--max_episodes`: Maximum number of episodes to use (default: None = all)
|
||||
- `--sample_fraction`: Fraction of chunks to sample per episode (default: 0.1)
|
||||
- `--encoded_dims`: Comma-separated dimension ranges to encode (default: "0:6,7:23")
|
||||
- Example: "0:7" encodes dimensions 0-6
|
||||
- Example: "0:3,6:9" encodes dimensions 0-2 and 6-8
|
||||
- `--delta_dims`: Comma-separated dimension indices for delta transform (default: None)
|
||||
- Example: "0,1,2,3,4,5" applies delta transform to first 6 dimensions
|
||||
- Delta transform: action[i] - state[i] for specified dimensions
|
||||
- `--state_key`: Dataset key for state observations (default: "observation.state")
|
||||
- `--vocab_size`: FAST vocabulary size / BPE vocab size (default: 1024)
|
||||
- `--scale`: DCT scaling factor (default: 10.0)
|
||||
- `--output_dir`: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1: Train on full action space
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
|
||||
--repo_id "lerobot/pusht" \
|
||||
--action_horizon 16 \
|
||||
--encoded_dims "0:2" \
|
||||
--vocab_size 512 \
|
||||
--max_episodes 100
|
||||
```
|
||||
|
||||
### Example 2: Train with delta transform
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
|
||||
--repo_id "lerobot/aloha_sim_insertion_human" \
|
||||
--action_horizon 10 \
|
||||
--encoded_dims "0:14" \
|
||||
--delta_dims "0,1,2,3,4,5,6,7,8,9,10,11,12,13" \
|
||||
--state_key "observation.state" \
|
||||
--vocab_size 1024 \
|
||||
--scale 10.0 \
|
||||
--sample_fraction 0.2
|
||||
```
|
||||
|
||||
### Example 3: Train on subset of dimensions
|
||||
|
||||
```bash
|
||||
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
|
||||
--repo_id "lerobot/aloha_sim_insertion_human" \
|
||||
--action_horizon 10 \
|
||||
--encoded_dims "0:7" \
|
||||
--vocab_size 1024 \
|
||||
--output_dir "./my_tokenizer"
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
The script saves:
|
||||
1. **Tokenizer files**: Trained FAST tokenizer (can be loaded with `AutoProcessor.from_pretrained()`)
|
||||
2. **metadata.json**: Contains:
|
||||
- Configuration parameters
|
||||
- Compression statistics (compression ratio, token lengths)
|
||||
- Training dataset information
|
||||
|
||||
## Understanding the Process
|
||||
|
||||
1. **Load Dataset**: Loads the LeRobotDataset from HuggingFace
|
||||
2. **Extract Action Chunks**: Creates sliding windows of actions with specified horizon
|
||||
3. **Apply Delta Transform**: (Optional) Computes action deltas relative to current state
|
||||
4. **Select Encoded Dimensions**: Extracts only the dimensions to be encoded
|
||||
5. **Normalize**: Applies quantile normalization ([q01, q99] → [-1, 1])
|
||||
6. **Train Tokenizer**: Trains BPE tokenizer on DCT coefficients
|
||||
7. **Compute Stats**: Reports compression ratio and token length statistics
|
||||
8. **Save**: Saves tokenizer and metadata
|
||||
|
||||
## Notes
|
||||
|
||||
- **Normalization**: The script uses quantile normalization (q01, q99) from the dataset's statistics
|
||||
- **Sampling**: To speed up training, you can sample a fraction of chunks per episode
|
||||
- **Delta Transform**: Applied per-dimension to make actions relative to current state
|
||||
- **Compression**: FAST uses DCT + BPE to compress action sequences efficiently
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
|
||||
accelerate launch --multi_gpu --num_processes=2 \
|
||||
$(which lerobot-train) \
|
||||
--dataset.repo_id=local \
|
||||
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
|
||||
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
|
||||
--job_name=pi0_multi_training \
|
||||
--policy.repo_id=jadechoghari/pi0-base1 \
|
||||
--policy.path=lerobot/pi05_base \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=50000 \
|
||||
--save_freq=5000 \
|
||||
--rename_map='{
|
||||
"observation.images.base": "observation.images.base_0_rgb",
|
||||
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
|
||||
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
|
||||
}' \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--batch_size=1 \
|
||||
--policy.device=cpu
|
||||
# --wandb.enable=true \
|
||||
# --wandb.disable_artifact=true \
|
||||
# --wandb.project=pi05hi-training \
|
||||
@@ -75,7 +75,7 @@ from .policy_robot_bridge import (
|
||||
RobotActionToPolicyActionProcessorStep,
|
||||
)
|
||||
from .rename_processor import RenameObservationsProcessorStep
|
||||
from .tokenizer_processor import TokenizerProcessorStep, ActionTokenizerProcessorStep
|
||||
from .tokenizer_processor import TokenizerProcessorStep
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessorStep",
|
||||
|
||||
@@ -15,13 +15,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script defines processors for tokenizing data from an environment transition.
|
||||
This script defines a processor for tokenizing natural language instructions from an environment transition.
|
||||
|
||||
It includes:
|
||||
- TokenizerProcessorStep: Uses a tokenizer from the Hugging Face `transformers` library to convert
|
||||
task descriptions (text) into token IDs and attention masks, which are then added to the observation dictionary.
|
||||
- ActionTokenizerProcessorStep: Uses a processor/tokenizer (e.g., the Physical Intelligence "fast" tokenizer)
|
||||
to tokenize action tensors into discrete token IDs for action modeling.
|
||||
It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
|
||||
token IDs and attention masks, which are then added to the observation dictionary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -33,25 +30,22 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
||||
OBS_LANGUAGE_PROMPT_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_PROMPT_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS,
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TARGET_TOKENS,
|
||||
OBS_LANGUAGE_TARGET_ATTENTION_MASK,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from transformers import AutoTokenizer
|
||||
else:
|
||||
AutoProcessor = None
|
||||
AutoTokenizer = None
|
||||
|
||||
|
||||
@@ -65,8 +59,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
|
||||
token IDs and attention mask to the `observation` dictionary.
|
||||
|
||||
Optionally, this step can also tokenize a high-level task (e.g., user prompt) and/or
|
||||
a subtask if present in the complementary data, creating separate tokenized observations.
|
||||
Optionally, this step can also tokenize a prompt (input for generation) and/or
|
||||
a target (text to predict) if present in the complementary data, creating separate tokenized observations.
|
||||
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
@@ -75,8 +69,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
max_length: The maximum length to pad or truncate sequences to.
|
||||
task_key: The key in `complementary_data` where the task string is stored.
|
||||
high_level_task_key: The key in `complementary_data` where the high-level task (user prompt) is stored.
|
||||
subtask_key: The key in `complementary_data` where the subtask string is stored.
|
||||
prompt_key: The key in `complementary_data` where the prompt (input for generation) is stored.
|
||||
target_key: The key in `complementary_data` where the target (text to predict) is stored.
|
||||
padding_side: The side to pad on ('left' or 'right').
|
||||
padding: The padding strategy ('max_length', 'longest', etc.).
|
||||
truncation: Whether to truncate sequences longer than `max_length`.
|
||||
@@ -87,8 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
|
||||
max_length: int = 512
|
||||
task_key: str = "task"
|
||||
high_level_task_key: str = "user_prompt"
|
||||
subtask_key: str = "subtask"
|
||||
prompt_key: str = "prompt"
|
||||
target_key: str = "target"
|
||||
padding_side: str = "right"
|
||||
padding: str = "max_length"
|
||||
truncation: bool = True
|
||||
@@ -153,57 +147,57 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_high_level_task(self, transition: EnvTransition) -> list[str] | None:
|
||||
def get_prompt(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the high-level task description(s) from the transition's complementary data.
|
||||
Extracts the prompt (input for generation) from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of high-level task strings, or None if the high-level task key is not found or the value is None.
|
||||
A list of prompt strings, or None if the prompt key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
high_level_task = complementary_data.get(self.high_level_task_key)
|
||||
prompt = complementary_data.get(self.prompt_key)
|
||||
|
||||
if high_level_task is None:
|
||||
if prompt is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(high_level_task, str):
|
||||
return [high_level_task]
|
||||
elif isinstance(high_level_task, list) and all(isinstance(t, str) for t in high_level_task):
|
||||
return high_level_task
|
||||
if isinstance(prompt, str):
|
||||
return [prompt]
|
||||
elif isinstance(prompt, list) and all(isinstance(t, str) for t in prompt):
|
||||
return prompt
|
||||
|
||||
return None
|
||||
|
||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
|
||||
def get_target(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the subtask description(s) from the transition's complementary data.
|
||||
Extracts the target (text to predict) from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of subtask strings, or None if the subtask key is not found or the value is None.
|
||||
A list of target strings, or None if the target key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
subtask = complementary_data.get(self.subtask_key)
|
||||
target = complementary_data.get(self.target_key)
|
||||
|
||||
if subtask is None:
|
||||
if target is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(subtask, str):
|
||||
return [subtask]
|
||||
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
|
||||
return subtask
|
||||
if isinstance(target, str):
|
||||
return [target]
|
||||
elif isinstance(target, list) and all(isinstance(t, str) for t in target):
|
||||
return target
|
||||
|
||||
return None
|
||||
|
||||
@@ -244,39 +238,37 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Also tokenize high-level task if available
|
||||
high_level_task = self.get_high_level_task(self.transition)
|
||||
if high_level_task is not None:
|
||||
# Tokenize the high-level task
|
||||
tokenized_high_level_prompt = self._tokenize_text(high_level_task)
|
||||
# Also tokenize prompt (input for generation) if available
|
||||
prompt = self.get_prompt(self.transition)
|
||||
if prompt is not None:
|
||||
tokenized_prompt_input = self._tokenize_text(prompt)
|
||||
|
||||
# Move to the same device
|
||||
if target_device is not None:
|
||||
tokenized_high_level_prompt = {
|
||||
tokenized_prompt_input = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_high_level_prompt.items()
|
||||
for k, v in tokenized_prompt_input.items()
|
||||
}
|
||||
|
||||
# Add high-level tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = tokenized_high_level_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = tokenized_high_level_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
# Add prompt tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_PROMPT_TOKENS] = tokenized_prompt_input["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_PROMPT_ATTENTION_MASK] = tokenized_prompt_input["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Also tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
if subtask is not None:
|
||||
# Tokenize the subtask
|
||||
tokenized_subtask_prompt = self._tokenize_text(subtask)
|
||||
# Also tokenize target (text to predict) if available
|
||||
target = self.get_target(self.transition)
|
||||
if target is not None:
|
||||
tokenized_target = self._tokenize_text(target)
|
||||
|
||||
# Move to the same device
|
||||
if target_device is not None:
|
||||
tokenized_subtask_prompt = {
|
||||
tokenized_target = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_subtask_prompt.items()
|
||||
for k, v in tokenized_target.items()
|
||||
}
|
||||
|
||||
# Add subtask tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = tokenized_subtask_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = tokenized_subtask_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
# Add target tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_TARGET_TOKENS] = tokenized_target["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_TARGET_ATTENTION_MASK] = tokenized_target["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
@@ -308,17 +300,15 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
A wrapper around the tokenizer call that appends an EOS token to each sequence.
|
||||
A wrapper around the tokenizer call.
|
||||
|
||||
Args:
|
||||
text: A string or list of strings to tokenize.
|
||||
|
||||
Returns:
|
||||
A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors,
|
||||
with EOS token appended at the end of each sequence.
|
||||
A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors.
|
||||
"""
|
||||
# Tokenize normally
|
||||
tokenized = self.input_tokenizer(
|
||||
return self.input_tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
truncation=self.truncation,
|
||||
@@ -326,34 +316,6 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
padding_side=self.padding_side,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Get EOS token ID
|
||||
eos_token_id = self.input_tokenizer.eos_token_id
|
||||
if eos_token_id is None:
|
||||
# Some tokenizers don't have an EOS token, skip modification
|
||||
return tokenized
|
||||
|
||||
# Append EOS token to each sequence (before padding)
|
||||
input_ids = tokenized["input_ids"]
|
||||
attention_mask = tokenized["attention_mask"]
|
||||
|
||||
for i in range(input_ids.shape[0]):
|
||||
# Find the position of the last non-padding token
|
||||
non_pad_positions = (attention_mask[i] == 1).nonzero(as_tuple=True)[0]
|
||||
|
||||
if len(non_pad_positions) > 0:
|
||||
last_token_pos = non_pad_positions[-1].item()
|
||||
|
||||
# Check if there's room to add EOS token
|
||||
if last_token_pos + 1 < self.max_length:
|
||||
# Insert EOS token after the last real token
|
||||
input_ids[i, last_token_pos + 1] = eos_token_id
|
||||
attention_mask[i, last_token_pos + 1] = 1
|
||||
else:
|
||||
# If at max length, replace the last token with EOS
|
||||
input_ids[i, last_token_pos] = eos_token_id
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -368,7 +330,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
config = {
|
||||
"max_length": self.max_length,
|
||||
"task_key": self.task_key,
|
||||
"high_level_task_key": self.high_level_task_key,
|
||||
"prompt_key": self.prompt_key,
|
||||
"target_key": self.target_key,
|
||||
"padding_side": self.padding_side,
|
||||
"padding": self.padding,
|
||||
"truncation": self.truncation,
|
||||
@@ -407,255 +370,26 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for high-level task tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = PolicyFeature(
|
||||
# Add features for prompt tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_PROMPT_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_PROMPT_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = PolicyFeature(
|
||||
if OBS_LANGUAGE_PROMPT_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_PROMPT_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_SUBTASK_ONLY_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = PolicyFeature(
|
||||
# Add features for target tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_TARGET_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TARGET_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = PolicyFeature(
|
||||
if OBS_LANGUAGE_TARGET_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TARGET_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="action_tokenizer_processor")
|
||||
class ActionTokenizerProcessorStep(ActionProcessorStep):
|
||||
"""
|
||||
Processor step to tokenize action data using a fast action tokenizer.
|
||||
|
||||
This step takes action tensors from an `EnvTransition`, tokenizes them using
|
||||
a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
|
||||
and returns the tokenized action.
|
||||
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
|
||||
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
|
||||
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
|
||||
"""
|
||||
|
||||
tokenizer_name: str | None = None
|
||||
tokenizer: Any | None = None
|
||||
trust_remote_code: bool = True
|
||||
max_action_tokens: int = 32
|
||||
# Internal tokenizer instance (not part of the config)
|
||||
action_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Initializes the action tokenizer after the dataclass is created.
|
||||
|
||||
It checks for the availability of the `transformers` library and loads the tokenizer
|
||||
either from a provided object or by name from the Hugging Face Hub.
|
||||
|
||||
Raises:
|
||||
ImportError: If the `transformers` library is not installed.
|
||||
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
|
||||
"""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
# Use provided tokenizer object directly
|
||||
self.action_tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
if AutoProcessor is None:
|
||||
raise ImportError("AutoProcessor is not available")
|
||||
self.action_tokenizer = AutoProcessor.from_pretrained(
|
||||
self.tokenizer_name, trust_remote_code=self.trust_remote_code
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
|
||||
"Pass a tokenizer object directly or a tokenizer name to auto-load."
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
Applies action tokenization to the transition.
|
||||
|
||||
This overrides the base class to handle both tokens and mask.
|
||||
|
||||
Args:
|
||||
transition: The input transition with action data.
|
||||
|
||||
Returns:
|
||||
The processed transition with tokenized actions and mask in complementary data.
|
||||
"""
|
||||
self._current_transition = transition.copy()
|
||||
new_transition = self._current_transition
|
||||
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
raise ValueError("ActionTokenizerProcessorStep requires an action in the transition.")
|
||||
|
||||
# Tokenize and get both tokens and mask
|
||||
tokens, mask = self._tokenize_action(action)
|
||||
|
||||
# Store mask in complementary data
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if complementary_data is None:
|
||||
complementary_data = {}
|
||||
complementary_data[ACTION_TOKEN_MASK] = mask
|
||||
complementary_data[ACTION_TOKENS] = tokens
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
return new_transition
|
||||
|
||||
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenizes the action tensor and creates a mask.
|
||||
|
||||
Args:
|
||||
action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
|
||||
|
||||
Returns:
|
||||
A tuple of (tokens, mask) where:
|
||||
- tokens: Tensor of token IDs with shape (B, max_action_tokens)
|
||||
- mask: Boolean mask with shape (B, max_action_tokens), True for real tokens, False for padding
|
||||
"""
|
||||
if action is None:
|
||||
raise ValueError("Action cannot be None")
|
||||
|
||||
# Get the device and dtype of the input action
|
||||
device = action.device if isinstance(action, torch.Tensor) else None
|
||||
|
||||
# Handle single sample (add batch dimension)
|
||||
single_sample = action.dim() == 1
|
||||
if single_sample:
|
||||
action = action.unsqueeze(0)
|
||||
|
||||
batch_size = action.shape[0]
|
||||
|
||||
# Tokenize the action batch
|
||||
# The fast tokenizer expects action data and returns token IDs
|
||||
tokens_list = []
|
||||
masks_list = []
|
||||
|
||||
for i in range(batch_size):
|
||||
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
|
||||
action_cpu = action[i:i+1].cpu()
|
||||
tokens = self.action_tokenizer(action_cpu)
|
||||
|
||||
# Convert to numpy array if it's a list
|
||||
if isinstance(tokens, list):
|
||||
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
|
||||
elif not isinstance(tokens, torch.Tensor):
|
||||
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
|
||||
else:
|
||||
# Move tokens back to the same device as input action
|
||||
tokens = tokens.to(device=action.device)
|
||||
|
||||
# Flatten to 1D if needed
|
||||
if tokens.dim() > 1:
|
||||
tokens = tokens.flatten()
|
||||
|
||||
# Truncate or pad to max_action_tokens
|
||||
if len(tokens) > self.max_action_tokens:
|
||||
tokens = tokens[:self.max_action_tokens]
|
||||
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
|
||||
else:
|
||||
mask = torch.cat([
|
||||
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
|
||||
torch.zeros(self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device)
|
||||
])
|
||||
# Pad tokens with zeros
|
||||
tokens = torch.nn.functional.pad(
|
||||
tokens,
|
||||
(0, self.max_action_tokens - len(tokens)),
|
||||
value=0
|
||||
)
|
||||
|
||||
tokens_list.append(tokens)
|
||||
masks_list.append(mask)
|
||||
|
||||
# Stack into batched tensors
|
||||
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
|
||||
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
|
||||
|
||||
# Remove batch dimension if input was single sample
|
||||
if single_sample:
|
||||
tokens_batch = tokens_batch.squeeze(0)
|
||||
masks_batch = masks_batch.squeeze(0)
|
||||
|
||||
# Move to the same device as the input
|
||||
if device is not None:
|
||||
tokens_batch = tokens_batch.to(device)
|
||||
masks_batch = masks_batch.to(device)
|
||||
|
||||
return tokens_batch, masks_batch
|
||||
|
||||
def action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This method is not used since we override __call__.
|
||||
Required by ActionProcessorStep ABC.
|
||||
"""
|
||||
tokens, _ = self._tokenize_action(action)
|
||||
return tokens
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Returns the serializable configuration of the processor.
|
||||
|
||||
Note: The tokenizer object itself is not serialized. If the processor was initialized
|
||||
with a tokenizer name, that name will be included in the config.
|
||||
|
||||
Returns:
|
||||
A dictionary with the processor's configuration parameters.
|
||||
"""
|
||||
config = {
|
||||
"trust_remote_code": self.trust_remote_code,
|
||||
"max_action_tokens": self.max_action_tokens,
|
||||
}
|
||||
|
||||
# Only save tokenizer_name if it was used to create the tokenizer
|
||||
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||
config["tokenizer_name"] = self.tokenizer_name
|
||||
|
||||
return config
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Updates feature definitions to reflect tokenized actions.
|
||||
|
||||
This updates the policy features dictionary to indicate that the action
|
||||
has been tokenized into a sequence of token IDs with shape (max_action_tokens,).
|
||||
|
||||
Args:
|
||||
features: The dictionary of existing policy features.
|
||||
|
||||
Returns:
|
||||
The updated dictionary of policy features.
|
||||
"""
|
||||
# Update the action feature to reflect the tokenized shape
|
||||
# The action is now a sequence of token IDs
|
||||
if PipelineFeatureType.ACTION in features:
|
||||
# Replace the action feature with the tokenized version
|
||||
features[PipelineFeatureType.ACTION] = {
|
||||
ACTION_TOKENS: PolicyFeature(
|
||||
type=FeatureType.SEQUENCE, # Token sequence
|
||||
shape=(self.max_action_tokens,)
|
||||
)
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
@@ -26,15 +26,13 @@ OBS_IMAGES = OBS_IMAGE + "s"
|
||||
OBS_LANGUAGE = OBS_STR + ".language"
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK = OBS_STR + ".user_prompt"
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".tokens"
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK = OBS_LANGUAGE_HIGH_LEVEL_TASK + ".attention_mask"
|
||||
OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask"
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens"
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask"
|
||||
OBS_LANGUAGE_PROMPT = OBS_STR + ".prompt"
|
||||
OBS_LANGUAGE_PROMPT_TOKENS = OBS_LANGUAGE_PROMPT + ".tokens"
|
||||
OBS_LANGUAGE_PROMPT_ATTENTION_MASK = OBS_LANGUAGE_PROMPT + ".attention_mask"
|
||||
OBS_LANGUAGE_TARGET = OBS_STR + ".target"
|
||||
OBS_LANGUAGE_TARGET_TOKENS = OBS_LANGUAGE_TARGET + ".tokens"
|
||||
OBS_LANGUAGE_TARGET_ATTENTION_MASK = OBS_LANGUAGE_TARGET + ".attention_mask"
|
||||
ACTION = "action"
|
||||
ACTION_TOKENS = ACTION + ".tokens"
|
||||
ACTION_TOKEN_MASK = ACTION + ".token_mask"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
|
||||
Reference in New Issue
Block a user