def _restore_pi052_pretrained_state(
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    pretrained_path: str,
) -> None:
    """Transplant saved stateful blobs from a pi052 checkpoint into fresh pipelines.

    pi052's preprocessor includes steps whose constructor args don't
    JSON-roundtrip (``RenderMessagesStep.recipe`` is a Python object,
    ``ActionTokenizerProcessorStep.action_tokenizer_name`` is a
    fitted-tokenizer path that may not exist at eval time). The caller
    rebuilds those pipelines fresh; this function then walks the saved
    ``policy_{pre,post}processor.json`` files, finds each step's
    ``state_file`` reference and loads that safetensors blob into the
    fresh step sitting at the same position.

    Pairing is by position, cross-checked against ``registry_name`` when
    both sides expose one. Every situation in which a saved blob is NOT
    loaded (non-local checkpoint path, step-count mismatch, registry-name
    mismatch, missing state file) is logged as a warning so a pipeline is
    never silently left at its fresh initialisation.

    Args:
        preprocessor: Freshly built preprocessing pipeline; its steps are
            mutated in place via ``load_state_dict``.
        postprocessor: Freshly built postprocessing pipeline; mutated in
            place the same way.
        pretrained_path: Checkpoint location. Only a local directory is
            handled here; anything else is reported and skipped.
    """
    import json  # noqa: PLC0415
    import logging  # noqa: PLC0415
    from pathlib import Path  # noqa: PLC0415

    from safetensors.torch import load_file  # noqa: PLC0415

    log = logging.getLogger(__name__)

    base = Path(pretrained_path)
    if not base.is_dir():
        # Previously a silent return: a non-local path (e.g. a Hub repo id)
        # left the pipelines at fresh init without telling anyone.
        log.warning(
            "PI052 state restore: %s is not a local directory; "
            "pipelines left at fresh init",
            pretrained_path,
        )
        return

    for pipeline, config_filename in [
        (preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
        (postprocessor, f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"),
    ]:
        config_path = base / config_filename
        if not config_path.exists():
            continue
        # JSON is UTF-8 by specification; don't depend on the locale default.
        saved = json.loads(config_path.read_text(encoding="utf-8"))
        saved_steps = saved.get("steps", [])
        fresh_steps = list(pipeline.steps)

        if len(saved_steps) != len(fresh_steps):
            # Positional pairing is only trustworthy when both sides have the
            # same shape; surface the discrepancy instead of truncating quietly.
            log.warning(
                "PI052 state restore: %s has %d saved steps but the fresh "
                "pipeline has %d; pairing by position may be off",
                config_filename, len(saved_steps), len(fresh_steps),
            )

        for idx, saved_step in enumerate(saved_steps):
            state_file = saved_step.get("state_file")
            if not state_file:
                continue
            if idx >= len(fresh_steps):
                # A saved stateful step with no fresh counterpart used to be
                # dropped silently by ``zip(..., strict=False)``.
                log.warning(
                    "PI052 state restore: %s step %d has no counterpart in "
                    "the fresh pipeline; skipping %s",
                    config_filename, idx, state_file,
                )
                continue
            fresh_step = fresh_steps[idx]
            saved_name = saved_step.get("registry_name")
            fresh_name = getattr(type(fresh_step), "_registry_name", None)
            if saved_name and fresh_name and saved_name != fresh_name:
                log.warning(
                    "PI052 state restore: %s step %d registry name mismatch "
                    "(saved=%s, fresh=%s); skipping %s",
                    config_filename, idx, saved_name, fresh_name, state_file,
                )
                continue
            state_path = base / state_file
            if not state_path.exists():
                log.warning(
                    "PI052 state restore: %s missing at %s; %s left at fresh init",
                    state_file, base, fresh_name,
                )
                continue
            fresh_step.load_state_dict(load_file(str(state_path)))
            log.info(
                "PI052 state restore: loaded %s into %s (step %d)",
                state_file, fresh_name, idx,
            )
""" + if pretrained_path and getattr(policy_cfg, "type", None) == "pi052": + # pi052 pipelines don't roundtrip through the saved + # ``policy_preprocessor.json``: ``RenderMessagesStep`` holds a + # Python ``TrainingRecipe`` (not JSON-serializable; saved as + # ``{}``) and ``ActionTokenizerProcessorStep`` saves a host-only + # FAST tokenizer path. Generic ``from_pretrained`` then dies + # with ``RenderMessagesStep.__init__() missing 1 required + # positional argument: 'recipe'`` (job 22164494). + # + # Mirror ``lerobot_pi052_runtime``'s bootstrap: build pipelines + # fresh from ``config.recipe_path`` and transplant the saved + # stateful blobs (normalizer stats) from the checkpoint dir. + from .pi052.processor_pi052 import make_pi052_pre_post_processors + + preprocessor, postprocessor = make_pi052_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + dataset_repo_id=kwargs.get("dataset_repo_id"), + ) + _restore_pi052_pretrained_state(preprocessor, postprocessor, pretrained_path) + _reconnect_relative_absolute_steps(preprocessor, postprocessor) + return preprocessor, postprocessor + if pretrained_path: # TODO(Steven): Temporary patch, implement correctly the processors for Gr00t if isinstance(policy_cfg, GrootConfig): diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index f18433f07..d2725891f 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -197,6 +197,14 @@ class PI052Config(PI05Config): # at -4.5% step time on H100 (bench job 22161421); peak memory # unchanged. ``fused_linear_cross_entropy`` ships separately via # ``_shifted_lin_ce`` / ``_fast_lin_ce``. + use_hf_kernels: bool = True + """Deprecated. 
Liger HF kernels are patched unconditionally by + ``_enable_hf_kernels`` — this field is retained as a no-op for + backward compatibility with checkpoints saved before commit + d70c8104 (which still serialize ``use_hf_kernels: true`` into + ``config.json``). Loading those configs would otherwise raise + ``DecodingError: The fields use_hf_kernels are not valid for + PI052Config`` (job 22164492). Remove in a future major bump.""" def __post_init__(self) -> None: super().__post_init__()