From 6380c0d0dd8a0c292cc4ae296d309cd3a830c9ae Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Thu, 29 Jan 2026 11:21:03 +0000 Subject: [PATCH 1/4] example change --- src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py b/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py index 7a48d6903..1b5ab30f5 100644 --- a/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py +++ b/src/lerobot/policies/pi05_full/annotate/load_lerobot_high.py @@ -5,7 +5,9 @@ import lerobot from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.policies.factory import make_pre_post_processors from lerobot.configs.policies import PreTrainedConfig -dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/libero-10-annotate") + +# /fsx/jade_choghari/data/libero_10_subtasks_kw_converted +dataset = LeRobotDataset(repo_id="lerobot/libero_10_image_subtask") dataloader = torch.utils.data.DataLoader( dataset, @@ -24,6 +26,7 @@ pre_processor, post_processor = make_pre_post_processors( pretrained_path="/fsx/jade_choghari/models/pi05-base", ) batch = next(iter(dataloader)) +breakpoint() batch1 = pre_processor(batch) breakpoint() print(batch.keys()) From 092f4617ca0315dd06915f4012be15e28d3ed773 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 2 Feb 2026 09:04:55 +0000 Subject: [PATCH 2/4] more changes --- src/lerobot/policies/pi05_full/modeling_pi05.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index b1181765e..a13767794 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -1117,7 +1117,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` "flow_mse_loss": flow_loss.mean(), "action_ce_loss": fast_loss.mean(), "subtask_ce_loss": subtask_loss, - "loss": flow_loss.mean() + 0.1 * subtask_loss.mean() + 0.05 * fast_loss.mean(), # TODO: jadechoghari: check weights + "loss": flow_loss.mean() + subtask_loss.mean() + fast_loss.mean(), # TODO: jadechoghari: check weights } @torch.no_grad() # see openpi `sample_actions` (slightly adapted) From 6c94fcd1b17c9c37556fcd4e387bbd8f6db5ba0e Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 2 Feb 2026 15:58:47 +0000 Subject: [PATCH 3/4] add KI optional --- .../policies/pi05_full/configuration_pi05.py | 6 + .../policies/pi05_full/modeling_pi05.py | 108 ++++++++++++++++-- 2 files changed, 107 insertions(+), 7 deletions(-) diff --git a/src/lerobot/policies/pi05_full/configuration_pi05.py b/src/lerobot/policies/pi05_full/configuration_pi05.py index a95645220..744854521 100644 --- a/src/lerobot/policies/pi05_full/configuration_pi05.py +++ b/src/lerobot/policies/pi05_full/configuration_pi05.py @@ -88,6 +88,12 @@ class PI05FullConfig(PreTrainedConfig): # Finetuning settings freeze_vision_encoder: bool = False # Freeze only the vision encoder train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections + knowledge_insulation: bool = True # Enable knowledge insulation in attention (blocks gradients from action to VLM K/V) + + # Loss weights (used when knowledge_insulation is enabled) + loss_weight_flow: float = 1.0 # Weight for flow matching MSE loss (continuous actions) + loss_weight_action_ce: float = 1.0 # Weight for FAST action token cross-entropy loss + loss_weight_subtask_ce: float = 1.0 # Weight for subtask token cross-entropy loss # Optimizer settings: see openpi `AdamW` optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index a13767794..ca66becb7 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -222,9 +222,84 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) return padded_images -# Define the complete layer computation function for gradient checkpointing +# Define the complete layer computation function for gradient checkpointing (without knowledge insulation) def compute_layer_complete( layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert +): + """Compute a single transformer layer with fused attention across VLM and action expert (no knowledge insulation).""" + models = [paligemma.language_model, gemma_expert.model] + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + gates.append(gate) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + batch_size = query_states.shape[0] + scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + paligemma.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + # Get head_dim from the current layer, not from the model + head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + outputs_embeds.append(out_emb) + start_pos = end_pos + return outputs_embeds + + +# Define the complete layer computation function with knowledge insulation for gradient checkpointing +def compute_layer_complete_knowledge_insulation( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert ): """ Compute a single transformer layer with fused attention across VLM and action expert. @@ -439,12 +514,14 @@ class PaliGemmaWithExpertModel( image_size: int = DEFAULT_IMAGE_SIZE, freeze_vision_encoder: bool = False, train_expert_only: bool = False, + knowledge_insulation: bool = True, ): if use_adarms is None: use_adarms = [False, False] super().__init__() self.freeze_vision_encoder = freeze_vision_encoder self.train_expert_only = train_expert_only + self.knowledge_insulation = knowledge_insulation vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf._vocab_size = 257152 # noqa: SLF001 @@ -578,11 +655,16 @@ class PaliGemmaWithExpertModel( and self.training ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) + # Select the appropriate layer computation function based on knowledge_insulation + layer_compute_fn = ( + compute_layer_complete_knowledge_insulation if self.knowledge_insulation else compute_layer_complete + ) + # Process all layers with gradient checkpointing if enabled for layer_idx in range(num_layers): if use_gradient_checkpointing: inputs_embeds = torch.utils.checkpoint.checkpoint( - compute_layer_complete, + layer_compute_fn, layer_idx, inputs_embeds, attention_mask, @@ -594,7 +676,7 @@ class PaliGemmaWithExpertModel( gemma_expert=self.gemma_expert, ) else: - inputs_embeds = compute_layer_complete( + inputs_embeds = layer_compute_fn( layer_idx, inputs_embeds, attention_mask, @@ -655,6 +737,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` image_size=config.image_resolution[0], freeze_vision_encoder=config.freeze_vision_encoder, train_expert_only=config.train_expert_only, + knowledge_insulation=config.knowledge_insulation, ) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) @@ -1113,11 +1196,22 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) flow_loss = F.mse_loss(u_t, v_t, reduction="none") + # Compute weighted total loss + flow_loss_mean = flow_loss.mean() + action_ce_loss_mean = fast_loss.mean() + subtask_ce_loss_mean = subtask_loss.mean() + + total_loss = ( + self.config.loss_weight_flow * flow_loss_mean + + self.config.loss_weight_action_ce * action_ce_loss_mean + + self.config.loss_weight_subtask_ce * subtask_ce_loss_mean + ) + return { - "flow_mse_loss": flow_loss.mean(), - "action_ce_loss": fast_loss.mean(), - "subtask_ce_loss": subtask_loss, - "loss": flow_loss.mean() + subtask_loss.mean() + fast_loss.mean(), # TODO: jadechoghari: check weights + "flow_mse_loss": flow_loss_mean, + "action_ce_loss": action_ce_loss_mean, + "subtask_ce_loss": subtask_ce_loss_mean, + "loss": total_loss, } @torch.no_grad() # see openpi `sample_actions` (slightly adapted) From 0059ca7924518d5c36dc52bf721acf636e3866c0 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 9 Feb 2026 07:33:12 +0000 Subject: [PATCH 4/4] add cached subtask inference --- .../policies/pi05_full/modeling_pi05.py | 47 ++++++++++++------- .../policies/pi05_full/processor_pi05.py | 29 ++++++------ 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/lerobot/policies/pi05_full/modeling_pi05.py b/src/lerobot/policies/pi05_full/modeling_pi05.py index ca66becb7..7d0f8a0a1 100644 --- a/src/lerobot/policies/pi05_full/modeling_pi05.py +++ b/src/lerobot/policies/pi05_full/modeling_pi05.py @@ -1809,9 +1809,34 @@ class PI05FullPolicy(PreTrainedPolicy): self.eval() + # generate subtask tokens with time-based caching (independent of action queue) + # only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate) + current_time = time.time() + interval = self.config.subtask_regeneration_interval + should_regenerate = ( + self._cached_subtask_tokens is None + or self._last_subtask_time is None + or interval <= 0 # 0 means regenerate every call + or (current_time - self._last_subtask_time) >= interval + ) + + if should_regenerate: + images, img_masks = self._preprocess_images(batch) + high_level_task_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + high_level_task_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + subtask_tokens, subtask_masks = self.model.generate_subtask_tokens( + images, img_masks, high_level_task_tokens, high_level_task_masks, + max_decoding_steps=self.config.tokenizer_max_length + ) + self._cached_subtask_tokens = subtask_tokens + self._cached_subtask_masks = subtask_masks + self._last_subtask_time = current_time + # log and decode the generate subtask tokens + print(f"Generated subtask tokens: {self.model._paligemma_tokenizer.decode(subtask_tokens[0].tolist(), skip_special_tokens=True)}") + # REMOVE + # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: - # TODO: jadechoghari, generate subtask tokens here - ideally every 1 second actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) @@ -1825,28 +1850,18 @@ class PI05FullPolicy(PreTrainedPolicy): # Prepare inputs images, img_masks = self._preprocess_images(batch) - # tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"] + high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - # Generate subtask tokens with time-based caching - # Only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate) - current_time = time.time() - interval = self.config.subtask_regeneration_interval - should_regenerate = ( - self._cached_subtask_tokens is None - or self._last_subtask_time is None - or interval <= 0 # 0 means regenerate every call - or (current_time - self._last_subtask_time) >= interval - ) - - if should_regenerate: + # Use cached subtask tokens (generated in select_action based on time interval) + # If called directly without select_action, generate subtask tokens + if self._cached_subtask_tokens is None: subtask_tokens, subtask_masks = self.model.generate_subtask_tokens( images, img_masks, high_level_task_tokens, high_level_task_masks, max_decoding_steps=self.config.tokenizer_max_length ) self._cached_subtask_tokens = subtask_tokens self._cached_subtask_masks = subtask_masks - self._last_subtask_time = current_time + self._last_subtask_time = time.time() else: subtask_tokens = self._cached_subtask_tokens subtask_masks = self._cached_subtask_masks diff --git a/src/lerobot/policies/pi05_full/processor_pi05.py b/src/lerobot/policies/pi05_full/processor_pi05.py index 80059e9c9..43b643f0b 100644 --- a/src/lerobot/policies/pi05_full/processor_pi05.py +++ b/src/lerobot/policies/pi05_full/processor_pi05.py @@ -54,8 +54,8 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep): """ max_state_dim: int = 32 - user_prompt_key: str = "task" - command_key: str = "subtask" + task_key: str = "task" + subtask_key: str = "subtask" def __call__(self, transition: EnvTransition) -> EnvTransition: transition = transition.copy() @@ -63,12 +63,10 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep): state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE) if state is None: raise ValueError("State is required for PI05") - user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.user_prompt_key) + user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key) if user_prompts is None: raise ValueError("No user prompts found in complementary data") - commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.command_key) - if commands is None: - raise ValueError("No commands found in complementary data") + commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.subtask_key) # TODO: check if this necessary state = deepcopy(state) @@ -89,17 +87,18 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep): full_prompt = f"Task: {cleaned_text}, State: {state_str};\n" full_prompts.append(full_prompt) - transition[TransitionKey.COMPLEMENTARY_DATA][self.user_prompt_key] = full_prompts + transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts - # process commands - full_commands = [] - for i, command in enumerate(commands): - cleaned_text = command.strip().replace("_", " ").replace("\n", " ") - cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari) - full_command = f"Subtask: {cleaned_text};\n" - full_commands.append(full_command) + # process commands (optional) + if commands is not None: + full_commands = [] + for i, command in enumerate(commands): + cleaned_text = command.strip().replace("_", " ").replace("\n", " ") + cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari) + full_command = f"Subtask: {cleaned_text};\n" + full_commands.append(full_command) - transition[TransitionKey.COMPLEMENTARY_DATA][self.command_key] = full_commands + transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_key] = full_commands # note: action tokens will be processed in the ActionTokenizerProcessorStep # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)