From 673cc6b0fe81cac0c8d84e2ff65e9749b794a333 Mon Sep 17 00:00:00 2001 From: pepijn Date: Mon, 25 May 2026 20:44:02 +0000 Subject: [PATCH] pi052: opt-in Liger fused kernels (rope + geglu + layer_norm) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds ``PI052Config.use_hf_kernels`` (default off). When enabled, ``PI052Policy.__init__`` calls ``apply_liger_kernel_to_paligemma`` before the backbone is built so PaliGemma / Gemma / Siglip layers pick up Liger's fused Triton forwards. Measured at BS=16 / L=512 / H100 80GB with KI+GC on (bench job 22161421, see ``examples/benchmark/bench_pi052_kernels.slurm``): rope only → -2.5% step time geglu only → -2.2% step time layer_norm only → -1.1% step time all three → -4.5% step time, peak_mem unchanged ``cross_entropy`` / ``fused_linear_cross_entropy`` are deliberately skipped — pi052 calls ``F.cross_entropy`` directly and bypasses ``PaliGemmaForConditionalGeneration.forward``, so neither patch fires without invasive model-code changes (left for a follow-up). ``rms_norm`` measured as noise on this workload (GC dominates), so it stays off to keep the patch surface minimal. Requires ``pip install liger-kernel``; falls back to a warning if missing so the default path is unaffected. Co-authored-by: Cursor --- .../policies/pi052/configuration_pi052.py | 21 +++++++++ src/lerobot/policies/pi052/modeling_pi052.py | 44 +++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/src/lerobot/policies/pi052/configuration_pi052.py b/src/lerobot/policies/pi052/configuration_pi052.py index c6563a04b..79b058dba 100644 --- a/src/lerobot/policies/pi052/configuration_pi052.py +++ b/src/lerobot/policies/pi052/configuration_pi052.py @@ -190,6 +190,27 @@ class PI052Config(PI05Config): # commonly cited weight; set 0 to disable entirely. text_ce_z_loss_weight: float = 1e-4 + # Fused kernels (Liger via HF kernels lib) --------------------------- + # Patches PaliGemma / Gemma / Siglip ops with Liger Triton kernels + # before the model is built. Measured on H100 80GB at BS=16 / L=512 + # with KI+GC on (bench job 22161421, see + # ``examples/benchmark/bench_pi052_kernels.slurm``): + # + # rope only → −2.5% step time + # geglu only → −2.2% step time + # layer_norm only → −1.1% step time + # all three → −4.5% step time, peak_mem unchanged + # + # ``cross_entropy`` / ``fused_linear_cross_entropy`` are NOT enabled + # — pi052 calls ``F.cross_entropy`` directly and bypasses + # ``PaliGemmaForConditionalGeneration.forward``, so neither Liger + # patch fires without invasive model-code changes. Reserved for a + # follow-up. + use_hf_kernels: bool = False + """If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's + fused Triton kernels (rope + geglu + layer_norm). Off by default; + requires ``pip install liger-kernel``.""" + def __post_init__(self) -> None: super().__post_init__() # Backbone needs gradients flowing through the text head when diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index af30d5f8a..9b0c66a4c 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -55,6 +55,44 @@ from .configuration_pi052 import PI052Config logger = logging.getLogger(__name__) +_HF_KERNELS_ENABLED = False + + +def _enable_hf_kernels() -> None: + """Patch PaliGemma / Gemma / Siglip layers with Liger fused kernels. + + Must run BEFORE ``PaliGemmaWithExpertModel`` is built — the patch + replaces classes in ``transformers.models.{gemma,paligemma,siglip}``, + so any model constructed after this picks up the fused forwards. + Idempotent (process-global). ``cross_entropy`` / ``fused_linear_*`` + are deliberately skipped — pi052 uses ``F.cross_entropy`` directly + and never traverses ``PaliGemmaForConditionalGeneration.forward``, + so those Liger paths wouldn't fire without model-code changes. + See bench job 22161421 in ``examples/benchmark/`` for the numbers. + """ + global _HF_KERNELS_ENABLED + if _HF_KERNELS_ENABLED: + return + try: + from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415 + except ImportError: + logger.warning( + "PI052: use_hf_kernels=True but liger-kernel is not installed; " + "skipping. Install with `pip install liger-kernel`." + ) + return + apply_liger_kernel_to_paligemma( + rope=True, + geglu=True, + layer_norm=True, + rms_norm=False, + cross_entropy=False, + fused_linear_cross_entropy=False, + ) + _HF_KERNELS_ENABLED = True + logger.info("PI052: HF kernels (Liger) enabled — rope, geglu, layer_norm fused.") + + # ---------------------------------------------------------------------- # Loss helpers (shared between fused and prefix-only paths) # ---------------------------------------------------------------------- @@ -358,6 +396,12 @@ class PI052Policy(PI05Policy): name = "pi052" def __init__(self, config: PI052Config, **kwargs: Any) -> None: + # Patch ops BEFORE the backbone is built (super().__init__ below + # constructs PaliGemmaWithExpertModel which instantiates the + # Gemma/Siglip layers we want to swap). + if getattr(config, "use_hf_kernels", False): + _enable_hf_kernels() + super().__init__(config, **kwargs) # ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and # freezes a few terminal layers when ``train_expert_only`` is