diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index b97fb2acf..b52d70143 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import builtins import logging import math from collections import deque +from pathlib import Path from typing import Literal import torch @@ -28,10 +30,11 @@ from transformers.models.gemma import modeling_gemma from transformers.models.gemma.modeling_gemma import GemmaForCausalLM from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +from lerobot.configs.policies import PreTrainedConfig from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig -from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.pretrained import PreTrainedPolicy, T # Helper functions @@ -865,10 +868,24 @@ class PI05OpenPIPolicy(PreTrainedPolicy): self.reset() @classmethod - def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs): + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( - "⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" + "⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" " This implementation follows the original OpenPI structure for compatibility. \n" " Original implementation: https://github.com/Physical-Intelligence/openpi" ) @@ -876,12 +893,23 @@ class PI05OpenPIPolicy(PreTrainedPolicy): raise ValueError("pretrained_name_or_path is required") # Use provided config if available, otherwise create default config - config = kwargs.get("config", cls.config_class()) + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) # Initialize model without loading weights # Check if dataset_stats were provided in kwargs - dataset_stats = kwargs.get("dataset_stats") - model = cls(config=config, dataset_stats=dataset_stats) + dataset_stats = kwargs.get("dataset_stats") # TODO(Adil, Pepijn): Remove this with pipeline + model = cls(config, dataset_stats=dataset_stats, **kwargs) # Now manually load and remap the state dict try: diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index db137f025..6fab67118 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -14,9 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import builtins import logging import math from collections import deque +from pathlib import Path from typing import Literal import torch @@ -28,10 +30,11 @@ from transformers.models.gemma import modeling_gemma from transformers.models.gemma.modeling_gemma import GemmaForCausalLM from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +from lerobot.configs.policies import PreTrainedConfig from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig -from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.pretrained import PreTrainedPolicy, T # Helper functions @@ -882,7 +885,21 @@ class PI0OpenPIPolicy(PreTrainedPolicy): self.reset() @classmethod - def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs): + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: """Override the from_pretrained method to handle key remapping and display important disclaimer.""" print( "⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n" @@ -893,12 +910,23 @@ class PI0OpenPIPolicy(PreTrainedPolicy): raise ValueError("pretrained_name_or_path is required") # Use provided config if available, otherwise create default config - config = kwargs.get("config", cls.config_class()) + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) # Initialize model without loading weights # Check if dataset_stats were provided in kwargs - dataset_stats = kwargs.get("dataset_stats") - model = cls(config=config, dataset_stats=dataset_stats) + dataset_stats = kwargs.get("dataset_stats") # TODO(Adil, Pepijn): Remove this with pipeline + model = cls(config, dataset_stats=dataset_stats, **kwargs) # Now manually load and remap the state dict try: