diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 104ec63bf..3929e9ebf 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -967,7 +967,13 @@ class PI05Policy(PreTrainedPolicy): # Initialize model without loading weights # Check if dataset_stats were provided in kwargs - model = cls(config, **kwargs) + if _transformers_available: + from transformers.modeling_utils import no_init_weights + with no_init_weights(): + model = cls(config, **kwargs) + model.model.paligemma_with_expert.paligemma.tie_weights() + else: + model = cls(config, **kwargs) # Now manually load and remap the state dict try: