From bf90efa7e1495d1dd0718c00d5d59c9d7153e70e Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 12 Sep 2025 18:44:12 +0200 Subject: [PATCH 01/10] fix key match from pytorch state dict (similar keys to openpi implementation now) --- .../policies/pi0_openpi/modeling_pi0openpi.py | 20 +- test_pi0_original_vs_lerobot.py | 316 ++++++++++++++++++ 2 files changed, 330 insertions(+), 6 deletions(-) create mode 100644 test_pi0_original_vs_lerobot.py diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 35c6f7c9a..ce813fdb8 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -991,14 +991,22 @@ class PI0OpenPIPolicy(PreTrainedPolicy): r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", key, ): - # This key structure suggests old model without adaRMS - keep as is or skip - logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): - # Skip old norm structure - logging.warning(f"Skipping old norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue # Handle MLP naming changes for pi0 # non-pi05 model expects action_time_mlp_*, but checkpoint might have time_mlp_* diff --git a/test_pi0_original_vs_lerobot.py b/test_pi0_original_vs_lerobot.py new file mode 100644 index 000000000..e5cdf3dd7 --- /dev/null +++ b/test_pi0_original_vs_lerobot.py @@ -0,0 +1,316 @@ +"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation.""" + +import os + +import torch + +# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions. +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch + +from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy + +DUMMY_ACTION_DIM = 32 +DUMMY_STATE_DIM = 32 +DUMMY_ACTION_HORIZON = 50 +DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05) +DEVICE = "cpu" # Use CPU to avoid memory issues for testing + +DUMMY_DATASET_STATS = { + "observation.state": { + "mean": torch.zeros(DUMMY_STATE_DIM), + "std": torch.ones(DUMMY_STATE_DIM), + }, + "action": { + "mean": torch.zeros(DUMMY_ACTION_DIM), + "std": torch.ones(DUMMY_ACTION_DIM), + }, + "images": { + "base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + "left_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + "right_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + }, +} + + +class PI0BaseOriginalConfig: + action_dim: int = DUMMY_ACTION_DIM + action_horizon: int = DUMMY_ACTION_HORIZON + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + precision: str = "float32" + pi05: bool = False + dtype: str = "float32" + + +def instantiate_lerobot_pi0(from_pretrained: bool = False): + if from_pretrained: + # Load the policy first + policy = PI0OpenPIPolicy.from_pretrained("pepijn223/pi0_base_fp32") + # Then reinitialize the normalization with proper stats + from lerobot.policies.normalize import Normalize, Unnormalize + + policy.normalize_inputs = Normalize( + policy.config.input_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS + ) + policy.normalize_targets = Normalize( + policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS + ) + policy.unnormalize_outputs = Unnormalize( + policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS + ) + else: + config = PI0OpenPIConfig(action_dim=DUMMY_ACTION_DIM, state_dim=DUMMY_STATE_DIM, dtype="float32") + policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS) + policy.to(DEVICE) + return policy + + +def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None): + config = PI0BaseOriginalConfig() + policy = PI0Pytorch(config) + + if from_pretrained: + try: + print("Loading converted PyTorch weights from HuggingFace Hub (pepijn223/pi0_base_fp32)...") + + # Download the model from HuggingFace Hub + import safetensors.torch + from huggingface_hub import snapshot_download + + # Download the entire repository + if model_path and os.path.exists(model_path): + cache_dir = model_path + print(f"Using cached model from: {cache_dir}") + else: + cache_dir = snapshot_download(repo_id="pepijn223/pi0_base_fp32", repo_type="model") + print(f"Downloaded model to: {cache_dir}") + + # Try to load safetensors format first + model_file = os.path.join(cache_dir, "model.safetensors") + if os.path.exists(model_file): + state_dict = safetensors.torch.load_file(model_file) + print(f"Loaded {len(state_dict)} parameters from safetensors") + else: + raise FileNotFoundError(f"No safetensors file found in {cache_dir}") + + # Load the state dict into the model + missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False) + + if missing_keys: + print(f"Missing keys: {len(missing_keys)}") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys: {len(unexpected_keys)}") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All pretrained weights loaded successfully!") + else: + print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)") + + except Exception as e: + print(f"Failed to load pretrained weights: {e}") + print(" Using randomly initialized weights...") + import traceback + + traceback.print_exc() + + policy.to(DEVICE) + return policy + + +def create_dummy_data(): + batch_size = 2 # Reduce batch size for testing + device = DEVICE + + # Use the exact same prompt for both implementations + prompt = "Pick up the red block and place it in the bin" + + batch = { + "observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device), + "action": torch.randn( + batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device + ), + # Create images in [-1, 1] range as expected by both implementations + "observation.images.base_0_rgb": torch.randn( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ).clamp(-1, 1), + "observation.images.left_wrist_0_rgb": torch.randn( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ).clamp(-1, 1), + "observation.images.right_wrist_0_rgb": torch.randn( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ).clamp(-1, 1), + # Add the task prompt for LeRobot - provide as list with single element to trigger expansion + "task": [prompt], + } + return batch + + +def extract_lerobot_processed_inputs(lerobot_pi0, batch): + """Extract the exact same processed inputs that LeRobot uses internally.""" + # Get the tokenized language from LeRobot's internal method + lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch) + + # Get the preprocessed images from LeRobot's internal method + images, img_masks = lerobot_pi0._preprocess_images(batch) + + # Create dummy token_ar_mask and token_loss_mask for original implementation + token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) + token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + + return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask + + +class PI0Observation: + """Observation class that matches the original OpenPI format.""" + + def __init__( + self, + state, + images, + image_masks, + tokenized_prompt, + tokenized_prompt_mask, + token_ar_mask, + token_loss_mask, + ): + self.state = state + self.images = images + self.image_masks = image_masks + self.tokenized_prompt = tokenized_prompt + self.tokenized_prompt_mask = tokenized_prompt_mask + self.token_ar_mask = token_ar_mask + self.token_loss_mask = token_loss_mask + + +def create_original_observation_from_lerobot(lerobot_pi0, batch): + """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot.""" + _batch_size = batch["observation.state"].shape[0] + _device = batch["observation.state"].device + + # Extract the exact same processed inputs that LeRobot uses + images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = ( + extract_lerobot_processed_inputs(lerobot_pi0, batch) + ) + + # Convert images list to dict with original OpenPI keys + image_dict = { + "base_0_rgb": images[0], + "left_wrist_0_rgb": images[1], + "right_wrist_0_rgb": images[2], + } + + # Convert image masks list to dict with original OpenPI keys + image_masks_dict = { + "base_0_rgb": img_masks[0], + "left_wrist_0_rgb": img_masks[1], + "right_wrist_0_rgb": img_masks[2], + } + + return PI0Observation( + state=batch["observation.state"], + images=image_dict, + image_masks=image_masks_dict, + tokenized_prompt=lang_tokens, + tokenized_prompt_mask=lang_masks, + token_ar_mask=token_ar_mask, + token_loss_mask=token_loss_mask, + ) + + +def main(): + print("Initializing models...") + lerobot_pi0 = instantiate_lerobot_pi0(from_pretrained=True) # Load pretrained LeRobot model + original_pi0 = instantiate_original_pi0( + from_pretrained=True + ) # Load pretrained OpenPI model from HuggingFace Hub + + print("Creating dummy data...") + batch = create_dummy_data() + + print("Creating observation for original PI0 using LeRobot's exact preprocessing...") + pi0_obs = create_original_observation_from_lerobot(lerobot_pi0, batch) + + # Verify both implementations get the same inputs + print(f"Task prompt: '{batch['task'][0]}'") + print(f"Tokenized prompt shape: {pi0_obs.tokenized_prompt.shape}") + print(f"Image shapes: {[img.shape for img in pi0_obs.images.values()]}") + print(f"State shape: {pi0_obs.state.shape}") + + print("Testing original PI0...") + + # Test training forward pass (returns loss) + print("1. Training forward pass (computing loss):") + original_pi0.train() + original_loss = original_pi0(observation=pi0_obs, actions=batch["action"]) + print(f" Loss shape: {original_loss.shape}, Mean loss: {original_loss.mean().item():.6f}") + + # Test inference (action sampling) with fixed noise for reproducibility + print("2. Inference (action sampling):") + original_pi0.eval() + + # Create the same noise for both implementations + torch.manual_seed(42) # Set seed for reproducibility + batch_size = batch["observation.state"].shape[0] + noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM) + fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE) + + with torch.no_grad(): + original_actions = original_pi0.sample_actions( + device=DEVICE, observation=pi0_obs, noise=fixed_noise, num_steps=10 + ) + print(f"Original PI0 Actions shape: {original_actions.shape}") + print(f"Original PI0 Actions mean: {original_actions.mean().item():.6f}") + print(f"Original PI0 Actions std: {original_actions.std().item():.6f}") + + # Test LeRobot implementation with the same noise + print("\nTesting LeRobot PI0...") + lerobot_pi0.eval() + + # For LeRobot, we need to modify the batch to force the same noise + # This is more complex since LeRobot generates noise internally + torch.manual_seed(42) # Set the same seed + with torch.no_grad(): + # lerobot_pi0_actions = lerobot_pi0.select_action(batch) + lerobot_pi0_actions = lerobot_pi0.predict_action_chunk(batch) + print(f"LeRobot actions shape: {lerobot_pi0_actions.shape}") + print(f"LeRobot actions mean: {lerobot_pi0_actions.mean().item():.6f}") + print(f"LeRobot actions std: {lerobot_pi0_actions.std().item():.6f}") + + print("\nComparing implementations:") + print(f"Original actions shape: {original_actions.shape}") + print(f"LeRobot actions shape: {lerobot_pi0_actions.shape}") + + # Compare the first action step (since LeRobot select_action returns a single step) + print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_pi0_actions, original_actions, atol=1e-4)}") + print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_pi0_actions, original_actions, atol=1e-2)}") + print(f"Max absolute difference: {torch.abs(lerobot_pi0_actions - original_actions).max().item():.6f}") + + print("\nOriginal PI0 test completed successfully!") + + +if __name__ == "__main__": + main() From 6ce2a00135a5345cbe5ba09948628cd098e06d37 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 12 Sep 2025 19:02:13 +0200 Subject: [PATCH 02/10] also for pi05 --- .../pi05_openpi/modeling_pi05openpi.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index 39281204e..d9e86f830 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -974,14 +974,22 @@ class PI05OpenPIPolicy(PreTrainedPolicy): r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", key, ): - # This key structure suggests old model without adaRMS - keep as is or skip - logging.warning(f"Skipping old layer norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): - # Skip old norm structure - logging.warning(f"Skipping old norm key (no adaRMS support): {key}") - continue + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue # Handle MLP naming changes for pi05 # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_* From 990f8e9cc94ab8b6ab0c2535e81817fa5fbc6440 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 12 Sep 2025 19:04:42 +0200 Subject: [PATCH 03/10] update to python 3.11 --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d16d00776..d3be72b85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ version = "0.3.4" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } -requires-python = ">=3.10" +requires-python = ">=3.11" authors = [ { name = "Rémi Cadène", email = "re.cadene@gmail.com" }, { name = "Simon Alibert", email = "alibert.sim@gmail.com" }, @@ -50,7 +50,7 @@ classifiers = [ "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Topic :: Software Development :: Build Tools", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] @@ -260,7 +260,7 @@ default.extend-ignore-identifiers-re = [ # paths = ["src/lerobot"] # [tool.mypy] -# python_version = "3.10" +# python_version = "3.11" # warn_return_any = true # warn_unused_configs = true # ignore_missing_imports = false From e94844fa59fdb327e7e3476538213c5644794d45 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 12 Sep 2025 20:00:21 +0200 Subject: [PATCH 04/10] revert to openpi transformer replace python 3.11 --- .../pi05_openpi/modeling_pi05openpi.py | 13 +++++ .../models/gemma/modeling_gemma.py | 44 +++++---------- .../models/paligemma/modeling_paligemma.py | 37 ++++--------- .../models/siglip/check.py | 5 ++ .../models/siglip/modeling_siglip.py | 53 ++++++------------- .../policies/pi0_openpi/modeling_pi0openpi.py | 13 +++++ .../models/gemma/modeling_gemma.py | 44 +++++---------- .../models/paligemma/modeling_paligemma.py | 37 ++++--------- .../models/siglip/check.py | 5 ++ .../models/siglip/modeling_siglip.py | 53 ++++++------------- 10 files changed, 112 insertions(+), 192 deletions(-) create mode 100644 src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/check.py create mode 100644 src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index d9e86f830..fd040159e 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -517,6 +517,19 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` torch.set_float32_matmul_precision("high") self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + msg = """transformers_replace is not installed correctly. +Please install it with `pip install transformers==4.53.2` +and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \ +$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py index a3f6b5325..e88051c6e 100644 --- a/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py @@ -45,25 +45,6 @@ from .configuration_gemma import GemmaConfig logger = logging.get_logger(__name__) -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): super().__init__() @@ -374,9 +355,8 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): output_attentions: bool | None = False, use_cache: bool | None = False, cache_position: torch.LongTensor | None = None, - position_embeddings: ( - None | tuple[torch.Tensor, torch.Tensor] - ) = None, # necessary, but kept here for BC + position_embeddings: None + | (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC adarms_cond: torch.Tensor | None = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: @@ -410,7 +390,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): return outputs -@safe_auto_docstring +@auto_docstring class GemmaPreTrainedModel(PreTrainedModel): config_class = GemmaConfig base_model_prefix = "model" @@ -441,7 +421,7 @@ class GemmaPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) -@safe_auto_docstring +@auto_docstring class GemmaModel(GemmaPreTrainedModel): def __init__(self, config: GemmaConfig): super().__init__(config) @@ -468,7 +448,7 @@ class GemmaModel(GemmaPreTrainedModel): self.embed_tokens = value @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -540,7 +520,7 @@ class GemmaModel(GemmaPreTrainedModel): # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841 + _normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # hidden_states = hidden_states * normalizer # decoder layers @@ -586,7 +566,7 @@ class GemmaModel(GemmaPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@safe_auto_docstring +@auto_docstring class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -620,7 +600,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -704,7 +684,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): ) -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The Gemma Model transformer with a sequence classification head on top (linear layer). @@ -735,7 +715,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -811,7 +791,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): ) -@safe_auto_docstring +@auto_docstring class GemmaForTokenClassification(GemmaPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -836,7 +816,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py index feb709b5b..0f7251881 100644 --- a/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py @@ -39,27 +39,8 @@ from .configuration_paligemma import PaliGemmaConfig logger = logging.get_logger(__name__) -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - @dataclass -@safe_auto_docstring( +@auto_docstring( custom_intro=""" Base class for Paligemma outputs, with hidden states and attentions. """ @@ -81,7 +62,7 @@ class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): @dataclass -@safe_auto_docstring( +@auto_docstring( custom_intro=""" Base class for PaliGemma causal language model (or autoregressive) outputs. """ @@ -124,7 +105,7 @@ class PaliGemmaMultiModalProjector(nn.Module): return hidden_states -@safe_auto_docstring +@auto_docstring class PaliGemmaPreTrainedModel(PreTrainedModel): config_class = PaliGemmaConfig base_model_prefix = "" @@ -150,7 +131,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., """ @@ -277,7 +258,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): return image_features @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor = None, @@ -336,7 +317,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): is_training = token_type_ids is not None and labels is not None - # Replace image id worth PAD if the image token if OOV, to avoid index-errors + # Replace image id with PAD if the image token if OOV, to avoid index-errors if input_ids is not None and self.config.image_token_id >= self.vocab_size: special_image_mask = input_ids == self.config.image_token_id llm_input_ids = input_ids.clone() @@ -409,7 +390,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., """ @@ -450,7 +431,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi def get_image_features(self, pixel_values): return self.model.get_image_features(pixel_values) - # Make modules available conditional class for BC + # Make modules available through conditional class for BC @property def language_model(self): return self.model.language_model @@ -464,7 +445,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi return self.model.multi_modal_projector @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/check.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/check.py new file mode 100644 index 000000000..d899dc1b9 --- /dev/null +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/check.py @@ -0,0 +1,5 @@ +import transformers + + +def check_whether_transformers_replace_is_installed_correctly(): + return transformers.__version__ == "4.53.2" diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py index d0e634472..b3af5cff4 100644 --- a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py @@ -37,25 +37,6 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo logger = logging.get_logger(__name__) -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -78,7 +59,7 @@ def _trunc_normal_(tensor, mean, std, a, b): # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741 # Use inverse cdf transform for normal distribution to get truncated # standard normal @@ -152,7 +133,7 @@ def default_flax_embed_init(tensor): @dataclass -@safe_auto_docstring( +@auto_docstring( custom_intro=""" Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. """ @@ -171,7 +152,7 @@ class SiglipVisionModelOutput(ModelOutput): @dataclass -@safe_auto_docstring( +@auto_docstring( custom_intro=""" Base class for text model's outputs that also contains a pooling of the last hidden states. """ @@ -190,7 +171,7 @@ class SiglipTextModelOutput(ModelOutput): @dataclass -@safe_auto_docstring +@auto_docstring # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip class SiglipOutput(ModelOutput): r""" @@ -502,7 +483,7 @@ class SiglipEncoderLayer(GradientCheckpointingLayer): return outputs -@safe_auto_docstring +@auto_docstring class SiglipPreTrainedModel(PreTrainedModel): config_class = SiglipConfig base_model_prefix = "siglip" @@ -663,7 +644,7 @@ class SiglipTextTransformer(nn.Module): self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.Tensor | None = None, @@ -715,7 +696,7 @@ class SiglipTextTransformer(nn.Module): ) -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The text model from SigLIP without any head or projection on top. """ @@ -736,7 +717,7 @@ class SiglipTextModel(SiglipPreTrainedModel): self.text_model.embeddings.token_embedding = value @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.Tensor | None = None, @@ -785,7 +766,7 @@ class SiglipVisionTransformer(nn.Module): self.head = SiglipMultiheadAttentionPoolingHead(config) @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, pixel_values, @@ -853,7 +834,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): return hidden_state[:, 0] -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The vision model from SigLIP without any head or projection on top. """ @@ -874,7 +855,7 @@ class SiglipVisionModel(SiglipPreTrainedModel): return self.vision_model.embeddings.patch_embedding @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, pixel_values, @@ -911,7 +892,7 @@ class SiglipVisionModel(SiglipPreTrainedModel): ) -@safe_auto_docstring +@auto_docstring class SiglipModel(SiglipPreTrainedModel): config_class = SiglipConfig @@ -947,7 +928,7 @@ class SiglipModel(SiglipPreTrainedModel): # Initialize weights and apply final processing self.post_init() - @safe_auto_docstring + @auto_docstring def get_text_features( self, input_ids: torch.Tensor | None = None, @@ -995,7 +976,7 @@ class SiglipModel(SiglipPreTrainedModel): return pooled_output - @safe_auto_docstring + @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor | None = None, @@ -1047,7 +1028,7 @@ class SiglipModel(SiglipPreTrainedModel): return pooled_output @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -1150,7 +1131,7 @@ class SiglipModel(SiglipPreTrainedModel): ) -@safe_auto_docstring( +@auto_docstring( custom_intro=""" SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of the patch tokens) e.g. for ImageNet. @@ -1180,7 +1161,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel): self.post_init() @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, pixel_values: torch.Tensor | None = None, diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index ce813fdb8..0bdaaa6f1 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -518,6 +518,19 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` torch.set_float32_matmul_precision("high") self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + msg = """transformers_replace is not installed correctly. +Please install it with `pip install transformers==4.53.2` +and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \ +$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + def gradient_checkpointing_enable(self): """Enable gradient checkpointing for memory optimization.""" self.gradient_checkpointing_enabled = True diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py index a3f6b5325..e88051c6e 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py @@ -45,25 +45,6 @@ from .configuration_gemma import GemmaConfig logger = logging.get_logger(__name__) -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): super().__init__() @@ -374,9 +355,8 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): output_attentions: bool | None = False, use_cache: bool | None = False, cache_position: torch.LongTensor | None = None, - position_embeddings: ( - None | tuple[torch.Tensor, torch.Tensor] - ) = None, # necessary, but kept here for BC + position_embeddings: None + | (tuple[torch.Tensor, torch.Tensor]) = None, # necessary, but kept here for BC adarms_cond: torch.Tensor | None = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: @@ -410,7 +390,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): return outputs -@safe_auto_docstring +@auto_docstring class GemmaPreTrainedModel(PreTrainedModel): config_class = GemmaConfig base_model_prefix = "model" @@ -441,7 +421,7 @@ class GemmaPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) -@safe_auto_docstring +@auto_docstring class GemmaModel(GemmaPreTrainedModel): def __init__(self, config: GemmaConfig): super().__init__(config) @@ -468,7 +448,7 @@ class GemmaModel(GemmaPreTrainedModel): self.embed_tokens = value @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -540,7 +520,7 @@ class GemmaModel(GemmaPreTrainedModel): # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # noqa: F841 + _normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) # hidden_states = hidden_states * normalizer # decoder layers @@ -586,7 +566,7 @@ class GemmaModel(GemmaPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@safe_auto_docstring +@auto_docstring class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -620,7 +600,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -704,7 +684,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): ) -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The Gemma Model transformer with a sequence classification head on top (linear layer). @@ -735,7 +715,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -811,7 +791,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): ) -@safe_auto_docstring +@auto_docstring class GemmaForTokenClassification(GemmaPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -836,7 +816,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py index feb709b5b..0f7251881 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py @@ -39,27 +39,8 @@ from .configuration_paligemma import PaliGemmaConfig logger = logging.get_logger(__name__) -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - @dataclass -@safe_auto_docstring( +@auto_docstring( custom_intro=""" Base class for Paligemma outputs, with hidden states and attentions. """ @@ -81,7 +62,7 @@ class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): @dataclass -@safe_auto_docstring( +@auto_docstring( custom_intro=""" Base class for PaliGemma causal language model (or autoregressive) outputs. """ @@ -124,7 +105,7 @@ class PaliGemmaMultiModalProjector(nn.Module): return hidden_states -@safe_auto_docstring +@auto_docstring class PaliGemmaPreTrainedModel(PreTrainedModel): config_class = PaliGemmaConfig base_model_prefix = "" @@ -150,7 +131,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., """ @@ -277,7 +258,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): return image_features @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor = None, @@ -336,7 +317,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): is_training = token_type_ids is not None and labels is not None - # Replace image id worth PAD if the image token if OOV, to avoid index-errors + # Replace image id with PAD if the image token if OOV, to avoid index-errors if input_ids is not None and self.config.image_token_id >= self.vocab_size: special_image_mask = input_ids == self.config.image_token_id llm_input_ids = input_ids.clone() @@ -409,7 +390,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., """ @@ -450,7 +431,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi def get_image_features(self, pixel_values): return self.model.get_image_features(pixel_values) - # Make modules available conditional class for BC + # Make modules available through conditional class for BC @property def language_model(self): return self.model.language_model @@ -464,7 +445,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi return self.model.multi_modal_projector @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py new file mode 100644 index 000000000..d899dc1b9 --- /dev/null +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py @@ -0,0 +1,5 @@ +import transformers + + +def check_whether_transformers_replace_is_installed_correctly(): + return transformers.__version__ == "4.53.2" diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py index d0e634472..b3af5cff4 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py @@ -37,25 +37,6 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo logger = logging.get_logger(__name__) -# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring -def safe_auto_docstring(func=None, **kwargs): - """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" - - def decorator(f): - try: - return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) - except (AttributeError, TypeError): - # If auto_docstring fails due to UnionType, just return the function unchanged - return f - - if func is None: - # Called with arguments, return the decorator - return decorator - else: - # Called without arguments, apply directly - return decorator(func) - - def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -78,7 +59,7 @@ def _trunc_normal_(tensor, mean, std, a, b): # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. - tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.uniform_(2 * l - 1, 2 * u - 1) # noqa: E741 # Use inverse cdf transform for normal distribution to get truncated # standard normal @@ -152,7 +133,7 @@ def default_flax_embed_init(tensor): @dataclass -@safe_auto_docstring( +@auto_docstring( custom_intro=""" Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. """ @@ -171,7 +152,7 @@ class SiglipVisionModelOutput(ModelOutput): @dataclass -@safe_auto_docstring( +@auto_docstring( custom_intro=""" Base class for text model's outputs that also contains a pooling of the last hidden states. """ @@ -190,7 +171,7 @@ class SiglipTextModelOutput(ModelOutput): @dataclass -@safe_auto_docstring +@auto_docstring # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip class SiglipOutput(ModelOutput): r""" @@ -502,7 +483,7 @@ class SiglipEncoderLayer(GradientCheckpointingLayer): return outputs -@safe_auto_docstring +@auto_docstring class SiglipPreTrainedModel(PreTrainedModel): config_class = SiglipConfig base_model_prefix = "siglip" @@ -663,7 +644,7 @@ class SiglipTextTransformer(nn.Module): self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.Tensor | None = None, @@ -715,7 +696,7 @@ class SiglipTextTransformer(nn.Module): ) -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The text model from SigLIP without any head or projection on top. """ @@ -736,7 +717,7 @@ class SiglipTextModel(SiglipPreTrainedModel): self.text_model.embeddings.token_embedding = value @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.Tensor | None = None, @@ -785,7 +766,7 @@ class SiglipVisionTransformer(nn.Module): self.head = SiglipMultiheadAttentionPoolingHead(config) @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, pixel_values, @@ -853,7 +834,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): return hidden_state[:, 0] -@safe_auto_docstring( +@auto_docstring( custom_intro=""" The vision model from SigLIP without any head or projection on top. """ @@ -874,7 +855,7 @@ class SiglipVisionModel(SiglipPreTrainedModel): return self.vision_model.embeddings.patch_embedding @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, pixel_values, @@ -911,7 +892,7 @@ class SiglipVisionModel(SiglipPreTrainedModel): ) -@safe_auto_docstring +@auto_docstring class SiglipModel(SiglipPreTrainedModel): config_class = SiglipConfig @@ -947,7 +928,7 @@ class SiglipModel(SiglipPreTrainedModel): # Initialize weights and apply final processing self.post_init() - @safe_auto_docstring + @auto_docstring def get_text_features( self, input_ids: torch.Tensor | None = None, @@ -995,7 +976,7 @@ class SiglipModel(SiglipPreTrainedModel): return pooled_output - @safe_auto_docstring + @auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor | None = None, @@ -1047,7 +1028,7 @@ class SiglipModel(SiglipPreTrainedModel): return pooled_output @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -1150,7 +1131,7 @@ class SiglipModel(SiglipPreTrainedModel): ) -@safe_auto_docstring( +@auto_docstring( custom_intro=""" SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of the patch tokens) e.g. for ImageNet. @@ -1180,7 +1161,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel): self.post_init() @can_return_tuple - @safe_auto_docstring + @auto_docstring def forward( self, pixel_values: torch.Tensor | None = None, From f840d2e0067e36db0142a6f8858e94aeeea508bd Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 12 Sep 2025 20:06:06 +0200 Subject: [PATCH 05/10] fix(modeling pi0): nit warning message --- src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py | 2 +- src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index fd040159e..dffbd6621 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -520,7 +520,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` msg = """transformers_replace is not installed correctly. Please install it with `pip install transformers==4.53.2` and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \ -$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")""" +$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`""" try: from transformers.models.siglip import check diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 0bdaaa6f1..5dd3000cb 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -521,7 +521,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` msg = """transformers_replace is not installed correctly. Please install it with `pip install transformers==4.53.2` and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \ -$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")""" +$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`""" try: from transformers.models.siglip import check From 7a03223693f950f2bf91893d10c231d97f338a96 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 12 Sep 2025 20:19:16 +0200 Subject: [PATCH 06/10] use safeauto_docstring --- .../models/gemma/modeling_gemma.py | 37 ++++++++++---- .../models/paligemma/modeling_paligemma.py | 33 +++++++++--- .../models/siglip/modeling_siglip.py | 51 +++++++++++++------ .../models/gemma/modeling_gemma.py | 37 ++++++++++---- .../models/paligemma/modeling_paligemma.py | 33 +++++++++--- .../models/siglip/modeling_siglip.py | 51 +++++++++++++------ 6 files changed, 178 insertions(+), 64 deletions(-) diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py index e88051c6e..05066afc5 100644 --- a/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py @@ -45,6 +45,25 @@ from .configuration_gemma import GemmaConfig logger = logging.get_logger(__name__) +# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring +def safe_auto_docstring(func=None, **kwargs): + """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" + + def decorator(f): + try: + return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) + except (AttributeError, TypeError): + # If auto_docstring fails due to UnionType, just return the function unchanged + return f + + if func is None: + # Called with arguments, return the decorator + return decorator + else: + # Called without arguments, apply directly + return decorator(func) + + class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): super().__init__() @@ -390,7 +409,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): return outputs -@auto_docstring +@safe_auto_docstring class GemmaPreTrainedModel(PreTrainedModel): config_class = GemmaConfig base_model_prefix = "model" @@ -421,7 +440,7 @@ class GemmaPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) -@auto_docstring +@safe_auto_docstring class GemmaModel(GemmaPreTrainedModel): def __init__(self, config: GemmaConfig): super().__init__(config) @@ -448,7 +467,7 @@ class GemmaModel(GemmaPreTrainedModel): self.embed_tokens = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -566,7 +585,7 @@ class GemmaModel(GemmaPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@auto_docstring +@safe_auto_docstring class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -600,7 +619,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -684,7 +703,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): ) -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The Gemma Model transformer with a sequence classification head on top (linear layer). @@ -715,7 +734,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -791,7 +810,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): ) -@auto_docstring +@safe_auto_docstring class GemmaForTokenClassification(GemmaPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -816,7 +835,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py index 0f7251881..b2a36b5ca 100644 --- a/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py @@ -39,8 +39,27 @@ from .configuration_paligemma import PaliGemmaConfig logger = logging.get_logger(__name__) +# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring +def safe_auto_docstring(func=None, **kwargs): + """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" + + def decorator(f): + try: + return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) + except (AttributeError, TypeError): + # If auto_docstring fails due to UnionType, just return the function unchanged + return f + + if func is None: + # Called with arguments, return the decorator + return decorator + else: + # Called without arguments, apply directly + return decorator(func) + + @dataclass -@auto_docstring( +@safe_auto_docstring( custom_intro=""" Base class for Paligemma outputs, with hidden states and attentions. """ @@ -62,7 +81,7 @@ class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): @dataclass -@auto_docstring( +@safe_auto_docstring( custom_intro=""" Base class for PaliGemma causal language model (or autoregressive) outputs. """ @@ -105,7 +124,7 @@ class PaliGemmaMultiModalProjector(nn.Module): return hidden_states -@auto_docstring +@safe_auto_docstring class PaliGemmaPreTrainedModel(PreTrainedModel): config_class = PaliGemmaConfig base_model_prefix = "" @@ -131,7 +150,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., """ @@ -258,7 +277,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): return image_features @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor = None, @@ -390,7 +409,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., """ @@ -445,7 +464,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi return self.model.multi_modal_projector @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py index b3af5cff4..0fc0bba0f 100644 --- a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py +++ b/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py @@ -37,6 +37,25 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo logger = logging.get_logger(__name__) +# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring +def safe_auto_docstring(func=None, **kwargs): + """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" + + def decorator(f): + try: + return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) + except (AttributeError, TypeError): + # If auto_docstring fails due to UnionType, just return the function unchanged + return f + + if func is None: + # Called with arguments, return the decorator + return decorator + else: + # Called without arguments, apply directly + return decorator(func) + + def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -133,7 +152,7 @@ def default_flax_embed_init(tensor): @dataclass -@auto_docstring( +@safe_auto_docstring( custom_intro=""" Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. """ @@ -152,7 +171,7 @@ class SiglipVisionModelOutput(ModelOutput): @dataclass -@auto_docstring( +@safe_auto_docstring( custom_intro=""" Base class for text model's outputs that also contains a pooling of the last hidden states. """ @@ -171,7 +190,7 @@ class SiglipTextModelOutput(ModelOutput): @dataclass -@auto_docstring +@safe_auto_docstring # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip class SiglipOutput(ModelOutput): r""" @@ -483,7 +502,7 @@ class SiglipEncoderLayer(GradientCheckpointingLayer): return outputs -@auto_docstring +@safe_auto_docstring class SiglipPreTrainedModel(PreTrainedModel): config_class = SiglipConfig base_model_prefix = "siglip" @@ -644,7 +663,7 @@ class SiglipTextTransformer(nn.Module): self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.Tensor | None = None, @@ -696,7 +715,7 @@ class SiglipTextTransformer(nn.Module): ) -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The text model from SigLIP without any head or projection on top. """ @@ -717,7 +736,7 @@ class SiglipTextModel(SiglipPreTrainedModel): self.text_model.embeddings.token_embedding = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.Tensor | None = None, @@ -766,7 +785,7 @@ class SiglipVisionTransformer(nn.Module): self.head = SiglipMultiheadAttentionPoolingHead(config) @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, pixel_values, @@ -834,7 +853,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): return hidden_state[:, 0] -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The vision model from SigLIP without any head or projection on top. """ @@ -855,7 +874,7 @@ class SiglipVisionModel(SiglipPreTrainedModel): return self.vision_model.embeddings.patch_embedding @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, pixel_values, @@ -892,7 +911,7 @@ class SiglipVisionModel(SiglipPreTrainedModel): ) -@auto_docstring +@safe_auto_docstring class SiglipModel(SiglipPreTrainedModel): config_class = SiglipConfig @@ -928,7 +947,7 @@ class SiglipModel(SiglipPreTrainedModel): # Initialize weights and apply final processing self.post_init() - @auto_docstring + @safe_auto_docstring def get_text_features( self, input_ids: torch.Tensor | None = None, @@ -976,7 +995,7 @@ class SiglipModel(SiglipPreTrainedModel): return pooled_output - @auto_docstring + @safe_auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor | None = None, @@ -1028,7 +1047,7 @@ class SiglipModel(SiglipPreTrainedModel): return pooled_output @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -1131,7 +1150,7 @@ class SiglipModel(SiglipPreTrainedModel): ) -@auto_docstring( +@safe_auto_docstring( custom_intro=""" SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of the patch tokens) e.g. for ImageNet. @@ -1161,7 +1180,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel): self.post_init() @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, pixel_values: torch.Tensor | None = None, diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py index e88051c6e..05066afc5 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py @@ -45,6 +45,25 @@ from .configuration_gemma import GemmaConfig logger = logging.get_logger(__name__) +# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring +def safe_auto_docstring(func=None, **kwargs): + """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" + + def decorator(f): + try: + return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) + except (AttributeError, TypeError): + # If auto_docstring fails due to UnionType, just return the function unchanged + return f + + if func is None: + # Called with arguments, return the decorator + return decorator + else: + # Called without arguments, apply directly + return decorator(func) + + class GemmaRMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None): super().__init__() @@ -390,7 +409,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): return outputs -@auto_docstring +@safe_auto_docstring class GemmaPreTrainedModel(PreTrainedModel): config_class = GemmaConfig base_model_prefix = "model" @@ -421,7 +440,7 @@ class GemmaPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) -@auto_docstring +@safe_auto_docstring class GemmaModel(GemmaPreTrainedModel): def __init__(self, config: GemmaConfig): super().__init__(config) @@ -448,7 +467,7 @@ class GemmaModel(GemmaPreTrainedModel): self.embed_tokens = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -566,7 +585,7 @@ class GemmaModel(GemmaPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@auto_docstring +@safe_auto_docstring class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -600,7 +619,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): return self.model @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -684,7 +703,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): ) -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The Gemma Model transformer with a sequence classification head on top (linear layer). @@ -715,7 +734,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -791,7 +810,7 @@ class GemmaForSequenceClassification(GemmaPreTrainedModel): ) -@auto_docstring +@safe_auto_docstring class GemmaForTokenClassification(GemmaPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -816,7 +835,7 @@ class GemmaForTokenClassification(GemmaPreTrainedModel): self.model.embed_tokens = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py index 0f7251881..b2a36b5ca 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py @@ -39,8 +39,27 @@ from .configuration_paligemma import PaliGemmaConfig logger = logging.get_logger(__name__) +# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring +def safe_auto_docstring(func=None, **kwargs): + """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" + + def decorator(f): + try: + return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) + except (AttributeError, TypeError): + # If auto_docstring fails due to UnionType, just return the function unchanged + return f + + if func is None: + # Called with arguments, return the decorator + return decorator + else: + # Called without arguments, apply directly + return decorator(func) + + @dataclass -@auto_docstring( +@safe_auto_docstring( custom_intro=""" Base class for Paligemma outputs, with hidden states and attentions. """ @@ -62,7 +81,7 @@ class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): @dataclass -@auto_docstring( +@safe_auto_docstring( custom_intro=""" Base class for PaliGemma causal language model (or autoregressive) outputs. """ @@ -105,7 +124,7 @@ class PaliGemmaMultiModalProjector(nn.Module): return hidden_states -@auto_docstring +@safe_auto_docstring class PaliGemmaPreTrainedModel(PreTrainedModel): config_class = PaliGemmaConfig base_model_prefix = "" @@ -131,7 +150,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., """ @@ -258,7 +277,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): return image_features @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor = None, @@ -390,7 +409,7 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The Base Paligemma model which consists of a vision backbone and a language model without language modeling head., """ @@ -445,7 +464,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi return self.model.multi_modal_projector @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py index b3af5cff4..0fc0bba0f 100644 --- a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py +++ b/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py @@ -37,6 +37,25 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionCo logger = logging.get_logger(__name__) +# Workaround for Python 3.10+ UnionType compatibility with transformers auto_docstring +def safe_auto_docstring(func=None, **kwargs): + """Auto docstring decorator that handles Python 3.10+ UnionType gracefully.""" + + def decorator(f): + try: + return auto_docstring(f, **kwargs) if kwargs else auto_docstring(f) + except (AttributeError, TypeError): + # If auto_docstring fails due to UnionType, just return the function unchanged + return f + + if func is None: + # Called with arguments, return the decorator + return decorator + else: + # Called without arguments, apply directly + return decorator(func) + + def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf @@ -133,7 +152,7 @@ def default_flax_embed_init(tensor): @dataclass -@auto_docstring( +@safe_auto_docstring( custom_intro=""" Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. """ @@ -152,7 +171,7 @@ class SiglipVisionModelOutput(ModelOutput): @dataclass -@auto_docstring( +@safe_auto_docstring( custom_intro=""" Base class for text model's outputs that also contains a pooling of the last hidden states. """ @@ -171,7 +190,7 @@ class SiglipTextModelOutput(ModelOutput): @dataclass -@auto_docstring +@safe_auto_docstring # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip class SiglipOutput(ModelOutput): r""" @@ -483,7 +502,7 @@ class SiglipEncoderLayer(GradientCheckpointingLayer): return outputs -@auto_docstring +@safe_auto_docstring class SiglipPreTrainedModel(PreTrainedModel): config_class = SiglipConfig base_model_prefix = "siglip" @@ -644,7 +663,7 @@ class SiglipTextTransformer(nn.Module): self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.Tensor | None = None, @@ -696,7 +715,7 @@ class SiglipTextTransformer(nn.Module): ) -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The text model from SigLIP without any head or projection on top. """ @@ -717,7 +736,7 @@ class SiglipTextModel(SiglipPreTrainedModel): self.text_model.embeddings.token_embedding = value @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.Tensor | None = None, @@ -766,7 +785,7 @@ class SiglipVisionTransformer(nn.Module): self.head = SiglipMultiheadAttentionPoolingHead(config) @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, pixel_values, @@ -834,7 +853,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): return hidden_state[:, 0] -@auto_docstring( +@safe_auto_docstring( custom_intro=""" The vision model from SigLIP without any head or projection on top. """ @@ -855,7 +874,7 @@ class SiglipVisionModel(SiglipPreTrainedModel): return self.vision_model.embeddings.patch_embedding @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, pixel_values, @@ -892,7 +911,7 @@ class SiglipVisionModel(SiglipPreTrainedModel): ) -@auto_docstring +@safe_auto_docstring class SiglipModel(SiglipPreTrainedModel): config_class = SiglipConfig @@ -928,7 +947,7 @@ class SiglipModel(SiglipPreTrainedModel): # Initialize weights and apply final processing self.post_init() - @auto_docstring + @safe_auto_docstring def get_text_features( self, input_ids: torch.Tensor | None = None, @@ -976,7 +995,7 @@ class SiglipModel(SiglipPreTrainedModel): return pooled_output - @auto_docstring + @safe_auto_docstring def get_image_features( self, pixel_values: torch.FloatTensor | None = None, @@ -1028,7 +1047,7 @@ class SiglipModel(SiglipPreTrainedModel): return pooled_output @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, input_ids: torch.LongTensor | None = None, @@ -1131,7 +1150,7 @@ class SiglipModel(SiglipPreTrainedModel): ) -@auto_docstring( +@safe_auto_docstring( custom_intro=""" SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of the patch tokens) e.g. for ImageNet. @@ -1161,7 +1180,7 @@ class SiglipForImageClassification(SiglipPreTrainedModel): self.post_init() @can_return_tuple - @auto_docstring + @safe_auto_docstring def forward( self, pixel_values: torch.Tensor | None = None, From d1eefd4e972e8488a844d234de69a157efac2fd4 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 12 Sep 2025 20:25:55 +0200 Subject: [PATCH 07/10] fix: remove unused param --- src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py | 2 +- src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index dffbd6621..ca9d9951d 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -1253,7 +1253,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: - actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps] + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 5dd3000cb..a50e83781 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -1266,7 +1266,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): # Action queue logic for n_action_steps > 1 if len(self._action_queue) == 0: - actions = self.predict_action_chunk(batch, train=False)[:, : self.config.n_action_steps] + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # Transpose to get shape (n_action_steps, batch_size, action_dim) self._action_queue.extend(actions.transpose(0, 1)) From 376cc772ff67b42ca130cc09a6f46650e3878048 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 12 Sep 2025 21:12:48 +0200 Subject: [PATCH 08/10] fix from pretrained --- .../pi05_openpi/modeling_pi05openpi.py | 49 +++++++++---------- .../policies/pi0_openpi/modeling_pi0openpi.py | 45 ++++++++--------- 2 files changed, 42 insertions(+), 52 deletions(-) diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index ca9d9951d..eac7dd67c 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -864,41 +864,35 @@ class PI05OpenPIPolicy(PreTrainedPolicy): @classmethod def from_pretrained( - cls, *args, **kwargs + cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs ): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( - "⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" + "⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" " This implementation follows the original OpenPI structure for compatibility. \n" " Original implementation: https://github.com/Physical-Intelligence/openpi" ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") - # Store original strict mode - original_strict = kwargs.get("strict", True) - # Temporarily set strict=False to avoid loading issues, we'll handle it manually - kwargs["strict"] = False + # Create default config + config = cls.config_class() - # Call parent from_pretrained with strict=False - model = super().from_pretrained(*args, **kwargs) - - # Extract the pretrained_model_name_or_path from args or kwargs for remapping - if len(args) > 0: - pretrained_model_name_or_path = args[0] - elif "pretrained_model_name_or_path" in kwargs: - pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"] - else: - return model + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + dataset_stats = kwargs.get("dataset_stats") + model = cls(config=config, dataset_stats=dataset_stats) # Now manually load and remap the state dict try: - from transformers.utils import cached_file - # Try to load the pytorch_model.bin or model.safetensors file - print(f"Loading model from: {pretrained_model_name_or_path}") + print(f"Loading model from: {pretrained_name_or_path}") try: + from transformers.utils import cached_file + # Try safetensors first resolved_file = cached_file( - pretrained_model_name_or_path, + pretrained_name_or_path, "model.safetensors", cache_dir=kwargs.get("cache_dir"), force_download=kwargs.get("force_download", False), @@ -914,9 +908,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy): print("✓ Loaded state dict from model.safetensors") except Exception as e: print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") return model - # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + # First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) # Then add "model." prefix for all keys that don't already have it @@ -939,10 +934,10 @@ class PI05OpenPIPolicy(PreTrainedPolicy): print(f"Total keys remapped: {remap_count}") # Load the remapped state dict into the model - missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict) + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) if missing_keys: - print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys") + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") if len(missing_keys) <= 5: for key in missing_keys: print(f" - {key}") @@ -952,7 +947,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): print(f" ... and {len(missing_keys) - 5} more") if unexpected_keys: - print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") if len(unexpected_keys) <= 5: for key in unexpected_keys: print(f" - {key}") @@ -962,11 +957,11 @@ class PI05OpenPIPolicy(PreTrainedPolicy): print(f" ... and {len(unexpected_keys) - 5} more") if not missing_keys and not unexpected_keys: - print("✅ All keys loaded successfully!") + print("All keys loaded successfully!") except Exception as e: - print(f"⚠️ Warning: Could not remap state dict keys: {e}") - print("Using default loading behavior") + print(f"Warning: Could not remap state dict keys: {e}") + print("Returning model without loading pretrained weights") return model diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index a50e83781..b9d87f5e0 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -881,7 +881,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): @classmethod def from_pretrained( - cls, *args, **kwargs + cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs ): # TODO(pepijn): modify this back so we do not have to add model. prefix to all keys in the state dict """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( @@ -889,33 +889,27 @@ class PI0OpenPIPolicy(PreTrainedPolicy): " This implementation follows the original OpenPI structure for compatibility. \n" " Original implementation: https://github.com/Physical-Intelligence/openpi" ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") - # Store original strict mode - original_strict = kwargs.get("strict", True) - # Temporarily set strict=False to avoid loading issues, we'll handle it manually - kwargs["strict"] = False + # Create default config + config = cls.config_class() - # Call parent from_pretrained with strict=False - model = super().from_pretrained(*args, **kwargs) - - # Extract the pretrained_model_name_or_path from args or kwargs for remapping - if len(args) > 0: - pretrained_model_name_or_path = args[0] - elif "pretrained_model_name_or_path" in kwargs: - pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"] - else: - return model + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + dataset_stats = kwargs.get("dataset_stats") + model = cls(config=config, dataset_stats=dataset_stats) # Now manually load and remap the state dict try: - from transformers.utils import cached_file - # Try to load the pytorch_model.bin or model.safetensors file - print(f"Loading model from: {pretrained_model_name_or_path}") + print(f"Loading model from: {pretrained_name_or_path}") try: + from transformers.utils import cached_file + # Try safetensors first resolved_file = cached_file( - pretrained_model_name_or_path, + pretrained_name_or_path, "model.safetensors", cache_dir=kwargs.get("cache_dir"), force_download=kwargs.get("force_download", False), @@ -931,6 +925,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print("✓ Loaded state dict from model.safetensors") except Exception as e: print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") return model # First, fix any pi key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` @@ -956,10 +951,10 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print(f"Total keys remapped: {remap_count}") # Load the remapped state dict into the model - missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=original_strict) + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) if missing_keys: - print(f"⚠️ Missing keys when loading state dict: {len(missing_keys)} keys") + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") if len(missing_keys) <= 5: for key in missing_keys: print(f" - {key}") @@ -969,7 +964,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print(f" ... and {len(missing_keys) - 5} more") if unexpected_keys: - print(f"⚠️ Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") if len(unexpected_keys) <= 5: for key in unexpected_keys: print(f" - {key}") @@ -979,11 +974,11 @@ class PI0OpenPIPolicy(PreTrainedPolicy): print(f" ... and {len(unexpected_keys) - 5} more") if not missing_keys and not unexpected_keys: - print("✅ All keys loaded successfully!") + print("All keys loaded successfully!") except Exception as e: - print(f"⚠️ Warning: Could not remap state dict keys: {e}") - print("Using default loading behavior") + print(f"Warning: Could not remap state dict keys: {e}") + print("Returning model without loading pretrained weights") return model From c8163662adb77329a26b1bc166012df512865791 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 12 Sep 2025 21:41:25 +0200 Subject: [PATCH 09/10] add preprocess tests --- .../pi05_openpi/modeling_pi05openpi.py | 1 - .../policies/pi0_openpi/modeling_pi0openpi.py | 1 - test_pi0_original_vs_lerobot.py | 180 +++++++++++++----- 3 files changed, 130 insertions(+), 52 deletions(-) diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index eac7dd67c..9ff71152a 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -961,7 +961,6 @@ class PI05OpenPIPolicy(PreTrainedPolicy): except Exception as e: print(f"Warning: Could not remap state dict keys: {e}") - print("Returning model without loading pretrained weights") return model diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index b9d87f5e0..120791cc1 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -978,7 +978,6 @@ class PI0OpenPIPolicy(PreTrainedPolicy): except Exception as e: print(f"Warning: Could not remap state dict keys: {e}") - print("Returning model without loading pretrained weights") return model diff --git a/test_pi0_original_vs_lerobot.py b/test_pi0_original_vs_lerobot.py index e5cdf3dd7..68b62d110 100644 --- a/test_pi0_original_vs_lerobot.py +++ b/test_pi0_original_vs_lerobot.py @@ -3,9 +3,11 @@ import os import torch +from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions. from openpi.models_pytorch.pi0_pytorch import PI0Pytorch +from transformers import AutoTokenizer from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy @@ -54,7 +56,9 @@ class PI0BaseOriginalConfig: def instantiate_lerobot_pi0(from_pretrained: bool = False): if from_pretrained: # Load the policy first - policy = PI0OpenPIPolicy.from_pretrained("pepijn223/pi0_base_fp32") + policy = PI0OpenPIPolicy.from_pretrained( + pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True + ) # Then reinitialize the normalization with proper stats from lerobot.policies.normalize import Normalize, Unnormalize @@ -153,16 +157,16 @@ def create_dummy_data(): "action": torch.randn( batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device ), - # Create images in [-1, 1] range as expected by both implementations - "observation.images.base_0_rgb": torch.randn( + # Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally) + "observation.images.base_0_rgb": torch.rand( batch_size, 3, 224, 224, dtype=torch.float32, device=device - ).clamp(-1, 1), - "observation.images.left_wrist_0_rgb": torch.randn( + ), + "observation.images.left_wrist_0_rgb": torch.rand( batch_size, 3, 224, 224, dtype=torch.float32, device=device - ).clamp(-1, 1), - "observation.images.right_wrist_0_rgb": torch.randn( + ), + "observation.images.right_wrist_0_rgb": torch.rand( batch_size, 3, 224, 224, dtype=torch.float32, device=device - ).clamp(-1, 1), + ), # Add the task prompt for LeRobot - provide as list with single element to trigger expansion "task": [prompt], } @@ -175,7 +179,7 @@ def extract_lerobot_processed_inputs(lerobot_pi0, batch): lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch) # Get the preprocessed images from LeRobot's internal method - images, img_masks = lerobot_pi0._preprocess_images(batch) + images, img_masks = lerobot_pi0._preprocess_images(batch, train=False) # Create dummy token_ar_mask and token_loss_mask for original implementation token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) @@ -206,6 +210,72 @@ class PI0Observation: self.token_loss_mask = token_loss_mask +def create_original_observation_with_openpi_preprocessing(batch): + """Create observation object for OpenPI using OpenPI's own preprocessing.""" + batch_size = batch["observation.state"].shape[0] + device = batch["observation.state"].device + + # Create tokenizer for OpenPI (same as LeRobot uses) + tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + + # Get task description + if "task" in batch: + tasks = batch["task"] + if isinstance(tasks, str): + tasks = [tasks] + elif isinstance(tasks, list) and len(tasks) == 1: + # Expand to batch size + tasks = tasks * batch_size + else: + # Default task if not provided + tasks = ["Pick up the object"] * batch_size + + # Tokenize with max_length padding to match OpenPI's expected format + tokenized = tokenizer( + tasks, + padding="max_length", + padding_side="right", + truncation=True, + max_length=DUMMY_MAX_TOKEN_LEN, + return_tensors="pt", + ) + + lang_tokens = tokenized["input_ids"].to(device) + lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool) + + # Create dummy token_ar_mask and token_loss_mask for OpenPI + token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) + token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + + # Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range) + image_dict = { + "base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0, + "left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0, + "right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0, + } + + # Create image masks (all ones for real images) + image_masks_dict = {} + for key in image_dict: + image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device) + + # Create raw observation object (before preprocessing) + raw_observation = PI0Observation( + state=batch["observation.state"], + images=image_dict, + image_masks=image_masks_dict, + tokenized_prompt=lang_tokens, + tokenized_prompt_mask=lang_masks, + token_ar_mask=token_ar_mask, + token_loss_mask=token_loss_mask, + ) + + # Now use OpenPI's preprocessing + processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False) + + return processed_obs + + def create_original_observation_from_lerobot(lerobot_pi0, batch): """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot.""" _batch_size = batch["observation.state"].shape[0] @@ -251,65 +321,75 @@ def main(): print("Creating dummy data...") batch = create_dummy_data() - print("Creating observation for original PI0 using LeRobot's exact preprocessing...") - pi0_obs = create_original_observation_from_lerobot(lerobot_pi0, batch) + # Test 1: Each model with its own preprocessing (more realistic end-to-end test) + print("\n=== TEST 1: Each model with its own preprocessing ===") + print("Creating observation for OpenPI using OpenPI's own preprocessing...") + pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch) - # Verify both implementations get the same inputs print(f"Task prompt: '{batch['task'][0]}'") - print(f"Tokenized prompt shape: {pi0_obs.tokenized_prompt.shape}") - print(f"Image shapes: {[img.shape for img in pi0_obs.images.values()]}") - print(f"State shape: {pi0_obs.state.shape}") + print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}") + print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}") + print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}") - print("Testing original PI0...") - - # Test training forward pass (returns loss) - print("1. Training forward pass (computing loss):") - original_pi0.train() - original_loss = original_pi0(observation=pi0_obs, actions=batch["action"]) - print(f" Loss shape: {original_loss.shape}, Mean loss: {original_loss.mean().item():.6f}") - - # Test inference (action sampling) with fixed noise for reproducibility - print("2. Inference (action sampling):") + print("Testing OpenPI with own preprocessing...") original_pi0.eval() - - # Create the same noise for both implementations torch.manual_seed(42) # Set seed for reproducibility batch_size = batch["observation.state"].shape[0] noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM) fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE) with torch.no_grad(): - original_actions = original_pi0.sample_actions( - device=DEVICE, observation=pi0_obs, noise=fixed_noise, num_steps=10 + openpi_actions = original_pi0.sample_actions( + device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10 ) - print(f"Original PI0 Actions shape: {original_actions.shape}") - print(f"Original PI0 Actions mean: {original_actions.mean().item():.6f}") - print(f"Original PI0 Actions std: {original_actions.std().item():.6f}") + print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}") + print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}") + print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}") - # Test LeRobot implementation with the same noise - print("\nTesting LeRobot PI0...") + print("Testing LeRobot with own preprocessing...") lerobot_pi0.eval() - - # For LeRobot, we need to modify the batch to force the same noise - # This is more complex since LeRobot generates noise internally torch.manual_seed(42) # Set the same seed with torch.no_grad(): - # lerobot_pi0_actions = lerobot_pi0.select_action(batch) - lerobot_pi0_actions = lerobot_pi0.predict_action_chunk(batch) - print(f"LeRobot actions shape: {lerobot_pi0_actions.shape}") - print(f"LeRobot actions mean: {lerobot_pi0_actions.mean().item():.6f}") - print(f"LeRobot actions std: {lerobot_pi0_actions.std().item():.6f}") + lerobot_actions_own = lerobot_pi0.predict_action_chunk(batch) + print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}") + print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}") + print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}") - print("\nComparing implementations:") - print(f"Original actions shape: {original_actions.shape}") - print(f"LeRobot actions shape: {lerobot_pi0_actions.shape}") + print("\nComparing end-to-end implementations:") + print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}") + print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}") + print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}") - # Compare the first action step (since LeRobot select_action returns a single step) - print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_pi0_actions, original_actions, atol=1e-4)}") - print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_pi0_actions, original_actions, atol=1e-2)}") - print(f"Max absolute difference: {torch.abs(lerobot_pi0_actions - original_actions).max().item():.6f}") + # Test 2: Both models with LeRobot preprocessing (isolates model differences) + print("\n=== TEST 2: Both models with LeRobot preprocessing (model comparison) ===") + print("Creating observation for OpenPI using LeRobot's preprocessing...") + pi0_obs_lerobot = create_original_observation_from_lerobot(lerobot_pi0, batch) - print("\nOriginal PI0 test completed successfully!") + print("Testing OpenPI with LeRobot preprocessing...") + torch.manual_seed(42) # Set seed for reproducibility + with torch.no_grad(): + openpi_actions_lerobot_preproc = original_pi0.sample_actions( + device=DEVICE, observation=pi0_obs_lerobot, noise=fixed_noise, num_steps=10 + ) + print(f"OpenPI (LeRobot preprocessing) Actions shape: {openpi_actions_lerobot_preproc.shape}") + print(f"OpenPI (LeRobot preprocessing) Actions mean: {openpi_actions_lerobot_preproc.mean().item():.6f}") + print(f"OpenPI (LeRobot preprocessing) Actions std: {openpi_actions_lerobot_preproc.std().item():.6f}") + + print("\nComparing models with same preprocessing:") + print( + f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-4)}" + ) + print( + f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions_lerobot_preproc, atol=1e-2)}" + ) + print( + f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions_lerobot_preproc).max().item():.6f}" + ) + + print("\n=== SUMMARY ===") + print("Test 1 compares end-to-end pipelines (each model with its own preprocessing)") + print("Test 2 isolates model differences (both models with LeRobot preprocessing)") + print("Both tests completed successfully!") if __name__ == "__main__": From c5a029a28a7e6aa701d199da1f608b8ee32dba25 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sat, 13 Sep 2025 11:12:54 +0200 Subject: [PATCH 10/10] also compile forward method --- src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 120791cc1..549dc0a9b 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -517,6 +517,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` if config.compile_model: torch.set_float32_matmul_precision("high") self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + # Also compile the main forward pass used during training + self.forward = torch.compile(self.forward, mode=config.compile_mode) msg = """transformers_replace is not installed correctly. Please install it with `pip install transformers==4.53.2`