mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
339 lines
12 KiB
Python
339 lines
12 KiB
Python
|
|
#!/usr/bin/env python
|
||
|
|
|
||
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
|
||
|
|
"""Benchmark ``PI052Policy.forward + backward`` on a single GPU.
|
||
|
|
|
||
|
|
Compares the new SDPA attention path against the eager baseline by
|
||
|
|
monkeypatching ``sdpa_attention_forward`` before the first model
|
||
|
|
forward — so both runs share identical Q/K/V plumbing and only the
|
||
|
|
attention kernel differs. Reports steps/sec and peak GPU memory.
|
||
|
|
|
||
|
|
SLURM-only:
|
||
|
|
|
||
|
|
sbatch examples/benchmark/bench_pi052_step.slurm
|
||
|
|
|
||
|
|
Or one-off:
|
||
|
|
|
||
|
|
srun --partition=hopper-prod --qos=high --gpus=1 --time=15 \\
|
||
|
|
python examples/benchmark/bench_pi052_step.py --attn sdpa --batch-size 8
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import gc
|
||
|
|
import math
|
||
|
|
import os
|
||
|
|
import time
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
|
||
|
|
def _maybe_patch_eager() -> None:
|
||
|
|
"""Swap ``sdpa_attention_forward`` for the original eager forward.
|
||
|
|
|
||
|
|
Must be called BEFORE PI052Policy is instantiated — the layer
|
||
|
|
compute functions resolve the symbol at call time (module-level
|
||
|
|
lookup), so this patch covers both pi05 and pi052 KI paths."""
|
||
|
|
from transformers.models.gemma import modeling_gemma
|
||
|
|
|
||
|
|
from lerobot.policies.pi05 import modeling_pi05
|
||
|
|
|
||
|
|
modeling_pi05.sdpa_attention_forward = modeling_gemma.eager_attention_forward
|
||
|
|
|
||
|
|
|
||
|
|
_LIGER_SUBKERNELS = ("rope", "rms_norm", "geglu", "layer_norm")
|
||
|
|
|
||
|
|
|
||
|
|
def _maybe_patch_liger(spec: str) -> dict:
|
||
|
|
"""Globally patch PaliGemma/Gemma/Siglip modules with Liger Triton kernels.
|
||
|
|
|
||
|
|
Must be called BEFORE PI052Policy is instantiated — Liger replaces
|
||
|
|
classes inside ``transformers.models.{gemma,gemma2,siglip,paligemma}``,
|
||
|
|
so any model built after the call picks up the fused forwards.
|
||
|
|
|
||
|
|
``spec`` is a comma-separated subset of {rope, rms_norm, geglu,
|
||
|
|
layer_norm} (also ``all`` and ``none``). ``cross_entropy`` and
|
||
|
|
``fused_linear_cross_entropy`` are intentionally skipped — pi052's
|
||
|
|
losses use ``F.cross_entropy`` directly (not ``nn.CrossEntropyLoss``)
|
||
|
|
and never traverse ``PaliGemmaForConditionalGeneration.forward``,
|
||
|
|
so neither patch would fire without invasive model-code changes.
|
||
|
|
"""
|
||
|
|
enabled = dict.fromkeys(_LIGER_SUBKERNELS, False)
|
||
|
|
if spec in ("", "none"):
|
||
|
|
return enabled
|
||
|
|
tokens = [t.strip() for t in spec.split(",") if t.strip()]
|
||
|
|
if tokens == ["all"]:
|
||
|
|
enabled = dict.fromkeys(_LIGER_SUBKERNELS, True)
|
||
|
|
else:
|
||
|
|
for t in tokens:
|
||
|
|
if t not in enabled:
|
||
|
|
raise SystemExit(f"Unknown liger subkernel: {t!r}. Choose from {_LIGER_SUBKERNELS} or 'all'.")
|
||
|
|
enabled[t] = True
|
||
|
|
|
||
|
|
from liger_kernel.transformers import apply_liger_kernel_to_paligemma
|
||
|
|
|
||
|
|
apply_liger_kernel_to_paligemma(
|
||
|
|
rope=enabled["rope"],
|
||
|
|
rms_norm=enabled["rms_norm"],
|
||
|
|
geglu=enabled["geglu"],
|
||
|
|
layer_norm=enabled["layer_norm"],
|
||
|
|
cross_entropy=False,
|
||
|
|
fused_linear_cross_entropy=False,
|
||
|
|
)
|
||
|
|
return enabled
|
||
|
|
|
||
|
|
|
||
|
|
def _maybe_patch_flex() -> None:
|
||
|
|
"""Swap ``sdpa_attention_forward`` for a FlexAttention-backed forward.
|
||
|
|
|
||
|
|
Experimental: builds a per-call ``score_mod`` from the additive
|
||
|
|
mask and dispatches to a compiled ``flex_attention`` kernel.
|
||
|
|
|
||
|
|
Known issue on torch 2.7.1: dynamo errors out with
|
||
|
|
``FlexAttentionHigherOrderVariable() has no type`` when the
|
||
|
|
``score_mod`` closure captures a per-call bias tensor. A proper
|
||
|
|
port needs ``create_block_mask(mask_mod, ...)`` plumbed at the
|
||
|
|
PI05Pytorch.forward level so a BlockMask object can be passed
|
||
|
|
down to the layer compute, not a per-call closure. Left as
|
||
|
|
future work; keep this stub for benchmark experimentation."""
|
||
|
|
import torch
|
||
|
|
from torch.nn.attention.flex_attention import flex_attention
|
||
|
|
|
||
|
|
from lerobot.policies.pi05 import modeling_pi05
|
||
|
|
|
||
|
|
compiled_flex = torch.compile(flex_attention, dynamic=True)
|
||
|
|
|
||
|
|
def flex_forward(module, query, key, value, attention_mask, scaling, dropout=0.0):
|
||
|
|
n_rep = module.num_key_value_groups
|
||
|
|
if n_rep > 1:
|
||
|
|
key = key.repeat_interleave(n_rep, dim=1)
|
||
|
|
value = value.repeat_interleave(n_rep, dim=1)
|
||
|
|
|
||
|
|
bias = attention_mask # (B, 1, Lq, Lk) additive
|
||
|
|
|
||
|
|
def score_mod(score, b, h, q_idx, kv_idx):
|
||
|
|
return score + bias[b, 0, q_idx, kv_idx]
|
||
|
|
|
||
|
|
attn_output = compiled_flex(query, key, value, score_mod=score_mod, scale=scaling)
|
||
|
|
return attn_output.transpose(1, 2).contiguous(), None
|
||
|
|
|
||
|
|
modeling_pi05.sdpa_attention_forward = flex_forward
|
||
|
|
|
||
|
|
|
||
|
|
def _build_policy(args, device: torch.device):
|
||
|
|
"""Random-init PI052Policy at production-relevant shapes."""
|
||
|
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||
|
|
from lerobot.policies.pi052.configuration_pi052 import PI052Config
|
||
|
|
from lerobot.policies.pi052.modeling_pi052 import PI052Policy
|
||
|
|
|
||
|
|
# Production has ``unfreeze_lm_head=True`` + ``text_loss_weight>0``,
|
||
|
|
# which flips ``train_expert_only=False`` in __post_init__ and
|
||
|
|
# makes the whole PaliGemma + Gemma-expert stack trainable. We
|
||
|
|
# mirror that here so the optimizer-state count reflects reality;
|
||
|
|
# the loss path still goes through ``PI05Policy.forward`` because
|
||
|
|
# ``text_labels`` / FAST tokens are absent from the synthetic batch
|
||
|
|
# (see ``PI052Policy.forward`` early-return).
|
||
|
|
config = PI052Config(
|
||
|
|
max_action_dim=args.action_dim,
|
||
|
|
max_state_dim=args.state_dim,
|
||
|
|
dtype=args.dtype,
|
||
|
|
knowledge_insulation=args.knowledge_insulation,
|
||
|
|
text_loss_weight=1e-3 if args.train_full else 0.0,
|
||
|
|
flow_loss_weight=1.0,
|
||
|
|
enable_fast_action_loss=False,
|
||
|
|
unfreeze_lm_head=args.train_full,
|
||
|
|
tokenizer_max_length=args.lang_tokens,
|
||
|
|
device="cuda",
|
||
|
|
compile_model=args.compile_model,
|
||
|
|
compile_mode=args.compile_mode,
|
||
|
|
)
|
||
|
|
config.input_features = {
|
||
|
|
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(args.state_dim,)),
|
||
|
|
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||
|
|
}
|
||
|
|
config.output_features = {
|
||
|
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(args.action_dim,)),
|
||
|
|
}
|
||
|
|
policy = PI052Policy(config)
|
||
|
|
policy.to(device)
|
||
|
|
if args.gradient_checkpointing:
|
||
|
|
policy.model.gradient_checkpointing_enable()
|
||
|
|
policy.train()
|
||
|
|
return policy, config
|
||
|
|
|
||
|
|
|
||
|
|
def _build_batch(args, config, device: torch.device) -> dict:
|
||
|
|
"""Synthetic batch matching the training-loop input contract."""
|
||
|
|
from lerobot.utils.constants import (
|
||
|
|
ACTION,
|
||
|
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||
|
|
OBS_LANGUAGE_TOKENS,
|
||
|
|
)
|
||
|
|
|
||
|
|
B = args.batch_size
|
||
|
|
L = args.lang_tokens
|
||
|
|
return {
|
||
|
|
OBS_LANGUAGE_TOKENS: torch.randint(0, 250000, (B, L), device=device),
|
||
|
|
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(B, L, dtype=torch.bool, device=device),
|
||
|
|
"observation.images.base_0_rgb": torch.rand(B, 3, 224, 224, device=device),
|
||
|
|
"observation.images.base_0_rgb_padding_mask": torch.ones(B, dtype=torch.bool, device=device),
|
||
|
|
"observation.state": torch.randn(B, args.state_dim, device=device),
|
||
|
|
ACTION: torch.randn(B, config.chunk_size, args.action_dim, device=device),
|
||
|
|
"action_is_pad": torch.zeros(B, config.chunk_size, dtype=torch.bool, device=device),
|
||
|
|
"task": ["bench task"] * B,
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def _step(policy, batch, optimizer=None) -> torch.Tensor:
|
||
|
|
loss, _ = policy.forward(batch)
|
||
|
|
loss.backward()
|
||
|
|
if optimizer is not None:
|
||
|
|
optimizer.step()
|
||
|
|
optimizer.zero_grad(set_to_none=True)
|
||
|
|
else:
|
||
|
|
for p in policy.parameters():
|
||
|
|
if p.grad is not None:
|
||
|
|
p.grad = None
|
||
|
|
return loss.detach()
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> int:
|
||
|
|
parser = argparse.ArgumentParser()
|
||
|
|
parser.add_argument("--attn", choices=["sdpa", "eager", "flex"], default="sdpa")
|
||
|
|
parser.add_argument(
|
||
|
|
"--kernels",
|
||
|
|
default="none",
|
||
|
|
help=(
|
||
|
|
"Liger sub-kernels to enable, comma-separated. Choose from "
|
||
|
|
f"{_LIGER_SUBKERNELS} or use 'all' / 'none' (default). Applied "
|
||
|
|
"via apply_liger_kernel_to_paligemma() BEFORE model build."
|
||
|
|
),
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"--compile",
|
||
|
|
dest="compile_model",
|
||
|
|
action="store_true",
|
||
|
|
help="Set policy.config.compile_model=True (torch.compile the forward).",
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"--compile-mode",
|
||
|
|
default="default",
|
||
|
|
help="torch.compile mode (default | reduce-overhead | max-autotune).",
|
||
|
|
)
|
||
|
|
parser.add_argument("--batch-size", type=int, default=8)
|
||
|
|
parser.add_argument("--warmup", type=int, default=8)
|
||
|
|
parser.add_argument("--steps", type=int, default=40)
|
||
|
|
parser.add_argument("--lang-tokens", type=int, default=512)
|
||
|
|
parser.add_argument("--dtype", choices=["bfloat16", "float32"], default="bfloat16")
|
||
|
|
parser.add_argument("--action-dim", type=int, default=14)
|
||
|
|
parser.add_argument("--state-dim", type=int, default=14)
|
||
|
|
parser.add_argument("--knowledge-insulation", action="store_true", default=True)
|
||
|
|
parser.add_argument(
|
||
|
|
"--gradient-checkpointing",
|
||
|
|
dest="gradient_checkpointing",
|
||
|
|
action=argparse.BooleanOptionalAction,
|
||
|
|
default=True,
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"--optimizer",
|
||
|
|
choices=["none", "adamw", "adamw_fused"],
|
||
|
|
default="adamw_fused",
|
||
|
|
help=(
|
||
|
|
"Whether to include an AdamW step in the timed iteration. "
|
||
|
|
"'none' mirrors the fwd+bwd-only original bench; 'adamw' / "
|
||
|
|
"'adamw_fused' add the realistic ~2x param-bytes optimizer "
|
||
|
|
"state and ``optimizer.step()`` cost."
|
||
|
|
),
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"--train-full",
|
||
|
|
action=argparse.BooleanOptionalAction,
|
||
|
|
default=True,
|
||
|
|
help=(
|
||
|
|
"Mirror production: unfreeze the PaliGemma backbone (full "
|
||
|
|
"~3B trainable params) instead of training only the 300M "
|
||
|
|
"action expert."
|
||
|
|
),
|
||
|
|
)
|
||
|
|
args = parser.parse_args()
|
||
|
|
|
||
|
|
if not torch.cuda.is_available():
|
||
|
|
raise SystemExit("Benchmark requires CUDA; submit via slurm (srun/sbatch).")
|
||
|
|
|
||
|
|
if args.attn == "eager":
|
||
|
|
_maybe_patch_eager()
|
||
|
|
elif args.attn == "flex":
|
||
|
|
_maybe_patch_flex()
|
||
|
|
|
||
|
|
liger_flags = _maybe_patch_liger(args.kernels)
|
||
|
|
|
||
|
|
device = torch.device("cuda")
|
||
|
|
torch.cuda.reset_peak_memory_stats()
|
||
|
|
|
||
|
|
policy, config = _build_policy(args, device)
|
||
|
|
batch = _build_batch(args, config, device)
|
||
|
|
|
||
|
|
optimizer = None
|
||
|
|
trainable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||
|
|
if args.optimizer != "none":
|
||
|
|
trainable = [p for p in policy.parameters() if p.requires_grad]
|
||
|
|
optimizer = torch.optim.AdamW(
|
||
|
|
trainable, lr=5e-5, fused=(args.optimizer == "adamw_fused")
|
||
|
|
)
|
||
|
|
|
||
|
|
for _ in range(args.warmup):
|
||
|
|
_step(policy, batch, optimizer)
|
||
|
|
torch.cuda.synchronize()
|
||
|
|
|
||
|
|
starter = torch.cuda.Event(enable_timing=True)
|
||
|
|
ender = torch.cuda.Event(enable_timing=True)
|
||
|
|
starter.record()
|
||
|
|
for _ in range(args.steps):
|
||
|
|
_step(policy, batch, optimizer)
|
||
|
|
ender.record()
|
||
|
|
torch.cuda.synchronize()
|
||
|
|
total_ms = starter.elapsed_time(ender)
|
||
|
|
step_ms = total_ms / args.steps
|
||
|
|
peak_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||
|
|
optim_gb = 0.0
|
||
|
|
if optimizer is not None:
|
||
|
|
for st in optimizer.state.values():
|
||
|
|
for v in st.values():
|
||
|
|
if torch.is_tensor(v):
|
||
|
|
optim_gb += v.numel() * v.element_size() / (1024**3)
|
||
|
|
|
||
|
|
liger_on = ",".join(k for k, v in liger_flags.items() if v) or "none"
|
||
|
|
name = (
|
||
|
|
f"{args.attn:>5} | BS={args.batch_size} | L={args.lang_tokens} | "
|
||
|
|
f"KI={args.knowledge_insulation} | GC={args.gradient_checkpointing} | "
|
||
|
|
f"compile={args.compile_model} | liger={liger_on} | opt={args.optimizer} | dtype={args.dtype}"
|
||
|
|
)
|
||
|
|
print(
|
||
|
|
f"{name}\n step_ms={step_ms:.1f} steps/sec={1000.0 / step_ms:.3f} "
|
||
|
|
f"peak_mem={peak_gb:.2f} GiB optim_state={optim_gb:.2f} GiB "
|
||
|
|
f"trainable_params={trainable_params / 1e9:.2f}B"
|
||
|
|
)
|
||
|
|
|
||
|
|
del policy, batch
|
||
|
|
gc.collect()
|
||
|
|
torch.cuda.empty_cache()
|
||
|
|
return 0
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
raise SystemExit(main())
|