speed up pi-05 model loading by 72s

This commit is contained in:
Jeremiah Coholich
2026-02-20 15:41:44 -05:00
parent 5865170d36
commit 49444652c6

View File

@@ -967,7 +967,13 @@ class PI05Policy(PreTrainedPolicy):
# Initialize model without loading weights
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
if _transformers_available:
from transformers.modeling_utils import no_init_weights
with no_init_weights():
model = cls(config, **kwargs)
model.model.paligemma_with_expert.paligemma.tie_weights()
else:
model = cls(config, **kwargs)
# Now manually load and remap the state dict
try: