mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 02:11:25 +00:00
pi052: make `lerobot-eval` work on saved checkpoints
pi052's preprocessor pipelines don't roundtrip through the saved
``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
``TrainingRecipe`` Python object (not JSON-serializable, saved as
``{}``) and ``ActionTokenizerProcessorStep`` saves the fitted FAST
tokenizer's host-only ``~/.cache/lerobot/fast_tokenizers/...`` path.
``PolicyProcessorPipeline.from_pretrained`` then dies with
``RenderMessagesStep.__init__() missing 1 required positional
argument: 'recipe'`` (job 22164494).
The pi052 training path was workable because the recipe-aware steps
were built directly; the runtime path
(``lerobot.scripts.lerobot_pi052_runtime``) sidesteps the loader by
passing ``pretrained_path=None`` to ``make_pre_post_processors`` and
building fresh from ``config.recipe_path``. The standard
``lerobot-eval`` entry point had no such escape hatch.
Two surgical fixes:
* ``factory.make_pre_post_processors``: when ``policy_cfg.type ==
"pi052"`` AND ``pretrained_path`` is set, bypass the generic
``PolicyProcessorPipeline.from_pretrained`` call. Build the
pipelines fresh via ``make_pi052_pre_post_processors`` (same
bootstrap the runtime uses) and transplant the saved stateful
blobs from each step's ``state_file`` reference in the saved JSON
(today: NormalizerProcessorStep + UnnormalizerProcessorStep
quantile stats). Pairing is by ``registry_name`` AND position so
a benign reorder logs a warning instead of silently mis-loading.
* ``PI052Config.use_hf_kernels``: re-add as a deprecated no-op
field. The flag was removed in d70c8104 (Liger kernels became
unconditional), but checkpoints saved before that commit
serialize ``use_hf_kernels: true`` into ``config.json``. Without
this field draccus rejects the load with ``DecodingError: The
fields use_hf_kernels are not valid for PI052Config`` (job
22164492). Mark for removal in a future major bump.
Together these let an external ``lerobot-eval --policy.path=<pi052
checkpoint>`` invocation evaluate the model against any env.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -61,6 +61,79 @@ from .wall_x.configuration_wall_x import WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig
|
||||
|
||||
|
||||
def _restore_pi052_pretrained_state(
    preprocessor: PolicyProcessorPipeline,
    postprocessor: PolicyProcessorPipeline,
    pretrained_path: str,
) -> None:
    """Transplant saved stateful blobs from a pi052 checkpoint into fresh pipelines.

    pi052's preprocessor includes steps whose constructor args don't
    JSON-roundtrip (``RenderMessagesStep.recipe`` is a Python object,
    ``ActionTokenizerProcessorStep.action_tokenizer_name`` is a
    fitted-tokenizer path that may not exist at eval time). The caller
    rebuilds those pipelines fresh; this function then walks the saved
    ``policy_{pre,post}processor.json`` files, finds each step's
    ``state_file`` reference and loads that safetensors blob into the
    step at the same position of the fresh pipeline via
    ``load_state_dict``. The loop is generic, so any step that records a
    ``state_file`` is restored.

    Pairing is positional, cross-checked against ``registry_name`` when
    both sides expose one: a name mismatch logs a warning and skips the
    blob rather than feeding the wrong tensors into the wrong step. A
    difference in step count between the saved and fresh pipelines is
    also logged, because the unmatched tail cannot be restored.

    Nothing here raises for "missing" inputs; every situation in which
    state could not be restored is reported as a warning so that an
    evaluation silently running on fresh-init statistics is visible in
    the logs.

    Args:
        preprocessor: Freshly built preprocessor pipeline; its steps are
            mutated in place.
        postprocessor: Freshly built postprocessor pipeline; its steps
            are mutated in place.
        pretrained_path: Checkpoint location. Only a local directory is
            supported; anything else (for example a Hub repo id that has
            not been downloaded) results in a warning and no restore.
    """
    import json  # noqa: PLC0415
    import logging  # noqa: PLC0415
    from pathlib import Path  # noqa: PLC0415

    from safetensors.torch import load_file  # noqa: PLC0415

    log = logging.getLogger(__name__)

    base = Path(pretrained_path)
    if not base.is_dir():
        # Previously a silent return: the pipelines would then keep their
        # fresh-init state with no hint that the checkpoint was ignored.
        log.warning(
            "PI052 state restore: %s is not a local directory; "
            "saved processor state was not restored",
            pretrained_path,
        )
        return

    for pipeline, config_filename in [
        (preprocessor, f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"),
        (postprocessor, f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"),
    ]:
        config_path = base / config_filename
        if not config_path.exists():
            log.warning(
                "PI052 state restore: %s not found in %s; pipeline left at fresh init",
                config_filename, base,
            )
            continue

        # JSON is UTF-8 by specification; don't depend on the locale default.
        saved = json.loads(config_path.read_text(encoding="utf-8"))
        saved_steps = saved.get("steps", [])
        fresh_steps = list(pipeline.steps)

        if len(saved_steps) != len(fresh_steps):
            # zip() below stops at the shorter side, so say so explicitly
            # instead of dropping the unmatched steps without a trace.
            log.warning(
                "PI052 state restore: %s has %d saved steps but the fresh "
                "pipeline has %d; only the first %d positions are paired",
                config_filename,
                len(saved_steps),
                len(fresh_steps),
                min(len(saved_steps), len(fresh_steps)),
            )

        for idx, (saved_step, fresh_step) in enumerate(
            zip(saved_steps, fresh_steps, strict=False)
        ):
            state_file = saved_step.get("state_file")
            if not state_file:
                # Stateless step: nothing was persisted for it.
                continue

            saved_name = saved_step.get("registry_name")
            fresh_name = getattr(type(fresh_step), "_registry_name", None)
            if saved_name and fresh_name and saved_name != fresh_name:
                log.warning(
                    "PI052 state restore: %s step %d registry name mismatch "
                    "(saved=%s, fresh=%s); skipping %s",
                    config_filename, idx, saved_name, fresh_name, state_file,
                )
                continue

            state_path = base / state_file
            if not state_path.exists():
                log.warning(
                    "PI052 state restore: %s missing at %s; %s left at fresh init",
                    state_file, base, fresh_name,
                )
                continue

            fresh_step.load_state_dict(load_file(str(state_path)))
            log.info(
                "PI052 state restore: loaded %s into %s (step %d)",
                state_file, fresh_name, idx,
            )
|
||||
|
||||
|
||||
def _reconnect_relative_absolute_steps(
|
||||
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
|
||||
) -> None:
|
||||
@@ -277,6 +350,29 @@ def make_pre_post_processors(
|
||||
NotImplementedError: If a processor factory is not implemented for the given
|
||||
policy configuration type.
|
||||
"""
|
||||
if pretrained_path and getattr(policy_cfg, "type", None) == "pi052":
|
||||
# pi052 pipelines don't roundtrip through the saved
|
||||
# ``policy_preprocessor.json``: ``RenderMessagesStep`` holds a
|
||||
# Python ``TrainingRecipe`` (not JSON-serializable; saved as
|
||||
# ``{}``) and ``ActionTokenizerProcessorStep`` saves a host-only
|
||||
# FAST tokenizer path. Generic ``from_pretrained`` then dies
|
||||
# with ``RenderMessagesStep.__init__() missing 1 required
|
||||
# positional argument: 'recipe'`` (job 22164494).
|
||||
#
|
||||
# Mirror ``lerobot_pi052_runtime``'s bootstrap: build pipelines
|
||||
# fresh from ``config.recipe_path`` and transplant the saved
|
||||
# stateful blobs (normalizer stats) from the checkpoint dir.
|
||||
from .pi052.processor_pi052 import make_pi052_pre_post_processors
|
||||
|
||||
preprocessor, postprocessor = make_pi052_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_repo_id=kwargs.get("dataset_repo_id"),
|
||||
)
|
||||
_restore_pi052_pretrained_state(preprocessor, postprocessor, pretrained_path)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
if pretrained_path:
|
||||
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
|
||||
if isinstance(policy_cfg, GrootConfig):
|
||||
|
||||
@@ -197,6 +197,14 @@ class PI052Config(PI05Config):
|
||||
# at -4.5% step time on H100 (bench job 22161421); peak memory
|
||||
# unchanged. ``fused_linear_cross_entropy`` ships separately via
|
||||
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
|
||||
use_hf_kernels: bool = True
|
||||
"""Deprecated. Liger HF kernels are patched unconditionally by
|
||||
``_enable_hf_kernels`` — this field is retained as a no-op for
|
||||
backward compatibility with checkpoints saved before commit
|
||||
d70c8104 (which still serialize ``use_hf_kernels: true`` into
|
||||
``config.json``). Loading those configs would otherwise raise
|
||||
``DecodingError: The fields use_hf_kernels are not valid for
|
||||
PI052Config`` (job 22164492). Remove in a future major bump."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
Reference in New Issue
Block a user