From d70c810416c4deaad24cf71ad65681f25ad33e9f Mon Sep 17 00:00:00 2001 From: pepijn Date: Tue, 26 May 2026 11:47:49 +0000 Subject: [PATCH] =?UTF-8?q?pi052:=20drop=20``use=5Fhf=5Fkernels``=20flag?= =?UTF-8?q?=20=E2=80=94=20always=20patch=20Liger=20kernels?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The flag gated a process-global, idempotent Liger patch that swaps in fused Triton rope / geglu / layer_norm kernels (~4.5 % step time on H100, bench job 22161421). Since liger-kernel is now a hard dependency of the loss path (``_shifted_lin_ce`` / ``_fast_lin_ce`` in ``modeling_pi052``), gating the same dep behind an opt-in flag was redundant — every pi052 run pulls the wheel in either way. * ``PI052Policy.__init__`` calls ``_enable_hf_kernels()`` unconditionally; the function still degrades gracefully if the wheel happens to be missing (logs a warning, returns). * Drop ``PI052Config.use_hf_kernels``; the bench numbers and the ``fused_linear_cross_entropy`` pointer to ``_shifted_lin_ce`` / ``_fast_lin_ce`` are kept as comments next to the docstring. * Update the warning + ``_shifted_lin_ce`` lazy-import comment to drop stale ``use_hf_kernels`` / ``reduce-overhead`` references. Co-authored-by: Cursor --- .../policies/pi052/configuration_pi052.py | 31 +++++-------------- src/lerobot/policies/pi052/modeling_pi052.py | 19 ++++++------ 2 files changed, 17 insertions(+), 33 deletions(-) diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index 2b02a2baa..f18433f07 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -190,30 +190,13 @@ class PI052Config(PI05Config): # commonly cited weight; set 0 to disable entirely. 
text_ce_z_loss_weight: float = 1e-4 - # Fused kernels (Liger via HF kernels lib) --------------------------- - # Patches PaliGemma / Gemma / Siglip ops with Liger Triton kernels - # before the model is built. Measured on H100 80GB at BS=16 / L=512 - # with KI+GC on (bench job 22161421, see - # ``examples/benchmark/bench_pi052_kernels.slurm``): - # - # rope only → −2.5% step time - # geglu only → −2.2% step time - # layer_norm only → −1.1% step time - # all three → −4.5% step time, peak_mem unchanged - # - # ``fused_linear_cross_entropy`` is now wired directly into the - # pi052 forward via ``_shifted_lin_ce`` / ``_fast_lin_ce`` (see - # ``modeling_pi052``). The kernel takes ``(hidden_states, - # lm_head.weight, labels)`` and computes matmul + softmax + CE in - # fused Triton blocks, never materialising the (B, T, 257k) logits - # tensor. Saves ~10 GB activation memory per CE branch and ~30 % - # step time on the dual-CE pi052 recipe (text + FAST). Removing the - # ``.any().item()`` sync also lets ``compile_mode=reduce-overhead`` - # capture full CUDA graphs over the loss path. - use_hf_kernels: bool = False - """If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's - fused Triton kernels (rope + geglu + layer_norm). Off by default; - requires ``pip install liger-kernel``.""" + # Liger Triton kernels (rope + geglu + layer_norm) are now patched + # unconditionally at model build time — see ``_enable_hf_kernels`` + # in ``modeling_pi052``. The patch is process-global, idempotent + # and degrades gracefully if ``liger-kernel`` is missing. Measured + # at -4.5% step time on H100 (bench job 22161421); peak memory + # unchanged. ``fused_linear_cross_entropy`` ships separately via + # ``_shifted_lin_ce`` / ``_fast_lin_ce``. 
def __post_init__(self) -> None: super().__post_init__() diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 5f942753c..cb491b7ed 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -77,8 +77,9 @@ def _enable_hf_kernels() -> None: from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415 except ImportError: logger.warning( - "PI052: use_hf_kernels=True but liger-kernel is not installed; " - "skipping. Install with `pip install liger-kernel`." + "PI052: liger-kernel is not installed; skipping fused Triton " + "kernels (rope/geglu/layer_norm). Install with " + "``pip install liger-kernel`` for a ~4.5% step speedup." ) return apply_liger_kernel_to_paligemma( @@ -126,15 +127,14 @@ def _shifted_lin_ce( * Shift convention identical to the eager version — hidden at position ``t`` predicts label at ``t+1``; ``ignore_index=-100``. * No ``.any().item()`` sync — Liger returns 0.0 cleanly when - every label is ignored, keeping the graph capturable for - ``compile_mode=reduce-overhead`` (CUDA graphs). + every label is ignored. * ``z_loss_weight`` maps directly to Liger's ``lse_square_scale`` (same ``z²·w`` formula on per-position logsumexp). Setting it to 0 disables the z-loss term at zero cost. """ # Liger is imported lazily so the module still imports on machines - # without liger-kernel; the call site only ever runs after - # use_hf_kernels / training has selected the Liger path. + # without liger-kernel — the call site only fires from the training + # forward, which always pulls in the kernel.
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415 LigerFusedLinearCrossEntropyLoss, ) @@ -434,9 +434,10 @@ class PI052Policy(PI05Policy): def __init__(self, config: PI052Config, **kwargs: Any) -> None: # Patch ops BEFORE the backbone is built (super().__init__ below # constructs PaliGemmaWithExpertModel which instantiates the - # Gemma/Siglip layers we want to swap). - if getattr(config, "use_hf_kernels", False): - _enable_hf_kernels() + # Gemma/Siglip layers we want to swap). Always-on — the patch + # is process-global / idempotent and degrades gracefully if + # liger-kernel is missing. + _enable_hf_kernels() super().__init__(config, **kwargs) # ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and