Compare commits

...

4 Commits

Author SHA1 Message Date
Jade Choghari
18ddc67714 add more changes 2025-12-17 18:23:23 +00:00
Pepijn
b229e7df28 Add voice example 2025-12-17 16:31:25 +01:00
Jade Choghari
8e05dc9a7a add fast tokenizer support 2025-12-16 11:28:27 +00:00
Jade Choghari
fddd044306 add eos token in tokenizer, working 2025-12-14 14:54:07 +00:00
30 changed files with 6916 additions and 157 deletions

View File

@@ -0,0 +1,138 @@
#!/usr/bin/env python
"""
Example demonstrating how to use the ActionTokenizerProcessorStep to tokenize actions.
This example shows how to:
1. Load a dataset with action data
2. Apply the action tokenizer processor to tokenize actions with proper padding/truncation
3. Access both the tokenized actions and the attention mask
4. Decode tokenized actions back to their original form
"""
import torch
from transformers import AutoProcessor
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.processor.tokenizer_processor import ActionTokenizerProcessorStep
from lerobot.utils.constants import ACTION_TOKEN_MASK
# Define delta timestamps for the dataset
delta_timestamps = {
'action': [
0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333,
0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3,
0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335,
0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6,
0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333,
0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9,
0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334,
1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2,
1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333,
1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5,
1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333
]
}
# Load the dataset
print("Loading dataset...")
dataset = LeRobotDataset(
repo_id="local",
root="/fsx/jade_choghari/outputs/pgen_annotations1",
delta_timestamps=delta_timestamps
)
# Create a dataloader
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
# Get a batch of data
batch = next(iter(dataloader))
action_data = batch["action"] # Shape: (batch_size, action_horizon, action_dim)
print(f"\nOriginal action shape: {action_data.shape}")
print(f"Original action data (first sample, first timestep):\n{action_data[0, 0]}")
# Method 1: Using the tokenizer directly (as in fast_tokenize.py)
print("\n" + "="*80)
print("Method 1: Direct tokenizer usage")
print("="*80)
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# Tokenize directly
tokens = tokenizer(action_data)
print(f"\nDirect tokenization result type: {type(tokens)}")
print(f"Tokens shape/length: {tokens.shape if isinstance(tokens, torch.Tensor) else len(tokens)}")
# Decode
decoded_actions = tokenizer.decode(tokens)
print(f"Decoded actions shape: {decoded_actions.shape}")
reconstruction_error = torch.abs(action_data - decoded_actions).mean()
print(f"Mean absolute reconstruction error: {reconstruction_error.item():.6f}")
# Method 2: Using the ActionTokenizerProcessorStep with proper padding/truncation
print("\n" + "="*80)
print("Method 2: Using ActionTokenizerProcessorStep (with padding & mask)")
print("="*80)
# Create the action tokenizer processor step
action_tokenizer_processor = ActionTokenizerProcessorStep(
tokenizer_name="physical-intelligence/fast",
trust_remote_code=True,
max_action_tokens=32, # Maximum number of tokens per action
)
# Create a transition with the action data
transition = {
TransitionKey.ACTION: action_data,
TransitionKey.OBSERVATION: {}, # Empty for this example
}
# Apply the processor
processed_transition = action_tokenizer_processor(transition)
# Extract tokenized actions and mask
tokenized_actions = processed_transition[TransitionKey.ACTION]
complementary_data = processed_transition[TransitionKey.COMPLEMENTARY_DATA]
action_mask = complementary_data[ACTION_TOKEN_MASK]
print(f"\nTokenized actions shape: {tokenized_actions.shape}") # (batch_size, max_action_tokens)
print(f"Action mask shape: {action_mask.shape}") # (batch_size, max_action_tokens)
print(f"Tokenized actions dtype: {tokenized_actions.dtype}")
print(f"Action mask dtype: {action_mask.dtype}")
# Show token statistics
print(f"\nFirst sample tokens: {tokenized_actions[0]}")
print(f"First sample mask: {action_mask[0]}")
num_real_tokens = action_mask[0].sum().item()
print(f"Number of real tokens (non-padding): {num_real_tokens}")
print(f"Number of padding tokens: {action_mask.shape[1] - num_real_tokens}")
# Decode using the mask
print("\nDecoding tokenized actions...")
decoded_with_processor = tokenizer.decode(tokenized_actions)
print(f"Decoded actions shape: {decoded_with_processor.shape}")
# Calculate reconstruction error
reconstruction_error_processor = torch.abs(action_data - decoded_with_processor).mean()
print(f"Mean absolute reconstruction error: {reconstruction_error_processor.item():.6f}")
# Show that masking works correctly
print("\n" + "="*80)
print("Mask demonstration")
print("="*80)
for i in range(min(4, tokenized_actions.shape[0])):
mask_i = action_mask[i]
num_real = mask_i.sum().item()
print(f"Sample {i}: {num_real} real tokens, {len(mask_i) - num_real} padding tokens")
print("\n" + "="*80)
print("Action tokenization example completed successfully!")
print("="*80)

View File

@@ -1402,6 +1402,13 @@ def main():
action="store_true",
help="Push modified dataset to HuggingFace Hub",
)
# add image key
parser.add_argument(
"--image-key",
type=str,
default=None,
help="Image observation key to use for image mode (default: None)",
)
args = parser.parse_args()
console = Console()
@@ -1443,7 +1450,10 @@ def main():
)
# Get image keys (for image mode)
image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
if args.image_key:
image_keys = [args.image_key]
else:
image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
if not args.video_mode:
console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")

View File

@@ -0,0 +1,25 @@
import numpy as np
from transformers import AutoProcessor
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=4,
shuffle=True,
)
batch = next(iter(dataloader))
# Load the tokenizer from the Hugging Face hub
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# Tokenize & decode action chunks (we use dummy data here)
action_data = batch["action"] # one batch of action chunks
tokens = tokenizer(action_data) # tokens = list[int]
decoded_actions = tokenizer.decode(tokens)
print("tokenized actions: ", tokens)

View File

@@ -10,17 +10,19 @@ from lerobot.policies.factory import make_policy, make_policy_config
from lerobot.configs.policies import PreTrainedConfig
cfg = PreTrainedConfig.from_pretrained(
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
pretrained_name_or_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
)
cfg.dtype = "bfloat16"
pre_processor, post_processor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path="/fsx/jade_choghari/outputs/pi0_training_new/checkpoints/last/pretrained_model",
pretrained_path="/fsx/jade_choghari/outputs/pi0_training/checkpoints/last/pretrained_model",
)
delta_timestamps = {'action': [0.0, 0.03333333333333333, 0.06666666666666667, 0.1, 0.13333333333333333, 0.16666666666666666, 0.2, 0.23333333333333334, 0.26666666666666666, 0.3, 0.3333333333333333, 0.36666666666666664, 0.4, 0.43333333333333335, 0.4666666666666667, 0.5, 0.5333333333333333, 0.5666666666666667, 0.6, 0.6333333333333333, 0.6666666666666666, 0.7, 0.7333333333333333, 0.7666666666666667, 0.8, 0.8333333333333334, 0.8666666666666667, 0.9, 0.9333333333333333, 0.9666666666666667, 1.0, 1.0333333333333334, 1.0666666666666667, 1.1, 1.1333333333333333, 1.1666666666666667, 1.2, 1.2333333333333334, 1.2666666666666666, 1.3, 1.3333333333333333, 1.3666666666666667, 1.4, 1.4333333333333333, 1.4666666666666666, 1.5, 1.5333333333333334, 1.5666666666666667, 1.6, 1.6333333333333333]}
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1", delta_timestamps=delta_timestamps)
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/pgen_annotations1")
# rename map --rename_map='{
# "observation.images.side": "observation.images.base_0_rgb",
# "observation.images.up": "observation.images.left_wrist_0_rgb"
@@ -43,12 +45,12 @@ dataloader = torch.utils.data.DataLoader(
)
batch = next(iter(dataloader))
batch = pre_processor(batch)
policy.train()
# run inference
# action = policy.select_action(batch)
loss, loss_dict = policy.forward(batch)
breakpoint()
# import requests
# from PIL import Image
# from transformers import AutoProcessor

159
examples/dataset/mask.md Normal file
View File

@@ -0,0 +1,159 @@
## One-sentence answer
> `make_att_2d_masks(prefix_pad_masks, prefix_att_masks)` builds the **actual 2D attention mask** `[B, L, L]` that tells the transformer **which token positions may attend to which others**, combining **padding** and **causality**.
Everything else youve seen so far was just metadata.
---
## What goes in
### Inputs
```python
prefix_pad_masks # shape [B, L]
prefix_att_masks # shape [B, L]
```
Where:
* `prefix_pad_masks[b, i] = True`
→ token `i` exists (not padding)
* `prefix_att_masks[b, i] = False`
→ token `i` is **bidirectional**
* `prefix_att_masks[b, i] = True`
→ token `i` is **causal (autoregressive)**
---
## What comes out
```python
att_2d_prefix # shape [B, L, L]
```
Each entry:
```text
att_2d_prefix[b, i, j] = True
```
means:
> “In batch `b`, **token i (query)** is allowed to attend to **token j (key)**.”
---
## How it is constructed (conceptually)
For **each batch b**, **each query position i**, **each key position j**:
```python
if not prefix_pad_masks[b, j]:
att[b, i, j] = False # cannot attend to padding
else if not prefix_att_masks[b, i]:
att[b, i, j] = True # bidirectional token → can see all real tokens
else:
att[b, i, j] = (j <= i) # causal token → can see only past + itself
```
Thats it.
---
## Tiny concrete example (exactly matching your code)
Suppose:
```python
prefix_pad_masks[0] = [T, T, T, T, T, F]
prefix_att_masks[0] = [F, F, F, T, T, T]
```
Tokens:
```
0: IMG
1: IMG
2: LANG
3: SUB0
4: SUB1
5: PAD
```
---
### Resulting `att_2d_prefix[0]`
`✓ = True, ✗ = False`
| Q \ K | 0 | 1 | 2 | 3 | 4 | 5 |
| ---------- | - | - | - | - | - | - |
| 0 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 1 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 2 (bi) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 3 (causal) | ✓ | ✓ | ✓ | ✓ | ✗ | ✗ |
| 4 (causal) | ✓ | ✓ | ✓ | ✓ | ✓ | ✗ |
| 5 (pad) | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ |
---
## Why this matters for your training code
This line:
```python
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix)
```
Converts `[B, L, L] → [B, 1, L, L]` and possibly flips True/False to `0/-inf`.
This is **exactly what Paligemma uses inside self-attention**.
---
## Key implications (VERY important)
### 1⃣ This mask does **not isolate token groups**
* Bidirectional tokens can attend to **everything**
* Causal tokens only restrict *their own row*
So **flow/action tokens must be blocked separately**.
---
### 2⃣ This is why your AR subtask prediction works
* Subtask tokens are causal
* Output at position `i` predicts token `i+1`
* Padding is fully ignored
---
### 3⃣ Inference behavior
When `subtask_tokens = None`:
* `prefix_att_masks` contains only `False`
* `att_2d_prefix` becomes **fully bidirectional**
* No AR behavior remains
Exactly what you want.
---
## One-sentence takeaway (commit this)
> `make_att_2d_masks` fuses **padding** and **causality** into a concrete `[B, L, L]` attention matrix that the transformer actually uses.
If you want next, I can:
* inspect `make_att_2d_masks()` source with you
* show how to block **flow → subtask** attention
* explain how this changes when suffix tokens are added
* help you refactor this into a cleaner “grouped attention” API
Youre now at the point where the models behavior should feel *predictable*, not magical.

View File

@@ -1,10 +1,11 @@
python examples/dataset/annotate.py \
--repo-id jadechoghari/collect-data \
--video-key observation.images.base \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--episodes 16 22
# python examples/dataset/annotate.py \
# --repo-id lerobot/svla_so101_pickplace \
# --video-key observation.images.side \
# --model Qwen/Qwen3-VL-30B-A3B-Instruct \
python examples/dataset/annotate.py \
--repo-id lerobot/svla_so101_pickplace \
--video-key observation.images.side \
--model Qwen/Qwen3-VL-30B-A3B-Instruct \
--episodes 3 5 7 44
# --episodes 5

View File

@@ -4,12 +4,12 @@
# This generates user prompts and robot utterances for hierarchical policy training
# Configuration
REPO_ID="lerobot/svla_so101_pickplace"
REPO_ID="jadechoghari/collect-data"
MODEL="Qwen/Qwen3-VL-30B-A3B-Instruct"
# Alternative: MODEL="Qwen/Qwen2-VL-7B-Instruct"
OUTPUT_DIR="/fsx/jade_choghari/outputs/pgen_annotations1"
OUTPUT_DIR="/fsx/jade_choghari/outputs/collect-data-pgen"
BATCH_SIZE=32
TEMPERATURE=0.9
SAMPLE_INTERVAL=5.0 # Generate dialogue every 1 second (all episodes processed)
@@ -22,6 +22,7 @@ python examples/dataset/annotate_pgen.py \
--temperature "$TEMPERATURE" \
--batch-size "$BATCH_SIZE" \
--sample-interval "$SAMPLE_INTERVAL" \
--image-key observation.images.base \
--num-image-views-per-sample 1
# For faster testing, increase sample interval:

View File

@@ -0,0 +1 @@
srun --time 12:00:00 --qos=high --gres=gpu:1 --mem=24G --partition=hopper-prod --container-image /fsx/michel_aractingi/docker_images/huggingface+lerobot-gpu+dev.sqsh --container-mounts /fsx/jade_choghari

View File

@@ -0,0 +1,47 @@
# Voice Assistant Examples
Voice-enabled robot assistant examples using speech-to-text (STT), and text-to-speech (TTS).
## Overview
These examples demonstrate how to build a voice interface for robot control:
1. **Hold SPACE** → Push-to-talk recording starts
2. **Release SPACE** → Recording stops
3. **STT (Whisper)** → Converts speech to text (high-level task prompt)
4. **Pi0.5** → Generates robot response/utterance
5. **TTS (Kokoro)** → Speaks the response back
## Requirements
```bash
pip install torch transformers sounddevice numpy pynput kokoro>=0.9.2
```
## Usage
### With Pi0.5 Model
```bash
python examples/voice_assistant/voice_assistant_pi05.py \
--pretrained_path path/to/pi05/checkpoint
```
## How It Works
### Pi0.5 Voice Integration
Pi0.5 can generate robot utterances as part of its subtask prediction. The flow:
1. **High-level prompt**: User voice command is transcribed and formatted as a task prompt
2. **Subtask generation**: Pi0.5 autoregressively generates a response
3. **Utterance extraction**: If the response contains `<utterance>...</utterance>` tags, the content is extracted
4. **TTS output**: The response is spoken back to the user
## Configuration Options
| Option | Default | Description |
|--------|---------|-------------|
| `--pretrained_path` | None | Path to Pi0.5 checkpoint |
| `--record_seconds` | 5.0 | Audio recording duration |
| `--max_response_tokens` | 100 | Max tokens in generated response |

View File

@@ -0,0 +1,336 @@
#!/usr/bin/env python
"""
Voice Assistant with Pi0.5: Microphone → STT → Pi0.5 → TTS → Speaker
This example demonstrates how to use Pi0.5 as a conversational robot assistant:
1. Hold SPACE to record your voice command
2. Speech-to-text (Whisper) converts speech to text
3. Text is fed as a high-level prompt to Pi0.5
4. Pi0.5 generates a response (robot utterance)
5. Text-to-speech (Kokoro) speaks the response back
Requirements:
pip install torch transformers sounddevice numpy pynput kokoro>=0.9.2
Usage:
python examples/voice_assistant/voice_assistant_pi05.py \
--pretrained_path lerobot/pi0.5-base
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import argparse
import re
import subprocess
import threading
import time
import numpy as np
import sounddevice as sd
import torch
from pynput import keyboard
from transformers import AutoTokenizer, WhisperForConditionalGeneration, WhisperProcessor
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch
SAMPLE_RATE = 16000
def get_device():
if torch.cuda.is_available():
return torch.device("cuda")
elif torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
class Pi05VoiceAssistant:
"""Voice assistant using Pi0.5 for generating robot utterances."""
def __init__(
self,
pretrained_path: str | None = None,
max_response_tokens: int = 100,
max_record_seconds: float = 30.0,
):
self.device = get_device()
self.dtype = torch.float32 if self.device.type == "mps" else torch.bfloat16
self.max_response_tokens = max_response_tokens
self.max_record_seconds = max_record_seconds
# Push-to-talk state
self._recording = False
self._audio_chunks: list[np.ndarray] = []
self._stream: sd.InputStream | None = None
print(f"Using device: {self.device}")
self._load_models(pretrained_path)
def _load_models(self, pretrained_path: str | None):
print("Loading STT (Whisper tiny)...")
self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
self.stt_model = WhisperForConditionalGeneration.from_pretrained(
"openai/whisper-tiny.en", torch_dtype=self.dtype
).to(self.device)
print("Loading Pi0.5 model...")
self._load_pi05(pretrained_path)
print("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
self._load_tts()
print("Ready!\n")
def _load_pi05(self, pretrained_path: str | None):
"""Load Pi0.5 model for utterance generation."""
config = PI05Config()
config.dtype = "float32" if self.device.type == "mps" else "bfloat16"
self.pi05_model = PI05Pytorch(config)
if pretrained_path:
try:
from safetensors.torch import load_file
state_dict = load_file(f"{pretrained_path}/model.safetensors")
self.pi05_model.load_state_dict(state_dict, strict=False)
print(f"✓ Loaded Pi0.5 weights from {pretrained_path}")
except Exception as e:
print(f"Warning: Could not load pretrained weights: {e}")
print("Using randomly initialized model for demo purposes")
self.pi05_model = self.pi05_model.to(self.device)
self.pi05_model.eval()
def _load_tts(self):
try:
print("Loading TTS (Kokoro 82M)...")
from kokoro import KPipeline
self.tts_pipeline = KPipeline(lang_code="a") # American English
self.tts_voice = "af_heart"
self.tts_type = "kokoro"
print("Kokoro loaded!")
except Exception as e:
print(f"Kokoro not available ({e})")
print("Using macOS `say` for TTS")
self.tts_pipeline = None
self.tts_type = "system"
def _audio_callback(self, indata, frames, time_info, status):
"""Callback for audio stream - collects chunks while recording."""
if self._recording:
self._audio_chunks.append(indata.copy())
def _start_recording(self):
"""Start recording audio."""
if self._recording:
return
self._recording = True
self._audio_chunks = []
print("🎤 Recording... (release SPACE to stop)")
def _stop_recording(self) -> np.ndarray | None:
"""Stop recording and return the audio."""
if not self._recording:
return None
self._recording = False
if not self._audio_chunks:
return None
audio = np.concatenate(self._audio_chunks, axis=0).flatten()
duration = len(audio) / SAMPLE_RATE
volume = np.abs(audio).max()
print(f"Recorded {duration:.1f}s, volume: {volume:.4f}")
if volume < 0.001:
print("⚠️ Very low audio - check microphone permissions!")
return None
return audio
def wait_for_spacebar(self) -> np.ndarray | None:
"""Wait for spacebar press, record while held, return audio on release."""
audio_result = None
recording_done = threading.Event()
def on_press(key):
if key == keyboard.Key.space:
self._start_recording()
def on_release(key):
nonlocal audio_result
if key == keyboard.Key.space and self._recording:
audio_result = self._stop_recording()
recording_done.set()
return False # Stop listener
# Start audio stream
self._stream = sd.InputStream(
samplerate=SAMPLE_RATE,
channels=1,
dtype="float32",
callback=self._audio_callback,
blocksize=int(SAMPLE_RATE * 0.1), # 100ms blocks
)
with self._stream:
print("\n⏳ Press and hold SPACE to speak...")
with keyboard.Listener(on_press=on_press, on_release=on_release) as listener:
# Wait for recording to complete or timeout
recording_done.wait(timeout=self.max_record_seconds)
if self._recording:
audio_result = self._stop_recording()
return audio_result
def transcribe(self, audio: np.ndarray) -> str:
start = time.perf_counter()
inputs = self.stt_processor(audio, sampling_rate=SAMPLE_RATE, return_tensors="pt")
input_features = inputs.input_features.to(self.device, dtype=self.dtype)
tokens = self.stt_model.generate(input_features)
text = self.stt_processor.batch_decode(tokens, skip_special_tokens=True)[0]
print(f"STT: {time.perf_counter() - start:.2f}s")
return text.strip()
def _create_dummy_images(self, batch_size: int = 1) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Create placeholder images for Pi0.5 when no camera is available."""
image_shape = (batch_size, 3, 224, 224)
dummy_image = torch.zeros(image_shape, dtype=torch.float32, device=self.device)
dummy_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)
return [dummy_image], [dummy_mask]
def _tokenize_prompt(self, text: str) -> tuple[torch.Tensor, torch.Tensor]:
"""Tokenize the user prompt for Pi0.5."""
prompt = f"User request: {text}\nRobot response:"
tokenized = self.tokenizer(
[prompt],
max_length=200,
truncation=True,
padding="max_length",
return_tensors="pt",
)
tokens = tokenized["input_ids"].to(self.device)
masks = tokenized["attention_mask"].to(self.device, dtype=torch.bool)
return tokens, masks
def generate_response(self, user_text: str) -> str:
"""Generate robot utterance using Pi0.5's language generation."""
start = time.perf_counter()
images, img_masks = self._create_dummy_images()
tokens, masks = self._tokenize_prompt(user_text)
with torch.no_grad():
generated_tokens = self.pi05_model._generate_subtask_tokens(
images=images,
img_masks=img_masks,
tokens=tokens,
masks=masks,
tokenizer=self.tokenizer,
max_length=self.max_response_tokens,
device=self.device,
)
# Decode generated tokens
valid_tokens = generated_tokens[0][generated_tokens[0] != 0]
response = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
# Extract utterance if marked with special tokens
response = self._extract_utterance(response)
print(f"Pi0.5: {time.perf_counter() - start:.2f}s")
return response.strip()
def _extract_utterance(self, text: str) -> str:
"""Extract utterance from between <utterance> tokens if present."""
pattern = r"<utterance>(.*?)</utterance>"
match = re.search(pattern, text, re.DOTALL)
if match:
return match.group(1).strip()
return text
def speak(self, text: str):
start = time.perf_counter()
if self.tts_type == "kokoro":
generator = self.tts_pipeline(text, voice=self.tts_voice)
audio_chunks = [audio for _, _, audio in generator]
if audio_chunks:
audio = np.concatenate(audio_chunks)
sd.play(audio, 24000)
sd.wait()
else:
subprocess.run(["say", text], check=True)
print(f"TTS: {time.perf_counter() - start:.2f}s")
def run(self):
print("=" * 50)
print("Pi0.5 Voice Assistant")
print("=" * 50)
print("• Hold SPACE to record your voice command")
print("• Release SPACE when done speaking")
print("• Press Ctrl+C to exit")
print("=" * 50)
while True:
try:
audio = self.wait_for_spacebar()
if audio is None:
print("(no audio captured)\n")
continue
user_text = self.transcribe(audio)
if not user_text:
print("(no speech detected)\n")
continue
print(f"You: {user_text}")
response = self.generate_response(user_text)
print(f"Robot: {response}\n")
self.speak(response)
except KeyboardInterrupt:
print("\nGoodbye!")
break
def main():
parser = argparse.ArgumentParser(description="Pi0.5 Voice Assistant")
parser.add_argument(
"--pretrained_path",
type=str,
default=None,
help="Path to pretrained Pi0.5 model (optional)",
)
parser.add_argument(
"--max_response_tokens",
type=int,
default=100,
help="Maximum tokens in generated response",
)
parser.add_argument(
"--max_record_seconds",
type=float,
default=30.0,
help="Maximum recording duration in seconds",
)
args = parser.parse_args()
assistant = Pi05VoiceAssistant(
pretrained_path=args.pretrained_path,
max_response_tokens=args.max_response_tokens,
max_record_seconds=args.max_record_seconds,
)
assistant.run()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,26 @@
{
"repo_id": "local",
"vocab_size": 1024,
"scale": 10.0,
"encoded_dims": "0:15",
"encoded_dim_ranges": [
[
0,
15
]
],
"total_encoded_dims": 15,
"delta_dims": null,
"delta_dim_list": null,
"use_delta_transform": false,
"state_key": "observation.state",
"action_horizon": 50,
"num_training_chunks": 4900,
"compression_stats": {
"compression_ratio": 15.85791309863622,
"mean_token_length": 47.295,
"p99_token_length": 90.0,
"min_token_length": 9.0,
"max_token_length": 109.0
}
}

View File

@@ -0,0 +1,158 @@
import logging
from typing import ClassVar
import numpy as np
from scipy.fft import dct
from scipy.fft import idct
from tokenizers import ByteLevelBPETokenizer
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast
from transformers.processing_utils import ProcessorMixin
class UniversalActionProcessor(ProcessorMixin):
attributes: ClassVar[list[str]] = ["bpe_tokenizer"]
bpe_tokenizer_class: str = "AutoTokenizer"
def __init__(
self,
bpe_tokenizer: PreTrainedTokenizerFast,
scale: float = 10,
vocab_size: int = 1024,
min_token: int = 0,
*,
action_dim: int | None = None,
time_horizon: int | None = None,
):
self.scale = scale
self.vocab_size = vocab_size
self.min_token = min_token
# Action horizon and dimension needed during decoding. These can be specified
# in three ways (in order of priority):
# 1. passed in as kwargs to decode()
# 2. in the constructor
# 3. cached from the last time decode() was called
self.time_horizon = time_horizon
self.action_dim = action_dim
self.called_time_horizon = time_horizon
self.called_action_dim = action_dim
super().__init__(bpe_tokenizer)
def __call__(self, action_chunk: np.array) -> np.array:
assert action_chunk.ndim <= 3, "Only 3 dimensions supported: [batch, timesteps, action_dim]"
if action_chunk.ndim == 2:
action_chunk = action_chunk[None, ...]
# Cache the time horizon and action dimension for decoding
self.called_time_horizon = action_chunk.shape[-2]
self.called_action_dim = action_chunk.shape[-1]
dct_coeff = dct(action_chunk, axis=1, norm="ortho")
dct_coeff = np.around(dct_coeff * self.scale)
tokens = []
for elem in dct_coeff:
token_str = "".join(map(chr, np.maximum(elem.flatten() - self.min_token, 0).astype(int)))
tokens.append(self.bpe_tokenizer(token_str)["input_ids"])
return tokens
def decode(
self,
tokens: list[list[int]],
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> np.array:
self.time_horizon = time_horizon or self.time_horizon or self.called_time_horizon
self.action_dim = action_dim or self.action_dim or self.called_action_dim
# Cache the time horizon and action dimension for the next call
self.called_time_horizon = self.time_horizon
self.called_action_dim = self.action_dim
assert (
self.time_horizon is not None and self.action_dim is not None
), "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
decoded_actions = []
for token in tokens:
try:
decoded_tokens = self.bpe_tokenizer.decode(token)
decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.min_token
decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim)
assert (
decoded_dct_coeff.shape
== (
self.time_horizon,
self.action_dim,
)
), f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})"
except Exception as e:
print(f"Error decoding tokens: {e}")
print(f"Tokens: {token}")
decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim))
decoded_actions.append(idct(decoded_dct_coeff / self.scale, axis=0, norm="ortho"))
return np.stack(decoded_actions)
@classmethod
def fit(
cls,
action_data: list[np.array],
scale: float = 10,
vocab_size: int = 1024,
*,
time_horizon: int | None = None,
action_dim: int | None = None,
) -> "UniversalActionProcessor":
# Run DCT over all inputs
dct_tokens = [dct(a, axis=0, norm="ortho").flatten() for a in action_data]
# Quantize and find min token
max_token = int(np.around(np.concatenate(dct_tokens) * scale).max())
min_token = int(np.around(np.concatenate(dct_tokens) * scale).min())
min_vocab_size = max_token - min_token
assert (
min_vocab_size <= vocab_size
), f"Vocab size {vocab_size} is too small for the range of tokens {min_vocab_size}"
if min_vocab_size + 100 > vocab_size:
logging.warning(
f"Initial alphabet size {min_vocab_size} is almost as large as the vocab"
f"size {vocab_size}, consider increasing vocab size"
)
# Make token iterator for BPE training
def _token_iter():
for tokens in dct_tokens:
rounded_tokens = np.around(tokens * scale) - min_token
rounded_tokens = rounded_tokens.astype(int)
string = "".join(map(chr, rounded_tokens))
yield string
# Train BPE tokenizer
bpe = ByteLevelBPETokenizer()
# Set up the entire range of possible tokens as the initial alphabet
alphabet = [chr(i) for i in range(max_token - min_token + 1)]
trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=2,
show_progress=True,
special_tokens=[],
initial_alphabet=alphabet,
max_token_length=10000,
)
# Train the inner tokenizer (don't use ByteLevelBPETokenizer.train_from_iterator()
# because it doesn't support custom alphabets)
bpe._tokenizer.train_from_iterator(_token_iter(), trainer=trainer)
return cls(
PreTrainedTokenizerFast(tokenizer_object=bpe, clean_up_tokenization_spaces=False),
scale=scale,
vocab_size=vocab_size,
min_token=min_token,
time_horizon=time_horizon,
action_dim=action_dim,
)

View File

@@ -0,0 +1,11 @@
{
"action_dim": 15,
"auto_map": {
"AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
},
"min_token": -71,
"processor_class": "UniversalActionProcessor",
"scale": 10.0,
"time_horizon": 50,
"vocab_size": 1024
}

View File

@@ -0,0 +1 @@
{}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,11 @@
{
"added_tokens_decoder": {},
"auto_map": {
"AutoProcessor": "processing_action_tokenizer.UniversalActionProcessor"
},
"clean_up_tokenization_spaces": false,
"extra_special_tokens": {},
"model_max_length": 1000000000000000019884624838656,
"processor_class": "UniversalActionProcessor",
"tokenizer_class": "PreTrainedTokenizerFast"
}

View File

@@ -0,0 +1,196 @@
# FAST Tokenizer Training for LeRobotDataset
This directory contains tools for training a FAST (Factorized Action Sequence Tokenizer) on LeRobot datasets.
## Files
- **`train_fast_tokenizer.py`**: Main training script (refactored for LeRobotDataset)
- **`train_fast_tokenizer_example.md`**: Usage examples and parameter documentation
- **`MIGRATION_NOTES.md`**: Migration guide from B1K to LeRobotDataset
## Quick Start
```bash
# Basic usage
python train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:14"
# With delta transform
python train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:14" \
--delta_dims "0,1,2,3,4,5,6,7,8,9,10,11,12,13" \
--state_key "observation.state" \
--vocab_size 1024
```
## What is FAST?
FAST is a tokenizer for robotic action sequences that:
1. Applies DCT (Discrete Cosine Transform) to action chunks
2. Quantizes DCT coefficients
3. Uses BPE (Byte-Pair Encoding) to compress the quantized sequence
4. Achieves high compression ratios (e.g., 10-20x) while maintaining accuracy
This enables efficient storage and processing of long action sequences in vision-language-action models.
## Requirements
- Python 3.10+
- LeRobot dataset (either local or from HuggingFace Hub)
- transformers (for AutoProcessor)
- numpy
- torch
- tyro
## Workflow
```
LeRobotDataset → Extract Episodes → Apply Delta Transform
Select Dimensions → Normalize (q01, q99) → Create Chunks
Train FAST Tokenizer → Compute Stats → Save
```
## Parameters Guide
### Essential Parameters
- **`repo_id`**: HuggingFace dataset repository ID
- Example: `"lerobot/aloha_sim_insertion_human"`
- **`action_horizon`**: Length of action sequences to tokenize
- Typical: 10-16 steps
- **`encoded_dims`**: Which action dimensions to encode
- Format: `"start:end,start:end"`
- Example: `"0:7"` = dimensions 0-6
- Example: `"0:3,7:10"` = dimensions 0-2 and 7-9
### Optional Parameters
- **`delta_dims`**: Apply delta transform (action - state) to these dimensions
- Format: `"0,1,2,3,4,5"`
- Use for position-based actions
- **`state_key`**: Dataset key containing state observations
- Default: `"observation.state"`
- **`vocab_size`**: BPE vocabulary size
- Default: 1024
- Larger = better compression but more memory
- **`scale`**: DCT quantization scale
- Default: 10.0
- Smaller = finer quantization, larger = coarser
- **`sample_fraction`**: Fraction of action chunks to use per episode
- Default: 0.1 (10%)
- Increase for small datasets, decrease for large datasets
## Output
The script creates a directory (default: `./fast_tokenizer_{repo_id}`) containing:
1. **Tokenizer files**: Can be loaded with `AutoProcessor.from_pretrained()`
2. **`metadata.json`**: Contains:
- Training configuration
- Compression statistics
- Dataset information
## Example Output
```
Loading dataset: lerobot/aloha_sim_insertion_human
Dataset loaded: 50 episodes, 5000 frames
Encoding 14 dimensions: 0:14
Delta dimensions: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
Action horizon: 10
Processing 50 episodes...
Collected 4500 action chunks
Extracted 14 encoded dimensions
Before normalization - overall stats:
Min: -2.3451, Max: 3.1234, Mean: 0.0234, Std: 0.8765
Applied quantile normalization [q01, q99] → [-1, 1]
After normalization - overall stats:
Min: -1.0000, Max: 1.0000, Mean: 0.0156, Std: 0.4321
Training FAST tokenizer on 4500 action chunks...
Action chunk shape: (4500, 10, 14)
Vocab size: 1024
DCT scale: 10.0
✓ Tokenizer training complete!
Compression Statistics:
Average compression ratio: 14.23x
Mean token length: 9.8
P99 token length: 15
Min token length: 6
Max token length: 18
✅ Saved FAST tokenizer to ./fast_tokenizer_lerobot_aloha_sim_insertion_human
```
## Using the Trained Tokenizer
```python
from transformers import AutoProcessor
# Load tokenizer
tokenizer = AutoProcessor.from_pretrained(
"./fast_tokenizer_lerobot_aloha_sim_insertion_human",
trust_remote_code=True
)
# Encode action chunk [horizon, action_dim]
action_chunk = np.random.randn(10, 14) # Example
tokens = tokenizer(action_chunk[None])[0] # Returns token IDs
# Decode tokens back to actions
reconstructed = tokenizer.decode(tokens)
```
## Tips
1. **Start Small**: Use `--max_episodes 10` for initial testing
2. **Check Dimensions**: Verify encoded dimensions match your robot's action space
3. **Delta Transform**: Use for position-based actions, not velocity-based
4. **Normalization**: Ensure dataset has proper statistics computed
5. **Compression Ratio**: Aim for 10-20x for good balance of compression and accuracy
## Troubleshooting
**Issue**: "No normalization stats found"
- **Solution**: Compute dataset statistics first, or use raw actions
**Issue**: "Episode too short for action horizon"
- **Solution**: Reduce `--action_horizon` or filter short episodes
**Issue**: "State key not found"
- **Solution**: Check dataset features and use correct `--state_key`
**Issue**: Memory error with large datasets
- **Solution**: Reduce `--sample_fraction` or `--max_episodes`
## Citation
If you use FAST in your research, please cite:
```bibtex
@article{black2023fast,
title={FAST: Factorized Action Sequence Tokenizer for Vision-Language-Action Models},
author={Black, Kevin and others},
journal={arXiv preprint},
year={2023}
}
```

View File

@@ -37,6 +37,9 @@ class PI05Config(PreTrainedConfig):
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
max_action_tokens: int = 32
fast_vocab_size: int = 2048
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10

View File

@@ -0,0 +1,21 @@
lerobot-train \
--dataset.repo_id=lerobot \
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
--output_dir=/fsx/jade_choghari/outputs/pi0test1 \
--job_name=pi0_training \
--policy.repo_id=jade_choghari/pi0-base \
--policy.path=/fsx/jade_choghari/outputs/pi0_fast_fruit1/checkpoints/last/pretrained_model \
--policy.dtype=bfloat16 \
--steps=3000 \
--save_freq=1000 \
--rename_map='{
"observation.images.base": "observation.images.base_0_rgb",
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
}' \
--batch_size=4 \
--policy.device=cuda \
# --wandb.enable=true \
# --wandb.disable_artifact=true \
# --wandb.project=pi05hi-training \

View File

@@ -537,6 +537,18 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
# FAST action token embedding and prediction head
self.fast_action_embedding = nn.Embedding(config.fast_vocab_size, paligemma_config.width)
self.fast_action_lm_head = nn.Linear(paligemma_config.width, config.fast_vocab_size)
# Apply dtype conversion to FAST layers to match model precision
if config.dtype == "bfloat16":
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
elif config.dtype == "float32":
self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
@@ -592,6 +604,194 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
result = result.to(dtype=dtype)
return result
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
"""Create custom 2D attention mask for the new attention pattern.
Attention rules:
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
- FAST: attend to images + language + subtask, causal among themselves
Args:
att_mask_segments: List of (type, length) tuples
pad_masks: Padding masks [B, total_seq_len]
bsize: Batch size
Returns:
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
"""
total_len = sum(length for _, length in att_mask_segments)
device = pad_masks.device
# Initialize attention mask as False (cannot attend)
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
# Track positions for each segment
positions = []
current_pos = 0
for seg_type, seg_len in att_mask_segments:
positions.append((seg_type, current_pos, current_pos + seg_len))
current_pos += seg_len
# Apply attention rules
for i, (query_type, query_start, query_end) in enumerate(positions):
for j, (key_type, key_start, key_end) in enumerate(positions):
# Images and Language can attend to each other bidirectionally
if query_type in ['image', 'language'] and key_type in ['image', 'language']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# Subtask tokens attend to images + language
elif query_type == 'subtask' and key_type in ['image', 'language']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# Subtask tokens attend causally to themselves
elif query_type == 'subtask' and key_type == 'subtask':
# Create causal mask for subtask tokens
subtask_len = query_end - query_start
causal_mask = torch.tril(torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device))
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
# FAST tokens attend to images + language + subtask
elif query_type == 'fast' and key_type in ['image', 'language', 'subtask']:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# FAST tokens attend causally to themselves
elif query_type == 'fast' and key_type == 'fast':
fast_len = query_end - query_start
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
# Apply padding masks
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def visualize_attention_mask(
self,
att_mask_segments,
att_2d_masks,
save_path,
batch_idx=0,
dpi=150,
max_display_tokens=None
):
"""Visualize the attention mask with labeled segments.
Args:
att_mask_segments: List of (type, length) tuples defining the segments
att_2d_masks: 2D attention mask tensor [B, total_seq_len, total_seq_len]
save_path: Path where to save the visualization image
batch_idx: Which batch item to visualize (default: 0)
dpi: DPI for the saved image (default: 150)
max_display_tokens: Maximum number of tokens to display (for very long sequences)
"""
try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
except ImportError:
logging.warning("matplotlib not available, skipping attention mask visualization")
return
# Extract the mask for the specified batch
mask = att_2d_masks[batch_idx].cpu().float().numpy()
# If sequence is too long, downsample for visualization
if max_display_tokens is not None and mask.shape[0] > max_display_tokens:
# Simple downsampling by taking every Nth token
step = mask.shape[0] // max_display_tokens
mask = mask[::step, ::step]
# Adjust segments accordingly
att_mask_segments = [(seg_type, max(1, seg_len // step)) for seg_type, seg_len in att_mask_segments]
# Calculate positions for each segment
positions = []
current_pos = 0
for seg_type, seg_len in att_mask_segments:
positions.append((seg_type, current_pos, current_pos + seg_len))
current_pos += seg_len
# Create figure
fig, ax = plt.subplots(figsize=(12, 10))
# Create custom colormap: white for False (no attention), blue for True (attention)
colors = ['white', '#2E86AB']
n_bins = 2
cmap = LinearSegmentedColormap.from_list('attention', colors, N=n_bins)
# Display the mask
im = ax.imshow(mask, cmap=cmap, aspect='auto', interpolation='nearest', vmin=0, vmax=1)
# Add colorbar
cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label('Attention Enabled', rotation=270, labelpad=20)
cbar.set_ticks([0.25, 0.75])
cbar.set_ticklabels(['No', 'Yes'])
# Define colors for each segment type
segment_colors = {
'image': '#A23B72',
'language': '#F18F01',
'subtask': '#C73E1D',
'fast': '#6A994E'
}
# Draw segment boundaries and labels
for seg_type, start, end in positions:
color = segment_colors.get(seg_type, '#666666')
# Draw vertical lines for columns (keys)
ax.axvline(x=start - 0.5, color=color, linewidth=2, alpha=0.7)
ax.axvline(x=end - 0.5, color=color, linewidth=2, alpha=0.7)
# Draw horizontal lines for rows (queries)
ax.axhline(y=start - 0.5, color=color, linewidth=2, alpha=0.7)
ax.axhline(y=end - 0.5, color=color, linewidth=2, alpha=0.7)
# Add labels at the top
mid_pos = (start + end) / 2
ax.text(mid_pos, -mask.shape[0] * 0.02, f"{seg_type.upper()}\n({end - start})",
ha='center', va='top', fontsize=10, fontweight='bold', color=color)
# Add labels on the left
ax.text(-mask.shape[1] * 0.02, mid_pos, f"{seg_type.upper()}\n({end - start})",
ha='right', va='center', fontsize=10, fontweight='bold', color=color, rotation=0)
# Set axis labels
ax.set_xlabel('Key Position (tokens being attended to)', fontsize=12, fontweight='bold')
ax.set_ylabel('Query Position (tokens attending)', fontsize=12, fontweight='bold')
ax.set_title('Attention Mask Pattern\n(White = No Attention, Blue = Attention Allowed)',
fontsize=14, fontweight='bold', pad=20)
# Create legend for segment types
legend_patches = []
attention_rules = {
'image': 'Bidirectional with lang',
'language': 'Bidirectional with images',
'subtask': 'Attends to img+lang, causal self',
'fast': 'Attends to all, causal self'
}
for seg_type, color in segment_colors.items():
if any(seg[0] == seg_type for seg in att_mask_segments):
rule = attention_rules.get(seg_type, '')
legend_patches.append(mpatches.Patch(color=color, label=f'{seg_type.upper()}: {rule}'))
ax.legend(handles=legend_patches, loc='upper right', bbox_to_anchor=(1.15, 1.0),
framealpha=0.9, fontsize=9)
# Adjust layout and save
plt.tight_layout()
# Ensure the directory exists
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
plt.close()
logging.info(f"Attention mask visualization saved to: {save_path}")
def sample_noise(self, shape, device):
return torch.normal(
mean=0.0,
@@ -607,10 +807,18 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self, images, img_masks, tokens, subtask_tokens, masks, subtask_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
self,
images,
img_masks,
tokens,
subtask_tokens,
masks,
subtask_masks,
fast_action_tokens=None,
fast_action_masks=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
Args:
@@ -619,17 +827,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
tokens: Language instruction tokens
subtask_tokens: Subtask tokens to predict (can be None for inference)
masks: Attention masks for tokens
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
fast_action_masks: Padding masks for FAST action tokens (can be None)
Returns:
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided)]
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
pad_masks: Padding masks
att_masks: Attention masks (with causal masking for subtask prediction if subtask_tokens provided)
att_masks: Custom 2D attention mask implementing the required pattern
total_T_images: Total number of image tokens
num_subtask_embs: Number of subtask token embeddings
num_fast_embs: Number of FAST action token embeddings
"""
embs = []
pad_masks = []
att_masks = []
att_mask_segments = [] # Store info about each segment for custom mask creation
total_T_images = 0
num_subtask_embs = 0
num_fast_embs = 0
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
@@ -642,7 +856,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
att_masks += [0] * num_img_embs # Images can attend to all previous tokens
att_mask_segments.append(('image', num_img_embs))
total_T_images += num_img_embs
# Process language instruction tokens
@@ -656,7 +870,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
pad_masks.append(masks)
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs # Language tokens can attend to all previous tokens (images + tokens)
att_mask_segments.append(('language', num_lang_embs))
# Process subtask tokens if provided (these are predicted, so use causal masking)
if subtask_tokens is not None:
@@ -672,18 +886,49 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
pad_masks.append(subtask_masks)
num_subtask_embs = subtask_emb.shape[1]
# Causal masking for subtask tokens: each subtask token can attend to images, all instruction tokens,
# and previous subtask tokens
att_masks += [1] * num_subtask_embs # Use 1 for causal attention on subtask tokens
att_mask_segments.append(('subtask', num_subtask_embs))
# Process FAST action tokens if provided (these are discrete token IDs)
if fast_action_tokens is not None:
def fast_action_embed_func(fast_action_tokens):
fast_emb = self.fast_action_embedding(fast_action_tokens)
fast_emb_dim = fast_emb.shape[-1]
return fast_emb * math.sqrt(fast_emb_dim)
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
embs.append(fast_action_emb)
# Use provided mask or create default (all valid)
if fast_action_masks is not None:
fast_pad_mask = fast_action_masks
else:
bsize = fast_action_tokens.shape[0]
num_fast_embs = fast_action_tokens.shape[1]
fast_pad_mask = torch.ones(bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device)
num_fast_embs = fast_action_tokens.shape[1]
pad_masks.append(fast_pad_mask)
att_mask_segments.append(('fast', num_fast_embs))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
# Create custom 2D attention mask
# Attention rules:
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
# - FAST: attend to images + language + subtask, causal among themselves
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
bsize = pad_masks.shape[0]
att_masks = att_masks[None, :].expand(bsize, att_masks.shape[0])
# # Optionally visualize the attention mask
# self.visualize_attention_mask(
# att_mask_segments=att_mask_segments,
# att_2d_masks=att_masks,
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
# batch_idx=0,
# max_display_tokens=512 # Limit display for very long sequences
# )
return embs, pad_masks, att_masks, total_T_images
return embs, pad_masks, att_masks, total_T_images, num_subtask_embs, num_fast_embs
def embed_suffix(self, noisy_actions, timestep):
"""Embed noisy_actions, timestep to prepare for Expert Gemma processing."""
@@ -732,8 +977,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return embs, pad_masks, att_masks, adarms_cond
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions)
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, noise=None, time=None) -> Tensor:
# loss_dict = self.model.forward(images, img_masks, high_level_task, tokens, masks, subtask_tokens, subtask_masks, actions, fast_action_tokens, fast_action_masks)
def forward(self, images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions, fast_action_tokens=None, fast_action_masks=None, noise=None, time=None) -> Tensor:
"""Do a full training forward pass and compute the loss.
Args:
@@ -743,7 +988,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
high_level_task_masks: Attention masks for high_level_task
subtask_tokens: Subtask tokens to predict (e.g., tokens for "pick up the cup")
subtask_masks: Attention masks for subtask_tokens
actions: Ground truth actions
actions: Ground truth actions [B, chunk_size, action_dim]
fast_action_tokens: Discrete action token IDs [B, max_action_tokens]
fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens]
noise: Optional noise for flow matching
time: Optional time for flow matching
"""
@@ -756,77 +1003,184 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
# Embed prefix (images + high_level_task + subtask_tokens)
# Use high_level_task (prompt WITHOUT subtask) + subtask_tokens to predict
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks
# Initialize FAST loss to 0 (will be computed only if FAST tokens are provided)
fast_loss = torch.tensor(0.0, device=actions.device, dtype=actions.dtype)
# ========== PASS 1: Prefix with FAST tokens for subtask + FAST prediction ==========
# Only run this pass if FAST action tokens are provided
if fast_action_tokens is not None and fast_action_masks is not None:
# Embed prefix (images + high_level_task + subtask_tokens + FAST tokens)
# FAST tokens are provided as discrete token IDs
prefix_with_fast_embs, prefix_with_fast_pad_masks, prefix_with_fast_att_masks, total_T_images, num_subtask_embs, num_fast_embs = self.embed_prefix(
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
)
# Convert embeddings to bfloat16 if needed for the model
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_with_fast_embs = prefix_with_fast_embs.to(dtype=torch.bfloat16)
# Prepare attention masks for prefix pass with FAST tokens
position_ids_prefix_with_fast = torch.cumsum(prefix_with_fast_pad_masks, dim=1) - 1
att_2d_prefix_with_fast_4d = self._prepare_attention_masks_4d(prefix_with_fast_att_masks, dtype=prefix_with_fast_embs.dtype)
# Forward pass through paligemma for subtask + FAST prediction
(prefix_with_fast_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_with_fast_4d,
position_ids=position_ids_prefix_with_fast,
past_key_values=None,
inputs_embeds=[prefix_with_fast_embs, None], # SUFFIX = None
use_cache=False,
adarms_cond=[None, None],
)
# LM HEAD → SUBTASK LOGITS
lm_head = self.paligemma_with_expert.paligemma.lm_head
logits = lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, vocab)
# Extract logits for subtask token prediction
T_high_level_task = high_level_task.size(1)
T_subtask = subtask_tokens.size(1)
start_index = total_T_images + T_high_level_task
end_index = start_index + T_subtask
logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab)
targets = subtask_tokens # (B, T_subtask)
# Compute cross-entropy loss for subtask
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
targets_flat = targets.reshape(-1)
loss_per_token = loss_fct(logits_flat, targets_flat)
loss_per_token = loss_per_token.reshape(targets.shape)
masked_loss = loss_per_token * subtask_masks.float()
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
# Extract outputs for FAST action token prediction and compute auxiliary loss
# FAST outputs start after subtask tokens
# Similar to subtask, we use autoregressive prediction where position i predicts token i+1
fast_start_index = end_index
fast_end_index = fast_start_index + num_fast_embs
# Get logits for FAST action tokens using the FAST LM head
fast_logits = self.fast_action_lm_head(prefix_with_fast_out) # (B, T_prefix_with_fast, fast_vocab_size)
# Extract logits for FAST token prediction (autoregressive: position i predicts token i+1)
# - Position (fast_start_index-1) predicts fast_action_tokens[0]
# - Position (fast_start_index) predicts fast_action_tokens[1], etc.
fast_logits_for_pred = fast_logits[:, fast_start_index-1:fast_end_index-1, :] # (B, max_action_tokens, fast_vocab_size)
# Compute cross-entropy loss for FAST action tokens
fast_targets = fast_action_tokens # (B, max_action_tokens)
loss_fct_fast = torch.nn.CrossEntropyLoss(reduction='none')
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) # (B*max_action_tokens, fast_vocab_size)
fast_targets_flat = fast_targets.reshape(-1) # (B*max_action_tokens)
fast_loss_per_token = loss_fct_fast(fast_logits_flat, fast_targets_flat) # (B*max_action_tokens)
fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape) # (B, max_action_tokens)
# Apply mask and compute mean loss over valid tokens
masked_fast_loss = fast_loss_per_token * fast_action_masks.float()
fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1)
else:
# If no FAST tokens provided, compute subtask loss without FAST tokens
# This is the fallback for backward compatibility
prefix_embs_for_subtask, prefix_pad_masks_for_subtask, prefix_att_masks_for_subtask, total_T_images, _, _ = self.embed_prefix(
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
fast_action_tokens=None, fast_action_masks=None
)
# Convert embeddings to bfloat16 if needed for the model
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs_for_subtask = prefix_embs_for_subtask.to(dtype=torch.bfloat16)
position_ids_prefix = torch.cumsum(prefix_pad_masks_for_subtask, dim=1) - 1
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks_for_subtask, dtype=prefix_embs_for_subtask.dtype)
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_4d,
position_ids=position_ids_prefix,
past_key_values=None,
inputs_embeds=[prefix_embs_for_subtask, None],
use_cache=False,
adarms_cond=[None, None],
)
lm_head = self.paligemma_with_expert.paligemma.lm_head
logits = lm_head(prefix_out)
T_high_level_task = high_level_task.size(1)
T_subtask = subtask_tokens.size(1)
start_index = total_T_images + T_high_level_task
end_index = start_index + T_subtask
logits_subtask = logits[:, start_index-1:end_index-1, :]
targets = subtask_tokens
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1))
targets_flat = targets.reshape(-1)
loss_per_token = loss_fct(logits_flat, targets_flat)
loss_per_token = loss_per_token.reshape(targets.shape)
masked_loss = loss_per_token * subtask_masks.float()
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
# ========== PASS 2: Full forward WITHOUT FAST tokens for flow matching ==========
# Embed prefix WITHOUT FAST tokens (images + high_level_task + subtask_tokens)
prefix_embs_no_fast, prefix_pad_masks_no_fast, prefix_att_masks_no_fast, _, _, _ = self.embed_prefix(
images, img_masks, high_level_task, subtask_tokens, high_level_task_masks, subtask_masks,
fast_action_tokens=None, fast_action_masks=None
)
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
# Prepare attention masks for prefix-only pass (for subtask token prediction)
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
# prefix-only transformer run for subtask token prediction
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_4d,
position_ids=position_ids_prefix,
past_key_values=None,
inputs_embeds=[prefix_embs, None], # SUFFIX = None
use_cache=False,
adarms_cond=[None, None],
)
# LM HEAD → SUBTASK LOGITS
# prefix_out: (B, T_prefix, H) where T_prefix = total_T_images + T_high_level_task + T_subtask
lm_head = self.paligemma_with_expert.paligemma.lm_head
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
# Extract logits for subtask token prediction
# In autoregressive modeling, output at position i predicts token at position i+1
# So we take logits from one position earlier:
# - Position (start_index-1) (last high_level_task token) predicts subtask_tokens[0]
# - Position (start_index) (first subtask token) predicts subtask_tokens[1], etc.
T_high_level_task = high_level_task.size(1)
T_subtask = subtask_tokens.size(1)
start_index = total_T_images + T_high_level_task
end_index = start_index + T_subtask
logits_subtask = logits[:, start_index-1:end_index-1, :] # (B, T_subtask, vocab)
targets = subtask_tokens # (B, T_subtask)
# Compute cross-entropy loss
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
# Reshape for loss computation
logits_flat = logits_subtask.reshape(-1, logits_subtask.size(-1)) # (B*T_subtask, vocab)
targets_flat = targets.reshape(-1) # (B*T_subtask)
loss_per_token = loss_fct(logits_flat, targets_flat) # (B*T_subtask)
loss_per_token = loss_per_token.reshape(targets.shape) # (B, T_subtask)
# Apply mask and compute mean loss over valid tokens
masked_loss = loss_per_token * subtask_masks.float()
subtask_loss = masked_loss.sum() / subtask_masks.sum().clamp(min=1)
breakpoint()
# Convert embeddings to bfloat16 if needed for the model
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
prefix_embs_no_fast = prefix_embs_no_fast.to(dtype=torch.bfloat16)
# Concatenate prefix (images + tokens + subtask_tokens) and suffix (actions) masks
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
# Prepare attention masks for full forward pass (prefix + suffix)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
# For the flow matching pass, we need custom attention where:
# - prefix follows the custom pattern (images+lang bidirectional, subtask causal, no cross-attention)
# - suffix attends to all prefix + causal to itself
# We'll construct this by extending prefix_att_masks_no_fast to include suffix
# prefix_att_masks_no_fast is already a 2D boolean mask [B, prefix_len, prefix_len]
# We need to extend it to [B, prefix_len + suffix_len, prefix_len + suffix_len]
bsize = prefix_pad_masks_no_fast.shape[0]
prefix_len = prefix_pad_masks_no_fast.shape[1]
suffix_len = suffix_pad_masks.shape[1]
total_len = prefix_len + suffix_len
device = prefix_pad_masks_no_fast.device
# Create full attention mask
full_att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
# Copy prefix attention pattern
full_att_2d_masks[:, :prefix_len, :prefix_len] = prefix_att_masks_no_fast
# Suffix attends to all prefix
full_att_2d_masks[:, prefix_len:, :prefix_len] = True
# Suffix has causal attention among itself
suffix_causal_mask = torch.tril(torch.ones(suffix_len, suffix_len, dtype=torch.bool, device=device))
full_att_2d_masks[:, prefix_len:, prefix_len:] = suffix_causal_mask[None, :, :]
# Apply padding masks
pad_masks = torch.cat([prefix_pad_masks_no_fast, suffix_pad_masks], dim=1)
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
full_att_2d_masks = full_att_2d_masks & pad_2d_masks
position_ids = torch.cumsum(pad_masks, dim=1) - 1
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks, dtype=prefix_embs.dtype)
att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks, dtype=prefix_embs_no_fast.dtype)
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
(_, suffix_out), _ = self.paligemma_with_expert.forward(
@@ -837,11 +1191,10 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
use_cache=False,
adarms_cond=[None, adarms_cond],
)
# prefix_out to be used for the language head
return suffix_out
suffix_out = self._apply_checkpoint(
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
forward_func, prefix_embs_no_fast, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
suffix_out = suffix_out[:, -self.config.chunk_size :]
@@ -857,80 +1210,81 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return {
"flow_loss": fm_loss,
"subtask_loss": subtask_loss,
"loss": 10 * fm_loss.mean() + subtask_loss,
"fast_loss": fast_loss,
"loss": fm_loss.mean() + 0.1 * subtask_loss + 0.05 * fast_loss, # ref: b1k winner
}
@torch.no_grad()
def _generate_subtask_tokens(
self, images, img_masks, tokens, masks, tokenizer, max_length, device
):
"""Generate subtask tokens autoregressively using next token prediction."""
bsize = tokens.shape[0]
# Get lm_head for token generation
lm_head = self.paligemma_with_expert.paligemma.lm_head
# Embed prefix without subtask tokens first
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images = self.embed_prefix(
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None
prefix_embs, prefix_pad_masks, prefix_att_masks, total_T_images, _, _ = self.embed_prefix(
images, img_masks, tokens, subtask_tokens=None, masks=masks, subtask_masks=None,
fast_action_tokens=None, fast_action_masks=None
)
# Initialize generated tokens list - start with BOS token or first token after instruction
# For PaliGemma, we'll start generation and accumulate tokens
generated_tokens = torch.zeros((bsize, max_length), dtype=torch.long, device=device)
# tracking mask: False = still generating, True = finished
finished = torch.zeros(bsize, dtype=torch.bool, device=device)
for t in range(max_length):
# Prepare attention masks for current prefix
att_2d_prefix = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
position_ids_prefix = torch.cumsum(prefix_pad_masks, dim=1) - 1
att_2d_prefix_4d = self._prepare_attention_masks_4d(att_2d_prefix, dtype=prefix_embs.dtype)
# Forward pass through model to get logits
att_2d_prefix_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
(prefix_out, _), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_prefix_4d,
attention_mask=att_2d_prefix_4d,
position_ids=position_ids_prefix,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=False,
adarms_cond=[None, None],
# ...
)
# Get logits from the last position
logits = lm_head(prefix_out) # (B, T_prefix, vocab)
next_token_logits = logits[:, -1, :] # (B, vocab)
logits = lm_head(prefix_out)
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
# Greedy decoding - take the most likely token
next_token = torch.argmax(next_token_logits, dim=-1) # (B,)
# 1. if a row was already finished, force the next token to be PAD (0)
next_token = torch.where(finished, torch.tensor(0, device=device), next_token)
# Store generated token
# 2. store the token
generated_tokens[:, t] = next_token
# Check for EOS token - if all batches have generated EOS, stop
# 3. update the finished mask
if tokenizer.eos_token_id is not None:
if (next_token == tokenizer.eos_token_id).all():
break
finished |= (next_token == tokenizer.eos_token_id)
# Embed the generated token and append to prefix
next_token_unsqueezed = next_token.unsqueeze(1) # (B, 1)
breakpoint()
# 4. break only if everyone is finished
if finished.all():
break
next_token_unsqueezed = next_token.unsqueeze(1)
def next_token_embed_func(next_token_unsqueezed):
next_emb = self.paligemma_with_expert.embed_language_tokens(next_token_unsqueezed)
next_emb_dim = next_emb.shape[-1]
return next_emb * math.sqrt(next_emb_dim)
return next_emb * math.sqrt(next_emb.shape[-1])
next_emb = self._apply_checkpoint(next_token_embed_func, next_token_unsqueezed)
# Append to prefix embeddings
# update embeddings
prefix_embs = torch.cat([prefix_embs, next_emb], dim=1)
# Update masks - new token is valid and uses causal attention
# update padding masks
prefix_pad_masks = torch.cat([
prefix_pad_masks,
torch.ones((bsize, 1), dtype=torch.bool, device=device)
], dim=1)
prefix_att_masks = torch.cat([prefix_att_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1)
# update attention masks
old_seq_len = prefix_att_masks.shape[1]
new_seq_len = old_seq_len + 1
new_att_masks = torch.zeros((bsize, new_seq_len, new_seq_len), dtype=torch.bool, device=device)
new_att_masks[:, :old_seq_len, :old_seq_len] = prefix_att_masks
new_att_masks[:, -1, :] = prefix_pad_masks
prefix_att_masks = new_att_masks
return generated_tokens
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
@@ -980,13 +1334,23 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
subtask_masks = torch.ones_like(generated_subtask_tokens, dtype=torch.bool)
# During inference, we don't have subtask_tokens yet, so pass None
prefix_embs, prefix_pad_masks, prefix_att_masks, _ = self.embed_prefix(
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks
# Also no FAST tokens during inference
prefix_embs, prefix_pad_masks, prefix_att_masks, _, _, _ = self.embed_prefix(
images, img_masks, tokens, subtask_tokens=generated_subtask_tokens, masks=masks, subtask_masks=subtask_masks,
fast_action_tokens=None, fast_action_masks=None
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
# Convert embeddings to bfloat16 if needed for the model
if (
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
# prefix_att_masks is already a 2D attention mask from embed_prefix
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks, dtype=prefix_embs.dtype)
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype)
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
@@ -1211,7 +1575,7 @@ class PI05Policy(PreTrainedPolicy):
print(f"Remapped {remap_count} state dict keys")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=False)
if missing_keys:
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
@@ -1419,7 +1783,7 @@ class PI05Policy(PreTrainedPolicy):
# Use high_level_task tokens (WITHOUT subtask) for inference - we'll generate the subtask
high_level_task = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS}"]
high_level_task_masks = batch[f"{OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK}"]
breakpoint()
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(
images, img_masks, high_level_task, high_level_task_masks,
@@ -1444,29 +1808,36 @@ class PI05Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
# Decode and print ground truth subtask tokens during training
if self.tokenizer is not None and self.training:
bsize = subtask_tokens.shape[0]
for i in range(bsize):
# Remove padding tokens (0) and special tokens
valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
if len(valid_tokens) > 0:
decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
print(f"[Training] Ground truth subtask {i}: {decoded_text}")
# if self.tokenizer is not None and self.training:
# bsize = subtask_tokens.shape[0]
# for i in range(bsize):
# # Remove padding tokens (0) and special tokens
# valid_tokens = subtask_tokens[i][subtask_masks[i].bool()]
# # if len(valid_tokens) > 0:
# # decoded_text = self.tokenizer.decode(valid_tokens, skip_special_tokens=True)
# # print(f"[Training] Ground truth subtask {i}: {decoded_text}")
# Get FAST action tokens from batch
fast_action_tokens = batch.get("action.tokens", None) # (B, max_action_tokens)
fast_action_masks = batch.get("action.token_mask", None) # (B, max_action_tokens)
# Compute loss (no separate state needed for PI05)
# high_level_task = instruction tokens WITHOUT subtask (e.g., "High level task: X; State: Y; Subtask:")
# subtask_tokens = subtask tokens to predict (e.g., "pick up the cup")
loss_dict = self.model.forward(images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions)
# fast_action_tokens = discrete action token IDs to predict
loss_dict = self.model.forward(
images, img_masks, high_level_task, high_level_task_masks, subtask_tokens, subtask_masks, actions,
fast_action_tokens=fast_action_tokens, fast_action_masks=fast_action_masks
)
# Extract the total loss
loss = loss_dict["loss"]
breakpoint()
# Prepare detailed loss dictionary for logging
detailed_loss_dict = {
"loss": loss.item(),
"flow_loss": loss_dict["flow_loss"].mean().item(),
"subtask_loss": loss_dict["subtask_loss"].item(),
"fast_loss": loss_dict["fast_loss"].item(),
}
return loss, detailed_loss_dict

View File

@@ -33,6 +33,7 @@ from lerobot.processor import (
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
ActionTokenizerProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
@@ -158,7 +159,6 @@ def make_pi05_pre_post_processors(
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
@@ -177,6 +177,9 @@ def make_pi05_pre_post_processors(
padding_side="right",
padding="max_length",
),
ActionTokenizerProcessorStep(
tokenizer_name="/fsx/jade_choghari/outputs/fast_tokenizer", # TODO: jade put the PI
),
DeviceProcessorStep(device=config.device),
]
@@ -186,7 +189,7 @@ def make_pi05_pre_post_processors(
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,

View File

@@ -0,0 +1,22 @@
export CUDA_LAUNCH_BLOCKING=1
lerobot-train \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
--output_dir=/fsx/jade_choghari/outputs/pi0_fast_fruit1 \
--job_name=pi0_training \
--policy.repo_id=jade_choghari/pi0-base1 \
--policy.path=lerobot/pi05_base \
--policy.dtype=bfloat16 \
--steps=200000 \
--save_freq=5000 \
--rename_map='{
"observation.images.base": "observation.images.base_0_rgb",
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
}' \
--batch_size=4 \
--policy.device=cuda \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=pi05hi-training \
# /fsx/jade_choghari/.cache/huggingface/lerobot/jadechoghari/collect-data

View File

@@ -0,0 +1,18 @@
rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
lerobot-train \
--dataset.repo_id=local\
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
--job_name=pi0_multi_training \
--policy.repo_id=jadechoghari/pi0-base1 \
--policy.path=lerobot/pi05_base \
--policy.dtype=bfloat16 \
--steps=50000 \
--save_freq=5000 \
--rename_map='{
"observation.images.base": "observation.images.base_0_rgb",
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
}' \
--batch_size=32 \
--policy.device=cuda \

View File

@@ -0,0 +1,9 @@
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "local" \
--root "/fsx/jade_choghari/outputs/collect-data-pgen" \
--action_horizon 16 \
--encoded_dims "0:15" \
--action_horizon 50 \
--vocab_size 1024 \
--scale 10.0 \
--output_dir "/fsx/jade_choghari/outputs/fast_tokenizer"

View File

@@ -0,0 +1,410 @@
"""Train FAST tokenizer for action encoding.
This script:
1. Loads action chunks from LeRobotDataset (with sampling)
2. Applies delta transforms and per-timestamp normalization
3. Trains FAST tokenizer on specified action dimensions
4. Saves tokenizer to assets directory
5. Reports compression statistics
"""
import json
import numpy as np
import tyro
from pathlib import Path
from transformers import AutoProcessor
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
"""Apply delta transform to specified dimensions.
Args:
state: Current state [D]
actions: Future actions [D]
delta_dims: List of dimension indices to apply delta transform to
Returns:
Transformed actions [D]
"""
if delta_dims is None or len(delta_dims) == 0:
return actions
delta_actions = actions.copy()
for dim in delta_dims:
delta_actions[dim] = actions[dim] - state[dim]
return delta_actions
def process_episode(args):
"""Process single episode and return action chunks."""
dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args
try:
# Get episode info
ep_info = dataset.meta.episodes[ep_idx]
from_idx = ep_info["dataset_from_index"]
to_idx = ep_info["dataset_to_index"]
ep_length = to_idx - from_idx
if ep_length < action_horizon:
return None
# Load all frames in episode
# If dataset has episode filtering, we need to use the mapping
states = []
actions = []
for abs_idx in range(from_idx, to_idx):
# Map absolute index to relative index if needed
if dataset._absolute_to_relative_idx is not None:
if abs_idx not in dataset._absolute_to_relative_idx:
# This episode's frames aren't in the filtered dataset
return None
rel_idx = dataset._absolute_to_relative_idx[abs_idx]
else:
rel_idx = abs_idx
frame = dataset.hf_dataset[rel_idx]
# Get state (could be from observation.state or other state key)
if state_key in frame:
state = frame[state_key].numpy() if torch.is_tensor(frame[state_key]) else np.array(frame[state_key])
else:
# If no state key, use zeros (no delta transform)
state = np.zeros_like(frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"]))
action = frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"])
states.append(state)
actions.append(action)
states = np.array(states)
actions = np.array(actions)
# Create action chunks (sliding window)
# All actions in a chunk are relative to the FIRST state in that chunk
action_chunks = []
for i in range(len(states) - action_horizon + 1):
current_state = states[i] # First state in chunk
future_absolute_actions = actions[i:i + action_horizon]
if use_delta_transform:
# Relative actions
delta_chunk = np.zeros_like(future_absolute_actions)
for t in range(action_horizon):
delta_chunk[t] = apply_delta_transform(
current_state,
future_absolute_actions[t],
delta_dims,
)
action_chunks.append(delta_chunk)
else:
# Absolute actions (NO delta)
action_chunks.append(future_absolute_actions)
if len(action_chunks) == 0:
return None
action_chunks = np.array(action_chunks)
# Sample chunks
if sample_fraction < 1.0:
n_chunks = len(action_chunks)
n_samples = max(1, int(n_chunks * sample_fraction))
episode_seed = hash(ep_idx) % (2**31)
rng = np.random.RandomState(episode_seed)
indices = rng.choice(n_chunks, size=n_samples, replace=False)
action_chunks = action_chunks[indices]
return action_chunks
except Exception as e:
print(f"Error processing episode {ep_idx}: {e}")
import traceback
traceback.print_exc()
return None
def train_fast_tokenizer(
action_chunks: np.ndarray,
vocab_size: int = 1024,
scale: float = 10.0,
) -> AutoProcessor:
"""
Train FAST tokenizer (BPE on DCT coefficients) on action chunks.
Uses the .fit() method to train a new tokenizer on the provided data.
Args:
action_chunks: Array of action chunks [N, H, D] where N=num_chunks, H=horizon, D=action_dim
vocab_size: BPE vocabulary size
scale: DCT scaling factor for quantization
Returns:
Trained FAST tokenizer
"""
print(f"Training FAST tokenizer on {len(action_chunks)} action chunks...")
print(f"Action chunk shape: {action_chunks.shape}")
print(f"Vocab size: {vocab_size}")
print(f"DCT scale: {scale}")
# Download the tokenizer source code (not pretrained weights)
# We'll train a new tokenizer on our own data
base_tokenizer = AutoProcessor.from_pretrained(
"physical-intelligence/fast",
trust_remote_code=True
)
# Convert action_chunks array to list of arrays (expected by .fit())
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]
# Train the new tokenizer on our action data using .fit()
# This trains the BPE tokenizer on DCT coefficients
print("Training new tokenizer (this may take a few minutes)...")
tokenizer = base_tokenizer.fit(
action_data_list,
scale=scale,
vocab_size=vocab_size,
time_horizon=action_chunks.shape[1], # action_horizon
action_dim=action_chunks.shape[2], # encoded dimensions
)
print("✓ Tokenizer training complete!")
# Validate it works
sample_chunk = action_chunks[0]
encoded = tokenizer(sample_chunk[None])[0]
if isinstance(encoded, list):
encoded = np.array(encoded)
print(f"Sample encoding: {len(encoded)} tokens for chunk shape {sample_chunk.shape}")
return tokenizer
def compute_compression_stats(tokenizer, action_chunks: np.ndarray):
"""Compute compression statistics."""
print("\nComputing compression statistics...")
# Sample for stats (use max 1000 chunks for speed)
sample_size = min(1000, len(action_chunks))
sample_indices = np.random.RandomState(42).choice(len(action_chunks), size=sample_size, replace=False)
sample_chunks = action_chunks[sample_indices]
token_lengths = []
for chunk in sample_chunks:
encoded = tokenizer(chunk[None])[0]
if isinstance(encoded, list):
token_lengths.append(len(encoded))
else:
token_lengths.append(encoded.shape[0] if hasattr(encoded, 'shape') else len(encoded))
token_lengths = np.array(token_lengths)
# Compression ratio: (H * D) / avg_tokens
input_size = action_chunks.shape[1] * action_chunks.shape[2]
avg_tokens = np.mean(token_lengths)
compression_ratio = input_size / avg_tokens
stats = {
'compression_ratio': float(compression_ratio),
'mean_token_length': float(np.mean(token_lengths)),
'p99_token_length': float(np.percentile(token_lengths, 99)),
'min_token_length': float(np.min(token_lengths)),
'max_token_length': float(np.max(token_lengths)),
}
print(f"Compression Statistics:")
print(f" Average compression ratio: {stats['compression_ratio']:.2f}x")
print(f" Mean token length: {stats['mean_token_length']:.1f}")
print(f" P99 token length: {stats['p99_token_length']:.0f}")
print(f" Min token length: {stats['min_token_length']:.0f}")
print(f" Max token length: {stats['max_token_length']:.0f}")
return stats
def main(
repo_id: str,
root: str | None = None,
action_horizon: int = 10,
max_episodes: int | None = None,
sample_fraction: float = 0.1,
encoded_dims: str = "0:6,7:23",
delta_dims: str | None = None,
use_delta_transform: bool = False,
state_key: str = "observation.state",
vocab_size: int = 1024,
scale: float = 10.0,
output_dir: str | None = None,
):
"""
Train FAST tokenizer for action encoding.
Args:
repo_id: LeRobot dataset repository ID
root: Root directory for dataset (default: ~/.cache/huggingface/lerobot)
action_horizon: Number of future actions in each chunk
max_episodes: Max episodes to use (None = all episodes in dataset)
sample_fraction: Fraction of chunks to sample per episode
encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions)
state_key: Dataset key for state observations (default: "observation.state")
vocab_size: FAST vocabulary size (BPE vocab size)
scale: DCT scaling factor (default: 10.0)
output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
"""
# Load dataset
print(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id=repo_id, root=root)
print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
# Parse encoded dimensions
encoded_dim_ranges = []
for range_str in encoded_dims.split(','):
start, end = map(int, range_str.strip().split(':'))
encoded_dim_ranges.append((start, end))
total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}")
# Parse delta dimensions
delta_dim_list = None
if delta_dims is not None and delta_dims.strip():
delta_dim_list = [int(d.strip()) for d in delta_dims.split(',')]
print(f"Delta dimensions: {delta_dim_list}")
else:
print("No delta dimensions specified")
print(f"Use delta transform: {use_delta_transform}")
if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")
print(f"Action horizon: {action_horizon}")
print(f"State key: {state_key}")
# Determine episodes to process
num_episodes = dataset.num_episodes
if max_episodes is not None:
num_episodes = min(max_episodes, num_episodes)
print(f"Processing {num_episodes} episodes...")
# Process episodes sequentially (to avoid pickling issues with dataset)
all_chunks = []
for ep_idx in range(num_episodes):
if ep_idx % 10 == 0:
print(f" Processing episode {ep_idx}/{num_episodes}...")
chunks = process_episode(
(dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform)
)
if chunks is not None:
all_chunks.append(chunks)
# Concatenate all chunks
all_chunks = np.concatenate(all_chunks, axis=0)
print(f"Collected {len(all_chunks)} action chunks")
# Extract only encoded dimensions FIRST (before normalization)
encoded_chunks = []
for start, end in encoded_dim_ranges:
encoded_chunks.append(all_chunks[:, :, start:end])
encoded_chunks = np.concatenate(encoded_chunks, axis=-1) # [N, H, D_encoded]
print(f"Extracted {encoded_chunks.shape[-1]} encoded dimensions")
# Apply normalization to encoded dimensions only
# NOTE: For FAST, we ALWAYS use QUANTILE normalization (no per-timestamp)
# This clips outliers and provides consistent [-1, 1] range for DCT compression
print(f"\nBefore normalization - overall stats:")
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
norm_stats = dataset.meta.stats
if norm_stats is not None and "action" in norm_stats:
action_stats = norm_stats["action"]
# Build encoded dimension indices
encoded_dim_indices = []
for start, end in encoded_dim_ranges:
encoded_dim_indices.extend(range(start, end))
encoded_dim_indices = np.array(encoded_dim_indices)
# Use QUANTILE normalization: clip to [q01, q99] and map to [-1, 1]
if "q01" in action_stats and "q99" in action_stats:
q01 = np.array(action_stats["q01"])[encoded_dim_indices] # [D_encoded]
q99 = np.array(action_stats["q99"])[encoded_dim_indices] # [D_encoded]
print(f"\nNormalization stats (q01, q99) for encoded dimensions:")
for i, dim_idx in enumerate(encoded_dim_indices):
print(f" Orig dim {dim_idx}: q01={q01[i]:7.4f}, q99={q99[i]:7.4f}, range={q99[i]-q01[i]:7.4f}")
# Clip to quantile range and normalize to [-1, 1]
encoded_chunks = np.clip(encoded_chunks, q01, q99)
encoded_chunks = 2.0 * (encoded_chunks - q01) / np.maximum(q99 - q01, 1e-6) - 1.0
print(f"\nApplied quantile normalization [q01, q99] → [-1, 1]")
print(f"\nAfter normalization - overall stats:")
print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}")
print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}")
print(f"\nPer-dimension stats (after normalization):")
for d in range(encoded_chunks.shape[-1]):
dim_data = encoded_chunks[:, :, d]
print(f" Dim {d}: min={np.min(dim_data):7.4f}, max={np.max(dim_data):7.4f}, "
f"mean={np.mean(dim_data):7.4f}, std={np.std(dim_data):7.4f}")
else:
print("Warning: q01/q99 stats not found, using raw actions")
else:
print("Warning: No normalization stats found, using raw actions")
print(f"Encoded chunks shape: {encoded_chunks.shape}")
# Train FAST tokenizer
tokenizer = train_fast_tokenizer(
encoded_chunks,
vocab_size=vocab_size,
scale=scale,
)
# Compute compression statistics
compression_stats = compute_compression_stats(tokenizer, encoded_chunks)
# Save tokenizer
if output_dir is None:
output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}"
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
tokenizer.save_pretrained(output_path)
# Save metadata
metadata = {
'repo_id': repo_id,
'vocab_size': vocab_size,
'scale': scale,
'encoded_dims': encoded_dims,
'encoded_dim_ranges': encoded_dim_ranges,
'total_encoded_dims': total_encoded_dims,
'delta_dims': delta_dims,
'delta_dim_list': delta_dim_list,
'use_delta_transform': use_delta_transform,
'state_key': state_key,
'action_horizon': action_horizon,
'num_training_chunks': len(encoded_chunks),
'compression_stats': compression_stats,
}
with open(output_path / "metadata.json", 'w') as f:
json.dump(metadata, f, indent=2)
print(f"\n✅ Saved FAST tokenizer to {output_path}")
print(f"Metadata: {json.dumps(metadata, indent=2)}")
if __name__ == "__main__":
tyro.cli(main)

View File

@@ -0,0 +1,101 @@
# Train FAST Tokenizer - Usage Examples
This script trains a FAST (Factorized Action Sequence Tokenizer) on LeRobotDataset action data.
## Basic Usage
```bash
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:7" \
--vocab_size 1024 \
--scale 10.0
```
## Parameters
### Required
- `--repo_id`: LeRobot dataset repository ID (e.g., "lerobot/aloha_sim_insertion_human")
### Optional
- `--root`: Root directory for dataset (default: ~/.cache/huggingface/lerobot)
- `--action_horizon`: Number of future actions in each chunk (default: 10)
- `--max_episodes`: Maximum number of episodes to use (default: None = all)
- `--sample_fraction`: Fraction of chunks to sample per episode (default: 0.1)
- `--encoded_dims`: Comma-separated dimension ranges to encode (default: "0:6,7:23")
- Example: "0:7" encodes dimensions 0-6
- Example: "0:3,6:9" encodes dimensions 0-2 and 6-8
- `--delta_dims`: Comma-separated dimension indices for delta transform (default: None)
- Example: "0,1,2,3,4,5" applies delta transform to first 6 dimensions
- Delta transform: action[i] - state[i] for specified dimensions
- `--state_key`: Dataset key for state observations (default: "observation.state")
- `--vocab_size`: FAST vocabulary size / BPE vocab size (default: 1024)
- `--scale`: DCT scaling factor (default: 10.0)
- `--output_dir`: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id})
## Examples
### Example 1: Train on full action space
```bash
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "lerobot/pusht" \
--action_horizon 16 \
--encoded_dims "0:2" \
--vocab_size 512 \
--max_episodes 100
```
### Example 2: Train with delta transform
```bash
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:14" \
--delta_dims "0,1,2,3,4,5,6,7,8,9,10,11,12,13" \
--state_key "observation.state" \
--vocab_size 1024 \
--scale 10.0 \
--sample_fraction 0.2
```
### Example 3: Train on subset of dimensions
```bash
python src/lerobot/policies/pi05/train_fast_tokenizer.py \
--repo_id "lerobot/aloha_sim_insertion_human" \
--action_horizon 10 \
--encoded_dims "0:7" \
--vocab_size 1024 \
--output_dir "./my_tokenizer"
```
## Output
The script saves:
1. **Tokenizer files**: Trained FAST tokenizer (can be loaded with `AutoProcessor.from_pretrained()`)
2. **metadata.json**: Contains:
- Configuration parameters
- Compression statistics (compression ratio, token lengths)
- Training dataset information
## Understanding the Process
1. **Load Dataset**: Loads the LeRobotDataset from HuggingFace
2. **Extract Action Chunks**: Creates sliding windows of actions with specified horizon
3. **Apply Delta Transform**: (Optional) Computes action deltas relative to current state
4. **Select Encoded Dimensions**: Extracts only the dimensions to be encoded
5. **Normalize**: Applies quantile normalization ([q01, q99] → [-1, 1])
6. **Train Tokenizer**: Trains BPE tokenizer on DCT coefficients
7. **Compute Stats**: Reports compression ratio and token length statistics
8. **Save**: Saves tokenizer and metadata
## Notes
- **Normalization**: The script uses quantile normalization (q01, q99) from the dataset's statistics
- **Sampling**: To speed up training, you can sample a fraction of chunks per episode
- **Delta Transform**: Applied per-dimension to make actions relative to current state
- **Compression**: FAST uses DCT + BPE to compress action sequences efficiently

View File

@@ -0,0 +1,23 @@
rm -rf /fsx/jade_choghari/outputs/pi0_multi_training
accelerate launch --multi_gpu --num_processes=2 \
$(which lerobot-train) \
--dataset.repo_id=local \
--dataset.root=/fsx/jade_choghari/outputs/collect-data-pgen \
--output_dir=/fsx/jade_choghari/outputs/pi0_multi_training \
--job_name=pi0_multi_training \
--policy.repo_id=jadechoghari/pi0-base1 \
--policy.path=lerobot/pi05_base \
--policy.dtype=bfloat16 \
--steps=50000 \
--save_freq=5000 \
--rename_map='{
"observation.images.base": "observation.images.base_0_rgb",
"observation.images.left_wrist": "observation.images.left_wrist_0_rgb",
"observation.images.right_wrist": "observation.images.right_wrist_0_rgb",
}' \
--policy.gradient_checkpointing=true \
--batch_size=1 \
--policy.device=cpu
# --wandb.enable=true \
# --wandb.disable_artifact=true \
# --wandb.project=pi05hi-training \

View File

@@ -75,7 +75,7 @@ from .policy_robot_bridge import (
RobotActionToPolicyActionProcessorStep,
)
from .rename_processor import RenameObservationsProcessorStep
from .tokenizer_processor import TokenizerProcessorStep
from .tokenizer_processor import TokenizerProcessorStep, ActionTokenizerProcessorStep
__all__ = [
"ActionProcessorStep",

View File

@@ -15,10 +15,13 @@
# limitations under the License.
"""
This script defines a processor for tokenizing natural language instructions from an environment transition.
This script defines processors for tokenizing data from an environment transition.
It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into
token IDs and attention masks, which are then added to the observation dictionary.
It includes:
- TokenizerProcessorStep: Uses a tokenizer from the Hugging Face `transformers` library to convert
task descriptions (text) into token IDs and attention masks, which are then added to the observation dictionary.
- ActionTokenizerProcessorStep: Uses a processor/tokenizer (e.g., the Physical Intelligence "fast" tokenizer)
to tokenize action tensors into discrete token IDs for action modeling.
"""
from __future__ import annotations
@@ -30,6 +33,8 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
@@ -40,12 +45,13 @@ from lerobot.utils.constants import (
from lerobot.utils.import_utils import _transformers_available
from .core import EnvTransition, TransitionKey
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers import AutoProcessor, AutoTokenizer
else:
AutoProcessor = None
AutoTokenizer = None
@@ -302,15 +308,17 @@ class TokenizerProcessorStep(ObservationProcessorStep):
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
"""
A wrapper around the tokenizer call.
A wrapper around the tokenizer call that appends an EOS token to each sequence.
Args:
text: A string or list of strings to tokenize.
Returns:
A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors.
A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors,
with EOS token appended at the end of each sequence.
"""
return self.input_tokenizer(
# Tokenize normally
tokenized = self.input_tokenizer(
text,
max_length=self.max_length,
truncation=self.truncation,
@@ -318,6 +326,34 @@ class TokenizerProcessorStep(ObservationProcessorStep):
padding_side=self.padding_side,
return_tensors="pt",
)
# Get EOS token ID
eos_token_id = self.input_tokenizer.eos_token_id
if eos_token_id is None:
# Some tokenizers don't have an EOS token, skip modification
return tokenized
# Append EOS token to each sequence (before padding)
input_ids = tokenized["input_ids"]
attention_mask = tokenized["attention_mask"]
for i in range(input_ids.shape[0]):
# Find the position of the last non-padding token
non_pad_positions = (attention_mask[i] == 1).nonzero(as_tuple=True)[0]
if len(non_pad_positions) > 0:
last_token_pos = non_pad_positions[-1].item()
# Check if there's room to add EOS token
if last_token_pos + 1 < self.max_length:
# Insert EOS token after the last real token
input_ids[i, last_token_pos + 1] = eos_token_id
attention_mask[i, last_token_pos + 1] = 1
else:
# If at max length, replace the last token with EOS
input_ids[i, last_token_pos] = eos_token_id
return {"input_ids": input_ids, "attention_mask": attention_mask}
def get_config(self) -> dict[str, Any]:
"""
@@ -393,3 +429,233 @@ class TokenizerProcessorStep(ObservationProcessorStep):
)
return features
@dataclass
@ProcessorStepRegistry.register(name="action_tokenizer_processor")
class ActionTokenizerProcessorStep(ActionProcessorStep):
"""
Processor step to tokenize action data using a fast action tokenizer.
This step takes action tensors from an `EnvTransition`, tokenizes them using
a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer),
and returns the tokenized action.
Requires the `transformers` library to be installed.
Attributes:
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None
trust_remote_code: bool = True
max_action_tokens: int = 32
# Internal tokenizer instance (not part of the config)
action_tokenizer: Any = field(default=None, init=False, repr=False)
def __post_init__(self):
"""
Initializes the action tokenizer after the dataclass is created.
It checks for the availability of the `transformers` library and loads the tokenizer
either from a provided object or by name from the Hugging Face Hub.
Raises:
ImportError: If the `transformers` library is not installed.
ValueError: If neither `tokenizer` nor `tokenizer_name` is provided.
"""
if not _transformers_available:
raise ImportError(
"The 'transformers' library is not installed. "
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self.action_tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if AutoProcessor is None:
raise ImportError("AutoProcessor is not available")
self.action_tokenizer = AutoProcessor.from_pretrained(
self.tokenizer_name, trust_remote_code=self.trust_remote_code
)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
Applies action tokenization to the transition.
This overrides the base class to handle both tokens and mask.
Args:
transition: The input transition with action data.
Returns:
The processed transition with tokenized actions and mask in complementary data.
"""
self._current_transition = transition.copy()
new_transition = self._current_transition
action = new_transition.get(TransitionKey.ACTION)
if action is None:
raise ValueError("ActionTokenizerProcessorStep requires an action in the transition.")
# Tokenize and get both tokens and mask
tokens, mask = self._tokenize_action(action)
# Store mask in complementary data
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
if complementary_data is None:
complementary_data = {}
complementary_data[ACTION_TOKEN_MASK] = mask
complementary_data[ACTION_TOKENS] = tokens
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Tokenizes the action tensor and creates a mask.
Args:
action: The input action tensor to tokenize. Shape: (B, action_dim) or (action_dim,)
Returns:
A tuple of (tokens, mask) where:
- tokens: Tensor of token IDs with shape (B, max_action_tokens)
- mask: Boolean mask with shape (B, max_action_tokens), True for real tokens, False for padding
"""
if action is None:
raise ValueError("Action cannot be None")
# Get the device and dtype of the input action
device = action.device if isinstance(action, torch.Tensor) else None
# Handle single sample (add batch dimension)
single_sample = action.dim() == 1
if single_sample:
action = action.unsqueeze(0)
batch_size = action.shape[0]
# Tokenize the action batch
# The fast tokenizer expects action data and returns token IDs
tokens_list = []
masks_list = []
for i in range(batch_size):
# Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy)
action_cpu = action[i:i+1].cpu()
tokens = self.action_tokenizer(action_cpu)
# Convert to numpy array if it's a list
if isinstance(tokens, list):
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
elif not isinstance(tokens, torch.Tensor):
tokens = torch.tensor(tokens, dtype=torch.long, device=action.device)
else:
# Move tokens back to the same device as input action
tokens = tokens.to(device=action.device)
# Flatten to 1D if needed
if tokens.dim() > 1:
tokens = tokens.flatten()
# Truncate or pad to max_action_tokens
if len(tokens) > self.max_action_tokens:
tokens = tokens[:self.max_action_tokens]
mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device)
else:
mask = torch.cat([
torch.ones(len(tokens), dtype=torch.bool, device=action.device),
torch.zeros(self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device)
])
# Pad tokens with zeros
tokens = torch.nn.functional.pad(
tokens,
(0, self.max_action_tokens - len(tokens)),
value=0
)
tokens_list.append(tokens)
masks_list.append(mask)
# Stack into batched tensors
tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens)
masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens)
# Remove batch dimension if input was single sample
if single_sample:
tokens_batch = tokens_batch.squeeze(0)
masks_batch = masks_batch.squeeze(0)
# Move to the same device as the input
if device is not None:
tokens_batch = tokens_batch.to(device)
masks_batch = masks_batch.to(device)
return tokens_batch, masks_batch
def action(self, action: torch.Tensor) -> torch.Tensor:
"""
This method is not used since we override __call__.
Required by ActionProcessorStep ABC.
"""
tokens, _ = self._tokenize_action(action)
return tokens
def get_config(self) -> dict[str, Any]:
"""
Returns the serializable configuration of the processor.
Note: The tokenizer object itself is not serialized. If the processor was initialized
with a tokenizer name, that name will be included in the config.
Returns:
A dictionary with the processor's configuration parameters.
"""
config = {
"trust_remote_code": self.trust_remote_code,
"max_action_tokens": self.max_action_tokens,
}
# Only save tokenizer_name if it was used to create the tokenizer
if self.tokenizer_name is not None and self.tokenizer is None:
config["tokenizer_name"] = self.tokenizer_name
return config
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Updates feature definitions to reflect tokenized actions.
This updates the policy features dictionary to indicate that the action
has been tokenized into a sequence of token IDs with shape (max_action_tokens,).
Args:
features: The dictionary of existing policy features.
Returns:
The updated dictionary of policy features.
"""
# Update the action feature to reflect the tokenized shape
# The action is now a sequence of token IDs
if PipelineFeatureType.ACTION in features:
# Replace the action feature with the tokenized version
features[PipelineFeatureType.ACTION] = {
ACTION_TOKENS: PolicyFeature(
type=FeatureType.SEQUENCE, # Token sequence
shape=(self.max_action_tokens,)
)
}
return features

View File

@@ -33,6 +33,8 @@ OBS_LANGUAGE_SUBTASK_ONLY = OBS_STR + ".subtask"
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS = OBS_LANGUAGE_SUBTASK_ONLY + ".tokens"
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK_ONLY + ".attention_mask"
ACTION = "action"
ACTION_TOKENS = ACTION + ".tokens"
ACTION_TOKEN_MASK = ACTION + ".token_mask"
REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"