From 911f7c4cc21e8990c0b26cd91c545cf9ee26ab86 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Sun, 4 Jan 2026 13:12:29 +0100 Subject: [PATCH] refactor the policy --- src/lerobot/policies/__init__.py | 1 + .../pi0_fast/configuration_pi0_fast.py | 20 +- .../policies/pi0_fast/modeling_pi0_fast.py | 557 ++---------------- .../policies/pi0_fast/processor_pi0_fast.py | 7 +- src/lerobot/processor/tokenizer_processor.py | 33 +- 5 files changed, 80 insertions(+), 538 deletions(-) diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index db7156210..5923fb954 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -14,6 +14,7 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig + # from .groot.configuration_groot import GrootConfig as GrootConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig diff --git a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py index 6bb4bff66..ac9ed0da8 100644 --- a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py @@ -32,24 +32,13 @@ class PI0FastConfig(PreTrainedConfig): action_expert_variant: str = "gemma_300m" dtype: str = "float32" # Options: "bfloat16", "float32" - n_obs_steps: int = 1 chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" n_action_steps: int = 50 # Number of action steps to execute # Shorter state and action vectors will be padded to these dimensions max_state_dim: int = 32 max_action_dim: int = 32 - max_action_tokens: int = 32 - fast_vocab_size: int = 2048 - - # Flow matching parameters: see openpi `PI0Pytorch` - num_inference_steps: int = 10 - time_sampling_beta_alpha: float = 1.5 - time_sampling_beta_beta: float = 1.0 - time_sampling_scale: float = 0.999 - time_sampling_offset: float = 0.001 - min_period: float = 4e-3 - max_period: float = 4.0 + max_action_tokens: int = 256 # Real-Time Chunking (RTC) configuration rtc_config: RTCConfig | None = None @@ -63,6 +52,11 @@ class PI0FastConfig(PreTrainedConfig): empty_cameras: int = 0 tokenizer_max_length: int = 200 # see openpi `__post_init__` + text_tokenizer_name: str = "google/paligemma-3b-pt-224" + action_tokenizer_name: str = "physical-intelligence/fast" + temperature: float = 0.0 + max_decoding_steps: int = 256 + fast_skip_tokens: int = 128 normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { @@ -92,8 +86,6 @@ class PI0FastConfig(PreTrainedConfig): scheduler_decay_steps: int = 30_000 scheduler_decay_lr: float = 2.5e-6 - tokenizer_max_length: int = 200 # see openpi `__post_init__` - def __post_init__(self): super().__post_init__() diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index 81e833aae..6e9bc84aa 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -32,6 +32,7 @@ from lerobot.utils.import_utils import _transformers_available # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: + from transformers import AutoTokenizer from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma from transformers.models.gemma.modeling_gemma import GemmaForCausalLM @@ -41,6 +42,7 @@ else: modeling_gemma = None GemmaForCausalLM = None PaliGemmaForConditionalGeneration = None + AutoTokenizer = None from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig @@ -48,6 +50,8 @@ from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.utils.constants import ( ACTION, + ACTION_TOKEN_MASK, + ACTION_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OPENPI_ATTENTION_MASK_VALUE, @@ -55,81 +59,7 @@ from lerobot.utils.constants import ( class ActionSelectKwargs(TypedDict, total=False): - inference_delay: int | None - prev_chunk_left_over: Tensor | None - execution_horizon: int | None - - -def get_safe_dtype(target_dtype, device_type): - """Get a safe dtype for the given device type.""" - if device_type == "mps" and target_dtype == torch.float64: - return torch.float32 - if device_type == "cpu": - # CPU doesn't support bfloat16, use float32 instead - if target_dtype == torch.bfloat16: - return torch.float32 - if target_dtype == torch.float64: - return torch.float64 - return target_dtype - - -def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy) - time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu" -) -> Tensor: - """Computes sine-cosine positional embedding vectors for scalar positions.""" - if dimension % 2 != 0: - raise ValueError(f"dimension ({dimension}) must be divisible by 2") - - if time.ndim != 1: - raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") - - dtype = get_safe_dtype(torch.float64, device.type) - fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) - period = min_period * (max_period / min_period) ** fraction - - # Compute the outer product - scaling_factor = 1.0 / period * 2 * math.pi - sin_input = scaling_factor[None, :] * time[:, None] - return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) - - -def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) - alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) - beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) - dist = torch.distributions.Beta(alpha_t, beta_t) - return dist.sample((bsize,)) - - -def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy) - """Copied from big_vision. - - Tokens can attend to valid inputs tokens which have a cumulative mask_ar - smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to - setup several types of attention, for example: - - [[1 1 1 1 1 1]]: pure causal attention. - - [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between - themselves and the last 3 tokens have a causal attention. The first - entry could also be a 1 without changing behaviour. - - [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a - block can attend all previous blocks and all tokens on the same block. - - Args: - input_mask: bool[B, N] true if its part of the input, false if padding. - mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on - it and 0 where it shares the same attention mask as the previous token. - """ - if att_masks.ndim != 2: - raise ValueError(att_masks.ndim) - if pad_masks.ndim != 2: - raise ValueError(pad_masks.ndim) - - cumsum = torch.cumsum(att_masks, dim=1) - att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] - pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] - return att_2d_masks & pad_2d_masks + temperature: float | None def pad_vector(vector, new_dim): @@ -514,10 +444,16 @@ class PaliGemmaWithExpertModel( class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` """Core PI0Fast PyTorch model.""" - def __init__(self, config: PI0FastConfig, rtc_processor: RTCProcessor | None = None): + def __init__( + self, + config: PI0FastConfig, + rtc_processor: RTCProcessor | None = None, + paligemma_tokenizer: AutoTokenizer | None = None, + ): super().__init__() self.config = config self.rtc_processor = rtc_processor + self._paligemma_tokenizer = paligemma_tokenizer paligemma_config = get_gemma_config(config.paligemma_variant) action_expert_config = get_gemma_config(config.action_expert_variant) @@ -528,19 +464,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` use_adarms=[False, True], precision=config.dtype, ) - from transformers import AutoTokenizer - - self._paligemma_tokenizer = AutoTokenizer.from_pretrained( - "google/paligemma-3b-pt-224", - trust_remote_code=True, - ) - # # Apply dtype conversion to FAST layers to match model precision - # if config.dtype == "bfloat16": - # self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.bfloat16) - # self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.bfloat16) - # elif config.dtype == "float32": - # self.fast_action_embedding = self.fast_action_embedding.to(dtype=torch.float32) - # self.fast_action_lm_head = self.fast_action_lm_head.to(dtype=torch.float32) # Initialize gradient checkpointing flag self.gradient_checkpointing_enabled = False @@ -548,8 +471,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` # Compile model if requested if config.compile_model: torch.set_float32_matmul_precision("high") - self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) - # Also compile the main forward pass used during training + self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode) self.forward = torch.compile(self.forward, mode=config.compile_mode) msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" @@ -578,9 +500,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI0FastPytorch model") - def _rtc_enabled(self): - return self.config.rtc_config is not None and self.config.rtc_config.enabled - def _apply_checkpoint(self, func, *args, **kwargs): """Helper method to apply gradient checkpointing if enabled.""" if self.gradient_checkpointing_enabled and self.training: @@ -597,352 +516,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` result = result.to(dtype=dtype) return result - def _create_custom_attention_mask(self, att_mask_segments, pad_masks, bsize): - """Create custom 2D attention mask for the new attention pattern. - - Attention rules: - - Images + Language: bidirectional among themselves, don't attend to subtask or FAST - - Subtask: attend to images + language, causal among themselves, don't attend to FAST - - FAST: attend to images + language + subtask, causal among themselves - - Args: - att_mask_segments: List of (type, length) tuples - pad_masks: Padding masks [B, total_seq_len] - bsize: Batch size - - Returns: - att_2d_masks: 2D attention mask [B, total_seq_len, total_seq_len] - """ - total_len = sum(length for _, length in att_mask_segments) - device = pad_masks.device - - # Initialize attention mask as False (cannot attend) - att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device) - - # Track positions for each segment - positions = [] - current_pos = 0 - for seg_type, seg_len in att_mask_segments: - positions.append((seg_type, current_pos, current_pos + seg_len)) - current_pos += seg_len - - # Apply attention rules - for _i, (query_type, query_start, query_end) in enumerate(positions): - for _j, (key_type, key_start, key_end) in enumerate(positions): - # Images and Language can attend to each other bidirectionally - if ( - query_type in ["image", "language"] - and key_type in ["image", "language"] - or query_type == "subtask" - and key_type in ["image", "language"] - ): - att_2d_masks[:, query_start:query_end, key_start:key_end] = True - - # Subtask tokens attend causally to themselves - elif query_type == "subtask" and key_type == "subtask": - # Create causal mask for subtask tokens - subtask_len = query_end - query_start - causal_mask = torch.tril( - torch.ones(subtask_len, subtask_len, dtype=torch.bool, device=device) - ) - att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :] - - # FAST tokens attend to images + language + subtask - elif query_type == "fast" and key_type in ["image", "language", "subtask"]: - att_2d_masks[:, query_start:query_end, key_start:key_end] = True - - # FAST tokens attend causally to themselves - elif query_type == "fast" and key_type == "fast": - fast_len = query_end - query_start - causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device)) - att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :] - - # Apply padding masks - pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] - att_2d_masks = att_2d_masks & pad_2d_masks - - return att_2d_masks - - def visualize_attention_mask( - self, att_mask_segments, att_2d_masks, save_path, batch_idx=0, dpi=150, max_display_tokens=None - ): - """Visualize the attention mask with labeled segments. - - Args: - att_mask_segments: List of (type, length) tuples defining the segments - att_2d_masks: 2D attention mask tensor [B, total_seq_len, total_seq_len] - save_path: Path where to save the visualization image - batch_idx: Which batch item to visualize (default: 0) - dpi: DPI for the saved image (default: 150) - max_display_tokens: Maximum number of tokens to display (for very long sequences) - """ - try: - import matplotlib.patches as mpatches - import matplotlib.pyplot as plt - from matplotlib.colors import LinearSegmentedColormap - except ImportError: - logging.warning("matplotlib not available, skipping attention mask visualization") - return - - # Extract the mask for the specified batch - mask = att_2d_masks[batch_idx].cpu().float().numpy() - - # If sequence is too long, downsample for visualization - if max_display_tokens is not None and mask.shape[0] > max_display_tokens: - # Simple downsampling by taking every Nth token - step = mask.shape[0] // max_display_tokens - mask = mask[::step, ::step] - # Adjust segments accordingly - att_mask_segments = [ - (seg_type, max(1, seg_len // step)) for seg_type, seg_len in att_mask_segments - ] - - # Calculate positions for each segment - positions = [] - current_pos = 0 - for seg_type, seg_len in att_mask_segments: - positions.append((seg_type, current_pos, current_pos + seg_len)) - current_pos += seg_len - - # Create figure - fig, ax = plt.subplots(figsize=(12, 10)) - - # Create custom colormap: white for False (no attention), blue for True (attention) - colors = ["white", "#2E86AB"] - n_bins = 2 - cmap = LinearSegmentedColormap.from_list("attention", colors, N=n_bins) - - # Display the mask - im = ax.imshow(mask, cmap=cmap, aspect="auto", interpolation="nearest", vmin=0, vmax=1) - - # Add colorbar - cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) - cbar.set_label("Attention Enabled", rotation=270, labelpad=20) - cbar.set_ticks([0.25, 0.75]) - cbar.set_ticklabels(["No", "Yes"]) - - # Define colors for each segment type - segment_colors = {"image": "#A23B72", "language": "#F18F01", "subtask": "#C73E1D", "fast": "#6A994E"} - - # Draw segment boundaries and labels - for seg_type, start, end in positions: - color = segment_colors.get(seg_type, "#666666") - - # Draw vertical lines for columns (keys) - ax.axvline(x=start - 0.5, color=color, linewidth=2, alpha=0.7) - ax.axvline(x=end - 0.5, color=color, linewidth=2, alpha=0.7) - - # Draw horizontal lines for rows (queries) - ax.axhline(y=start - 0.5, color=color, linewidth=2, alpha=0.7) - ax.axhline(y=end - 0.5, color=color, linewidth=2, alpha=0.7) - - # Add labels at the top - mid_pos = (start + end) / 2 - ax.text( - mid_pos, - -mask.shape[0] * 0.02, - f"{seg_type.upper()}\n({end - start})", - ha="center", - va="top", - fontsize=10, - fontweight="bold", - color=color, - ) - - # Add labels on the left - ax.text( - -mask.shape[1] * 0.02, - mid_pos, - f"{seg_type.upper()}\n({end - start})", - ha="right", - va="center", - fontsize=10, - fontweight="bold", - color=color, - rotation=0, - ) - - # Set axis labels - ax.set_xlabel("Key Position (tokens being attended to)", fontsize=12, fontweight="bold") - ax.set_ylabel("Query Position (tokens attending)", fontsize=12, fontweight="bold") - ax.set_title( - "Attention Mask Pattern\n(White = No Attention, Blue = Attention Allowed)", - fontsize=14, - fontweight="bold", - pad=20, - ) - - # Create legend for segment types - legend_patches = [] - attention_rules = { - "image": "Bidirectional with lang", - "language": "Bidirectional with images", - "subtask": "Attends to img+lang, causal self", - "fast": "Attends to all, causal self", - } - for seg_type, color in segment_colors.items(): - if any(seg[0] == seg_type for seg in att_mask_segments): - rule = attention_rules.get(seg_type, "") - legend_patches.append(mpatches.Patch(color=color, label=f"{seg_type.upper()}: {rule}")) - - ax.legend( - handles=legend_patches, loc="upper right", bbox_to_anchor=(1.15, 1.0), framealpha=0.9, fontsize=9 - ) - - # Adjust layout and save - plt.tight_layout() - - # Ensure the directory exists - save_path = Path(save_path) - save_path.parent.mkdir(parents=True, exist_ok=True) - - plt.savefig(save_path, dpi=dpi, bbox_inches="tight") - plt.close() - - logging.info(f"Attention mask visualization saved to: {save_path}") - - def sample_noise(self, shape, device): - return torch.normal( - mean=0.0, - std=1.0, - size=shape, - dtype=torch.float32, - device=device, - ) - - def sample_time(self, bsize, device): - time_beta = sample_beta( - self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device - ) - time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset - return time.to(dtype=torch.float32, device=device) - - def embed_prefix( - self, - images, - img_masks, - tokens, - subtask_tokens, - masks, - subtask_masks, - fast_action_tokens=None, - fast_action_masks=None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: - """Embed images with SigLIP, tokens, and optionally subtask tokens with embedding layer. - - Args: - images: List of image tensors - img_masks: List of image masks - tokens: Language instruction tokens - subtask_tokens: Subtask tokens to predict (can be None for inference) - masks: Attention masks for tokens - fast_action_tokens: FAST action tokens for auxiliary prediction (can be None) - discrete token IDs - fast_action_masks: Padding masks for FAST action tokens (can be None) - - Returns: - embs: Concatenated embeddings [images, tokens, (subtask_tokens if provided), (fast_action_tokens if provided)] - pad_masks: Padding masks - att_masks: Custom 2D attention mask implementing the required pattern - total_T_images: Total number of image tokens - num_subtask_embs: Number of subtask token embeddings - num_fast_embs: Number of FAST action token embeddings - """ - embs = [] - pad_masks = [] - att_mask_segments = [] # Store info about each segment for custom mask creation - total_t_images = 0 - num_subtask_embs = 0 - num_fast_embs = 0 - - # Process images - for img, img_mask in zip(images, img_masks, strict=True): - - def image_embed_func(img): - return self.paligemma_with_expert.embed_image(img) - - img_emb = self._apply_checkpoint(image_embed_func, img) - bsize, num_img_embs = img_emb.shape[:2] - - embs.append(img_emb) - pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) - att_mask_segments.append(("image", num_img_embs)) - total_t_images += num_img_embs - - # Process language instruction tokens - def lang_embed_func(tokens): - lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) - lang_emb_dim = lang_emb.shape[-1] - return lang_emb * math.sqrt(lang_emb_dim) - - lang_emb = self._apply_checkpoint(lang_embed_func, tokens) - embs.append(lang_emb) - pad_masks.append(masks) - - num_lang_embs = lang_emb.shape[1] - att_mask_segments.append(("language", num_lang_embs)) - - # Process subtask tokens if provided (these are predicted, so use causal masking) - if subtask_tokens is not None: - - def subtask_embed_func(subtask_tokens): - subtask_emb = self.paligemma_with_expert.embed_language_tokens(subtask_tokens) - subtask_emb_dim = subtask_emb.shape[-1] - return subtask_emb * math.sqrt(subtask_emb_dim) - - subtask_emb = self._apply_checkpoint(subtask_embed_func, subtask_tokens) - embs.append(subtask_emb) - - # Create subtask pad masks (non-zero tokens are valid) - pad_masks.append(subtask_masks) - - num_subtask_embs = subtask_emb.shape[1] - att_mask_segments.append(("subtask", num_subtask_embs)) - # Process FAST action tokens if provided (these are discrete token IDs) - if fast_action_tokens is not None: - - def fast_action_embed_func(fast_action_tokens): - fast_emb = self.fast_action_embedding(fast_action_tokens) - fast_emb_dim = fast_emb.shape[-1] - return fast_emb * math.sqrt(fast_emb_dim) - - fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens) - embs.append(fast_action_emb) - - # Use provided mask or create default (all valid) - if fast_action_masks is not None: - fast_pad_mask = fast_action_masks - else: - bsize = fast_action_tokens.shape[0] - num_fast_embs = fast_action_tokens.shape[1] - fast_pad_mask = torch.ones( - bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device - ) - - num_fast_embs = fast_action_tokens.shape[1] - pad_masks.append(fast_pad_mask) - att_mask_segments.append(("fast", num_fast_embs)) - - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - - # Create custom 2D attention mask - # Attention rules: - # - Images + Language: bidirectional among themselves, don't attend to subtask or FAST - # - Subtask: attend to images + language, causal among themselves, don't attend to FAST - # - FAST: attend to images + language + subtask, causal among themselves - att_masks = self._create_custom_attention_mask(att_mask_segments, pad_masks, bsize) - - # # Optionally visualize the attention mask - # self.visualize_attention_mask( - # att_mask_segments=att_mask_segments, - # att_2d_masks=att_masks, - # save_path="/admin/home/jade_choghari/lerobot/src/lerobot/policies/pi05/attention_mask_visualization.png", - # batch_idx=0, - # max_display_tokens=512 # Limit display for very long sequences - # ) - - return embs, pad_masks, att_masks, total_t_images, num_subtask_embs, num_fast_embs - def embed_prefix_fast( self, images, @@ -952,9 +525,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` fast_action_tokens=None, fast_action_masks=None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: - """Embed images, language tokens, and FAST action tokens for FAST-only mode. + """Embed images, language tokens, and FAST action tokens. - This is a simplified version of embed_prefix without subtask tokens. Attention pattern: - Images + Language: bidirectional among themselves - FAST: attend to images + language, causal among themselves @@ -1018,23 +590,14 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens) embs.append(fast_action_emb) - if fast_action_masks is not None: - fast_pad_mask = fast_action_masks - else: - bsize = fast_action_tokens.shape[0] - num_fast_embs = fast_action_tokens.shape[1] - fast_pad_mask = torch.ones( - bsize, num_fast_embs, dtype=torch.bool, device=fast_action_tokens.device - ) - num_fast_embs = fast_action_tokens.shape[1] - pad_masks.append(fast_pad_mask) + pad_masks.append(fast_action_masks) att_mask_segments.append(("fast", num_fast_embs)) embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) - # Create custom 2D attention mask for FAST-only mode: + # Create custom 2D attention mask: # - Images + Language: bidirectional among themselves # - FAST: attend to images + language, causal among themselves att_masks = self._create_custom_attention_mask_fast(att_mask_segments, pad_masks, bsize) @@ -1042,7 +605,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` return embs, pad_masks, att_masks, total_t_images, num_fast_embs def _create_custom_attention_mask_fast(self, att_mask_segments, pad_masks, bsize): - """Create custom 2D attention mask for FAST-only mode. + """Create custom 2D attention mask. Attention rules: - Images + Language: bidirectional among themselves @@ -1091,7 +654,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` fast_action_tokens, fast_action_masks, ) -> dict: - """Forward pass for FAST-only mode (no flow matching, no subtask). + """Forward pass for PI0Fast. This implements the Pi0FAST training objective: predict next action token using cross-entropy loss. @@ -1151,11 +714,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` # only compute logits for the positions that predict FAST tokens lm_head = self.paligemma_with_expert.paligemma.lm_head - # The FAST tokens start at position (total_T_images + num_lang_tokens) - # For next-token prediction: - # - Position (fast_start - 1) in input predicts fast_action_tokens[0] - # - Position (fast_start) in input predicts fast_action_tokens[1], etc. - # Targets are the FAST action tokens fast_targets = fast_action_tokens # (B, num_fast_embs) @@ -1169,24 +727,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` fast_targets = fast_targets[:, 1:] # shift targets right fast_action_masks = fast_action_masks[:, 1:] # shift masks to match targets - # from transformers import AutoTokenizer - # self._paligemma_tokenizer = AutoTokenizer.from_pretrained( - # "google/paligemma-3b-pt-224", - # trust_remote_code=True, - # add_eos_token=True, - # add_bos_token=False - # ) - # # remove - # decoded_tokens = [ - # self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) - # for seq in fast_targets - # ] - # corrected_tokens = [ - # self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) - # for seq in fast_logits_for_pred.argmax(dim=-1) - # ] - # breakpoint() - # compute cross-entropy loss loss_fct = torch.nn.CrossEntropyLoss(reduction="none") fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) @@ -1320,38 +860,38 @@ class PI0FastPolicy(PreTrainedPolicy): config.validate_features() self.config = config - # Initialize the core PI0Fast model - self.init_rtc_processor() - self.model = PI0FastPytorch(config, rtc_processor=self.rtc_processor) - - # Enable gradient checkpointing if requested - if config.gradient_checkpointing: - self.model.gradient_checkpointing_enable() - - self.model.to(config.device) - - # Load FAST tokenizer for action detokenization (only if fast_only mode) - self.action_tokenizer = None - self._paligemma_tokenizer = None - self._fast_skip_tokens = 128 - + # Load tokenizers first try: from transformers import AutoProcessor, AutoTokenizer # Load FAST tokenizer self.action_tokenizer = AutoProcessor.from_pretrained( - "jadechoghari/fast-libero-tokenizer-mean-std", trust_remote_code=True + config.action_tokenizer_name, trust_remote_code=True ) # Load PaliGemma tokenizer for token conversion self._paligemma_tokenizer = AutoTokenizer.from_pretrained( - "google/paligemma-3b-pt-224", trust_remote_code=True, add_eos_token=True, add_bos_token=False + config.text_tokenizer_name, trust_remote_code=True, add_eos_token=True, add_bos_token=False ) logging.info("Loaded FAST tokenizer for action detokenization") except Exception as e: logging.warning(f"Could not load FAST tokenizer for action detokenization: {e}") logging.warning("Action tokens will be returned without detokenization") + self._paligemma_tokenizer = None + self.action_tokenizer = None + + # Initialize the core PI0Fast model + self.init_rtc_processor() + self.model = PI0FastPytorch( + config, rtc_processor=self.rtc_processor, paligemma_tokenizer=self._paligemma_tokenizer + ) + + # Enable gradient checkpointing if requested + if config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + self.model.to(config.device) self.reset() @@ -1634,7 +1174,7 @@ class PI0FastPolicy(PreTrainedPolicy): Returns: Action token IDs """ - return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens + return self._paligemma_tokenizer.vocab_size - 1 - self.config.fast_skip_tokens - tokens def decode_actions_with_fast( self, token_ids: list[int], time_horizon: int, action_dim: int, relaxed_decoding: bool = True @@ -1807,9 +1347,9 @@ class PI0FastPolicy(PreTrainedPolicy): tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - # Get optional parameters - temperature = kwargs.get("temperature", 0.0) - max_decoding_steps = 256 + # Get decoding parameters + temperature = self.config.temperature + max_decoding_steps = self.config.max_decoding_steps # Sample action tokens autoregressively action_tokens = self.model.sample_actions_fast( @@ -1820,9 +1360,10 @@ class PI0FastPolicy(PreTrainedPolicy): max_decoding_steps=max_decoding_steps, temperature=temperature, ) + # Detokenize action tokens to continuous actions action_horizon = self.config.n_action_steps - action_dim = 7 + action_dim = self.config.output_features[ACTION].shape[0] continuous_actions = self.detokenize_actions( action_tokens, action_horizon=action_horizon, action_dim=action_dim @@ -1837,23 +1378,25 @@ class PI0FastPolicy(PreTrainedPolicy): images, img_masks = self._preprocess_images(batch) # Get FAST action tokens from batch - fast_action_tokens = batch.get("action.tokens") # (B, max_action_tokens) - fast_action_masks = batch.get("action.token_mask") # (B, max_action_tokens) + fast_action_tokens = batch.get(ACTION_TOKENS) # (B, max_action_tokens) + fast_action_masks = batch.get(ACTION_TOKEN_MASK) # (B, max_action_tokens) # Use full language tokens (no separation into high_level_task and subtask) - tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] - masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + tokens = batch.get(OBS_LANGUAGE_TOKENS) + masks = batch.get(OBS_LANGUAGE_ATTENTION_MASK) if fast_action_tokens is None or fast_action_masks is None: - raise ValueError("FAST-only mode requires action.tokens and action.token_mask in the batch") + raise ValueError( + f"PI0Fast requires {ACTION_TOKENS} and {ACTION_TOKEN_MASK} to be present in the batch" + ) loss_dict = self.model.forward( images, img_masks, tokens, masks, - fast_action_tokens=fast_action_tokens, - fast_action_masks=fast_action_masks, + fast_action_tokens, + fast_action_masks, ) loss = loss_dict["loss"] diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py index cde21980a..bff94e092 100644 --- a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py @@ -142,13 +142,16 @@ def make_pi0_fast_pre_post_processors( ), Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim), TokenizerProcessorStep( - tokenizer_name="google/paligemma-3b-pt-224", + tokenizer_name=config.tokenizer_name, max_length=config.tokenizer_max_length, padding_side="right", padding="max_length", ), ActionTokenizerProcessorStep( - tokenizer_name="/fsx/jade_choghari/outputs/fast_tokenizer", # TODO: jade put the PI + action_tokenizer_name=config.action_tokenizer_name, + max_action_tokens=config.max_action_tokens, + fast_skip_tokens=config.fast_skip_tokens, + paligemma_tokenizer_name=config.text_tokenizer_name, ), DeviceProcessorStep(device=config.device), ] diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 1b68000f7..fc7d79e57 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -157,7 +157,7 @@ class TokenizerProcessorStep(ObservationProcessorStep): # Tokenize the task (this will create CPU tensors) tokenized_prompt = self._tokenize_text(task) - + # Detect the device from existing tensors in the transition to ensure consistency target_device = self._detect_device(self.transition) @@ -295,14 +295,15 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): action_tokenizer: The internal tokenizer/processor instance, loaded during initialization. """ - tokenizer_name: str | None = None - tokenizer: Any | None = None + action_tokenizer_name: str | None = None + action_tokenizer_input_object: Any | None = None trust_remote_code: bool = True max_action_tokens: int = 256 + fast_skip_tokens: int = 128 + paligemma_tokenizer_name: str = "google/paligemma-3b-pt-224" # Internal tokenizer instance (not part of the config) action_tokenizer: Any = field(default=None, init=False, repr=False) _paligemma_tokenizer: Any = field(default=None, init=False, repr=False) - _fast_skip_tokens: int = field(default=128, init=False, repr=False) def __post_init__(self): """ @@ -321,25 +322,27 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep." ) - if self.tokenizer is not None: - # Use provided tokenizer object directly - self.action_tokenizer = self.tokenizer - elif self.tokenizer_name is not None: + if self.action_tokenizer_input_object is not None: + self.action_tokenizer = self.action_tokenizer_input_object + + elif self.action_tokenizer_name is not None: if AutoProcessor is None: raise ImportError("AutoProcessor is not available") self.action_tokenizer = AutoProcessor.from_pretrained( - self.tokenizer_name, trust_remote_code=self.trust_remote_code + self.action_tokenizer_name, trust_remote_code=self.trust_remote_code ) else: raise ValueError( - "Either 'tokenizer' or 'tokenizer_name' must be provided. " + "Either 'action_tokenizer' or 'action_tokenizer_name' must be provided. " "Pass a tokenizer object directly or a tokenizer name to auto-load." ) self._paligemma_tokenizer = AutoTokenizer.from_pretrained( - "google/paligemma-3b-pt-224", trust_remote_code=True, add_eos_token=True, add_bos_token=False + self.paligemma_tokenizer_name, + trust_remote_code=self.trust_remote_code, + add_eos_token=True, + add_bos_token=False, ) - self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens def __call__(self, transition: EnvTransition) -> EnvTransition: """ @@ -358,7 +361,8 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): action = new_transition.get(TransitionKey.ACTION) if action is None: - raise ValueError("ActionTokenizerProcessorStep requires an action in the transition.") + # During inference, no action is available, skip tokenization + return new_transition # Tokenize and get both tokens and mask tokens, mask = self._tokenize_action(action) @@ -376,7 +380,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): """ Converts action tokens to PaliGemma tokens. """ - return self._paligemma_tokenizer.vocab_size - 1 - self._fast_skip_tokens - tokens + return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -523,5 +527,4 @@ class ActionTokenizerProcessorStep(ActionProcessorStep): Returns: The updated dictionary of policy features. """ - # TODO: jadechoghari, should we add the tokenized action to the features? return features