Merge branch 'feat/add_pi' into feat/validate_pi_libero

This commit is contained in:
Pepijn
2025-09-13 11:13:13 +02:00
12 changed files with 519 additions and 83 deletions

View File

@@ -29,7 +29,7 @@ version = "0.3.4"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
requires-python = ">=3.10"
requires-python = ">=3.11"
authors = [
{ name = "Rémi Cadène", email = "re.cadene@gmail.com" },
{ name = "Simon Alibert", email = "alibert.sim@gmail.com" },
@@ -50,7 +50,7 @@ classifiers = [
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Software Development :: Build Tools",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
@@ -280,7 +280,7 @@ default.extend-ignore-identifiers-re = [
# paths = ["src/lerobot"]
# [tool.mypy]
# python_version = "3.10"
# python_version = "3.11"
# warn_return_any = true
# warn_unused_configs = true
# ignore_missing_imports = false

View File

@@ -517,6 +517,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
msg = """transformers_replace is not installed correctly.
Please install it with `pip install transformers==4.53.2`
and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
@@ -851,41 +864,35 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
@classmethod
def from_pretrained(
cls, *args, **kwargs
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
# Store original strict mode
original_strict = kwargs.get("strict", True)
# Temporarily set strict=False to avoid loading issues, we'll handle it manually
kwargs["strict"] = False
# Create default config
config = cls.config_class()
# Call parent from_pretrained with strict=False
model = super().from_pretrained(*args, **kwargs)
# Extract the pretrained_model_name_or_path from args or kwargs for remapping
if len(args) > 0:
pretrained_model_name_or_path = args[0]
elif "pretrained_model_name_or_path" in kwargs:
pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
else:
return model
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
dataset_stats = kwargs.get("dataset_stats")
model = cls(config=config, dataset_stats=dataset_stats)
# Now manually load and remap the state dict
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_model_name_or_path}")
try:
# Try safetensors first
resolved_file = cached_file(
pretrained_model_name_or_path,
pretrained_name_or_path,
"model.safetensors",
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download", False),
@@ -901,9 +908,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
print("✓ Loaded state dict from model.safetensors")
except Exception as e:
print(f"Could not load state dict from remote files: {e}")
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -926,10 +934,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
print(f"Total keys remapped: {remap_count}")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict)
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
if missing_keys:
print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys")
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
@@ -939,7 +947,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
@@ -949,11 +957,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All keys loaded successfully!")
print("All keys loaded successfully!")
except Exception as e:
print(f"⚠️ Warning: Could not remap state dict keys: {e}")
print("Using default loading behavior")
print(f"Warning: Could not remap state dict keys: {e}")
return model
@@ -974,13 +981,21 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
key,
):
# This key structure suggests old model without adaRMS - keep as is or skip
logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}")
# Check if the model actually has adaRMS enabled for the expert
expert_uses_adarms = getattr(
self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
)
if expert_uses_adarms:
logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
continue
if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
# Skip old norm structure
logging.warning(f"Skipping old norm key (no adaRMS support): {key}")
# Check if the model actually has adaRMS enabled for the expert
expert_uses_adarms = getattr(
self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
)
if expert_uses_adarms:
logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
continue
# Handle MLP naming changes for pi05
@@ -1232,7 +1247,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
# Action queue logic for n_action_steps > 1
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps]
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
# Transpose to get shape (n_action_steps, batch_size, action_dim)
self._action_queue.extend(actions.transpose(0, 1))

View File

@@ -374,9 +374,8 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
output_attentions: bool | None = False,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: (
None | tuple[torch.Tensor, torch.Tensor]
) = None, # necessary, but kept here for BC
position_embeddings: None
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
adarms_cond: torch.Tensor | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
@@ -540,7 +539,7 @@ class GemmaModel(GemmaPreTrainedModel):
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841
_normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
# decoder layers

View File

@@ -336,7 +336,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
is_training = token_type_ids is not None and labels is not None
# Replace image id worth PAD if the image token if OOV, to avoid index-errors
# Replace image id with PAD if the image token if OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
@@ -450,7 +450,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
def get_image_features(self, pixel_values):
return self.model.get_image_features(pixel_values)
# Make modules available conditional class for BC
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model

View File

@@ -0,0 +1,5 @@
import transformers
def check_whether_transformers_replace_is_installed_correctly():
return transformers.__version__ == "4.53.2"

View File

@@ -78,7 +78,7 @@ def _trunc_normal_(tensor, mean, std, a, b):
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741
# Use inverse cdf transform for normal distribution to get truncated
# standard normal

View File

@@ -517,6 +517,21 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """transformers_replace is not installed correctly.
Please install it with `pip install transformers==4.53.2`
and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
@@ -868,7 +883,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
@classmethod
def from_pretrained(
cls, *args, **kwargs
cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs
): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
@@ -876,33 +891,27 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
if pretrained_name_or_path is None:
raise ValueError("pretrained_name_or_path is required")
# Store original strict mode
original_strict = kwargs.get("strict", True)
# Temporarily set strict=False to avoid loading issues, we'll handle it manually
kwargs["strict"] = False
# Create default config
config = cls.config_class()
# Call parent from_pretrained with strict=False
model = super().from_pretrained(*args, **kwargs)
# Extract the pretrained_model_name_or_path from args or kwargs for remapping
if len(args) > 0:
pretrained_model_name_or_path = args[0]
elif "pretrained_model_name_or_path" in kwargs:
pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
else:
return model
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
dataset_stats = kwargs.get("dataset_stats")
model = cls(config=config, dataset_stats=dataset_stats)
# Now manually load and remap the state dict
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_model_name_or_path}")
try:
# Try safetensors first
resolved_file = cached_file(
pretrained_model_name_or_path,
pretrained_name_or_path,
"model.safetensors",
cache_dir=kwargs.get("cache_dir"),
force_download=kwargs.get("force_download", False),
@@ -918,6 +927,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
print("✓ Loaded state dict from model.safetensors")
except Exception as e:
print(f"Could not load state dict from remote files: {e}")
print("Returning model without loading pretrained weights")
return model
# First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
@@ -943,10 +953,10 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
print(f"Total keys remapped: {remap_count}")
# Load the remapped state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict)
missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict)
if missing_keys:
print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys")
print(f"Missing keys when loading state dict: {len(missing_keys)} keys")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
@@ -956,7 +966,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
@@ -966,11 +976,10 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All keys loaded successfully!")
print("All keys loaded successfully!")
except Exception as e:
print(f"⚠️ Warning: Could not remap state dict keys: {e}")
print("Using default loading behavior")
print(f"Warning: Could not remap state dict keys: {e}")
return model
@@ -991,13 +1000,21 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight",
key,
):
# This key structure suggests old model without adaRMS - keep as is or skip
logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}")
# Check if the model actually has adaRMS enabled for the expert
expert_uses_adarms = getattr(
self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
)
if expert_uses_adarms:
logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}")
continue
if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key):
# Skip old norm structure
logging.warning(f"Skipping old norm key (no adaRMS support): {key}")
# Check if the model actually has adaRMS enabled for the expert
expert_uses_adarms = getattr(
self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False
)
if expert_uses_adarms:
logging.warning(f"Skipping norm key (adaRMS mismatch): {key}")
continue
# Handle MLP naming changes for pi0
@@ -1245,7 +1262,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
# Action queue logic for n_action_steps > 1
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps]
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
# Transpose to get shape (n_action_steps, batch_size, action_dim)
self._action_queue.extend(actions.transpose(0, 1))

View File

@@ -374,9 +374,8 @@ class GemmaDecoderLayer(GradientCheckpointingLayer):
output_attentions: bool | None = False,
use_cache: bool | None = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: (
None | tuple[torch.Tensor, torch.Tensor]
) = None, # necessary, but kept here for BC
position_embeddings: None
| (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC
adarms_cond: torch.Tensor | None = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
@@ -540,7 +539,7 @@ class GemmaModel(GemmaPreTrainedModel):
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841
_normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
# decoder layers

View File

@@ -336,7 +336,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
is_training = token_type_ids is not None and labels is not None
# Replace image id worth PAD if the image token if OOV, to avoid index-errors
# Replace image id with PAD if the image token if OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
@@ -450,7 +450,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
def get_image_features(self, pixel_values):
return self.model.get_image_features(pixel_values)
# Make modules available conditional class for BC
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model

View File

@@ -0,0 +1,5 @@
import transformers
def check_whether_transformers_replace_is_installed_correctly():
return transformers.__version__ == "4.53.2"

View File

@@ -78,7 +78,7 @@ def _trunc_normal_(tensor, mean, std, a, b):
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741
# Use inverse cdf transform for normal distribution to get truncated
# standard normal

View File

@@ -0,0 +1,396 @@
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation."""
import os
import torch
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch
from transformers import AutoTokenizer
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DUMMY_DATASET_STATS = {
"observation.state": {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
},
"action": {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
},
"images": {
"base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
"left_wrist_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
"right_wrist_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
},
}
class PI0BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
precision: str = "float32"
pi05: bool = False
dtype: str = "float32"
def instantiate_lerobot_pi0(from_pretrained: bool = False):
if from_pretrained:
# Load the policy first
policy = PI0OpenPIPolicy.from_pretrained(
pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
)
# Then reinitialize the normalization with proper stats
from lerobot.policies.normalize import Normalize, Unnormalize
policy.normalize_inputs = Normalize(
policy.config.input_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
)
policy.normalize_targets = Normalize(
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
)
policy.unnormalize_outputs = Unnormalize(
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
)
else:
config = PI0OpenPIConfig(action_dim=DUMMY_ACTION_DIM, state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS)
policy.to(DEVICE)
return policy
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None):
config = PI0BaseOriginalConfig()
policy = PI0Pytorch(config)
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (pepijn223/pi0_base_fp32)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="pepijn223/pi0_base_fp32", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
return policy
def create_dummy_data():
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
prompt = "Pick up the red block and place it in the bin"
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
# Add the task prompt for LeRobot - provide as list with single element to trigger expansion
"task": [prompt],
}
return batch
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
class PI0Observation:
"""Observation class that matches the original OpenPI format."""
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description
if "task" in batch:
tasks = batch["task"]
if isinstance(tasks, str):
tasks = [tasks]
elif isinstance(tasks, list) and len(tasks) == 1:
# Expand to batch size
tasks = tasks * batch_size
else:
# Default task if not provided
tasks = ["Pick up the object"] * batch_size
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
tasks,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def main():
print("Initializing models...")
lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi0(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
# Test 1: Each model with its own preprocessing (more realistic end-to-end test)
print("\n=== TEST 1: Each model with its own preprocessing ===")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
print("Testing LeRobot with own preprocessing...")
lerobot_pi0.eval()
torch.manual_seed(42) # Set the same seed
with torch.no_grad():
lerobot_actions_own = lerobot_pi0.predict_action_chunk(batch)
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
# Test 2: Both models with LeRobot preprocessing (isolates model differences)
print("\n=== TEST 2: Both models with LeRobot preprocessing (model comparison) ===")
print("Creating observation for OpenPI using LeRobot's preprocessing...")
pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch)
print("Testing OpenPI with LeRobot preprocessing...")
torch.manual_seed(42) # Set seed for reproducibility
with torch.no_grad():
openpi_actions_lerobot_preproc = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_lerobot, noise=fixed_noise, num_steps=10
)
print(f"OpenPI (LeRobot preprocessing) Actions shape: {openpi_actions_lerobot_preproc.shape}")
print(f"OpenPI (LeRobot preprocessing) Actions mean: {openpi_actions_lerobot_preproc.mean().item():.6f}")
print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}")
print("\nComparing models with same preprocessing:")
print(
f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)}"
)
print(
f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)}"
)
print(
f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item():.6f}"
)
print("\n=== SUMMARY ===")
print("Test 1 compares end-to-end pipelines (each model with its own preprocessing)")
print("Test 2 isolates model differences (both models with LeRobot preprocessing)")
print("Both tests completed successfully!")
if __name__ == "__main__":
main()