fix from pretrained

This commit is contained in:
Pepijn
2025-09-17 18:52:32 +02:00
parent 64974c38c2
commit 9461b9f8d5
2 changed files with 67 additions and 11 deletions

View File

@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import logging
import math
from collections import deque
from pathlib import Path
from typing import Literal
import torch
@@ -28,10 +30,11 @@ from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.configs.policies import PreTrainedConfig
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.pretrained import PreTrainedPolicy, T
# Helper functions
@@ -865,10 +868,24 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
self.reset()
@classmethod
def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs):
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = True,
**kwargs,
) -> T:
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
"⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
" This implementation follows the original OpenPI structure for compatibility. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
@@ -876,12 +893,23 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
raise ValueError("pretrained_name_or_path is required")
# Use provided config if available, otherwise create default config
config = kwargs.get("config", cls.config_class())
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
dataset_stats = kwargs.get("dataset_stats")
model = cls(config=config, dataset_stats=dataset_stats)
dataset_stats = kwargs.get("dataset_stats") # TODO(Adil, Pepijn): Remove this with pipeline
model = cls(config, dataset_stats=dataset_stats, **kwargs)
# Now manually load and remap the state dict
try:

View File

@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import logging
import math
from collections import deque
from pathlib import Path
from typing import Literal
import torch
@@ -28,10 +30,11 @@ from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from lerobot.configs.policies import PreTrainedConfig
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.pretrained import PreTrainedPolicy, T
# Helper functions
@@ -882,7 +885,21 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
self.reset()
@classmethod
def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs):
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = True,
**kwargs,
) -> T:
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
@@ -893,12 +910,23 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
raise ValueError("pretrained_name_or_path is required")
# Use provided config if available, otherwise create default config
config = kwargs.get("config", cls.config_class())
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
dataset_stats = kwargs.get("dataset_stats")
model = cls(config=config, dataset_stats=dataset_stats)
dataset_stats = kwargs.get("dataset_stats") # TODO(Adil, Pepijn): Remove this with pipeline
model = cls(config, dataset_stats=dataset_stats, **kwargs)
# Now manually load and remap the state dict
try: