mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610)
* chore(gr00t): sync with #3606 for fixing gr00t config crash * fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions * fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete * feat(test): add compile test and benchamrk for pi0 and pi05 * feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
@@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -258,15 +265,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -274,13 +282,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -255,15 +262,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -271,13 +279,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -880,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
Reference in New Issue
Block a user