mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
pi052: opt-in Liger fused kernels (rope + geglu + layer_norm)
Adds ``PI052Config.use_hf_kernels`` (default off). When enabled, ``PI052Policy.__init__`` calls ``apply_liger_kernel_to_paligemma`` before the backbone is built, so the PaliGemma / Gemma / Siglip layers pick up Liger's fused Triton forwards.

Measured at BS=16 / L=512 on an H100 80GB with KI+GC on (bench job 22161421, see ``examples/benchmark/bench_pi052_kernels.slurm``):

- rope only → -2.5% step time
- geglu only → -2.2% step time
- layer_norm only → -1.1% step time
- all three → -4.5% step time, peak_mem unchanged

``cross_entropy`` / ``fused_linear_cross_entropy`` are deliberately skipped: pi052 calls ``F.cross_entropy`` directly and bypasses ``PaliGemmaForConditionalGeneration.forward``, so neither patch fires without invasive model-code changes (left for a follow-up). ``rms_norm`` measured as noise on this workload (GC dominates), so it stays off to keep the patch surface minimal.

Requires ``pip install liger-kernel``. If the package is missing, a warning is logged and the patch is skipped, so the default path is unaffected.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -190,6 +190,27 @@ class PI052Config(PI05Config):
|
||||
# commonly cited weight; set 0 to disable entirely.
|
||||
text_ce_z_loss_weight: float = 1e-4
|
||||
|
||||
# Fused kernels (Liger, from the ``liger-kernel`` package) ------------
|
||||
# Patches PaliGemma / Gemma / Siglip ops with Liger Triton kernels
|
||||
# before the model is built. Measured on H100 80GB at BS=16 / L=512
|
||||
# with KI+GC on (bench job 22161421, see
|
||||
# ``examples/benchmark/bench_pi052_kernels.slurm``):
|
||||
#
|
||||
# rope only → −2.5% step time
|
||||
# geglu only → −2.2% step time
|
||||
# layer_norm only → −1.1% step time
|
||||
# all three → −4.5% step time, peak_mem unchanged
|
||||
#
|
||||
# ``cross_entropy`` / ``fused_linear_cross_entropy`` are NOT enabled
|
||||
# — pi052 calls ``F.cross_entropy`` directly and bypasses
|
||||
# ``PaliGemmaForConditionalGeneration.forward``, so neither Liger
|
||||
# patch fires without invasive model-code changes. Reserved for a
|
||||
# follow-up.
|
||||
use_hf_kernels: bool = False
|
||||
"""If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's
|
||||
fused Triton kernels (rope + geglu + layer_norm). Off by default;
|
||||
requires ``pip install liger-kernel``."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through the text head when
|
||||
|
||||
@@ -55,6 +55,44 @@ from .configuration_pi052 import PI052Config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Process-wide latch for the Liger patch: ``_enable_hf_kernels`` sets this to
# True after it has applied the patch and returns early on later calls.
_HF_KERNELS_ENABLED = False
|
||||
|
||||
|
||||
def _enable_hf_kernels() -> None:
    """Apply Liger's fused kernels to PaliGemma / Gemma / Siglip, at most once.

    Call this BEFORE ``PaliGemmaWithExpertModel`` is constructed: the patch
    swaps classes inside ``transformers.models.{gemma,paligemma,siglip}``,
    so only models built afterwards pick up the fused forwards.

    The patch is process-global and guarded by ``_HF_KERNELS_ENABLED``, which
    makes repeated calls a no-op once it has been applied. If ``liger-kernel``
    cannot be imported, a warning is logged, nothing is patched, and the flag
    stays ``False``.

    Only ``rope``, ``geglu`` and ``layer_norm`` are switched on.
    ``cross_entropy`` / ``fused_linear_*`` stay off on purpose: pi052 calls
    ``F.cross_entropy`` directly and never goes through
    ``PaliGemmaForConditionalGeneration.forward``, so those Liger paths would
    not fire without model-code changes. Benchmark numbers: bench job
    22161421 under ``examples/benchmark/``.
    """
    global _HF_KERNELS_ENABLED

    if not _HF_KERNELS_ENABLED:
        try:
            # Imported lazily so that liger-kernel remains an optional dependency.
            from liger_kernel.transformers import apply_liger_kernel_to_paligemma  # noqa: PLC0415
        except ImportError:
            logger.warning(
                "PI052: use_hf_kernels=True but liger-kernel is not installed; "
                "skipping. Install with `pip install liger-kernel`."
            )
        else:
            # Every switch is spelled out so the patch surface does not depend
            # on Liger's own defaults.
            liger_switches = {
                "rope": True,
                "geglu": True,
                "layer_norm": True,
                "rms_norm": False,
                "cross_entropy": False,
                "fused_linear_cross_entropy": False,
            }
            apply_liger_kernel_to_paligemma(**liger_switches)
            _HF_KERNELS_ENABLED = True
            logger.info("PI052: HF kernels (Liger) enabled — rope, geglu, layer_norm fused.")
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Loss helpers (shared between fused and prefix-only paths)
|
||||
# ----------------------------------------------------------------------
|
||||
@@ -358,6 +396,12 @@ class PI052Policy(PI05Policy):
|
||||
name = "pi052"
|
||||
|
||||
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
|
||||
# Patch ops BEFORE the backbone is built (super().__init__ below
|
||||
# constructs PaliGemmaWithExpertModel which instantiates the
|
||||
# Gemma/Siglip layers we want to swap).
|
||||
if getattr(config, "use_hf_kernels", False):
|
||||
_enable_hf_kernels()
|
||||
|
||||
super().__init__(config, **kwargs)
|
||||
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
|
||||
# freezes a few terminal layers when ``train_expert_only`` is
|
||||
|
||||
Reference in New Issue
Block a user