diff --git a/pyproject.toml b/pyproject.toml index bb5aadc8e..5ba99b946 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ version = "0.3.4" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } -requires-python = ">=3.10" +requires-python = ">=3.11" authors = [ { name = "Rémi Cadène", email = "re.cadene@gmail.com" }, { name = "Simon Alibert", email = "alibert.sim@gmail.com" }, @@ -50,7 +50,7 @@ classifiers = [ "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Software Development :: Build Tools", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] @@ -280,7 +280,7 @@ default.extend-ignore-identifiers-re = [ # paths = ["src/lerobot"] # [tool.mypy] -# python_version = "3.10" +# python_version = "3.11" # warn_return_any = true # warn_unused_configs = true # ignore_missing_imports = false diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index 39281204e..9ff71152a 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -517,6 +517,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` torch.set_float32_matmul_precision("high") self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + msg = """transformers_replace is not installed correctly. +Please install it with `pip install transformers==4.53.2` +and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \ +$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True @@ -851,41 +864,35 @@ class PI05OpenPIPolicy(PreTrainedPolicy): @classmethod def from_pretrained( - cls, *args, **kwargs + cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs ): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( - "⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" + "⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" " This implementation follows the original OpenPI structure for compatibility. \n" " Original implementation: https://github.com/Physical-Intelligence/openpi" ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") - # Store original strict mode - original_strict = kwargs.get("strict", True) - # Temporarily set strict=False to avoid loading issues, we'll handle it manually - kwargs["strict"] = False + # Create default config + config = cls.config_class() - # Call parent from_pretrained with strict=False - model = super().from_pretrained(*args, **kwargs) - - # Extract the pretrained_model_name_or_path from args or kwargs for remapping - if len(args) > 0: - pretrained_model_name_or_path = args[0] - elif "pretrained_model_name_or_path" in kwargs: - pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"] - else: - return model + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + dataset_stats = kwargs.get("dataset_stats") + model = cls(config=config, dataset_stats=dataset_stats) # Now manually load and remap the state dict try: - from transformers.utils import cached_file - # Try to load the pytorch_model.bin or model.safetensors file - print(f"Loading model from: {pretrained_model_name_or_path}") + print(f"Loading model from: {pretrained_name_or_path}") try: + from transformers.utils import cached_file + # Try safetensors first resolved_file = cached_file( - pretrained_model_name_or_path, + pretrained_name_or_path, "model.safetensors", cache_dir=kwargs.get("cache_dir"), force_download=kwargs.get("force_download", False), @@ -901,9 +908,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy): print("✓ Loaded state dict from model.safetensors") except Exception as e: print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") return model - # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it @@ -926,10 +934,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy): print(f"Total keys remapped: {remap_count}") # Load the remapped state dict into the model - missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict) + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) if missing_keys: - print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys") + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") if len(missing_keys) <= 5: for key in missing_keys: print(f" - {key}") @@ -939,7 +947,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): print(f" ... and {len(missing_keys) - 5} more") if unexpected_keys: - print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") if len(unexpected_keys) <= 5: for key in unexpected_keys: print(f" - {key}") @@ -949,11 +957,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy): print(f" ... and {len(unexpected_keys) - 5} more") if not missing_keys and not unexpected_keys: - print("✅ All keys loaded successfully!") + print("All keys loaded successfully!") except Exception as e: - print(f"⚠️ Warning: Could not remap state dict keys: {e}") - print("Using default loading behavior") + print(f"Warning: Could not remap state dict keys: {e}") return model @@ -974,14 +981,22 @@ class PI05OpenPIPolicy(PreTrainedPolicy): r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", key, ): - # This key structure suggests old model without adaRMS - keep as is or skip - logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): - # Skip old norm structure - logging.warning(f"Skipping old norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue # Handle MLP naming changes for pi05 # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_* @@ -1232,7 +1247,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: - actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps] + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py index a3f6b5325..05066afc5 100644 --- a/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py @@ -374,9 +374,8 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): output_attentions: bool | None = False, use_cache: bool | None = False, cache_position: torch.LongTensor | None = None, - position_embeddings: ( - None | tuple[torch.Tensor, torch.Tensor] - ) = None, # necessary, but kept here for BC + position_embeddings: None + | (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC adarms_cond: torch.Tensor | None = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: @@ -540,7 +539,7 @@ class GemmaModel(GemmaPreTrainedModel): # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841 + _normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # hidden_states = hidden_states * normalizer # decoder layers diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py index feb709b5b..b2a36b5ca 100644 --- a/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py @@ -336,7 +336,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): is_training = token_type_ids is not None and labels is not None - # Replace image id worth PAD if the image token if OOV, to avoid index-errors + # Replace image id with PAD if the image token if OOV, to avoid index-errors if input_ids is not None and self.config.image_token_id >= self.vocab_size: special_image_mask = input_ids == self.config.image_token_id llm_input_ids = input_ids.clone() @@ -450,7 +450,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi def get_image_features(self, pixel_values): return self.model.get_image_features(pixel_values) - # Make modules available conditional class for BC + # Make modules available through conditional class for BC @property def language_model(self): return self.model.language_model diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/check.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/check.py new file mode 100644 index 000000000..d899dc1b9 --- /dev/null +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/check.py @@ -0,0 +1,5 @@ +import transformers + + +def check_whether_transformers_replace_is_installed_correctly(): + return transformers.__version__ == "4.53.2" diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py index d0e634472..0fc0bba0f 100644 --- a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py @@ -78,7 +78,7 @@ def _trunc_normal_(tensor, mean, std, a, b): # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741 # Use inverse cdf transform for normal distribution to get truncated # standard normal diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 35c6f7c9a..549dc0a9b 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -517,6 +517,21 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` if config.compile_model: torch.set_float32_matmul_precision("high") self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + # Also compile the main forward pass used during training + self.forward = torch.compile(self.forward, mode=config.compile_mode) + + msg = """transformers_replace is not installed correctly. +Please install it with `pip install transformers==4.53.2` +and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \ +$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" @@ -868,7 +883,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): @classmethod def from_pretrained( - cls, *args, **kwargs + cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs ): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( @@ -876,33 +891,27 @@ class PI0OpenPIPolicy(PreTrainedPolicy): " This implementation follows the original OpenPI structure for compatibility. \n" " Original implementation: https://github.com/Physical-Intelligence/openpi" ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") - # Store original strict mode - original_strict = kwargs.get("strict", True) - # Temporarily set strict=False to avoid loading issues, we'll handle it manually - kwargs["strict"] = False + # Create default config + config = cls.config_class() - # Call parent from_pretrained with strict=False - model = super().from_pretrained(*args, **kwargs) - - # Extract the pretrained_model_name_or_path from args or kwargs for remapping - if len(args) > 0: - pretrained_model_name_or_path = args[0] - elif "pretrained_model_name_or_path" in kwargs: - pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"] - else: - return model + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + dataset_stats = kwargs.get("dataset_stats") + model = cls(config=config, dataset_stats=dataset_stats) # Now manually load and remap the state dict try: - from transformers.utils import cached_file - # Try to load the pytorch_model.bin or model.safetensors file - print(f"Loading model from: {pretrained_model_name_or_path}") + print(f"Loading model from: {pretrained_name_or_path}") try: + from transformers.utils import cached_file + # Try safetensors first resolved_file = cached_file( - pretrained_model_name_or_path, + pretrained_name_or_path, "model.safetensors", cache_dir=kwargs.get("cache_dir"), force_download=kwargs.get("force_download", False), @@ -918,6 +927,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print("✓ Loaded state dict from model.safetensors") except Exception as e: print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") return model # First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` @@ -943,10 +953,10 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print(f"Total keys remapped: {remap_count}") # Load the remapped state dict into the model - missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict) + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) if missing_keys: - print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys") + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") if len(missing_keys) <= 5: for key in missing_keys: print(f" - {key}") @@ -956,7 +966,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print(f" ... and {len(missing_keys) - 5} more") if unexpected_keys: - print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") if len(unexpected_keys) <= 5: for key in unexpected_keys: print(f" - {key}") @@ -966,11 +976,10 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print(f" ... and {len(unexpected_keys) - 5} more") if not missing_keys and not unexpected_keys: - print("✅ All keys loaded successfully!") + print("All keys loaded successfully!") except Exception as e: - print(f"⚠️ Warning: Could not remap state dict keys: {e}") - print("Using default loading behavior") + print(f"Warning: Could not remap state dict keys: {e}") return model @@ -991,14 +1000,22 @@ class PI0OpenPIPolicy(PreTrainedPolicy): r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", key, ): - # This key structure suggests old model without adaRMS - keep as is or skip - logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): - # Skip old norm structure - logging.warning(f"Skipping old norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue # Handle MLP naming changes for pi0 # non-pi05 model expects action_time_mlp_*, but checkpoint might have time_mlp_* @@ -1245,7 +1262,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: - actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps] + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py index a3f6b5325..05066afc5 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py @@ -374,9 +374,8 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): output_attentions: bool | None = False, use_cache: bool | None = False, cache_position: torch.LongTensor | None = None, - position_embeddings: ( - None | tuple[torch.Tensor, torch.Tensor] - ) = None, # necessary, but kept here for BC + position_embeddings: None + | (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC adarms_cond: torch.Tensor | None = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: @@ -540,7 +539,7 @@ class GemmaModel(GemmaPreTrainedModel): # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841 + _normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # hidden_states = hidden_states * normalizer # decoder layers diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py index feb709b5b..b2a36b5ca 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py @@ -336,7 +336,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): is_training = token_type_ids is not None and labels is not None - # Replace image id worth PAD if the image token if OOV, to avoid index-errors + # Replace image id with PAD if the image token if OOV, to avoid index-errors if input_ids is not None and self.config.image_token_id >= self.vocab_size: special_image_mask = input_ids == self.config.image_token_id llm_input_ids = input_ids.clone() @@ -450,7 +450,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi def get_image_features(self, pixel_values): return self.model.get_image_features(pixel_values) - # Make modules available conditional class for BC + # Make modules available through conditional class for BC @property def language_model(self): return self.model.language_model diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py new file mode 100644 index 000000000..d899dc1b9 --- /dev/null +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py @@ -0,0 +1,5 @@ +import transformers + + +def check_whether_transformers_replace_is_installed_correctly(): + return transformers.__version__ == "4.53.2" diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py index d0e634472..0fc0bba0f 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py @@ -78,7 +78,7 @@ def _trunc_normal_(tensor, mean, std, a, b): # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741 # Use inverse cdf transform for normal distribution to get truncated # standard normal diff --git a/test_pi0_original_vs_lerobot.py b/test_pi0_original_vs_lerobot.py new file mode 100644 index 000000000..68b62d110 --- /dev/null +++ b/test_pi0_original_vs_lerobot.py @@ -0,0 +1,396 @@ +"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation.""" + +import os + +import torch +from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing + +# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions. +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch +from transformers import AutoTokenizer + +from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy + +DUMMY_ACTION_DIM = 32 +DUMMY_STATE_DIM = 32 +DUMMY_ACTION_HORIZON = 50 +DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05) +DEVICE = "cpu" # Use CPU to avoid memory issues for testing + +DUMMY_DATASET_STATS = { + "observation.state": { + "mean": torch.zeros(DUMMY_STATE_DIM), + "std": torch.ones(DUMMY_STATE_DIM), + }, + "action": { + "mean": torch.zeros(DUMMY_ACTION_DIM), + "std": torch.ones(DUMMY_ACTION_DIM), + }, + "images": { + "base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + "left_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + "right_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + }, +} + + +class PI0BaseOriginalConfig: + action_dim: int = DUMMY_ACTION_DIM + action_horizon: int = DUMMY_ACTION_HORIZON + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + precision: str = "float32" + pi05: bool = False + dtype: str = "float32" + + +def instantiate_lerobot_pi0(from_pretrained: bool = False): + if from_pretrained: + # Load the policy first + policy = PI0OpenPIPolicy.from_pretrained( + pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True + ) + # Then reinitialize the normalization with proper stats + from lerobot.policies.normalize import Normalize, Unnormalize + + policy.normalize_inputs = Normalize( + policy.config.input_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS + ) + policy.normalize_targets = Normalize( + policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS + ) + policy.unnormalize_outputs = Unnormalize( + policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS + ) + else: + config = PI0OpenPIConfig(action_dim=DUMMY_ACTION_DIM, state_dim=DUMMY_STATE_DIM, dtype="float32") + policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS) + policy.to(DEVICE) + return policy + + +def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None): + config = PI0BaseOriginalConfig() + policy = PI0Pytorch(config) + + if from_pretrained: + try: + print("Loading converted PyTorch weights from HuggingFace Hub (pepijn223/pi0_base_fp32)...") + + # Download the model from HuggingFace Hub + import safetensors.torch + from huggingface_hub import snapshot_download + + # Download the entire repository + if model_path and os.path.exists(model_path): + cache_dir = model_path + print(f"Using cached model from: {cache_dir}") + else: + cache_dir = snapshot_download(repo_id="pepijn223/pi0_base_fp32", repo_type="model") + print(f"Downloaded model to: {cache_dir}") + + # Try to load safetensors format first + model_file = os.path.join(cache_dir, "model.safetensors") + if os.path.exists(model_file): + state_dict = safetensors.torch.load_file(model_file) + print(f"Loaded {len(state_dict)} parameters from safetensors") + else: + raise FileNotFoundError(f"No safetensors file found in {cache_dir}") + + # Load the state dict into the model + missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False) + + if missing_keys: + print(f"Missing keys: {len(missing_keys)}") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys: {len(unexpected_keys)}") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All pretrained weights loaded successfully!") + else: + print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)") + + except Exception as e: + print(f"Failed to load pretrained weights: {e}") + print(" Using randomly initialized weights...") + import traceback + + traceback.print_exc() + + policy.to(DEVICE) + return policy + + +def create_dummy_data(): + batch_size = 2 # Reduce batch size for testing + device = DEVICE + + # Use the exact same prompt for both implementations + prompt = "Pick up the red block and place it in the bin" + + batch = { + "observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device), + "action": torch.randn( + batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device + ), + # Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally) + "observation.images.base_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + "observation.images.left_wrist_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + "observation.images.right_wrist_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + # Add the task prompt for LeRobot - provide as list with single element to trigger expansion + "task": [prompt], + } + return batch + + +def extract_lerobot_processed_inputs(lerobot_pi0, batch): + """Extract the exact same processed inputs that LeRobot uses internally.""" + # Get the tokenized language from LeRobot's internal method + lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch) + + # Get the preprocessed images from LeRobot's internal method + images, img_masks = lerobot_pi0._preprocess_images(batch, train=False) + + # Create dummy token_ar_mask and token_loss_mask for original implementation + token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) + token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + + return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask + + +class PI0Observation: + """Observation class that matches the original OpenPI format.""" + + def __init__( + self, + state, + images, + image_masks, + tokenized_prompt, + tokenized_prompt_mask, + token_ar_mask, + token_loss_mask, + ): + self.state = state + self.images = images + self.image_masks = image_masks + self.tokenized_prompt = tokenized_prompt + self.tokenized_prompt_mask = tokenized_prompt_mask + self.token_ar_mask = token_ar_mask + self.token_loss_mask = token_loss_mask + + +def create_original_observation_with_openpi_preprocessing(batch): + """Create observation object for OpenPI using OpenPI's own preprocessing.""" + batch_size = batch["observation.state"].shape[0] + device = batch["observation.state"].device + + # Create tokenizer for OpenPI (same as LeRobot uses) + tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + + # Get task description + if "task" in batch: + tasks = batch["task"] + if isinstance(tasks, str): + tasks = [tasks] + elif isinstance(tasks, list) and len(tasks) == 1: + # Expand to batch size + tasks = tasks * batch_size + else: + # Default task if not provided + tasks = ["Pick up the object"] * batch_size + + # Tokenize with max_length padding to match OpenPI's expected format + tokenized = tokenizer( + tasks, + padding="max_length", + padding_side="right", + truncation=True, + max_length=DUMMY_MAX_TOKEN_LEN, + return_tensors="pt", + ) + + lang_tokens = tokenized["input_ids"].to(device) + lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool) + + # Create dummy token_ar_mask and token_loss_mask for OpenPI + token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) + token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + + # Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range) + image_dict = { + "base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0, + "left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0, + "right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0, + } + + # Create image masks (all ones for real images) + image_masks_dict = {} + for key in image_dict: + image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device) + + # Create raw observation object (before preprocessing) + raw_observation = PI0Observation( + state=batch["observation.state"], + images=image_dict, + image_masks=image_masks_dict, + tokenized_prompt=lang_tokens, + tokenized_prompt_mask=lang_masks, + token_ar_mask=token_ar_mask, + token_loss_mask=token_loss_mask, + ) + + # Now use OpenPI's preprocessing + processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False) + + return processed_obs + + +def create_original_observation_from_lerobot(lerobot_pi0, batch): + """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot.""" + _batch_size = batch["observation.state"].shape[0] + _device = batch["observation.state"].device + + # Extract the exact same processed inputs that LeRobot uses + images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = ( + extract_lerobot_processed_inputs(lerobot_pi0, batch) + ) + + # Convert images list to dict with original OpenPI keys + image_dict = { + "base_0_rgb": images[0], + "left_wrist_0_rgb": images[1], + "right_wrist_0_rgb": images[2], + } + + # Convert image masks list to dict with original OpenPI keys + image_masks_dict = { + "base_0_rgb": img_masks[0], + "left_wrist_0_rgb": img_masks[1], + "right_wrist_0_rgb": img_masks[2], + } + + return PI0Observation( + state=batch["observation.state"], + images=image_dict, + image_masks=image_masks_dict, + tokenized_prompt=lang_tokens, + tokenized_prompt_mask=lang_masks, + token_ar_mask=token_ar_mask, + token_loss_mask=token_loss_mask, + ) + + +def main(): + print("Initializing models...") + lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model + original_pi0 = instantiate_original_pi0( + from_pretrained=True + ) # Load pretrained OpenPI model from HuggingFace Hub + + print("Creating dummy data...") + batch = create_dummy_data() + + # Test 1: Each model with its own preprocessing (more realistic end-to-end test) + print("\n=== TEST 1: Each model with its own preprocessing ===") + print("Creating observation for OpenPI using OpenPI's own preprocessing...") + pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch) + + print(f"Task prompt: '{batch['task'][0]}'") + print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}") + print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}") + print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}") + + print("Testing OpenPI with own preprocessing...") + original_pi0.eval() + torch.manual_seed(42) # Set seed for reproducibility + batch_size = batch["observation.state"].shape[0] + noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM) + fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE) + + with torch.no_grad(): + openpi_actions = original_pi0.sample_actions( + device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10 + ) + print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}") + print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}") + print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}") + + print("Testing LeRobot with own preprocessing...") + lerobot_pi0.eval() + torch.manual_seed(42) # Set the same seed + with torch.no_grad(): + lerobot_actions_own = lerobot_pi0.predict_action_chunk(batch) + print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}") + print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}") + print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}") + + print("\nComparing end-to-end implementations:") + print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}") + print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}") + print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}") + + # Test 2: Both models with LeRobot preprocessing (isolates model differences) + print("\n=== TEST 2: Both models with LeRobot preprocessing (model comparison) ===") + print("Creating observation for OpenPI using LeRobot's preprocessing...") + pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch) + + print("Testing OpenPI with LeRobot preprocessing...") + torch.manual_seed(42) # Set seed for reproducibility + with torch.no_grad(): + openpi_actions_lerobot_preproc = original_pi0.sample_actions( + device=DEVICE, observation=pi0_obs_lerobot, noise=fixed_noise, num_steps=10 + ) + print(f"OpenPI (LeRobot preprocessing) Actions shape: {openpi_actions_lerobot_preproc.shape}") + print(f"OpenPI (LeRobot preprocessing) Actions mean: {openpi_actions_lerobot_preproc.mean().item():.6f}") + print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}") + + print("\nComparing models with same preprocessing:") + print( + f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)}" + ) + print( + f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)}" + ) + print( + f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item():.6f}" + ) + + print("\n=== SUMMARY ===") + print("Test 1 compares end-to-end pipelines (each model with its own preprocessing)") + print("Test 2 isolates model differences (both models with LeRobot preprocessing)") + print("Both tests completed successfully!") + + +if __name__ == "__main__": + main()