mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610)
* chore(gr00t): sync with #3606 for fixing gr00t config crash * fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions * fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete * feat(test): add compile test and benchamrk for pi0 and pi05 * feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -26,9 +26,14 @@ from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from huggingface_hub.dataclasses import strict
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
|
||||
def strict(cls):
|
||||
return cls
|
||||
|
||||
AutoConfig = None
|
||||
AutoModel = None
|
||||
PretrainedConfig = object
|
||||
@@ -173,19 +178,20 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@strict
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
backbone_cfg: dict[str, Any] | None = None
|
||||
action_head_cfg: dict[str, Any] | None = None
|
||||
action_horizon: int = 0
|
||||
action_dim: int = 0
|
||||
compute_dtype: str = "float32"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
def __post_init__(self, **kwargs):
|
||||
self.backbone_cfg = {} if self.backbone_cfg is None else self.backbone_cfg
|
||||
self.action_head_cfg = {} if self.action_head_cfg is None else self.action_head_cfg
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
|
||||
# real model
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -141,6 +142,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -227,16 +237,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -258,15 +265,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -274,13 +282,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -488,8 +496,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -499,36 +508,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -907,7 +919,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
@@ -30,6 +29,7 @@ from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.cache_utils import DynamicCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING or _transformers_available:
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
DynamicCache = None
|
||||
modeling_gemma = None
|
||||
PiGemmaForCausalLM = None
|
||||
_gated_residual = None
|
||||
@@ -138,6 +139,15 @@ def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
def clone_past_key_values(past_key_values):
|
||||
"""Clone the DynamicCache returned by prefix prefill for compiled denoising."""
|
||||
return DynamicCache(
|
||||
tuple(
|
||||
(keys.clone(), values.clone(), sliding_window) for keys, values, sliding_window in past_key_values
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
@@ -224,16 +234,13 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
|
||||
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
|
||||
):
|
||||
models = [paligemma.model.language_model, gemma_expert.model]
|
||||
def compute_layer_complete(inputs_embeds, attention_mask, position_ids, adarms_cond, layers, rotary_emb):
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
|
||||
gates.append(gate)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
@@ -255,15 +262,16 @@ def compute_layer_complete(
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
cos, sin = rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
paligemma_layer = layers[0]
|
||||
scaling = paligemma_layer.self_attn.scaling
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
paligemma_layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
@@ -271,13 +279,13 @@ def compute_layer_complete(
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
head_dim = paligemma_layer.self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
layer = layers[i]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
@@ -485,8 +493,9 @@ class PaliGemmaWithExpertModel(
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
paligemma_layers = self.paligemma.model.language_model.layers
|
||||
gemma_expert_layers = self.gemma_expert.model.layers
|
||||
rotary_emb = self.paligemma.model.language_model.rotary_emb
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
@@ -496,36 +505,39 @@ class PaliGemmaWithExpertModel(
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
for layers in zip(paligemma_layers, gemma_expert_layers, strict=True):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
paligemma=self.paligemma,
|
||||
gemma_expert=self.gemma_expert,
|
||||
layers=layers,
|
||||
rotary_emb=rotary_emb,
|
||||
)
|
||||
|
||||
# final norm
|
||||
final_norms = (
|
||||
self.paligemma.model.language_model.norm,
|
||||
self.gemma_expert.model.norm,
|
||||
)
|
||||
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
out_emb, _ = layernorm_forward(final_norms[i], hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
@@ -880,7 +892,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
past_key_values = clone_past_key_values(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
|
||||
1
tests/policies/pi0_pi05/openpi_pytorch/__init__.py
Normal file
1
tests/policies/pi0_pi05/openpi_pytorch/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Lightweight vendored OpenPI PyTorch modules for PI0/PI05 parity tests."""
|
||||
22
tests/policies/pi0_pi05/openpi_pytorch/gemma.py
Normal file
22
tests/policies/pi0_pi05/openpi_pytorch/gemma.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
|
||||
|
||||
def get_config(variant: str) -> Config:
|
||||
"""Return the Gemma shape config needed by the OpenPI PyTorch model."""
|
||||
if variant == "dummy":
|
||||
return Config(width=64, depth=4, mlp_dim=128, num_heads=8, num_kv_heads=1, head_dim=16)
|
||||
if variant == "gemma_300m":
|
||||
return Config(width=1024, depth=18, mlp_dim=4096, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
if variant == "gemma_2b":
|
||||
return Config(width=2048, depth=18, mlp_dim=16_384, num_heads=8, num_kv_heads=1, head_dim=256)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
300
tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py
Normal file
300
tests/policies/pi0_pi05/openpi_pytorch/gemma_pytorch.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaForCausalLM,
|
||||
_gated_residual,
|
||||
layernorm_forward,
|
||||
)
|
||||
|
||||
|
||||
class PaliGemmaWithExpertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vlm_config,
|
||||
action_expert_config,
|
||||
use_adarms=None,
|
||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||
):
|
||||
if use_adarms is None:
|
||||
use_adarms = [False, False]
|
||||
super().__init__()
|
||||
|
||||
vlm_config_hf = CONFIG_MAPPING["paligemma"]()
|
||||
vlm_config_hf._vocab_size = 257152 # noqa: SLF001
|
||||
vlm_config_hf.image_token_index = 257152
|
||||
vlm_config_hf.text_config.hidden_size = vlm_config.width
|
||||
vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim
|
||||
vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads
|
||||
vlm_config_hf.text_config.head_dim = vlm_config.head_dim
|
||||
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
|
||||
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
|
||||
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
|
||||
vlm_config_hf.text_config.dtype = "float32"
|
||||
vlm_config_hf.text_config.vocab_size = 257152
|
||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||
vlm_config_hf.vision_config.projection_dim = 2048
|
||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||
vlm_config_hf.vision_config.dtype = "float32"
|
||||
|
||||
action_expert_config_hf = CONFIG_MAPPING["gemma"](
|
||||
head_dim=action_expert_config.head_dim,
|
||||
hidden_size=action_expert_config.width,
|
||||
intermediate_size=action_expert_config.mlp_dim,
|
||||
num_attention_heads=action_expert_config.num_heads,
|
||||
num_hidden_layers=action_expert_config.depth,
|
||||
num_key_value_heads=action_expert_config.num_kv_heads,
|
||||
vocab_size=257152,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
dtype="float32",
|
||||
use_adarms=use_adarms[1],
|
||||
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
|
||||
)
|
||||
|
||||
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
|
||||
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
|
||||
self.gemma_expert.model.embed_tokens = None
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
|
||||
def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
|
||||
if precision == "bfloat16":
|
||||
self.to(dtype=torch.bfloat16)
|
||||
elif precision == "float32":
|
||||
self.to(dtype=torch.float32)
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
params_to_keep_float32 = [
|
||||
"vision_tower",
|
||||
"multi_modal_projector",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
"model.norm",
|
||||
]
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_keep_float32):
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Transformers 5.4 no longer divides PaliGemma image features by sqrt(hidden_size),
|
||||
# so the upstream helper now matches OpenPI's patched PaliGemma image-scale semantics.
|
||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-c916907e7e52ac85ee1a1527560eae4656cd6c76141ceb1fe3da61bd5f697d2a
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
inputs_embeds: list[torch.FloatTensor] | None = None,
|
||||
use_cache: bool | None = None,
|
||||
adarms_cond: list[torch.Tensor] | None = None,
|
||||
):
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
if inputs_embeds[1] is None:
|
||||
prefix_output = self.paligemma.model.language_model.forward(
|
||||
inputs_embeds=inputs_embeds[0],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[0] if adarms_cond is not None else None,
|
||||
)
|
||||
prefix_past_key_values = prefix_output.past_key_values
|
||||
prefix_output = prefix_output.last_hidden_state
|
||||
suffix_output = None
|
||||
elif inputs_embeds[0] is None:
|
||||
suffix_output = self.gemma_expert.model.forward(
|
||||
inputs_embeds=inputs_embeds[1],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
||||
)
|
||||
suffix_output = suffix_output.last_hidden_state
|
||||
prefix_output = None
|
||||
prefix_past_key_values = None
|
||||
else:
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
|
||||
# Check if gradient checkpointing is enabled for any of the models
|
||||
use_gradient_checkpointing = (
|
||||
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
||||
and self.gemma_expert.model.gradient_checkpointing
|
||||
and self.training
|
||||
) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training)
|
||||
|
||||
# Force enable gradient checkpointing if we're in training mode and the model supports it
|
||||
if self.training and hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
||||
if not self.gemma_expert.model.gradient_checkpointing:
|
||||
print("Forcing gradient checkpointing to be enabled for Gemma expert model")
|
||||
self.gemma_expert.model.gradient_checkpointing = True
|
||||
use_gradient_checkpointing = True
|
||||
|
||||
# Debug gradient checkpointing status
|
||||
if hasattr(self, "_debug_gc_printed") and not self._debug_gc_printed:
|
||||
print(f"Gemma expert model gradient checkpointing: {use_gradient_checkpointing}")
|
||||
print(f"Model training mode: {self.training}")
|
||||
print(
|
||||
f"Gemma expert model has gradient_checkpointing attr: {hasattr(self.gemma_expert.model, 'gradient_checkpointing')}"
|
||||
)
|
||||
if hasattr(self.gemma_expert.model, "gradient_checkpointing"):
|
||||
print(
|
||||
f"Gemma expert model gradient_checkpointing value: {self.gemma_expert.model.gradient_checkpointing}"
|
||||
)
|
||||
self._debug_gc_printed = True
|
||||
|
||||
# Define the complete layer computation function for gradient checkpointing
|
||||
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
gates = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
hidden_states, gate = layernorm_forward(
|
||||
layer.input_layernorm, hidden_states, adarms_cond[i]
|
||||
)
|
||||
gates.append(gate)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# Concatenate and process attention
|
||||
query_states = torch.cat(query_states, dim=2)
|
||||
key_states = torch.cat(key_states, dim=2)
|
||||
value_states = torch.cat(value_states, dim=2)
|
||||
|
||||
dummy_tensor = torch.zeros(
|
||||
query_states.shape[0],
|
||||
query_states.shape[2],
|
||||
query_states.shape[-1],
|
||||
device=query_states.device,
|
||||
dtype=query_states.dtype,
|
||||
)
|
||||
cos, sin = self.paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids)
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
batch_size = query_states.shape[0]
|
||||
scaling = self.paligemma.model.language_model.layers[layer_idx].self_attn.scaling
|
||||
|
||||
# Attention computation
|
||||
att_output, _ = modeling_gemma.eager_attention_forward(
|
||||
self.paligemma.model.language_model.layers[layer_idx].self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
scaling,
|
||||
)
|
||||
# Get head_dim from the current layer, not from the model
|
||||
head_dim = self.paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
|
||||
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
|
||||
|
||||
# Process layer outputs
|
||||
outputs_embeds = []
|
||||
start_pos = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = models[i].layers[layer_idx]
|
||||
end_pos = start_pos + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
|
||||
|
||||
# first residual
|
||||
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
|
||||
after_first_residual = out_emb.clone()
|
||||
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
|
||||
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
|
||||
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
|
||||
out_emb = out_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
out_emb = layer.mlp(out_emb)
|
||||
# second residual
|
||||
out_emb = _gated_residual(after_first_residual, out_emb, gate)
|
||||
outputs_embeds.append(out_emb)
|
||||
start_pos = end_pos
|
||||
|
||||
return outputs_embeds
|
||||
|
||||
# Process all layers with gradient checkpointing if enabled
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gradient_checkpointing:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_layer_complete,
|
||||
layer_idx,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = compute_layer_complete(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
# Old code removed - now using compute_layer_complete function above
|
||||
|
||||
# final norm
|
||||
# Define final norm computation function for gradient checkpointing
|
||||
def compute_final_norms(inputs_embeds, adarms_cond):
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return outputs_embeds
|
||||
|
||||
# Apply gradient checkpointing to final norm if enabled
|
||||
if use_gradient_checkpointing:
|
||||
outputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
compute_final_norms,
|
||||
inputs_embeds,
|
||||
adarms_cond,
|
||||
use_reentrant=False,
|
||||
preserve_rng_state=False,
|
||||
)
|
||||
else:
|
||||
outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond)
|
||||
|
||||
prefix_output = outputs_embeds[0]
|
||||
suffix_output = outputs_embeds[1]
|
||||
prefix_past_key_values = None
|
||||
|
||||
return [prefix_output, suffix_output], prefix_past_key_values
|
||||
79
tests/policies/pi0_pi05/openpi_pytorch/image_tools.py
Normal file
79
tests/policies/pi0_pi05/openpi_pytorch/image_tools.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
|
||||
|
||||
def resize_with_pad_torch(
|
||||
images: torch.Tensor,
|
||||
height: int,
|
||||
width: int,
|
||||
mode: str = "bilinear",
|
||||
) -> torch.Tensor:
|
||||
"""PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
|
||||
by padding with black. If the image is float32, it must be in the range [-1, 1].
|
||||
|
||||
Args:
|
||||
images: Tensor of shape [*b, h, w, c] or [*b, c, h, w]
|
||||
height: Target height
|
||||
width: Target width
|
||||
mode: Interpolation mode ('bilinear', 'nearest', etc.)
|
||||
|
||||
Returns:
|
||||
Resized and padded tensor with same shape format as input
|
||||
"""
|
||||
# Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w]
|
||||
if images.shape[-1] <= 4: # Assume channels-last format
|
||||
channels_last = True
|
||||
# Convert to channels-first for torch operations
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0) # Add batch dimension
|
||||
images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w]
|
||||
else:
|
||||
channels_last = False
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0) # Add batch dimension
|
||||
|
||||
batch_size, channels, cur_height, cur_width = images.shape
|
||||
|
||||
# Calculate resize ratio
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
|
||||
# Resize
|
||||
resized_images = F.interpolate(
|
||||
images,
|
||||
size=(resized_height, resized_width),
|
||||
mode=mode,
|
||||
align_corners=False if mode == "bilinear" else None,
|
||||
)
|
||||
|
||||
# Handle dtype-specific clipping
|
||||
if images.dtype == torch.uint8:
|
||||
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
|
||||
elif images.dtype == torch.float32:
|
||||
resized_images = resized_images.clamp(-1.0, 1.0)
|
||||
else:
|
||||
raise ValueError(f"Unsupported image dtype: {images.dtype}")
|
||||
|
||||
# Calculate padding
|
||||
pad_h0, remainder_h = divmod(height - resized_height, 2)
|
||||
pad_h1 = pad_h0 + remainder_h
|
||||
pad_w0, remainder_w = divmod(width - resized_width, 2)
|
||||
pad_w1 = pad_w0 + remainder_w
|
||||
|
||||
# Pad
|
||||
constant_value = 0 if images.dtype == torch.uint8 else -1.0
|
||||
padded_images = F.pad(
|
||||
resized_images,
|
||||
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
|
||||
mode="constant",
|
||||
value=constant_value,
|
||||
)
|
||||
|
||||
# Convert back to original format if needed
|
||||
if channels_last:
|
||||
padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
if batch_size == 1 and images.shape[0] == 1:
|
||||
padded_images = padded_images.squeeze(0) # Remove batch dimension if it was added
|
||||
|
||||
return padded_images
|
||||
471
tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py
Normal file
471
tests/policies/pi0_pi05/openpi_pytorch/pi0_pytorch.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
import tests.policies.pi0_pi05.openpi_pytorch.gemma as _gemma
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as _preprocessing
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.gemma_pytorch import PaliGemmaWithExpertModel
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,))
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
return att_2d_masks & pad_2d_masks
|
||||
|
||||
|
||||
class PI0Pytorch(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.pi05 = config.pi05
|
||||
|
||||
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
||||
action_expert_config = _gemma.get_config(config.action_expert_variant)
|
||||
|
||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||
paligemma_config,
|
||||
action_expert_config,
|
||||
use_adarms=[False, True] if self.pi05 else [False, False],
|
||||
precision=config.dtype,
|
||||
)
|
||||
|
||||
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim)
|
||||
|
||||
if self.pi05:
|
||||
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
else:
|
||||
self.state_proj = nn.Linear(config.action_dim, action_expert_config.width)
|
||||
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
|
||||
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
if config.pytorch_compile_mode is not None:
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.pytorch_compile_mode)
|
||||
|
||||
# Initialize gradient checkpointing flag
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
# The upstream OpenPI module verifies a site-package Transformers patch here.
|
||||
# This vendored test copy instead routes through LeRobot's local PiGemma compatibility layer.
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
|
||||
|
||||
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def is_gradient_checkpointing_enabled(self):
|
||||
"""Check if gradient checkpointing is enabled."""
|
||||
return self.gradient_checkpointing_enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def _prepare_attention_masks_4d(self, att_2d_masks):
|
||||
"""Helper method to prepare 4D attention masks for transformer."""
|
||||
att_2d_masks_4d = att_2d_masks[:, None, :, :]
|
||||
return torch.where(att_2d_masks_4d, 0.0, -2.3819763e38)
|
||||
|
||||
def _preprocess_observation(self, observation, *, train=True):
|
||||
"""Helper method to preprocess observation."""
|
||||
observation = _preprocessing.preprocess_observation_pytorch(observation, train=train)
|
||||
return (
|
||||
list(observation.images.values()),
|
||||
list(observation.image_masks.values()),
|
||||
observation.tokenized_prompt,
|
||||
observation.tokenized_prompt_mask,
|
||||
observation.state,
|
||||
)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
return torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for PaliGemma transformer processing.
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Process images
|
||||
for img, img_mask in zip(images, img_masks, strict=True):
|
||||
|
||||
def image_embed_func(img):
|
||||
return self.paligemma_with_expert.embed_image(img)
|
||||
|
||||
img_emb = self._apply_checkpoint(image_embed_func, img)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs))
|
||||
|
||||
# Create attention masks so that image tokens attend to each other
|
||||
att_masks += [0] * num_img_embs
|
||||
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
# Transformers > 5.4 scales Gemma token embeddings inside embed_tokens, matching
|
||||
# OpenPI's former explicit sqrt(hidden_size) multiply without applying it twice.
|
||||
# See https://github.com/huggingface/transformers/pull/44432/changes#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
# full attention between image and language inputs
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
|
||||
# Get batch size from the first dimension of the concatenated tensors
|
||||
bsize = pad_masks.shape[0]
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, state, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
if not self.pi05:
|
||||
if self.state_proj.weight.dtype == torch.float32:
|
||||
state = state.to(torch.float32)
|
||||
|
||||
# Embed state
|
||||
def state_proj_func(state):
|
||||
return self.state_proj(state)
|
||||
|
||||
state_emb = self._apply_checkpoint(state_proj_func, state)
|
||||
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
device = state_emb.device
|
||||
|
||||
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1]
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.action_in_proj.out_features,
|
||||
min_period=4e-3,
|
||||
max_period=4.0,
|
||||
device=timestep.device,
|
||||
)
|
||||
time_emb = time_emb.type(dtype=timestep.dtype)
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
def action_proj_func(noisy_actions):
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
|
||||
if not self.pi05:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
# Apply MLP layers
|
||||
def mlp_func(action_time_emb):
|
||||
x = self.action_time_mlp_in(action_time_emb)
|
||||
x = F.silu(x) # swish == silu
|
||||
return self.action_time_mlp_out(x)
|
||||
|
||||
action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb)
|
||||
adarms_cond = None
|
||||
else:
|
||||
# time MLP (for adaRMS)
|
||||
def time_mlp_func(time_emb):
|
||||
x = self.time_mlp_in(time_emb)
|
||||
x = F.silu(x) # swish == silu
|
||||
x = self.time_mlp_out(x)
|
||||
return F.silu(x)
|
||||
|
||||
time_emb = self._apply_checkpoint(time_mlp_func, time_emb)
|
||||
action_time_emb = action_emb
|
||||
adarms_cond = time_emb
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, observation, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
||||
observation, train=True
|
||||
)
|
||||
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
|
||||
if (
|
||||
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
|
||||
== torch.bfloat16
|
||||
):
|
||||
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
|
||||
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
|
||||
# Prepare attention masks
|
||||
att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks)
|
||||
|
||||
# Apply gradient checkpointing if enabled
|
||||
def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond):
|
||||
(_, suffix_out), _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
return suffix_out
|
||||
|
||||
suffix_out = self._apply_checkpoint(
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
|
||||
# Apply gradient checkpointing to final action projection if enabled
|
||||
def action_out_proj_func(suffix_out):
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, suffix_out)
|
||||
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(self, device, observation, noise=None, num_steps=10) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = observation.state.shape[0]
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(
|
||||
observation, train=False
|
||||
)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
|
||||
# Compute image and language key value cache
|
||||
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
|
||||
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
_, past_key_values = self.paligemma_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks_4d,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Euler step - use new tensor assignment instead of in-place operation
|
||||
x_t = x_t + dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
# Prepare attention masks
|
||||
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
|
||||
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
|
||||
|
||||
past_key_values = copy.deepcopy(past_key_values)
|
||||
outputs_embeds, _ = self.paligemma_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks_4d,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=False,
|
||||
adarms_cond=[None, adarms_cond],
|
||||
)
|
||||
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
return self.action_out_proj(suffix_out)
|
||||
179
tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py
Normal file
179
tests/policies/pi0_pi05/openpi_pytorch/preprocessing_pytorch.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import image_tools
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
# Constants moved from model.py
|
||||
IMAGE_KEYS = (
|
||||
"base_0_rgb",
|
||||
"left_wrist_0_rgb",
|
||||
"right_wrist_0_rgb",
|
||||
)
|
||||
|
||||
IMAGE_RESOLUTION = (224, 224)
|
||||
|
||||
|
||||
def preprocess_observation_pytorch(
|
||||
observation,
|
||||
*,
|
||||
train: bool = False,
|
||||
image_keys: Sequence[str] = IMAGE_KEYS,
|
||||
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
||||
):
|
||||
"""Torch.compile-compatible version of preprocess_observation_pytorch with simplified type annotations.
|
||||
|
||||
This function avoids complex type annotations that can cause torch.compile issues.
|
||||
"""
|
||||
if not set(image_keys).issubset(observation.images):
|
||||
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
|
||||
|
||||
batch_shape = observation.state.shape[:-1]
|
||||
|
||||
out_images = {}
|
||||
for key in image_keys:
|
||||
image = observation.images[key]
|
||||
|
||||
# TODO: This is a hack to handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
# Handle both [B, C, H, W] and [B, H, W, C] formats
|
||||
is_channels_first = image.shape[1] == 3 # Check if channels are in dimension 1
|
||||
|
||||
if is_channels_first:
|
||||
# Convert [B, C, H, W] to [B, H, W, C] for processing
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
if image.shape[1:3] != image_resolution:
|
||||
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
|
||||
image = image_tools.resize_with_pad_torch(image, *image_resolution)
|
||||
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
|
||||
image = image / 2.0 + 0.5
|
||||
|
||||
# Apply PyTorch-based augmentations
|
||||
if "wrist" not in key:
|
||||
# Geometric augmentations for non-wrist cameras
|
||||
height, width = image.shape[1:3]
|
||||
|
||||
# Random crop and resize
|
||||
crop_height = int(height * 0.95)
|
||||
crop_width = int(width * 0.95)
|
||||
|
||||
# Random crop
|
||||
max_h = height - crop_height
|
||||
max_w = width - crop_width
|
||||
if max_h > 0 and max_w > 0:
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
start_h = torch.randint(0, max_h + 1, (1,), device=image.device)
|
||||
start_w = torch.randint(0, max_w + 1, (1,), device=image.device)
|
||||
image = image[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
||||
|
||||
# Resize back to original size
|
||||
image = torch.nn.functional.interpolate(
|
||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Random rotation (small angles)
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
angle = torch.rand(1, device=image.device) * 10 - 5 # Random angle between -5 and 5 degrees
|
||||
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
|
||||
# Convert to radians
|
||||
angle_rad = angle * torch.pi / 180.0
|
||||
|
||||
# Create rotation matrix
|
||||
cos_a = torch.cos(angle_rad)
|
||||
sin_a = torch.sin(angle_rad)
|
||||
|
||||
# Apply rotation using grid_sample
|
||||
grid_x = torch.linspace(-1, 1, width, device=image.device)
|
||||
grid_y = torch.linspace(-1, 1, height, device=image.device)
|
||||
|
||||
# Create meshgrid
|
||||
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
||||
|
||||
# Expand to batch dimension
|
||||
grid_x = grid_x.unsqueeze(0).expand(image.shape[0], -1, -1)
|
||||
grid_y = grid_y.unsqueeze(0).expand(image.shape[0], -1, -1)
|
||||
|
||||
# Apply rotation transformation
|
||||
grid_x_rot = grid_x * cos_a - grid_y * sin_a
|
||||
grid_y_rot = grid_x * sin_a + grid_y * cos_a
|
||||
|
||||
# Stack and reshape for grid_sample
|
||||
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
|
||||
|
||||
image = torch.nn.functional.grid_sample(
|
||||
image.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
grid,
|
||||
mode="bilinear",
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Color augmentations for all cameras
|
||||
# Random brightness
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
brightness_factor = (
|
||||
0.7 + torch.rand(1, device=image.device) * 0.6
|
||||
) # Random factor between 0.7 and 1.3
|
||||
image = image * brightness_factor
|
||||
|
||||
# Random contrast
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
contrast_factor = (
|
||||
0.6 + torch.rand(1, device=image.device) * 0.8
|
||||
) # Random factor between 0.6 and 1.4
|
||||
mean = image.mean(dim=[1, 2, 3], keepdim=True)
|
||||
image = (image - mean) * contrast_factor + mean
|
||||
|
||||
# Random saturation (convert to HSV, modify S, convert back)
|
||||
# For simplicity, we'll just apply a random scaling to the color channels
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
saturation_factor = (
|
||||
0.5 + torch.rand(1, device=image.device) * 1.0
|
||||
) # Random factor between 0.5 and 1.5
|
||||
gray = image.mean(dim=-1, keepdim=True)
|
||||
image = gray + (image - gray) * saturation_factor
|
||||
|
||||
# Clamp values to [0, 1]
|
||||
image = torch.clamp(image, 0, 1)
|
||||
|
||||
# Back to [-1, 1]
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
# Convert back to [B, C, H, W] format if it was originally channels-first
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
|
||||
|
||||
out_images[key] = image
|
||||
|
||||
# obtain mask
|
||||
out_masks = {}
|
||||
for key in out_images:
|
||||
if key not in observation.image_masks:
|
||||
# do not mask by default
|
||||
out_masks[key] = torch.ones(batch_shape, dtype=torch.bool, device=observation.state.device)
|
||||
else:
|
||||
out_masks[key] = observation.image_masks[key]
|
||||
|
||||
# Create a simple object with the required attributes instead of using the complex Observation class
|
||||
class SimpleProcessedObservation:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
return SimpleProcessedObservation(
|
||||
images=out_images,
|
||||
image_masks=out_masks,
|
||||
state=observation.state,
|
||||
tokenized_prompt=observation.tokenized_prompt,
|
||||
tokenized_prompt_mask=observation.tokenized_prompt_mask,
|
||||
token_ar_mask=observation.token_ar_mask,
|
||||
token_loss_mask=observation.token_loss_mask,
|
||||
)
|
||||
101
tests/policies/pi0_pi05/test_pi05_compile.py
Normal file
101
tests/policies/pi0_pi05/test_pi05_compile.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config # noqa: E402
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
||||
assert_cache_stability,
|
||||
assert_compiled_output_matches_eager,
|
||||
assert_explain_has_no_graph_breaks,
|
||||
benchmark_runtime,
|
||||
make_compile_config,
|
||||
reset_compile_state,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
|
||||
def _make_model(*, compile_model):
|
||||
return PI05Pytorch(make_compile_config(PI05Config, compile_model=compile_model)).cuda().eval()
|
||||
|
||||
|
||||
def _make_dummy_inputs(config):
|
||||
device = torch.device("cuda")
|
||||
common = {
|
||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
||||
"tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
||||
"masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
||||
}
|
||||
forward_kwargs = {
|
||||
**common,
|
||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"time": torch.rand(1, device=device),
|
||||
}
|
||||
sample_kwargs = {
|
||||
**common,
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"num_steps": config.num_inference_steps,
|
||||
}
|
||||
return forward_kwargs, sample_kwargs
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_torch_compile_forward_and_sample_actions():
|
||||
if not hasattr(torch, "compile"):
|
||||
pytest.skip("torch.compile is not available")
|
||||
if not torch._dynamo.is_dynamo_supported():
|
||||
pytest.skip("torch._dynamo is not supported on this platform")
|
||||
|
||||
torch.manual_seed(0)
|
||||
eager_model = _make_model(compile_model=False)
|
||||
torch.manual_seed(0)
|
||||
compiled_model = _make_model(compile_model=True)
|
||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
||||
|
||||
try:
|
||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
||||
|
||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
|
||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
||||
|
||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
|
||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")
|
||||
|
||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
|
||||
benchmark_runtime(
|
||||
eager_model.sample_actions,
|
||||
compiled_model.sample_actions,
|
||||
sample_kwargs,
|
||||
"pi05.sample_actions",
|
||||
)
|
||||
finally:
|
||||
reset_compile_state()
|
||||
del eager_model
|
||||
del compiled_model
|
||||
torch.cuda.empty_cache()
|
||||
@@ -14,52 +14,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
|
||||
"""Compare LeRobot PI0.5 against the vendored OpenPI PyTorch reference."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
deterministic_openpi_forward_preprocess,
|
||||
fix_reference_state_dict,
|
||||
fixed_flow_sampling,
|
||||
load_openpi_reference_state_dict,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.types import PolicyAction # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
COMPILE_MODE = "default"
|
||||
FORWARD_RTOL = 1e-4
|
||||
FORWARD_ATOL = 1e-4
|
||||
SAMPLE_RTOL = 1e-2
|
||||
SAMPLE_ATOL = 5e-3
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
@@ -88,6 +92,15 @@ DUMMY_DATASET_STATS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_cuda_after_test():
|
||||
yield
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
class PI05BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
@@ -96,341 +109,163 @@ class PI05BaseOriginalConfig:
|
||||
precision: str = "float32"
|
||||
pi05: bool = True
|
||||
dtype: str = "float32"
|
||||
pytorch_compile_mode: str | None = None
|
||||
|
||||
|
||||
def instantiate_lerobot_pi05(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI05Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
|
||||
else:
|
||||
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI05Policy(config)
|
||||
def instantiate_lerobot_pi05(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
config = PreTrainedConfig.from_pretrained("lerobot/pi05_base")
|
||||
config.device = str(DEVICE)
|
||||
config.dtype = "float32"
|
||||
config.compile_model = compile_model
|
||||
config.compile_mode = COMPILE_MODE
|
||||
config.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
policy = PI05Policy.from_pretrained("lerobot/pi05_base", config=config, strict=True)
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi05_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
policy.config.device = str(DEVICE)
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
return policy, preprocessor
|
||||
|
||||
|
||||
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
|
||||
config = PI05BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
def instantiate_original_pi05():
|
||||
policy = PI0Pytorch(PI05BaseOriginalConfig()).to(DEVICE)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
# NOTE: `lerobot/pi05_base` 的 LeRobot loader 和 PI0 一样会在 strict load 前做 key
|
||||
# 兼容转换,因此预期没有 missing_keys 或 unexpected_keys。vendored reference 则是裸
|
||||
# `nn.Module`,需要在测试侧补齐 checkpoint 与模块命名之间的最小差异。
|
||||
# NOTE: `lm_head.weight` 是 PaliGemma tied embedding 的保存名;LeRobot 的
|
||||
# from_pretrained 会把它映射到内部 `embed_tokens.weight`,而 reference 模型没有这层
|
||||
# loader,所以这里手动复用同一份 tensor,避免把权重别名差异误判成模型差异。
|
||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi05_base"))
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
assert missing_keys == []
|
||||
assert unexpected_keys == []
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
def prepare_parity_inputs(lerobot_pi05, lerobot_preprocessor):
|
||||
torch.manual_seed(0)
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
openpi_actions = openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
assert_processor_inputs_match_lerobot(
|
||||
lerobot_pi05,
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=False,
|
||||
)
|
||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
||||
noise = torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE,
|
||||
)
|
||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
||||
|
||||
|
||||
class PI05Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description (pi05 processor handles all text formatting)
|
||||
tasks = batch.get("task", ["Pick up the object"] * batch_size)
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks] * batch_size
|
||||
elif len(tasks) == 1:
|
||||
tasks = tasks * batch_size
|
||||
|
||||
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
|
||||
state = batch["observation.state"]
|
||||
state = deepcopy(state)
|
||||
|
||||
# Prepare state (pad to max_state_dim)
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
|
||||
state = pad_vector(state, DUMMY_STATE_DIM)
|
||||
|
||||
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
|
||||
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
|
||||
state_np = state.cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Create pi05-formatted prompts that include state information
|
||||
full_prompts = []
|
||||
for i, task in enumerate(tasks):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, discretized_states[i]))
|
||||
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
|
||||
full_prompts.append(full_prompt)
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
full_prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(
|
||||
compile_model=compile_model,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
original_pi05 = instantiate_original_pi05()
|
||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
||||
lerobot_pi05,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
if gradient_checkpointing:
|
||||
lerobot_pi05.train()
|
||||
else:
|
||||
lerobot_pi05.eval()
|
||||
original_pi05.eval()
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
with fixed_flow_sampling(lerobot_pi05.model, noise=noise, time=time):
|
||||
lerobot_loss, _ = lerobot_pi05(lerobot_batch, reduction="none")
|
||||
with deterministic_openpi_forward_preprocess(original_pi05):
|
||||
openpi_losses = original_pi05(openpi_observation, openpi_actions, noise=noise, time=time)
|
||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
||||
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
||||
lerobot_pi05, lerobot_preprocessor = instantiate_lerobot_pi05(compile_model=compile_model)
|
||||
original_pi05 = instantiate_original_pi05()
|
||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
||||
lerobot_pi05,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI05Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi05_original_vs_lerobot():
|
||||
"""Test PI05 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi05(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi05.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
original_pi05.eval()
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
lerobot_actions = lerobot_pi05.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
||||
openpi_actions = original_pi05.sample_actions(
|
||||
device=DEVICE,
|
||||
observation=openpi_observation,
|
||||
noise=noise,
|
||||
num_steps=10,
|
||||
)
|
||||
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
|
||||
def test_pi05_forward_matches_openpi():
|
||||
assert_forward_matches()
|
||||
|
||||
|
||||
def test_pi05_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi()
|
||||
|
||||
|
||||
def test_pi05_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(gradient_checkpointing=True)
|
||||
|
||||
|
||||
def test_pi05_compile_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True)
|
||||
|
||||
|
||||
def test_pi05_compile_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi(compile_model=True)
|
||||
|
||||
|
||||
def test_pi05_compile_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
||||
|
||||
99
tests/policies/pi0_pi05/test_pi0_compile.py
Normal file
99
tests/policies/pi0_pi05/test_pi0_compile.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config # noqa: E402
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
||||
assert_cache_stability,
|
||||
assert_compiled_output_matches_eager,
|
||||
assert_explain_has_no_graph_breaks,
|
||||
benchmark_runtime,
|
||||
make_compile_config,
|
||||
reset_compile_state,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
|
||||
def _make_model(*, compile_model):
|
||||
return PI0Pytorch(make_compile_config(PI0Config, compile_model=compile_model)).cuda().eval()
|
||||
|
||||
|
||||
def _make_dummy_inputs(config):
|
||||
device = torch.device("cuda")
|
||||
common = {
|
||||
"images": [torch.randn(1, 3, *config.image_resolution, device=device)],
|
||||
"img_masks": [torch.ones(1, dtype=torch.bool, device=device)],
|
||||
"lang_tokens": torch.randint(0, 1024, (1, 5), dtype=torch.long, device=device),
|
||||
"lang_masks": torch.ones(1, 5, dtype=torch.bool, device=device),
|
||||
"state": torch.randn(1, config.max_state_dim, device=device),
|
||||
}
|
||||
forward_kwargs = {
|
||||
**common,
|
||||
"actions": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"time": torch.rand(1, device=device),
|
||||
}
|
||||
sample_kwargs = {
|
||||
**common,
|
||||
"noise": torch.randn(1, config.chunk_size, config.max_action_dim, device=device),
|
||||
"num_steps": config.num_inference_steps,
|
||||
}
|
||||
return forward_kwargs, sample_kwargs
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_torch_compile_forward_and_sample_actions():
|
||||
if not hasattr(torch, "compile"):
|
||||
pytest.skip("torch.compile is not available")
|
||||
if not torch._dynamo.is_dynamo_supported():
|
||||
pytest.skip("torch._dynamo is not supported on this platform")
|
||||
|
||||
torch.manual_seed(0)
|
||||
eager_model = _make_model(compile_model=False)
|
||||
torch.manual_seed(0)
|
||||
compiled_model = _make_model(compile_model=True)
|
||||
forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)
|
||||
|
||||
try:
|
||||
assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)
|
||||
|
||||
assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi0.forward")
|
||||
assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
||||
|
||||
assert_cache_stability(compiled_model.forward, forward_kwargs, "pi0.forward")
|
||||
assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions")
|
||||
|
||||
benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi0.forward")
|
||||
benchmark_runtime(
|
||||
eager_model.sample_actions, compiled_model.sample_actions, sample_kwargs, "pi0.sample_actions"
|
||||
)
|
||||
finally:
|
||||
reset_compile_state()
|
||||
del eager_model
|
||||
del compiled_model
|
||||
torch.cuda.empty_cache()
|
||||
@@ -14,51 +14,56 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
|
||||
"""Compare LeRobot PI0 against the vendored OpenPI PyTorch reference."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip if openpi or transformers is not available
|
||||
pytest.importorskip("openpi")
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
from lerobot.configs import PreTrainedConfig # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.openpi_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
deterministic_openpi_forward_preprocess,
|
||||
fix_reference_state_dict,
|
||||
fixed_flow_sampling,
|
||||
load_openpi_reference_state_dict,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI parity and torch.compile checks are too slow for CI; run manually on GPU nodes",
|
||||
)
|
||||
|
||||
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
|
||||
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
|
||||
from transformers import AutoTokenizer # noqa: E402
|
||||
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.processor import PolicyProcessorPipeline # noqa: E402
|
||||
from lerobot.types import PolicyAction # noqa: E402
|
||||
|
||||
# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
|
||||
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
|
||||
DUMMY_MAX_TOKEN_LEN = 48
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
COMPILE_MODE = "default"
|
||||
FORWARD_RTOL = 1e-4
|
||||
FORWARD_ATOL = 1e-4
|
||||
SAMPLE_RTOL = 1e-2
|
||||
SAMPLE_ATOL = 5e-3
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
"observation.state": {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
"action": {
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
@@ -87,6 +92,15 @@ DUMMY_DATASET_STATS = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_cuda_after_test():
|
||||
yield
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
class PI0BaseOriginalConfig:
|
||||
action_dim: int = DUMMY_ACTION_DIM
|
||||
action_horizon: int = DUMMY_ACTION_HORIZON
|
||||
@@ -95,333 +109,156 @@ class PI0BaseOriginalConfig:
|
||||
precision: str = "float32"
|
||||
pi05: bool = False
|
||||
dtype: str = "float32"
|
||||
pytorch_compile_mode: str | None = None
|
||||
|
||||
|
||||
def instantiate_lerobot_pi0(
|
||||
from_pretrained: bool = False,
|
||||
) -> tuple[
|
||||
PI0Policy,
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
if from_pretrained:
|
||||
# Load the policy first
|
||||
policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True)
|
||||
else:
|
||||
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
policy = PI0Policy(config)
|
||||
def instantiate_lerobot_pi0(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
config = PreTrainedConfig.from_pretrained("lerobot/pi0_base")
|
||||
config.device = str(DEVICE)
|
||||
config.dtype = "float32"
|
||||
config.compile_model = compile_model
|
||||
config.compile_mode = COMPILE_MODE
|
||||
config.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
policy = PI0Policy.from_pretrained("lerobot/pi0_base", config=config, strict=True)
|
||||
policy.to(DEVICE)
|
||||
policy.config.device = DEVICE
|
||||
preprocessor, postprocessor = make_pi0_pre_post_processors(
|
||||
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
|
||||
)
|
||||
return (policy, preprocessor, postprocessor)
|
||||
policy.config.device = str(DEVICE)
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=policy.config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
return policy, preprocessor
|
||||
|
||||
|
||||
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
|
||||
config = PI0BaseOriginalConfig()
|
||||
policy = PI0Pytorch(config)
|
||||
|
||||
if from_pretrained:
|
||||
try:
|
||||
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...")
|
||||
|
||||
# Download the model from HuggingFace Hub
|
||||
import safetensors.torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download the entire repository
|
||||
if model_path and os.path.exists(model_path):
|
||||
cache_dir = model_path
|
||||
print(f"Using cached model from: {cache_dir}")
|
||||
else:
|
||||
cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model")
|
||||
print(f"Downloaded model to: {cache_dir}")
|
||||
|
||||
# Try to load safetensors format first
|
||||
model_file = os.path.join(cache_dir, "model.safetensors")
|
||||
if os.path.exists(model_file):
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
print(f"Loaded {len(state_dict)} parameters from safetensors")
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
|
||||
|
||||
# Load the state dict into the model
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"Missing keys: {len(missing_keys)}")
|
||||
if len(missing_keys) <= 5:
|
||||
for key in missing_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in missing_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(missing_keys) - 5} more")
|
||||
|
||||
if unexpected_keys:
|
||||
print(f"Unexpected keys: {len(unexpected_keys)}")
|
||||
if len(unexpected_keys) <= 5:
|
||||
for key in unexpected_keys:
|
||||
print(f" - {key}")
|
||||
else:
|
||||
for key in unexpected_keys[:5]:
|
||||
print(f" - {key}")
|
||||
print(f" ... and {len(unexpected_keys) - 5} more")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("All pretrained weights loaded successfully!")
|
||||
else:
|
||||
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load pretrained weights: {e}")
|
||||
print(" Using randomly initialized weights...")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
policy.to(DEVICE)
|
||||
def instantiate_original_pi0():
|
||||
policy = PI0Pytorch(PI0BaseOriginalConfig()).to(DEVICE)
|
||||
state_dict = fix_reference_state_dict(load_openpi_reference_state_dict("lerobot/pi0_base"))
|
||||
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
|
||||
assert missing_keys == []
|
||||
assert unexpected_keys == []
|
||||
return policy
|
||||
|
||||
|
||||
def create_dummy_data():
|
||||
batch_size = 2 # Reduce batch size for testing
|
||||
device = DEVICE
|
||||
|
||||
# Use the exact same prompt for both implementations
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.left_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
"observation.images.right_wrist_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
|
||||
"""Extract the exact same processed inputs that LeRobot uses internally."""
|
||||
# Get the tokenized language from LeRobot's internal method
|
||||
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
|
||||
|
||||
# Get the preprocessed images from LeRobot's internal method
|
||||
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for original implementation
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
|
||||
def prepare_parity_inputs(lerobot_pi0, lerobot_preprocessor):
|
||||
torch.manual_seed(0)
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = lerobot_preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
openpi_actions = openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
assert_processor_inputs_match_lerobot(
|
||||
lerobot_pi0,
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=True,
|
||||
)
|
||||
batch_size = raw_batch[OBS_STATE].shape[0]
|
||||
noise = torch.randn(
|
||||
batch_size,
|
||||
DUMMY_ACTION_HORIZON,
|
||||
DUMMY_ACTION_DIM,
|
||||
dtype=torch.float32,
|
||||
device=DEVICE,
|
||||
)
|
||||
time = torch.linspace(0.2, 0.8, batch_size, dtype=torch.float32, device=DEVICE)
|
||||
return lerobot_batch, openpi_observation, openpi_actions, noise, time
|
||||
|
||||
|
||||
class PI0Observation:
|
||||
"""Observation class that matches the original OpenPI format."""
|
||||
def assert_forward_matches(*, compile_model: bool = False, gradient_checkpointing: bool = False):
|
||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(
|
||||
compile_model=compile_model,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
original_pi0 = instantiate_original_pi0()
|
||||
lerobot_batch, openpi_observation, openpi_actions, noise, time = prepare_parity_inputs(
|
||||
lerobot_pi0,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state,
|
||||
images,
|
||||
image_masks,
|
||||
tokenized_prompt,
|
||||
tokenized_prompt_mask,
|
||||
token_ar_mask,
|
||||
token_loss_mask,
|
||||
):
|
||||
self.state = state
|
||||
self.images = images
|
||||
self.image_masks = image_masks
|
||||
self.tokenized_prompt = tokenized_prompt
|
||||
self.tokenized_prompt_mask = tokenized_prompt_mask
|
||||
self.token_ar_mask = token_ar_mask
|
||||
self.token_loss_mask = token_loss_mask
|
||||
|
||||
|
||||
def create_original_observation_with_openpi_preprocessing(batch):
|
||||
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
device = batch["observation.state"].device
|
||||
|
||||
# Create tokenizer for OpenPI (same as LeRobot uses)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
|
||||
# Get task description
|
||||
if "task" in batch:
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
# Single string: add newline if not present, then convert to list
|
||||
if not tasks.endswith("\n"):
|
||||
tasks = f"{tasks}\n"
|
||||
tasks = [tasks]
|
||||
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
|
||||
# List of strings: add newline to each if not present
|
||||
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
|
||||
if len(tasks) == 1:
|
||||
# Expand to batch size
|
||||
tasks = tasks * batch_size
|
||||
if len(tasks) != batch_size:
|
||||
raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
if gradient_checkpointing:
|
||||
lerobot_pi0.train()
|
||||
else:
|
||||
# Default task if not provided
|
||||
tasks = ["Pick up the object\n"] * batch_size
|
||||
|
||||
# Tokenize with max_length padding to match OpenPI's expected format
|
||||
tokenized = tokenizer(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=DUMMY_MAX_TOKEN_LEN,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
lang_tokens = tokenized["input_ids"].to(device)
|
||||
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
|
||||
|
||||
# Create dummy token_ar_mask and token_loss_mask for OpenPI
|
||||
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
|
||||
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
|
||||
|
||||
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
|
||||
image_dict = {
|
||||
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
|
||||
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
|
||||
}
|
||||
|
||||
# Create image masks (all ones for real images)
|
||||
image_masks_dict = {}
|
||||
for key in image_dict:
|
||||
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
|
||||
|
||||
# Create raw observation object (before preprocessing)
|
||||
raw_observation = PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
# Now use OpenPI's preprocessing
|
||||
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
|
||||
|
||||
return processed_obs
|
||||
|
||||
|
||||
def create_original_observation_from_lerobot(lerobot_pi0, batch):
|
||||
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
|
||||
_batch_size = batch["observation.state"].shape[0]
|
||||
_device = batch["observation.state"].device
|
||||
|
||||
# Extract the exact same processed inputs that LeRobot uses
|
||||
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
|
||||
extract_lerobot_processed_inputs(lerobot_pi0, batch)
|
||||
)
|
||||
|
||||
# Convert images list to dict with original OpenPI keys
|
||||
image_dict = {
|
||||
"base_0_rgb": images[0],
|
||||
"left_wrist_0_rgb": images[1],
|
||||
"right_wrist_0_rgb": images[2],
|
||||
}
|
||||
|
||||
# Convert image masks list to dict with original OpenPI keys
|
||||
image_masks_dict = {
|
||||
"base_0_rgb": img_masks[0],
|
||||
"left_wrist_0_rgb": img_masks[1],
|
||||
"right_wrist_0_rgb": img_masks[2],
|
||||
}
|
||||
|
||||
return PI0Observation(
|
||||
state=batch["observation.state"],
|
||||
images=image_dict,
|
||||
image_masks=image_masks_dict,
|
||||
tokenized_prompt=lang_tokens,
|
||||
tokenized_prompt_mask=lang_masks,
|
||||
token_ar_mask=token_ar_mask,
|
||||
token_loss_mask=token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
def test_pi0_original_vs_lerobot():
|
||||
"""Test PI0 original implementation vs LeRobot implementation."""
|
||||
print("Initializing models...")
|
||||
lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained LeRobot model
|
||||
original_pi0 = instantiate_original_pi0(
|
||||
from_pretrained=True
|
||||
) # Load pretrained OpenPI model from HuggingFace Hub
|
||||
|
||||
print("Creating dummy data...")
|
||||
batch = create_dummy_data()
|
||||
batch_lerobot = deepcopy(batch)
|
||||
|
||||
# Test each model with its own preprocessing (more realistic end-to-end test)
|
||||
print("\nTest each model with its own preprocessing")
|
||||
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
|
||||
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
|
||||
|
||||
print(f"Task prompt: '{batch['task'][0]}'")
|
||||
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
|
||||
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
|
||||
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
|
||||
|
||||
print("Testing OpenPI with own preprocessing...")
|
||||
lerobot_pi0.eval()
|
||||
original_pi0.eval()
|
||||
torch.manual_seed(42) # Set seed for reproducibility
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
|
||||
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
|
||||
)
|
||||
openpi_actions_unit = openpi_actions[:, 0, :]
|
||||
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
|
||||
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
|
||||
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
|
||||
with fixed_flow_sampling(lerobot_pi0.model, noise=noise, time=time):
|
||||
lerobot_loss, _ = lerobot_pi0(lerobot_batch, reduction="none")
|
||||
with deterministic_openpi_forward_preprocess(original_pi0):
|
||||
openpi_losses = original_pi0(openpi_observation, openpi_actions, noise=noise, time=time)
|
||||
openpi_loss = openpi_losses.mean(dim=(1, 2))
|
||||
|
||||
torch.testing.assert_close(lerobot_loss, openpi_loss, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
|
||||
def assert_sample_actions_match_openpi(*, compile_model: bool = False):
|
||||
lerobot_pi0, lerobot_preprocessor = instantiate_lerobot_pi0(compile_model=compile_model)
|
||||
original_pi0 = instantiate_original_pi0()
|
||||
lerobot_batch, openpi_observation, _openpi_actions, noise, _time = prepare_parity_inputs(
|
||||
lerobot_pi0,
|
||||
lerobot_preprocessor,
|
||||
)
|
||||
|
||||
print("Testing LeRobot with own preprocessing...")
|
||||
lerobot_pi0.eval()
|
||||
torch.manual_seed(42) # Set the same seed
|
||||
|
||||
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
|
||||
original_pi0.eval()
|
||||
with torch.no_grad():
|
||||
lerobot_actions_own = lerobot_pi0.predict_action_chunk(
|
||||
batch_lerobot_processed
|
||||
) # batch_size, n_action_steps, action_dim
|
||||
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
|
||||
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
|
||||
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
|
||||
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
|
||||
lerobot_actions = lerobot_pi0.predict_action_chunk(lerobot_batch, noise=noise, num_steps=10)
|
||||
openpi_actions = original_pi0.sample_actions(
|
||||
device=DEVICE,
|
||||
observation=openpi_observation,
|
||||
noise=noise,
|
||||
num_steps=10,
|
||||
)
|
||||
|
||||
print("\nComparing end-to-end implementations:")
|
||||
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
|
||||
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
|
||||
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
|
||||
torch.testing.assert_close(lerobot_actions, openpi_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
|
||||
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)
|
||||
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
|
||||
|
||||
def test_pi0_forward_matches_openpi():
|
||||
assert_forward_matches()
|
||||
|
||||
|
||||
def test_pi0_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi()
|
||||
|
||||
|
||||
def test_pi0_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(gradient_checkpointing=True)
|
||||
|
||||
|
||||
def test_pi0_compile_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True)
|
||||
|
||||
|
||||
def test_pi0_compile_sample_actions_match_openpi():
|
||||
assert_sample_actions_match_openpi(compile_model=True)
|
||||
|
||||
|
||||
def test_pi0_compile_gradient_checkpointing_forward_matches_openpi():
|
||||
assert_forward_matches(compile_model=True, gradient_checkpointing=True)
|
||||
|
||||
1
tests/policies/pi0_pi05/utils/__init__.py
Normal file
1
tests/policies/pi0_pi05/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Utilities shared by PI0/PI05 policy tests."""
|
||||
291
tests/policies/pi0_pi05/utils/openpi_parity.py
Normal file
291
tests/policies/pi0_pi05/utils/openpi_parity.py
Normal file
@@ -0,0 +1,291 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as openpi_preprocessing
|
||||
|
||||
IMAGE_KEYS = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
|
||||
TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenPIObservation:
|
||||
state: torch.Tensor
|
||||
images: dict[str, torch.Tensor]
|
||||
image_masks: dict[str, torch.Tensor]
|
||||
tokenized_prompt: torch.Tensor
|
||||
tokenized_prompt_mask: torch.Tensor
|
||||
token_ar_mask: torch.Tensor
|
||||
token_loss_mask: torch.Tensor
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def paligemma_tokenizer():
|
||||
return AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
||||
|
||||
|
||||
def clone_batch(batch: dict) -> dict:
|
||||
return {
|
||||
key: value.clone() if isinstance(value, torch.Tensor) else list(value) for key, value in batch.items()
|
||||
}
|
||||
|
||||
|
||||
def pad_last_dim(tensor: torch.Tensor, target_dim: int) -> torch.Tensor:
|
||||
if tensor.shape[-1] > target_dim:
|
||||
raise ValueError(f"Cannot pad last dimension {tensor.shape[-1]} down to {target_dim}")
|
||||
return F.pad(tensor, (0, target_dim - tensor.shape[-1]))
|
||||
|
||||
|
||||
def mean_std_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
mean = stats["mean"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
std = stats["std"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
return (tensor - mean) / (std + 1e-8)
|
||||
|
||||
|
||||
def quantile_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
q01 = stats["q01"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
q99 = stats["q99"].to(device=tensor.device, dtype=tensor.dtype)
|
||||
denom = torch.where(q99 == q01, torch.full_like(q99, 1e-8), q99 - q01)
|
||||
return 2.0 * (tensor - q01) / denom - 1.0
|
||||
|
||||
|
||||
def openpi_model_state_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> torch.Tensor:
|
||||
state = batch[OBS_STATE].to(dtype=torch.float32)
|
||||
if pi05:
|
||||
state = quantile_normalize(state, dataset_stats[OBS_STATE])
|
||||
else:
|
||||
state = mean_std_normalize(state, dataset_stats[OBS_STATE])
|
||||
return pad_last_dim(state, action_dim)
|
||||
|
||||
|
||||
def openpi_model_actions_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> torch.Tensor:
|
||||
actions = batch[ACTION].to(dtype=torch.float32)
|
||||
if pi05:
|
||||
actions = quantile_normalize(actions, dataset_stats[ACTION])
|
||||
else:
|
||||
actions = mean_std_normalize(actions, dataset_stats[ACTION])
|
||||
return pad_last_dim(actions, action_dim)
|
||||
|
||||
|
||||
def _tasks_from_raw(batch: dict, batch_size: int) -> list[str]:
|
||||
tasks = batch.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("The parity batch must include a task prompt.")
|
||||
if isinstance(tasks, str):
|
||||
return [tasks] * batch_size
|
||||
if len(tasks) == 1:
|
||||
return [tasks[0]] * batch_size
|
||||
if len(tasks) != batch_size:
|
||||
raise ValueError(f"Expected {batch_size} task prompts, got {len(tasks)}")
|
||||
return list(tasks)
|
||||
|
||||
|
||||
def _format_pi0_prompts(tasks: list[str]) -> list[str]:
|
||||
return [f"{task.strip().replace('_', ' ').replace(chr(10), ' ')}\n" for task in tasks]
|
||||
|
||||
|
||||
def _format_pi05_prompts(tasks: list[str], normalized_state: torch.Tensor) -> list[str]:
|
||||
state_np = normalized_state.detach().cpu().numpy()
|
||||
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
prompts = []
|
||||
for task, state in zip(tasks, discretized_states, strict=True):
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
state_str = " ".join(map(str, state))
|
||||
prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")
|
||||
return prompts
|
||||
|
||||
|
||||
def _tokenize_prompts(prompts: list[str], *, max_token_len: int, device: torch.device | str):
|
||||
tokenized = paligemma_tokenizer()(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
truncation=True,
|
||||
max_length=max_token_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = tokenized["input_ids"].to(device)
|
||||
masks = tokenized["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
return tokens, masks
|
||||
|
||||
|
||||
def make_openpi_observation_from_raw(
|
||||
batch: dict[str, torch.Tensor],
|
||||
*,
|
||||
action_dim: int,
|
||||
max_token_len: int,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]],
|
||||
pi05: bool,
|
||||
) -> OpenPIObservation:
|
||||
batch_size = batch[OBS_STATE].shape[0]
|
||||
device = batch[OBS_STATE].device
|
||||
state = openpi_model_state_from_raw(
|
||||
batch,
|
||||
action_dim=action_dim,
|
||||
dataset_stats=dataset_stats,
|
||||
pi05=pi05,
|
||||
)
|
||||
|
||||
tasks = _tasks_from_raw(batch, batch_size)
|
||||
prompts = _format_pi05_prompts(tasks, state) if pi05 else _format_pi0_prompts(tasks)
|
||||
tokens, masks = _tokenize_prompts(prompts, max_token_len=max_token_len, device=device)
|
||||
|
||||
images = {
|
||||
key: batch[f"observation.images.{key}"].to(device=device, dtype=torch.float32) * 2.0 - 1.0
|
||||
for key in IMAGE_KEYS
|
||||
}
|
||||
image_masks = {key: torch.ones(batch_size, dtype=torch.bool, device=device) for key in IMAGE_KEYS}
|
||||
|
||||
return OpenPIObservation(
|
||||
state=state,
|
||||
images=images,
|
||||
image_masks=image_masks,
|
||||
tokenized_prompt=tokens,
|
||||
tokenized_prompt_mask=masks,
|
||||
token_ar_mask=torch.zeros_like(tokens, dtype=torch.int32),
|
||||
token_loss_mask=torch.ones_like(masks, dtype=torch.bool),
|
||||
)
|
||||
|
||||
|
||||
def assert_processor_inputs_match_lerobot(
|
||||
lerobot_policy,
|
||||
lerobot_batch: dict[str, torch.Tensor],
|
||||
openpi_observation: OpenPIObservation,
|
||||
*,
|
||||
compare_state: bool,
|
||||
):
|
||||
openpi_processed = openpi_preprocessing.preprocess_observation_pytorch(openpi_observation, train=False)
|
||||
lerobot_images, lerobot_image_masks = lerobot_policy._preprocess_images(lerobot_batch)
|
||||
|
||||
# Token IDs, token masks, images, image masks, and PI0 state are intentionally built from the same
|
||||
# raw batch through independent LeRobot/OpenPI-style processor logic. They must be bitwise equal.
|
||||
torch.testing.assert_close(
|
||||
openpi_observation.tokenized_prompt, lerobot_batch[OBS_LANGUAGE_TOKENS], rtol=0, atol=0
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
openpi_observation.tokenized_prompt_mask,
|
||||
lerobot_batch[OBS_LANGUAGE_ATTENTION_MASK],
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
|
||||
for openpi_image, lerobot_image in zip(openpi_processed.images.values(), lerobot_images, strict=True):
|
||||
torch.testing.assert_close(openpi_image, lerobot_image, rtol=0, atol=0)
|
||||
|
||||
for openpi_mask, lerobot_mask in zip(
|
||||
openpi_processed.image_masks.values(), lerobot_image_masks, strict=True
|
||||
):
|
||||
torch.testing.assert_close(openpi_mask, lerobot_mask, rtol=0, atol=0)
|
||||
|
||||
if compare_state:
|
||||
torch.testing.assert_close(
|
||||
openpi_processed.state, lerobot_policy.prepare_state(lerobot_batch), rtol=0, atol=0
|
||||
)
|
||||
|
||||
|
||||
def load_openpi_reference_state_dict(repo_id: str) -> dict[str, torch.Tensor]:
|
||||
cache_dir = Path(snapshot_download(repo_id=repo_id, repo_type="model"))
|
||||
return safetensors.torch.load_file(cache_dir / "model.safetensors")
|
||||
|
||||
|
||||
def fix_reference_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
fixed_state_dict = dict(state_dict)
|
||||
lm_head_key = "paligemma_with_expert.paligemma.lm_head.weight"
|
||||
embed_tokens_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
if lm_head_key in fixed_state_dict and embed_tokens_key not in fixed_state_dict:
|
||||
fixed_state_dict[embed_tokens_key] = fixed_state_dict[lm_head_key].clone()
|
||||
return fixed_state_dict
|
||||
|
||||
|
||||
@contextmanager
|
||||
def fixed_flow_sampling(model, *, noise: torch.Tensor, time: torch.Tensor) -> Iterator[None]:
|
||||
original_sample_noise = model.sample_noise
|
||||
original_sample_time = model.sample_time
|
||||
|
||||
def sample_noise(shape, device):
|
||||
if tuple(shape) != tuple(noise.shape):
|
||||
raise ValueError(f"Expected noise shape {tuple(noise.shape)}, got {tuple(shape)}")
|
||||
return noise.to(device=device)
|
||||
|
||||
def sample_time(batch_size, device):
|
||||
if batch_size != time.shape[0]:
|
||||
raise ValueError(f"Expected time batch size {time.shape[0]}, got {batch_size}")
|
||||
return time.to(device=device)
|
||||
|
||||
model.sample_noise = sample_noise
|
||||
model.sample_time = sample_time
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
model.sample_noise = original_sample_noise
|
||||
model.sample_time = original_sample_time
|
||||
|
||||
|
||||
@contextmanager
|
||||
def deterministic_openpi_forward_preprocess(openpi_policy) -> Iterator[None]:
|
||||
"""Disable OpenPI's training-time image augmentation only inside a parity forward block.
|
||||
|
||||
OpenPI's `forward()` calls `_preprocess_observation(..., train=True)`, which can apply stochastic
|
||||
image augmentation. LeRobot's policy forward path does not apply that augmentation, so parity would
|
||||
otherwise compare two different image tensors rather than two model implementations. The context manager
|
||||
keeps the public `openpi_policy.forward(observation, ...)` call while making preprocessing deterministic.
|
||||
|
||||
`yield` marks the body of the caller's `with` block. The `try/finally` restores the original method even
|
||||
if the assertion inside the block fails, so the temporary monkeypatch cannot leak into later tests.
|
||||
"""
|
||||
|
||||
original_preprocess_observation = openpi_policy._preprocess_observation
|
||||
|
||||
def preprocess_observation(observation, *, train=True):
|
||||
return original_preprocess_observation(observation, train=False)
|
||||
|
||||
openpi_policy._preprocess_observation = preprocess_observation
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
openpi_policy._preprocess_observation = original_preprocess_observation
|
||||
207
tests/policies/pi0_pi05/utils/torch_compile.py
Normal file
207
tests/policies/pi0_pi05/utils/torch_compile.py
Normal file
@@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters, guard_failures
|
||||
from torch.profiler import ProfilerActivity
|
||||
|
||||
FORWARD_RTOL = 1e-5
|
||||
FORWARD_ATOL = 5e-2
|
||||
SAMPLE_RTOL = 1e-5
|
||||
SAMPLE_ATOL = 1e-2
|
||||
COMPILE_MODE = "max-autotune"
|
||||
STEADY_STATE_WARMUPS = 3
|
||||
STEADY_STATE_REPEATS = 3
|
||||
|
||||
|
||||
def make_compile_config(config_cls, *, compile_model):
|
||||
return config_cls(device="cuda", compile_model=compile_model, compile_mode=COMPILE_MODE)
|
||||
|
||||
|
||||
def counter_total(name):
|
||||
return sum(counters.get(name, {}).values())
|
||||
|
||||
|
||||
def compile_snapshot():
|
||||
return {
|
||||
"graph_breaks": counter_total("graph_break"),
|
||||
"recompiles": counter_total("recompiles"),
|
||||
"recompile_limits": counter_total("recompile_limit"),
|
||||
"unique_graphs": counters["stats"].get("unique_graphs", 0),
|
||||
}
|
||||
|
||||
|
||||
def reset_compile_state():
|
||||
torch._dynamo.reset()
|
||||
counters.clear()
|
||||
guard_failures.clear()
|
||||
|
||||
|
||||
def clone_cuda_graph_output(output):
|
||||
if torch.is_tensor(output):
|
||||
return output.clone()
|
||||
if isinstance(output, tuple):
|
||||
return tuple(clone_cuda_graph_output(item) for item in output)
|
||||
if isinstance(output, list):
|
||||
return [clone_cuda_graph_output(item) for item in output]
|
||||
if isinstance(output, dict):
|
||||
return {key: clone_cuda_graph_output(value) for key, value in output.items()}
|
||||
return output
|
||||
|
||||
|
||||
def run_model_step(fn: Callable, kwargs: dict):
|
||||
if hasattr(torch.compiler, "cudagraph_mark_step_begin"):
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
return fn(**kwargs)
|
||||
|
||||
|
||||
def assert_explain_has_no_graph_breaks(fn: Callable, kwargs: dict, label: str):
|
||||
reset_compile_state()
|
||||
explanation = torch._dynamo.explain(fn)(**kwargs)
|
||||
|
||||
assert explanation.graph_count > 0, f"{label} was not captured by Dynamo"
|
||||
assert explanation.graph_break_count == 0, (
|
||||
f"{label} has {explanation.graph_break_count} graph break(s): {explanation.break_reasons}"
|
||||
)
|
||||
assert not explanation.break_reasons, f"{label} graph break reasons: {explanation.break_reasons}"
|
||||
|
||||
print(
|
||||
f"{label} capture: graphs={explanation.graph_count}, "
|
||||
f"graph_breaks={explanation.graph_break_count}, ops={explanation.op_count}, "
|
||||
f"guards={len(explanation.out_guards or [])}"
|
||||
)
|
||||
return explanation
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs):
|
||||
eager_forward = eager_model.forward(**forward_kwargs)
|
||||
compiled_forward = compiled_model.forward(**forward_kwargs)
|
||||
torch.testing.assert_close(compiled_forward, eager_forward, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
|
||||
eager_actions = eager_model.sample_actions(**sample_kwargs)
|
||||
compiled_actions = compiled_model.sample_actions(**sample_kwargs)
|
||||
torch.testing.assert_close(compiled_actions, eager_actions, rtol=SAMPLE_RTOL, atol=SAMPLE_ATOL)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def assert_cache_stability(fn: Callable, kwargs: dict, label: str):
|
||||
reset_compile_state()
|
||||
|
||||
first_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
first_snapshot = compile_snapshot()
|
||||
second_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
second_snapshot = compile_snapshot()
|
||||
third_output = clone_cuda_graph_output(run_model_step(fn, kwargs))
|
||||
third_snapshot = compile_snapshot()
|
||||
|
||||
torch.testing.assert_close(second_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
torch.testing.assert_close(third_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)
|
||||
assert first_snapshot["unique_graphs"] > 0, f"{label} did not compile any graph"
|
||||
assert third_snapshot["graph_breaks"] == 0, f"{label} graph breaks: {third_snapshot}"
|
||||
assert third_snapshot["recompiles"] == 0, f"{label} recompiled: {third_snapshot}"
|
||||
assert third_snapshot["recompile_limits"] == 0, f"{label} hit recompile limit: {third_snapshot}"
|
||||
assert second_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
||||
f"{label} compiled new graph on second call: first={first_snapshot}, second={second_snapshot}"
|
||||
)
|
||||
assert third_snapshot["unique_graphs"] == first_snapshot["unique_graphs"], (
|
||||
f"{label} compiled new graph on third call: first={first_snapshot}, third={third_snapshot}"
|
||||
)
|
||||
assert not guard_failures, f"{label} guard failures: {dict(guard_failures)}"
|
||||
|
||||
print(f"{label} cache: first={first_snapshot}, third={third_snapshot}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_runtime(eager_fn: Callable, compiled_fn: Callable, kwargs: dict, label: str):
|
||||
run_warmups(eager_fn, kwargs)
|
||||
run_warmups(compiled_fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
eager_metrics = profile_callable(eager_fn, kwargs)
|
||||
compiled_metrics = profile_callable(compiled_fn, kwargs)
|
||||
speedup = eager_metrics["cuda_event_ms"] / compiled_metrics["cuda_event_ms"]
|
||||
|
||||
print(
|
||||
f"{label} runtime: eager_cuda={eager_metrics['cuda_event_ms']:.3f} ms, "
|
||||
f"compiled_cuda={compiled_metrics['cuda_event_ms']:.3f} ms, speedup={speedup:.3f}x, "
|
||||
f"host_wall_ms eager/compiled={eager_metrics['host_wall_ms']:.3f}/"
|
||||
f"{compiled_metrics['host_wall_ms']:.3f}, "
|
||||
f"cpu_self_time_ms eager/compiled={eager_metrics['cpu_self_time_ms']:.3f}/"
|
||||
f"{compiled_metrics['cpu_self_time_ms']:.3f}, "
|
||||
f"cuda_launches eager/compiled={eager_metrics['cuda_launch_count']}/"
|
||||
f"{compiled_metrics['cuda_launch_count']}, "
|
||||
f"profiler_events eager/compiled={eager_metrics['profiler_event_count']}/"
|
||||
f"{compiled_metrics['profiler_event_count']}, "
|
||||
f"peak_mem_mib eager/compiled={eager_metrics['peak_mem_mib']:.1f}/"
|
||||
f"{compiled_metrics['peak_mem_mib']:.1f}"
|
||||
)
|
||||
|
||||
assert eager_metrics["cuda_event_ms"] > 0
|
||||
assert compiled_metrics["cuda_event_ms"] > 0
|
||||
assert eager_metrics["profiler_event_count"] > 0
|
||||
assert compiled_metrics["profiler_event_count"] > 0
|
||||
return eager_metrics, compiled_metrics
|
||||
|
||||
|
||||
def run_warmups(fn: Callable, kwargs: dict):
|
||||
for _ in range(STEADY_STATE_WARMUPS):
|
||||
run_model_step(fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_callable(fn: Callable, kwargs: dict):
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
host_start = time.perf_counter()
|
||||
start_event.record()
|
||||
for _ in range(STEADY_STATE_REPEATS):
|
||||
run_model_step(fn, kwargs)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
cuda_event_ms = start_event.elapsed_time(end_event) / STEADY_STATE_REPEATS
|
||||
host_wall_ms = (time.perf_counter() - host_start) * 1000 / STEADY_STATE_REPEATS
|
||||
peak_mem_mib = torch.cuda.max_memory_allocated() / 1024**2
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[ProfilerActivity.CPU],
|
||||
) as profiler:
|
||||
run_model_step(fn, kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
key_averages = profiler.key_averages()
|
||||
cpu_self_time_ms = sum(event.self_cpu_time_total for event in key_averages) / 1000
|
||||
cuda_launch_count = sum(
|
||||
event.count
|
||||
for event in key_averages
|
||||
if event.key in {"cudaLaunchKernel", "cudaGraphLaunch", "cudaLaunchKernelExC"}
|
||||
)
|
||||
profiler_event_count = sum(event.count for event in key_averages)
|
||||
|
||||
return {
|
||||
"cuda_event_ms": cuda_event_ms,
|
||||
"host_wall_ms": host_wall_ms,
|
||||
"cpu_self_time_ms": cpu_self_time_ms,
|
||||
"cuda_launch_count": cuda_launch_count,
|
||||
"profiler_event_count": profiler_event_count,
|
||||
"peak_mem_mib": peak_mem_mib,
|
||||
}
|
||||
155
tests/processor/test_pi05_processor.py
Normal file
155
tests/processor/test_pi05_processor.py
Normal file
@@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402
|
||||
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
IMAGE_KEYS,
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
||||
)
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 200
|
||||
DEVICE = torch.device("cpu")
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
key: {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
}
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI05PolicyInputAdapter(torch.nn.Module):
|
||||
"""Minimal adapter exposing PI0.5 policy image preparation without loading model weights."""
|
||||
|
||||
_preprocess_images = PI05Policy._preprocess_images
|
||||
|
||||
def __init__(self, config: PI05Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
||||
|
||||
|
||||
def create_pi05_config() -> PI05Config:
|
||||
config = PI05Config(device=str(DEVICE))
|
||||
config.max_state_dim = DUMMY_STATE_DIM
|
||||
config.max_action_dim = DUMMY_ACTION_DIM
|
||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
||||
**{
|
||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_dummy_data() -> dict:
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
**{
|
||||
f"observation.images.{key}": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
)
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
|
||||
def test_pi05_processor_inputs_match_openpi_reference():
|
||||
torch.manual_seed(0)
|
||||
config = create_pi05_config()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
)
|
||||
|
||||
assert_processor_inputs_match_lerobot(
|
||||
PI05PolicyInputAdapter(config),
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=False,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
lerobot_batch[ACTION],
|
||||
openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=True,
|
||||
),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
156
tests/processor/test_pi0_processor.py
Normal file
156
tests/processor/test_pi0_processor.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Compare the PI0 processor pipeline against the vendored OpenPI reference processors."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Policy # noqa: E402
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config # noqa: E402
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
||||
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
||||
IMAGE_KEYS,
|
||||
assert_processor_inputs_match_lerobot,
|
||||
clone_batch,
|
||||
make_openpi_observation_from_raw,
|
||||
openpi_model_actions_from_raw,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
|
||||
)
|
||||
|
||||
DUMMY_ACTION_DIM = 32
|
||||
DUMMY_STATE_DIM = 32
|
||||
DUMMY_ACTION_HORIZON = 50
|
||||
DUMMY_MAX_TOKEN_LEN = 48
|
||||
DEVICE = torch.device("cpu")
|
||||
|
||||
DUMMY_DATASET_STATS = {
|
||||
OBS_STATE: {
|
||||
"mean": torch.zeros(DUMMY_STATE_DIM),
|
||||
"std": torch.ones(DUMMY_STATE_DIM),
|
||||
"q01": torch.zeros(DUMMY_STATE_DIM),
|
||||
"q99": torch.ones(DUMMY_STATE_DIM),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"std": torch.ones(DUMMY_ACTION_DIM),
|
||||
"q01": torch.zeros(DUMMY_ACTION_DIM),
|
||||
"q99": torch.ones(DUMMY_ACTION_DIM),
|
||||
},
|
||||
"images": {
|
||||
key: {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
"q01": torch.zeros(3, 224, 224),
|
||||
"q99": torch.ones(3, 224, 224),
|
||||
}
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PI0PolicyInputAdapter(torch.nn.Module):
|
||||
"""Minimal adapter exposing PI0 policy input-preparation helpers without loading model weights."""
|
||||
|
||||
_preprocess_images = PI0Policy._preprocess_images
|
||||
prepare_state = PI0Policy.prepare_state
|
||||
|
||||
def __init__(self, config: PI0Config) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device_anchor = torch.nn.Parameter(torch.empty((), device=config.device), requires_grad=False)
|
||||
|
||||
|
||||
def create_pi0_config() -> PI0Config:
|
||||
config = PI0Config(device=str(DEVICE))
|
||||
config.max_state_dim = DUMMY_STATE_DIM
|
||||
config.max_action_dim = DUMMY_ACTION_DIM
|
||||
config.chunk_size = DUMMY_ACTION_HORIZON
|
||||
config.n_action_steps = DUMMY_ACTION_HORIZON
|
||||
config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
|
||||
**{
|
||||
f"observation.images.{key}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def create_dummy_data() -> dict:
|
||||
batch_size = 2
|
||||
prompt = "Pick up the red block and place it in the bin"
|
||||
return {
|
||||
OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
|
||||
ACTION: torch.randn(
|
||||
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
|
||||
),
|
||||
**{
|
||||
f"observation.images.{key}": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
|
||||
)
|
||||
for key in IMAGE_KEYS
|
||||
},
|
||||
"task": [prompt for _ in range(batch_size)],
|
||||
}
|
||||
|
||||
|
||||
def test_pi0_processor_inputs_match_openpi_reference():
|
||||
torch.manual_seed(0)
|
||||
config = create_pi0_config()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=DUMMY_DATASET_STATS)
|
||||
|
||||
raw_batch = create_dummy_data()
|
||||
lerobot_batch = preprocessor(clone_batch(raw_batch))
|
||||
openpi_observation = make_openpi_observation_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
max_token_len=DUMMY_MAX_TOKEN_LEN,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
)
|
||||
|
||||
assert_processor_inputs_match_lerobot(
|
||||
PI0PolicyInputAdapter(config),
|
||||
lerobot_batch,
|
||||
openpi_observation,
|
||||
compare_state=True,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
lerobot_batch[ACTION],
|
||||
openpi_model_actions_from_raw(
|
||||
raw_batch,
|
||||
action_dim=DUMMY_ACTION_DIM,
|
||||
dataset_stats=DUMMY_DATASET_STATS,
|
||||
pi05=False,
|
||||
),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
Reference in New Issue
Block a user