mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
fix from pretrained
This commit is contained in:
@@ -14,9 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
@@ -28,10 +30,11 @@ from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
|
||||
|
||||
# Helper functions
|
||||
@@ -865,10 +868,24 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs):
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = True,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
"⚠️ DISCLAIMER: The PI05OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
" This implementation follows the original OpenPI structure for compatibility. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
@@ -876,12 +893,23 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
raise ValueError("pretrained_name_or_path is required")
|
||||
|
||||
# Use provided config if available, otherwise create default config
|
||||
config = kwargs.get("config", cls.config_class())
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Initialize model without loading weights
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
dataset_stats = kwargs.get("dataset_stats")
|
||||
model = cls(config=config, dataset_stats=dataset_stats)
|
||||
dataset_stats = kwargs.get("dataset_stats") # TODO(Adil, Pepijn): Remove this with pipeline
|
||||
model = cls(config, dataset_stats=dataset_stats, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
@@ -28,10 +30,11 @@ from transformers.models.gemma import modeling_gemma
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
|
||||
|
||||
# Helper functions
|
||||
@@ -882,7 +885,21 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_name_or_path: str, strict: bool = True, *args, **kwargs):
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = True,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Override the from_pretrained method to handle key remapping and display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI0OpenPI model is a direct PyTorch port of the OpenPI implementation. \n"
|
||||
@@ -893,12 +910,23 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
raise ValueError("pretrained_name_or_path is required")
|
||||
|
||||
# Use provided config if available, otherwise create default config
|
||||
config = kwargs.get("config", cls.config_class())
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Initialize model without loading weights
|
||||
# Check if dataset_stats were provided in kwargs
|
||||
dataset_stats = kwargs.get("dataset_stats")
|
||||
model = cls(config=config, dataset_stats=dataset_stats)
|
||||
dataset_stats = kwargs.get("dataset_stats") # TODO(Adil, Pepijn): Remove this with pipeline
|
||||
model = cls(config, dataset_stats=dataset_stats, **kwargs)
|
||||
|
||||
# Now manually load and remap the state dict
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user