pi052: opt-in Liger fused kernels (rope + geglu + layer_norm)

Adds ``PI052Config.use_hf_kernels`` (default off). When enabled,
``PI052Policy.__init__`` calls ``apply_liger_kernel_to_paligemma``
before the backbone is built so PaliGemma / Gemma / Siglip layers
pick up Liger's fused Triton forwards.

Measured at BS=16 / L=512 / H100 80GB with KI+GC on (bench job
22161421, see ``examples/benchmark/bench_pi052_kernels.slurm``):

  rope only        →  -2.5% step time
  geglu only       →  -2.2% step time
  layer_norm only  →  -1.1% step time
  all three        →  -4.5% step time, peak_mem unchanged

``cross_entropy`` / ``fused_linear_cross_entropy`` are deliberately
skipped — pi052 calls ``F.cross_entropy`` directly and bypasses
``PaliGemmaForConditionalGeneration.forward``, so neither patch
fires without invasive model-code changes (left for a follow-up).
``rms_norm`` measured as noise on this workload (GC dominates),
so it stays off to keep the patch surface minimal.

Requires ``pip install liger-kernel``; falls back to a warning if
missing so the default path is unaffected.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-25 20:44:02 +00:00
parent 2ed6519a93
commit 673cc6b0fe
2 changed files with 65 additions and 0 deletions

View File

@@ -190,6 +190,27 @@ class PI052Config(PI05Config):
# commonly cited weight; set 0 to disable entirely.
text_ce_z_loss_weight: float = 1e-4
# Fused kernels (Liger via HF kernels lib) ---------------------------
# Patches PaliGemma / Gemma / Siglip ops with Liger Triton kernels
# before the model is built. Measured on H100 80GB at BS=16 / L=512
# with KI+GC on (bench job 22161421, see
# ``examples/benchmark/bench_pi052_kernels.slurm``):
#
# rope only → 2.5% step time
# geglu only → 2.2% step time
# layer_norm only → 1.1% step time
# all three → 4.5% step time, peak_mem unchanged
#
# ``cross_entropy`` / ``fused_linear_cross_entropy`` are NOT enabled
# — pi052 calls ``F.cross_entropy`` directly and bypasses
# ``PaliGemmaForConditionalGeneration.forward``, so neither Liger
# patch fires without invasive model-code changes. Reserved for a
# follow-up.
use_hf_kernels: bool = False
"""If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's
fused Triton kernels (rope + geglu + layer_norm). Off by default;
requires ``pip install liger-kernel``."""
def __post_init__(self) -> None:
super().__post_init__()
# Backbone needs gradients flowing through the text head when

View File

@@ -55,6 +55,44 @@ from .configuration_pi052 import PI052Config
logger = logging.getLogger(__name__)
_HF_KERNELS_ENABLED = False
def _enable_hf_kernels() -> None:
"""Patch PaliGemma / Gemma / Siglip layers with Liger fused kernels.
Must run BEFORE ``PaliGemmaWithExpertModel`` is built — the patch
replaces classes in ``transformers.models.{gemma,paligemma,siglip}``,
so any model constructed after this picks up the fused forwards.
Idempotent (process-global). ``cross_entropy`` / ``fused_linear_*``
are deliberately skipped — pi052 uses ``F.cross_entropy`` directly
and never traverses ``PaliGemmaForConditionalGeneration.forward``,
so those Liger paths wouldn't fire without model-code changes.
See bench job 22161421 in ``examples/benchmark/`` for the numbers.
"""
global _HF_KERNELS_ENABLED
if _HF_KERNELS_ENABLED:
return
try:
from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415
except ImportError:
logger.warning(
"PI052: use_hf_kernels=True but liger-kernel is not installed; "
"skipping. Install with `pip install liger-kernel`."
)
return
apply_liger_kernel_to_paligemma(
rope=True,
geglu=True,
layer_norm=True,
rms_norm=False,
cross_entropy=False,
fused_linear_cross_entropy=False,
)
_HF_KERNELS_ENABLED = True
logger.info("PI052: HF kernels (Liger) enabled — rope, geglu, layer_norm fused.")
# ----------------------------------------------------------------------
# Loss helpers (shared between fused and prefix-only paths)
# ----------------------------------------------------------------------
@@ -358,6 +396,12 @@ class PI052Policy(PI05Policy):
name = "pi052"
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
# Patch ops BEFORE the backbone is built (super().__init__ below
# constructs PaliGemmaWithExpertModel which instantiates the
# Gemma/Siglip layers we want to swap).
if getattr(config, "use_hf_kernels", False):
_enable_hf_kernels()
super().__init__(config, **kwargs)
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
# freezes a few terminal layers when ``train_expert_only`` is