Merge branch 'feat/add-pi05' of github.com:huggingface/lerobot into feat/add-pi05

This commit is contained in:
Jade Choghari
2026-02-09 08:34:01 +01:00
4 changed files with 156 additions and 39 deletions

View File

@@ -5,7 +5,9 @@ import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies.factory import make_pre_post_processors
from lerobot.configs.policies import PreTrainedConfig
dataset = LeRobotDataset(repo_id="local", root="/fsx/jade_choghari/outputs/libero-10-annotate")
# /fsx/jade_choghari/data/libero_10_subtasks_kw_converted
dataset = LeRobotDataset(repo_id="lerobot/libero_10_image_subtask")
dataloader = torch.utils.data.DataLoader(
dataset,
@@ -24,6 +26,7 @@ pre_processor, post_processor = make_pre_post_processors(
pretrained_path="/fsx/jade_choghari/models/pi05-base",
)
batch = next(iter(dataloader))
breakpoint()
batch1 = pre_processor(batch)
breakpoint()
print(batch.keys())

View File

@@ -88,6 +88,12 @@ class PI05FullConfig(PreTrainedConfig):
# Finetuning settings
freeze_vision_encoder: bool = False # Freeze only the vision encoder
train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections
knowledge_insulation: bool = True # Enable knowledge insulation in attention (blocks gradients from action to VLM K/V)
# Loss weights (used when knowledge_insulation is enabled)
loss_weight_flow: float = 1.0 # Weight for flow matching MSE loss (continuous actions)
loss_weight_action_ce: float = 1.0 # Weight for FAST action token cross-entropy loss
loss_weight_subtask_ce: float = 1.0 # Weight for subtask token cross-entropy loss
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`

View File

@@ -222,9 +222,84 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
return padded_images
# Define the complete layer computation function for gradient checkpointing
# Define the complete layer computation function for gradient checkpointing (without knowledge insulation)
def compute_layer_complete(
    layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
    """Compute a single transformer layer with fused attention across VLM and action expert (no knowledge insulation).

    Each stream in ``inputs_embeds`` is normalized and projected to Q/K/V by the
    matching layer of its own model (index 0 -> ``paligemma.language_model``,
    index 1 -> ``gemma_expert.model``). The projections are concatenated along
    the sequence axis so that a single attention call covers both streams, and
    the result is split back per stream for the output projection, the two
    gated residuals and the MLP.

    Args:
        layer_idx: Index of the transformer layer to run in both models.
        inputs_embeds: Sequence of hidden-state tensors, one per model, in the
            order [VLM, action expert]. Assumed to be (batch, seq, hidden);
            the code slices on dim 1 as the sequence axis.
        attention_mask: Mask forwarded unchanged to the attention function.
        position_ids: Position ids used to build the rotary embeddings.
        adarms_cond: Per-stream conditioning passed as ``cond`` to the layer norms.
        paligemma: VLM wrapper providing ``language_model`` and the rotary embedding.
        gemma_expert: Action-expert wrapper providing ``model``.

    Returns:
        A list with one output hidden-state tensor per entry of ``inputs_embeds``.
    """
    models = [paligemma.language_model, gemma_expert.model]
    query_states = []
    key_states = []
    value_states = []
    gates = []
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])  # noqa: PLW2901
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        # Split the projection width into (heads, head_dim), then move heads before the sequence axis.
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states.append(query_state)
        key_states.append(key_state)
        value_states.append(value_state)

    # Concatenate along the sequence axis (dim 2) so both streams share one attention call
    query_states = torch.cat(query_states, dim=2)
    key_states = torch.cat(key_states, dim=2)
    value_states = torch.cat(value_states, dim=2)

    # The rotary embedding only needs a tensor carrying the right device/dtype.
    dummy_tensor = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )

    batch_size = query_states.shape[0]
    # Dim 1 holds the query heads after the transpose above; read it from the tensor
    # instead of hard-coding 8 so other head configurations reshape correctly.
    num_query_heads = query_states.shape[1]
    vlm_self_attn = paligemma.language_model.layers[layer_idx].self_attn
    scaling = vlm_self_attn.scaling

    # Attention computation
    att_output, _ = modeling_gemma.eager_attention_forward(
        vlm_self_attn,
        query_states,
        key_states,
        value_states,
        attention_mask,
        scaling,
    )
    # Get head_dim from the current layer, not from the model
    head_dim = vlm_self_attn.head_dim
    # NOTE(review): assumes the attention output carries num_query_heads * head_dim
    # features per token (identical to the previous `1 * 8 * head_dim` for 8 heads).
    att_output = att_output.reshape(batch_size, -1, num_query_heads * head_dim)

    # Process layer outputs
    outputs_embeds = []
    start_pos = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        # Each stream owns the slice of the fused sequence matching its own length.
        end_pos = start_pos + hidden_states.shape[1]
        if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
            att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])

        # first residual
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        after_first_residual = out_emb.clone()
        out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
        # Convert to bfloat16 if the next layer (mlp) uses bfloat16
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)

        # second residual
        out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate)  # noqa: SLF001
        outputs_embeds.append(out_emb)
        start_pos = end_pos
    return outputs_embeds
# Define the complete layer computation function with knowledge insulation for gradient checkpointing
def compute_layer_complete_knowledge_insulation(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
"""
Compute a single transformer layer with fused attention across VLM and action expert.
@@ -439,12 +514,14 @@ class PaliGemmaWithExpertModel(
image_size: int = DEFAULT_IMAGE_SIZE,
freeze_vision_encoder: bool = False,
train_expert_only: bool = False,
knowledge_insulation: bool = True,
):
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.knowledge_insulation = knowledge_insulation
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
@@ -578,11 +655,16 @@ class PaliGemmaWithExpertModel(
and self.training
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
# Select the appropriate layer computation function based on knowledge_insulation
layer_compute_fn = (
compute_layer_complete_knowledge_insulation if self.knowledge_insulation else compute_layer_complete
)
# Process all layers with gradient checkpointing if enabled
for layer_idx in range(num_layers):
if use_gradient_checkpointing:
inputs_embeds = torch.utils.checkpoint.checkpoint(
compute_layer_complete,
layer_compute_fn,
layer_idx,
inputs_embeds,
attention_mask,
@@ -594,7 +676,7 @@ class PaliGemmaWithExpertModel(
gemma_expert=self.gemma_expert,
)
else:
inputs_embeds = compute_layer_complete(
inputs_embeds = layer_compute_fn(
layer_idx,
inputs_embeds,
attention_mask,
@@ -655,6 +737,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
image_size=config.image_resolution[0],
freeze_vision_encoder=config.freeze_vision_encoder,
train_expert_only=config.train_expert_only,
knowledge_insulation=config.knowledge_insulation,
)
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
@@ -1113,11 +1196,22 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
flow_loss = F.mse_loss(u_t, v_t, reduction="none")
# Compute weighted total loss
flow_loss_mean = flow_loss.mean()
action_ce_loss_mean = fast_loss.mean()
subtask_ce_loss_mean = subtask_loss.mean()
total_loss = (
self.config.loss_weight_flow * flow_loss_mean
+ self.config.loss_weight_action_ce * action_ce_loss_mean
+ self.config.loss_weight_subtask_ce * subtask_ce_loss_mean
)
return {
"flow_mse_loss": flow_loss.mean(),
"action_ce_loss": fast_loss.mean(),
"subtask_ce_loss": subtask_loss,
"loss": flow_loss.mean() + 0.1 * subtask_loss.mean() + 0.05 * fast_loss.mean(), # TODO: jadechoghari: check weights
"flow_mse_loss": flow_loss_mean,
"action_ce_loss": action_ce_loss_mean,
"subtask_ce_loss": subtask_ce_loss_mean,
"loss": total_loss,
}
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
@@ -1715,9 +1809,34 @@ class PI05FullPolicy(PreTrainedPolicy):
self.eval()
# generate subtask tokens with time-based caching (independent of action queue)
# only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate)
current_time = time.time()
interval = self.config.subtask_regeneration_interval
should_regenerate = (
self._cached_subtask_tokens is None
or self._last_subtask_time is None
or interval <= 0 # 0 means regenerate every call
or (current_time - self._last_subtask_time) >= interval
)
if should_regenerate:
images, img_masks = self._preprocess_images(batch)
high_level_task_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
high_level_task_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(
images, img_masks, high_level_task_tokens, high_level_task_masks,
max_decoding_steps=self.config.tokenizer_max_length
)
self._cached_subtask_tokens = subtask_tokens
self._cached_subtask_masks = subtask_masks
self._last_subtask_time = current_time
# log and decode the generated subtask tokens
print(f"Generated subtask tokens: {self.model._paligemma_tokenizer.decode(subtask_tokens[0].tolist(), skip_special_tokens=True)}")
# REMOVE
# Action queue logic for n_action_steps > 1
if len(self._action_queue) == 0:
# TODO: jadechoghari, generate subtask tokens here - ideally every 1 second
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
# Transpose to get shape (n_action_steps, batch_size, action_dim)
self._action_queue.extend(actions.transpose(0, 1))
@@ -1731,28 +1850,18 @@ class PI05FullPolicy(PreTrainedPolicy):
# Prepare inputs
images, img_masks = self._preprocess_images(batch)
# tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_USER_PROMPT_TOKENS}"], batch[f"{OBS_LANGUAGE_USER_PROMPT_ATTENTION_MASK}"]
high_level_task_tokens, high_level_task_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Generate subtask tokens with time-based caching
# Only regenerate if: no cache, or interval elapsed, or interval is 0 (always regenerate)
current_time = time.time()
interval = self.config.subtask_regeneration_interval
should_regenerate = (
self._cached_subtask_tokens is None
or self._last_subtask_time is None
or interval <= 0 # 0 means regenerate every call
or (current_time - self._last_subtask_time) >= interval
)
if should_regenerate:
# Use cached subtask tokens (generated in select_action based on time interval)
# If called directly without select_action, generate subtask tokens
if self._cached_subtask_tokens is None:
subtask_tokens, subtask_masks = self.model.generate_subtask_tokens(
images, img_masks, high_level_task_tokens, high_level_task_masks,
max_decoding_steps=self.config.tokenizer_max_length
)
self._cached_subtask_tokens = subtask_tokens
self._cached_subtask_masks = subtask_masks
self._last_subtask_time = current_time
self._last_subtask_time = time.time()
else:
subtask_tokens = self._cached_subtask_tokens
subtask_masks = self._cached_subtask_masks

View File

@@ -54,8 +54,8 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
"""
max_state_dim: int = 32
user_prompt_key: str = "task"
command_key: str = "subtask"
task_key: str = "task"
subtask_key: str = "subtask"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
@@ -63,12 +63,10 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
if state is None:
raise ValueError("State is required for PI05")
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.user_prompt_key)
user_prompts = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if user_prompts is None:
raise ValueError("No user prompts found in complementary data")
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.command_key)
if commands is None:
raise ValueError("No commands found in complementary data")
commands = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.subtask_key)
# TODO: check if this is necessary
state = deepcopy(state)
@@ -89,17 +87,18 @@ class Pi05FullPrepareStateTokenizerProcessorStep(ProcessorStep):
full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.user_prompt_key] = full_prompts
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# process commands
full_commands = []
for i, command in enumerate(commands):
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
full_command = f"Subtask: {cleaned_text};\n"
full_commands.append(full_command)
# process commands (optional)
if commands is not None:
full_commands = []
for i, command in enumerate(commands):
cleaned_text = command.strip().replace("_", " ").replace("\n", " ")
cleaned_text = cleaned_text.lower() # all lowercase # NOTE: added by (jadechoghari)
full_command = f"Subtask: {cleaned_text};\n"
full_commands.append(full_command)
transition[TransitionKey.COMPLEMENTARY_DATA][self.command_key] = full_commands
transition[TransitionKey.COMPLEMENTARY_DATA][self.subtask_key] = full_commands
# note: action tokens will be processed in the ActionTokenizerProcessorStep
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)