refactor the policy

This commit is contained in:
Jade Choghari
2026-01-04 13:12:29 +01:00
parent f0d0faa8a1
commit 911f7c4cc2
5 changed files with 80 additions and 538 deletions

View File

@@ -14,6 +14,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
# from .groot.configuration_groot import GrootConfig as GrootConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig

View File

@@ -32,24 +32,13 @@ class PI0FastConfig(PreTrainedConfig):
action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32"
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
max_action_tokens: int = 32
fast_vocab_size: int = 2048
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
max_action_tokens: int = 256
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
@@ -63,6 +52,11 @@ class PI0FastConfig(PreTrainedConfig):
empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__`
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
action_tokenizer_name: str = "physical-intelligence/fast"
temperature: float = 0.0
max_decoding_steps: int = 256
fast_skip_tokens: int = 128
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
@@ -92,8 +86,6 @@ class PI0FastConfig(PreTrainedConfig):
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 200 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()

View File

@@ -32,6 +32,7 @@ from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
@@ -41,6 +42,7 @@ else:
modeling_gemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
AutoTokenizer = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
@@ -48,6 +50,8 @@ from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OPENPI_ATTENTION_MASK_VALUE,
@@ -55,81 +59,7 @@ from lerobot.utils.constants import (
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
return torch.float32
if device_type == "cpu":
# CPU doesn't support bfloat16, use float32 instead
if target_dtype == torch.bfloat16:
return torch.float32
if target_dtype == torch.float64:
return torch.float64
return target_dtype
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,))
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
return att_2d_masks & pad_2d_masks
temperature: float | None
def pad_vector(vector, new_dim):
@@ -514,10 +444,16 @@ class PaliGemmaWithExpertModel(
class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI0Fast PyTorch model."""
def __init__(self, config: PI0FastConfig, rtc_processor: RTCProcessor | None = None):
def __init__(
self,
config: PI0FastConfig,
rtc_processor: RTCProcessor | None = None,
paligemma_tokenizer: AutoTokenizer | None = None,
):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
self._paligemma_tokenizer = paligemma_tokenizer
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -528,19 +464,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
use_adarms=[False, True],
precision=config.dtype,
)
from transformers import AutoTokenizer
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224",
trust_remote_code=True,
)
# # Apply dtype conversion to FAST layers to match model precision
# if config.dtype == "bfloat16":
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16)
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16)
# elif config.dtype == "float32":
# self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32)
# self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32)
# Initialize gradient checkpointing flag
self.gradient_checkpointing_enabled = False
@@ -548,8 +471,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# Compile model if requested
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
# Also compile the main forward pass used during training
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
@@ -578,9 +500,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -597,352 +516,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
result = result.to(dtype=dtype)
return result
def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize):
"""Create custom 2D attention mask for the new attention pattern.
Attention rules:
- Images + Language: bidirectional among themselves, don't attend to subtask or FAST
- Subtask: attend to images + language, causal among themselves, don't attend to FAST
- FAST: attend to images + language + subtask, causal among themselves
Args:
att_mask_segments: List of (type, length) tuples
pad_masks: Padding masks [B, total_seq_len]
bsize: Batch size
Returns:
att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len]
"""
total_len = sum(length for _, length in att_mask_segments)
device = pad_masks.device
# Initialize attention mask as False (cannot attend)
att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device)
# Track positions for each segment
positions = []
current_pos = 0
for seg_type, seg_len in att_mask_segments:
positions.append((seg_type, current_pos, current_pos + seg_len))
current_pos += seg_len
# Apply attention rules
for _i, (query_type, query_start, query_end) in enumerate(positions):
for _j, (key_type, key_start, key_end) in enumerate(positions):
# Images and Language can attend to each other bidirectionally
if (
query_type in ["image", "language"]
and key_type in ["image", "language"]
or query_type == "subtask"
and key_type in ["image", "language"]
):
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# Subtask tokens attend causally to themselves
elif query_type == "subtask" and key_type == "subtask":
# Create causal mask for subtask tokens
subtask_len = query_end - query_start
causal_mask = torch.tril(
torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device)
)
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
# FAST tokens attend to images + language + subtask
elif query_type == "fast" and key_type in ["image", "language", "subtask"]:
att_2d_masks[:, query_start:query_end, key_start:key_end] = True
# FAST tokens attend causally to themselves
elif query_type == "fast" and key_type == "fast":
fast_len = query_end - query_start
causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device))
att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :]
# Apply padding masks
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def visualize_attention_mask(
self, att_mask_segments, att_2d_masks, save_path, batch_idx=0, dpi=150, max_display_tokens=None
):
"""Visualize the attention mask with labeled segments.
Args:
att_mask_segments: List of (type, length) tuples defining the segments
att_2d_masks: 2D attention mask tensor [B, total_seq_len, total_seq_len]
save_path: Path where to save the visualization image
batch_idx: Which batch item to visualize (default: 0)
dpi: DPI for the saved image (default: 150)
max_display_tokens: Maximum number of tokens to display (for very long sequences)
"""
try:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
except ImportError:
logging.warning("matplotlib not available, skipping attention mask visualization")
return
# Extract the mask for the specified batch
mask = att_2d_masks[batch_idx].cpu().float().numpy()
# If sequence is too long, downsample for visualization
if max_display_tokens is not None and mask.shape[0] > max_display_tokens:
# Simple downsampling by taking every Nth token
step = mask.shape[0] // max_display_tokens
mask = mask[::step, ::step]
# Adjust segments accordingly
att_mask_segments = [
(seg_type, max(1, seg_len // step)) for seg_type, seg_len in att_mask_segments
]
# Calculate positions for each segment
positions = []
current_pos = 0
for seg_type, seg_len in att_mask_segments:
positions.append((seg_type, current_pos, current_pos + seg_len))
current_pos += seg_len
# Create figure
fig, ax = plt.subplots(figsize=(12, 10))
# Create custom colormap: white for False (no attention), blue for True (attention)
colors = ["white", "#2E86AB"]
n_bins = 2
cmap = LinearSegmentedColormap.from_list("attention", colors, N=n_bins)
# Display the mask
im = ax.imshow(mask, cmap=cmap, aspect="auto", interpolation="nearest", vmin=0, vmax=1)
# Add colorbar
cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label("Attention Enabled", rotation=270, labelpad=20)
cbar.set_ticks([0.25, 0.75])
cbar.set_ticklabels(["No", "Yes"])
# Define colors for each segment type
segment_colors = {"image": "#A23B72", "language": "#F18F01", "subtask": "#C73E1D", "fast": "#6A994E"}
# Draw segment boundaries and labels
for seg_type, start, end in positions:
color = segment_colors.get(seg_type, "#666666")
# Draw vertical lines for columns (keys)
ax.axvline(x=start - 0.5, color=color, linewidth=2, alpha=0.7)
ax.axvline(x=end - 0.5, color=color, linewidth=2, alpha=0.7)
# Draw horizontal lines for rows (queries)
ax.axhline(y=start - 0.5, color=color, linewidth=2, alpha=0.7)
ax.axhline(y=end - 0.5, color=color, linewidth=2, alpha=0.7)
# Add labels at the top
mid_pos = (start + end) / 2
ax.text(
mid_pos,
-mask.shape[0] * 0.02,
f"{seg_type.upper()}\n({end - start})",
ha="center",
va="top",
fontsize=10,
fontweight="bold",
color=color,
)
# Add labels on the left
ax.text(
-mask.shape[1] * 0.02,
mid_pos,
f"{seg_type.upper()}\n({end - start})",
ha="right",
va="center",
fontsize=10,
fontweight="bold",
color=color,
rotation=0,
)
# Set axis labels
ax.set_xlabel("Key Position (tokens being attended to)", fontsize=12, fontweight="bold")
ax.set_ylabel("Query Position (tokens attending)", fontsize=12, fontweight="bold")
ax.set_title(
"Attention Mask Pattern\n(White = No Attention, Blue = Attention Allowed)",
fontsize=14,
fontweight="bold",
pad=20,
)
# Create legend for segment types
legend_patches = []
attention_rules = {
"image": "Bidirectional with lang",
"language": "Bidirectional with images",
"subtask": "Attends to img+lang, causal self",
"fast": "Attends to all, causal self",
}
for seg_type, color in segment_colors.items():
if any(seg[0] == seg_type for seg in att_mask_segments):
rule = attention_rules.get(seg_type, "")
legend_patches.append(mpatches.Patch(color=color, label=f"{seg_type.upper()}: {rule}"))
ax.legend(
handles=legend_patches, loc="upper right", bbox_to_anchor=(1.15, 1.0), framealpha=0.9, fontsize=9
)
# Adjust layout and save
plt.tight_layout()
# Ensure the directory exists
save_path = Path(save_path)
save_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=dpi, bbox_inches="tight")
plt.close()
logging.info(f"Attention mask visualization saved to: {save_path}")
def sample_noise(self, shape, device):
return torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
def sample_time(self, bsize, device):
time_beta = sample_beta(
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
)
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
return time.to(dtype=torch.float32, device=device)
def embed_prefix(
self,
images,
img_masks,
tokens,
subtask_tokens,
masks,
subtask_masks,
fast_action_tokens=None,
fast_action_masks=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer.
Args:
images: List of image tensors
img_masks: List of image masks
tokens: Language instruction tokens
subtask_tokens: Subtask tokens to predict (can be None for inference)
masks: Attention masks for tokens
fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs
fast_action_masks: Padding masks for FAST action tokens (can be None)
Returns:
embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)]
pad_masks: Padding masks
att_masks: Custom 2D attention mask implementing the required pattern
total_T_images: Total number of image tokens
num_subtask_embs: Number of subtask token embeddings
num_fast_embs: Number of FAST action token embeddings
"""
embs = []
pad_masks = []
att_mask_segments = [] # Store info about each segment for custom mask creation
total_t_images = 0
num_subtask_embs = 0
num_fast_embs = 0
# Process images
for img, img_mask in zip(images, img_masks, strict=True):
def image_embed_func(img):
return self.paligemma_with_expert.embed_image(img)
img_emb = self._apply_checkpoint(image_embed_func, img)
bsize, num_img_embs = img_emb.shape[:2]
embs.append(img_emb)
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
att_mask_segments.append(("image", num_img_embs))
total_t_images += num_img_embs
# Process language instruction tokens
def lang_embed_func(tokens):
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
lang_emb_dim = lang_emb.shape[-1]
return lang_emb * math.sqrt(lang_emb_dim)
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
embs.append(lang_emb)
pad_masks.append(masks)
num_lang_embs = lang_emb.shape[1]
att_mask_segments.append(("language", num_lang_embs))
# Process subtask tokens if provided (these are predicted, so use causal masking)
if subtask_tokens is not None:
def subtask_embed_func(subtask_tokens):
subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens)
subtask_emb_dim = subtask_emb.shape[-1]
return subtask_emb * math.sqrt(subtask_emb_dim)
subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens)
embs.append(subtask_emb)
# Create subtask pad masks (non-zero tokens are valid)
pad_masks.append(subtask_masks)
num_subtask_embs = subtask_emb.shape[1]
att_mask_segments.append(("subtask", num_subtask_embs))
# Process FAST action tokens if provided (these are discrete token IDs)
if fast_action_tokens is not None:
def fast_action_embed_func(fast_action_tokens):
fast_emb = self.fast_action_embedding(fast_action_tokens)
fast_emb_dim = fast_emb.shape[-1]
return fast_emb * math.sqrt(fast_emb_dim)
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
embs.append(fast_action_emb)
# Use provided mask or create default (all valid)
if fast_action_masks is not None:
fast_pad_mask = fast_action_masks
else:
bsize = fast_action_tokens.shape[0]
num_fast_embs = fast_action_tokens.shape[1]
fast_pad_mask = torch.ones(
bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device
)
num_fast_embs = fast_action_tokens.shape[1]
pad_masks.append(fast_pad_mask)
att_mask_segments.append(("fast", num_fast_embs))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
# Create custom 2D attention mask
# Attention rules:
# - Images + Language: bidirectional among themselves, don't attend to subtask or FAST
# - Subtask: attend to images + language, causal among themselves, don't attend to FAST
# - FAST: attend to images + language + subtask, causal among themselves
att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize)
# # Optionally visualize the attention mask
# self.visualize_attention_mask(
# att_mask_segments=att_mask_segments,
# att_2d_masks=att_masks,
# save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png",
# batch_idx=0,
# max_display_tokens=512 # Limit display for very long sequences
# )
return embs, pad_masks, att_masks, total_t_images, num_subtask_embs, num_fast_embs
def embed_prefix_fast(
self,
images,
@@ -952,9 +525,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
fast_action_tokens=None,
fast_action_masks=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
"""Embed images, language tokens, and FAST action tokens for FAST-only mode.
"""Embed images, language tokens, and FAST action tokens.
This is a simplified version of embed_prefix without subtask tokens.
Attention pattern:
- Images + Language: bidirectional among themselves
- FAST: attend to images + language, causal among themselves
@@ -1018,23 +590,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
embs.append(fast_action_emb)
if fast_action_masks is not None:
fast_pad_mask = fast_action_masks
else:
bsize = fast_action_tokens.shape[0]
num_fast_embs = fast_action_tokens.shape[1]
fast_pad_mask = torch.ones(
bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device
)
num_fast_embs = fast_action_tokens.shape[1]
pad_masks.append(fast_pad_mask)
pad_masks.append(fast_action_masks)
att_mask_segments.append(("fast", num_fast_embs))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
# Create custom 2D attention mask for FAST-only mode:
# Create custom 2D attention mask:
# - Images + Language: bidirectional among themselves
# - FAST: attend to images + language, causal among themselves
att_masks = self._create_custom_attention_mask_fast(att_mask_segments, pad_masks, bsize)
@@ -1042,7 +605,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
return embs, pad_masks, att_masks, total_t_images, num_fast_embs
def _create_custom_attention_mask_fast(self, att_mask_segments, pad_masks, bsize):
"""Create custom 2D attention mask for FAST-only mode.
"""Create custom 2D attention mask.
Attention rules:
- Images + Language: bidirectional among themselves
@@ -1091,7 +654,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
fast_action_tokens,
fast_action_masks,
) -> dict:
"""Forward pass for FAST-only mode (no flow matching, no subtask).
"""Forward pass for PI0Fast.
This implements the Pi0FAST training objective: predict next action token
using cross-entropy loss.
@@ -1151,11 +714,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# only compute logits for the positions that predict FAST tokens
lm_head = self.paligemma_with_expert.paligemma.lm_head
# The FAST tokens start at position (total_T_images + num_lang_tokens)
# For next-token prediction:
# - Position (fast_start - 1) in input predicts fast_action_tokens[0]
# - Position (fast_start) in input predicts fast_action_tokens[1], etc.
# Targets are the FAST action tokens
fast_targets = fast_action_tokens # (B, num_fast_embs)
@@ -1169,24 +727,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
fast_targets = fast_targets[:, 1:] # shift targets right
fast_action_masks = fast_action_masks[:, 1:] # shift masks to match targets
# from transformers import AutoTokenizer
# self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
# "google/paligemma-3b-pt-224",
# trust_remote_code=True,
# add_eos_token=True,
# add_bos_token=False
# )
# # remove
# decoded_tokens = [
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
# for seq in fast_targets
# ]
# corrected_tokens = [
# self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist())
# for seq in fast_logits_for_pred.argmax(dim=-1)
# ]
# breakpoint()
# compute cross-entropy loss
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1))
@@ -1320,38 +860,38 @@ class PI0FastPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
# Initialize the core PI0Fast model
self.init_rtc_processor()
self.model = PI0FastPytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
self.model.to(config.device)
# Load FAST tokenizer for action detokenization (only if fast_only mode)
self.action_tokenizer = None
self._paligemma_tokenizer = None
self._fast_skip_tokens = 128
# Load tokenizers first
try:
from transformers import AutoProcessor, AutoTokenizer
# Load FAST tokenizer
self.action_tokenizer = AutoProcessor.from_pretrained(
"jadechoghari/fast-libero-tokenizer-mean-std", trust_remote_code=True
config.action_tokenizer_name, trust_remote_code=True
)
# Load PaliGemma tokenizer for token conversion
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224", trust_remote_code=True, add_eos_token=True, add_bos_token=False
config.text_tokenizer_name, trust_remote_code=True, add_eos_token=True, add_bos_token=False
)
logging.info("Loaded FAST tokenizer for action detokenization")
except Exception as e:
logging.warning(f"Could not load FAST tokenizer for action detokenization: {e}")
logging.warning("Action tokens will be returned without detokenization")
self._paligemma_tokenizer = None
self.action_tokenizer = None
# Initialize the core PI0Fast model
self.init_rtc_processor()
self.model = PI0FastPytorch(
config, rtc_processor=self.rtc_processor, paligemma_tokenizer=self._paligemma_tokenizer
)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
self.model.to(config.device)
self.reset()
@@ -1634,7 +1174,7 @@ class PI0FastPolicy(PreTrainedPolicy):
Returns:
Action token IDs
"""
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
return self._paligemma_tokenizer.vocab_size - 1 - self.config.fast_skip_tokens - tokens
def decode_actions_with_fast(
self, token_ids: list[int], time_horizon: int, action_dim: int, relaxed_decoding: bool = True
@@ -1807,9 +1347,9 @@ class PI0FastPolicy(PreTrainedPolicy):
tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Get optional parameters
temperature = kwargs.get("temperature", 0.0)
max_decoding_steps = 256
# Get decoding parameters
temperature = self.config.temperature
max_decoding_steps = self.config.max_decoding_steps
# Sample action tokens autoregressively
action_tokens = self.model.sample_actions_fast(
@@ -1820,9 +1360,10 @@ class PI0FastPolicy(PreTrainedPolicy):
max_decoding_steps=max_decoding_steps,
temperature=temperature,
)
# Detokenize action tokens to continuous actions
action_horizon = self.config.n_action_steps
action_dim = 7
action_dim = self.config.output_features[ACTION].shape[0]
continuous_actions = self.detokenize_actions(
action_tokens, action_horizon=action_horizon, action_dim=action_dim
@@ -1837,23 +1378,25 @@ class PI0FastPolicy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch)
# Get FAST action tokens from batch
fast_action_tokens = batch.get("action.tokens") # (B, max_action_tokens)
fast_action_masks = batch.get("action.token_mask") # (B, max_action_tokens)
fast_action_tokens = batch.get(ACTION_TOKENS) # (B, max_action_tokens)
fast_action_masks = batch.get(ACTION_TOKEN_MASK) # (B, max_action_tokens)
# Use full language tokens (no separation into high_level_task and subtask)
tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
tokens = batch.get(OBS_LANGUAGE_TOKENS)
masks = batch.get(OBS_LANGUAGE_ATTENTION_MASK)
if fast_action_tokens is None or fast_action_masks is None:
raise ValueError("FAST-only mode requires action.tokens and action.token_mask in the batch")
raise ValueError(
f"PI0Fast requires {ACTION_TOKENS} and {ACTION_TOKEN_MASK} to be present in the batch"
)
loss_dict = self.model.forward(
images,
img_masks,
tokens,
masks,
fast_action_tokens=fast_action_tokens,
fast_action_masks=fast_action_masks,
fast_action_tokens,
fast_action_masks,
)
loss = loss_dict["loss"]

View File

@@ -142,13 +142,16 @@ def make_pi0_fast_pre_post_processors(
),
Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
tokenizer_name=config.tokenizer_name,
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
ActionTokenizerProcessorStep(
tokenizer_name="/fsx/jade_choghari/outputs/fast_tokenizer", # TODO: jade put the PI
action_tokenizer_name=config.action_tokenizer_name,
max_action_tokens=config.max_action_tokens,
fast_skip_tokens=config.fast_skip_tokens,
paligemma_tokenizer_name=config.text_tokenizer_name,
),
DeviceProcessorStep(device=config.device),
]

View File

@@ -157,7 +157,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
# Tokenize the task (this will create CPU tensors)
tokenized_prompt = self._tokenize_text(task)
# Detect the device from existing tensors in the transition to ensure consistency
target_device = self._detect_device(self.transition)
@@ -295,14 +295,15 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.
"""
tokenizer_name: str | None = None
tokenizer: Any | None = None
action_tokenizer_name: str | None = None
action_tokenizer_input_object: Any | None = None
trust_remote_code: bool = True
max_action_tokens: int = 256
fast_skip_tokens: int = 128
paligemma_tokenizer_name: str = "google/paligemma-3b-pt-224"
# Internal tokenizer instance (not part of the config)
action_tokenizer: Any = field(default=None, init=False, repr=False)
_paligemma_tokenizer: Any = field(default=None, init=False, repr=False)
_fast_skip_tokens: int = field(default=128, init=False, repr=False)
def __post_init__(self):
"""
@@ -321,25 +322,27 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep."
)
if self.tokenizer is not None:
# Use provided tokenizer object directly
self.action_tokenizer = self.tokenizer
elif self.tokenizer_name is not None:
if self.action_tokenizer_input_object is not None:
self.action_tokenizer = self.action_tokenizer_input_object
elif self.action_tokenizer_name is not None:
if AutoProcessor is None:
raise ImportError("AutoProcessor is not available")
self.action_tokenizer = AutoProcessor.from_pretrained(
self.tokenizer_name, trust_remote_code=self.trust_remote_code
self.action_tokenizer_name, trust_remote_code=self.trust_remote_code
)
else:
raise ValueError(
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
"Either 'action_tokenizer' or 'action_tokenizer_name' must be provided. "
"Pass a tokenizer object directly or a tokenizer name to auto-load."
)
self._paligemma_tokenizer = AutoTokenizer.from_pretrained(
"google/paligemma-3b-pt-224", trust_remote_code=True, add_eos_token=True, add_bos_token=False
self.paligemma_tokenizer_name,
trust_remote_code=self.trust_remote_code,
add_eos_token=True,
add_bos_token=False,
)
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""
@@ -358,7 +361,8 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
action = new_transition.get(TransitionKey.ACTION)
if action is None:
raise ValueError("ActionTokenizerProcessorStep requires an action in the transition.")
# During inference, no action is available, skip tokenization
return new_transition
# Tokenize and get both tokens and mask
tokens, mask = self._tokenize_action(action)
@@ -376,7 +380,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
"""
Converts action tokens to PaliGemma tokens.
"""
return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens
return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens
def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -523,5 +527,4 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
Returns:
The updated dictionary of policy features.
"""
# TODO: jadechoghari, should we add the tokenized action to the features?
return features