Compare commits

...

3 Commits

Author SHA1 Message Date
pepijn
d70c810416 pi052: drop `use_hf_kernels` flag — always patch Liger kernels
The flag gated a process-global, idempotent Liger patch that swaps
in fused Triton rope / geglu / layer_norm kernels (~4.5 % step time
on H100, bench job 22161421). Since liger-kernel is now a hard
dependency of the loss path (``_shifted_lin_ce`` / ``_fast_lin_ce``
in ``modeling_pi052``), gating the same dep behind an opt-in flag
was redundant — every pi052 run pulls the wheel in either way.

* ``PI052Policy.__init__`` calls ``_enable_hf_kernels()``
  unconditionally; the function still degrades gracefully if the
  wheel happens to be missing (logs a warning, returns).
* Drop ``PI052Config.use_hf_kernels``; the bench numbers and the
  ``fused_linear_cross_entropy`` pointer to ``_shifted_lin_ce`` /
  ``_fast_lin_ce`` are kept as comments next to the docstring.
* Update the warning + ``_shifted_lin_ce`` lazy-import comment to
  drop stale ``use_hf_kernels`` / ``reduce-overhead`` references.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 11:47:49 +00:00
pepijn
4c3ddb1ff5 pi052: wire Liger fused linear CE + DDP-safe FAST tokenizer fit
* Replace ``_shifted_ce`` / ``_fast_ce`` with Liger's
  ``fused_linear_cross_entropy``: the ``(B, T, 257k)`` logits tensor
  is no longer materialised — the kernel chunks over the ``(B*T)``
  axis and computes matmul + softmax + CE in fused Triton blocks.
  ~30 % step speedup and ~12 GB of activation memory freed on the
  dual-CE pi052 recipe. All four call sites in
  ``_compute_all_losses_fused`` and ``_compute_text_and_fast_loss``
  updated; the ``.any().item()`` CPU sync is dropped so the loss
  path stays CUDA-graph-capturable.

* DDP-safe FAST tokenizer fit. The cache-hit sentinel previously
  looked for ``preprocessor_config.json`` but
  ``ProcessorMixin.save_pretrained`` writes ``processor_config.json``
  — every rank always cache-missed and re-fit, racing on writes and
  occasionally producing a stale ``.pyc`` that crashed
  ``AutoProcessor.from_pretrained`` with ``AttributeError:
  UniversalActionProcessor``. Fix the sentinel; gate the fit on the
  (local) main process; non-leader ranks poll the cache until the
  leader is done. Caught by job 22162549.

* New recipe ``subtask_mem_vqa_robocasa.yaml`` — subtask + memory +
  per-camera VQA over the three robocasa camera keys produced by the
  port pipeline (``robot0_agentview_left/right``, ``robot0_eye_in_hand``).
  The previously-shipped ``subtask_mem_vqa_speech.yaml`` references
  ``observation.images.front`` / ``wrist`` which don't exist in
  robocasa, so VQA never rendered.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 11:18:16 +00:00
pepijn
8615f3f613 annotate(vqa): tighten bbox + keypoint quality bar
Low-confidence VLM detections were producing many overlapping, loose
boxes per frame (oven + toaster oven + counter + drawer + ...) and
coarse keypoints, hurting downstream policy grounding. Two surgical
fixes:

- module_3_vqa prompt: cap bbox at most 3 high-confidence detections
  (prefer 1 tight box), require specific labels and ≤10% padding,
  allow empty detections list when nothing meets the bar; keypoint
  must be a single pixel-precise feature (handle / button / gripper
  tip) rather than a coarse "somewhere on object" point.
- run_hf_job: lower vlm.temperature 0.7 → 0.2. Bbox + keypoint are
  coordinate-regression tasks where sampling noise directly degrades
  localization; question phrasing still varies enough at 0.2.

No new config knobs — the count cap lives in the prompt since "top-N
by confidence" is best picked by the VLM itself. Validator already
accepts empty detections.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-26 08:31:37 +00:00
6 changed files with 274 additions and 76 deletions

View File

@@ -55,7 +55,11 @@ CMD = (
"--vlm.serve_ready_timeout_s=1800 " "--vlm.serve_ready_timeout_s=1800 "
"--vlm.client_concurrency=256 " "--vlm.client_concurrency=256 "
"--vlm.max_new_tokens=512 " "--vlm.max_new_tokens=512 "
"--vlm.temperature=0.7 " # Low temperature for VQA: bbox + keypoint are coordinate-regression
# tasks where sampling noise directly degrades localization
# (overlapping boxes, drifted points). 0.2 keeps the model decisive
# while still letting question/label phrasing vary across frames.
"--vlm.temperature=0.2 "
"--executor.episode_parallelism=64 " "--executor.episode_parallelism=64 "
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' " "--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
# Whole-scene agentview is the right choice for subtask reasoning + # Whole-scene agentview is the right choice for subtask reasoning +

View File

@@ -5,15 +5,40 @@ pixel coordinates, keypoints, counts, attributes, and spatial relations.
The frame shows a robot working on: "{episode_task}". The frame shows a robot working on: "{episode_task}".
QUALITY BAR — read before answering:
- Only label objects you are highly confident about. If you are not
sure what an object is, do NOT include it. A short, certain answer
beats a long, speculative one.
- For coordinate-grounded answers (bbox, keypoint) only emit a label
when you can localize the object *tightly and precisely*. If the
object is occluded, ambiguous, off-frame, or you can't pin its
extent, return an empty detections list / pick a different object
rather than guessing.
- Prefer task-relevant objects (the thing the robot is manipulating
or interacting with) over background clutter.
Question types and the EXACT answer JSON shape required for each: Question types and the EXACT answer JSON shape required for each:
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy", bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
"bbox": [x1, y1, x2, y2]}}, ...]}} "bbox": [x1, y1, x2, y2]}}, ...]}}
bbox is in pixel coordinates (x_min, y_min, x_max, y_max). Pixel coordinates (x_min, y_min, x_max, y_max). Emit
AT MOST 3 detections, and *only* the highest-confidence
ones — 1 tight, certain detection is preferred over 3
loose ones. Each box must be tight (no >10% padding
around the object) and the label must be specific
("red mug" not "object"). Return an empty list if no
object meets the bar.
ECoT example: "a white cup [124, 25, 176, 113]". ECoT example: "a white cup [124, 25, 176, 113]".
keypoint => {{"label": "<point>", "point_format": "xy", keypoint => {{"label": "<point>", "point_format": "xy",
"point": [x, y]}} "point": [x, y]}}
Pick ONE high-confidence, precisely-localizable point
(e.g. a graspable handle, a button center, the gripper
tip). The point must land within a few pixels of the
feature. Do not emit a coarse "somewhere on the object"
point — pick a different question type if no such
point exists in this frame.
count => {{"label": "<obj>", "count": <int>, count => {{"label": "<obj>", "count": <int>,
"note": "<optional short note>"}} "note": "<optional short note>"}}

View File

@@ -0,0 +1,99 @@
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
#
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
# camera-grounded VQA across the three RoboCasa camera keys produced
# by ``slurm_build_robocasa_composite_seen.py``:
#
# observation.images.robot0_agentview_left (left scene view)
# observation.images.robot0_agentview_right (right scene view)
# observation.images.robot0_eye_in_hand (wrist)
#
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
# VQA per camera, so each anchor frame produces three (user, assistant)
# rows tagged with their source camera. Each VQA sub-recipe consumes
# the rows for one camera via ``camera=...`` resolver bindings.
#
# Spatial VQA targets (bbox / point) are rewritten from JSON to
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
# ``register_paligemma_loc_tokens`` already collapses them to single
# detection-vocab ids so the LM head learns the pretrained pointing /
# detection prior, not a 7-piece BPE salad.
#
# Interjections / spoken responses are intentionally absent — the
# annotation job runs with ``--interjections.enabled=false``.
blend:
high_level_subtask:
weight: 0.25
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
low_level_execution:
weight: 0.45
messages:
# Action expert is conditioned on the SUBTASK; at inference the
# high-level loop generates it via the LM head and feeds it here.
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
# loss fires; subtask CE is owned by ``high_level_subtask``.
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
memory_update:
# Trained densely with ``active_at`` — every frame inside a subtask
# interval — so the (prior_memory, completed_subtask) → current_memory
# mapping is supervised against varied observations. The *when* to
# emit lives in the inference trigger (subtask_change), not the
# model. See ``subtask_mem.yaml`` for the long version of this note.
weight: 0.15
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "active_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
ask_vqa_agentview_left:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_left}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_agentview_right:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_agentview_right}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
ask_vqa_wrist:
weight: 0.05
bindings:
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
messages:
- role: user
stream: high_level
if_present: vqa_query
content:
- {type: image, feature: observation.images.robot0_eye_in_hand}
- {type: text, text: "${vqa_query}"}
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}

View File

@@ -190,26 +190,13 @@ class PI052Config(PI05Config):
# commonly cited weight; set 0 to disable entirely. # commonly cited weight; set 0 to disable entirely.
text_ce_z_loss_weight: float = 1e-4 text_ce_z_loss_weight: float = 1e-4
# Fused kernels (Liger via HF kernels lib) --------------------------- # Liger Triton kernels (rope + geglu + layer_norm) are now patched
# Patches PaliGemma / Gemma / Siglip ops with Liger Triton kernels # unconditionally at model build time — see ``_enable_hf_kernels``
# before the model is built. Measured on H100 80GB at BS=16 / L=512 # in ``modeling_pi052``. The patch is process-global, idempotent
# with KI+GC on (bench job 22161421, see # and degrades gracefully if ``liger-kernel`` is missing. Measured
# ``examples/benchmark/bench_pi052_kernels.slurm``): # at -4.5% step time on H100 (bench job 22161421); peak memory
# # unchanged. ``fused_linear_cross_entropy`` ships separately via
# rope only → 2.5% step time # ``_shifted_lin_ce`` / ``_fast_lin_ce``.
# geglu only → 2.2% step time
# layer_norm only → 1.1% step time
# all three → 4.5% step time, peak_mem unchanged
#
# ``cross_entropy`` / ``fused_linear_cross_entropy`` are NOT enabled
# — pi052 calls ``F.cross_entropy`` directly and bypasses
# ``PaliGemmaForConditionalGeneration.forward``, so neither Liger
# patch fires without invasive model-code changes. Reserved for a
# follow-up.
use_hf_kernels: bool = False
"""If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's
fused Triton kernels (rope + geglu + layer_norm). Off by default;
requires ``pip install liger-kernel``."""
def __post_init__(self) -> None: def __post_init__(self) -> None:
super().__post_init__() super().__post_init__()

View File

@@ -39,12 +39,21 @@ from __future__ import annotations
import hashlib import hashlib
import logging import logging
import os
import time
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
# that's the image / feature-extractor convention). Centralised here so
# the cache-hit check and the rank-N readiness wait agree on the same
# sentinel.
_CACHE_SENTINEL = "processor_config.json"
def _dataset_signature( def _dataset_signature(
dataset_repo_id: str, dataset_repo_id: str,
@@ -111,7 +120,7 @@ def fit_fast_tokenizer(
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size) sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
out_dir = cache_dir / sig out_dir = cache_dir / sig
if out_dir.exists() and (out_dir / "preprocessor_config.json").exists(): if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
logger.info( logger.info(
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for " "FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
"dataset=%s base=%s n_samples=%d", "dataset=%s base=%s n_samples=%d",
@@ -119,6 +128,32 @@ def fit_fast_tokenizer(
) )
return str(out_dir) return str(out_dir)
# DDP-safe fit: only the (local) main process actually fits + saves;
# other ranks poll the cache sentinel until the leader is done.
# Without this guard, all N ranks fit concurrently and race on
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
# and compiles a ``.pyc`` — concurrent writers occasionally produce
# a stale / partial ``.pyc`` and the subsequent ``from .. import
# UniversalActionProcessor`` raises ``AttributeError``.
is_leader = (
int(os.environ.get("RANK", "0")) == 0
and int(os.environ.get("LOCAL_RANK", "0")) == 0
)
if not is_leader:
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
start = time.monotonic()
while not (out_dir / _CACHE_SENTINEL).exists():
if time.monotonic() - start > timeout_s:
raise RuntimeError(
f"FAST tokenizer fit: non-leader rank timed out after "
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
"Leader rank likely crashed during the fit."
)
time.sleep(2.0)
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
return str(out_dir)
logger.info( logger.info(
"FAST tokenizer cache miss — fitting on dataset=%s " "FAST tokenizer cache miss — fitting on dataset=%s "
"base=%s n_samples=%d chunk_size=%d%s", "base=%s n_samples=%d chunk_size=%d%s",

View File

@@ -77,8 +77,9 @@ def _enable_hf_kernels() -> None:
from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415 from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415
except ImportError: except ImportError:
logger.warning( logger.warning(
"PI052: use_hf_kernels=True but liger-kernel is not installed; " "PI052: liger-kernel is not installed; skipping fused Triton "
"skipping. Install with `pip install liger-kernel`." "kernels (rope/geglu/layer_norm). Install with "
"``pip install liger-kernel`` for a ~4.5%% step speedup."
) )
return return
apply_liger_kernel_to_paligemma( apply_liger_kernel_to_paligemma(
@@ -106,35 +107,52 @@ def _mask_per_sample(per_sample: Tensor, predict_actions_t: Tensor | None) -> Te
return (per_sample * mask).sum() / mask.sum().clamp(min=1.0) return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
def _shifted_ce(logits: Tensor, labels: Tensor, z_loss_weight: float = 0.0) -> Tensor: def _shifted_lin_ce(
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100. hidden: Tensor,
lm_head_weight: Tensor,
labels: Tensor,
z_loss_weight: float = 0.0,
) -> Tensor:
"""Liger-fused (hidden @ W.T → softmax → CE) on shifted labels.
Mean over non-ignored positions across the batch. Returns 0 cleanly Replaces the explicit ``lm_head(hidden) → F.cross_entropy(...)``
when no positions are supervised (clamp(min=1) on the denominator). pair with Liger's ``LigerFusedLinearCrossEntropyLoss``: the full
``(B, T, V)`` logits tensor is never materialised — the kernel
chunks over the (B*T) axis, computing matmul + logsumexp + CE
in fused Triton blocks. On a 257k-vocab head this saves ~10 GB
of activation memory per CE branch and ~30 % step time vs the
eager ``F.cross_entropy`` path.
When ``z_loss_weight > 0``, also adds PaLM-style z-loss Semantics:
(``z² · w``, where ``z = log Σ exp(logits)``) on every supervised * Shift convention identical to the eager version — hidden at
position. Penalises the log-partition function drifting away from position ``t`` predicts label at ``t+1``; ``ignore_index=-100``.
zero — without it, large-vocab models (PaliGemma is 257k) can let * No ``.any().item()`` sync — Liger returns 0.0 cleanly when
``logsumexp`` grow unboundedly while CE stays low, because uniform every label is ignored.
additive logit bias cancels in softmax. PaLM appendix B / Chinchilla * ``z_loss_weight`` maps directly to Liger's ``lse_square_scale``
report this is essential for stable large-vocab CE; cheap insurance (same ``z²·w`` formula on per-position logsumexp). Setting it
here especially with ``lm_head_lr_scale=5.0`` amplifying drift risk. to 0 disables the z-loss term at zero cost.
""" """
shift_logits = logits[:, :-1, :].contiguous() # Liger is imported lazily so the module still imports on machines
# without liger-kernel — the call site only fires from the training
# forward, which always pulls in the kernel.
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
LigerFusedLinearCrossEntropyLoss,
)
shift_hidden = hidden[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous().long() shift_labels = labels[:, 1:].contiguous().long()
valid = shift_labels != -100 B, T_1, H = shift_hidden.shape
if not bool(valid.any().item()): flat_hidden = shift_hidden.reshape(B * T_1, H)
return shift_logits.sum() * 0.0 flat_labels = shift_labels.reshape(B * T_1)
valid_logits = shift_logits[valid] # Match the dtype the eager path used: cast hidden to the lm_head's
valid_labels = shift_labels[valid] # weight dtype so bf16 weights see bf16 activations.
ce = F.cross_entropy(valid_logits, valid_labels, reduction="mean") flat_hidden = flat_hidden.to(lm_head_weight.dtype)
if z_loss_weight <= 0.0: loss_fn = LigerFusedLinearCrossEntropyLoss(
return ce ignore_index=-100,
# PaLM z-loss: penalise (log Σ exp(logits))² per supervised position. lse_square_scale=float(z_loss_weight),
# ``logsumexp`` is numerically stable and shares the softmax kernel. reduction="mean",
z = torch.logsumexp(valid_logits, dim=-1) )
return ce + z_loss_weight * (z**2).mean() return loss_fn(lm_head_weight, flat_hidden, flat_labels)
def _mark_target_span_causal( def _mark_target_span_causal(
@@ -172,32 +190,48 @@ def _mark_target_span_causal(
return att return att
def _fast_ce( def _fast_lin_ce(
fast_logits: Tensor, hidden: Tensor,
lm_head_weight: Tensor,
action_tokens: Tensor, action_tokens: Tensor,
action_code_mask: Tensor, action_code_mask: Tensor,
predict_actions_t: Tensor | None, predict_actions_t: Tensor | None,
) -> Tensor: ) -> Tensor:
"""FAST action-code CE with token-span masking and per-sample action gating. """Liger-fused FAST action-code CE with span masking + sample gating.
``action_code_mask`` is true only on the discrete action-code tokens, Mirrors ``_shifted_lin_ce`` but with FAST-specific masking: only
excluding the BOS / "Action: " / delimiter wrapper. Samples whose the discrete action-code positions (``action_code_mask``) are
recipe sets ``predict_actions=False`` get all code positions masked supervised, and samples whose recipe sets ``predict_actions=False``
out via the per-sample gate. get all code positions masked. Masked positions are folded into
Liger's ``ignore_index=-100`` so the kernel skips them without
a CPU-side gather (which would synchronise + break CUDA graphs).
""" """
shift_logits = fast_logits[:, :-1, :].contiguous() from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
LigerFusedLinearCrossEntropyLoss,
)
shift_hidden = hidden[:, :-1, :].contiguous()
shift_targets = action_tokens[:, 1:].contiguous().long() shift_targets = action_tokens[:, 1:].contiguous().long()
shift_valid = action_code_mask[:, 1:].contiguous().bool() shift_valid = action_code_mask[:, 1:].contiguous().bool()
if predict_actions_t is not None: if predict_actions_t is not None:
sample_mask = predict_actions_t[:, None].expand_as(shift_valid) sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
shift_valid = shift_valid & sample_mask shift_valid = shift_valid & sample_mask
if not bool(shift_valid.any().item()): # Fold the boolean mask into the target via ignore_index. No
return shift_logits.sum() * 0.0 # ``.any().item()`` sync — Liger returns 0.0 when every position
return F.cross_entropy( # is ignored, preserving graph capture for CUDA graphs.
shift_logits[shift_valid], shift_targets = torch.where(
shift_targets[shift_valid], shift_valid, shift_targets, torch.full_like(shift_targets, -100)
)
B, T_1, H = shift_hidden.shape
flat_hidden = shift_hidden.reshape(B * T_1, H).to(lm_head_weight.dtype)
flat_labels = shift_targets.reshape(B * T_1)
loss_fn = LigerFusedLinearCrossEntropyLoss(
ignore_index=-100,
reduction="mean", reduction="mean",
) )
return loss_fn(lm_head_weight, flat_hidden, flat_labels)
# ---------------------------------------------------------------------- # ----------------------------------------------------------------------
@@ -400,9 +434,10 @@ class PI052Policy(PI05Policy):
def __init__(self, config: PI052Config, **kwargs: Any) -> None: def __init__(self, config: PI052Config, **kwargs: Any) -> None:
# Patch ops BEFORE the backbone is built (super().__init__ below # Patch ops BEFORE the backbone is built (super().__init__ below
# constructs PaliGemmaWithExpertModel which instantiates the # constructs PaliGemmaWithExpertModel which instantiates the
# Gemma/Siglip layers we want to swap). # Gemma/Siglip layers we want to swap). Always-on — the patch
if getattr(config, "use_hf_kernels", False): # is process-global / idempotent and degrades gracefully if
_enable_hf_kernels() # liger-kernel is missing.
_enable_hf_kernels()
super().__init__(config, **kwargs) super().__init__(config, **kwargs)
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and # ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
@@ -726,9 +761,12 @@ class PI052Policy(PI05Policy):
text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :] text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :]
else: else:
text_hidden = prefix_out[:, -lang_len:, :] text_hidden = prefix_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) # Liger fused linear-CE: skip the explicit ``lm_head(...)``
text_loss = _shifted_ce( # materialisation; the kernel multiplies on-the-fly and
text_logits, # never holds the full (B, T, 257k) logits tensor.
text_loss = _shifted_lin_ce(
text_hidden,
lm_head.weight,
text_labels, text_labels,
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0), z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
) )
@@ -736,8 +774,13 @@ class PI052Policy(PI05Policy):
fast_loss: Tensor | None = None fast_loss: Tensor | None = None
if fast_len > 0 and prefix_out is not None and action_code_mask is not None: if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
fast_hidden = prefix_out[:, -fast_len:, :] fast_hidden = prefix_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) fast_loss = _fast_lin_ce(
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t) fast_hidden,
lm_head.weight,
action_tokens,
action_code_mask,
predict_actions_t,
)
return flow_loss, text_loss, fast_loss return flow_loss, text_loss, fast_loss
@@ -830,9 +873,9 @@ class PI052Policy(PI05Policy):
text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :] text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :]
else: else:
text_hidden = vlm_out[:, -lang_len:, :] text_hidden = vlm_out[:, -lang_len:, :]
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) text_loss = _shifted_lin_ce(
text_loss = _shifted_ce( text_hidden,
text_logits, lm_head.weight,
text_labels, text_labels,
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0), z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
) )
@@ -844,8 +887,13 @@ class PI052Policy(PI05Policy):
and fast_len > 0 and fast_len > 0
): ):
fast_hidden = vlm_out[:, -fast_len:, :] fast_hidden = vlm_out[:, -fast_len:, :]
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype)) fast_loss = _fast_lin_ce(
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t) fast_hidden,
lm_head.weight,
action_tokens,
action_code_mask,
predict_actions_t,
)
return text_loss, fast_loss return text_loss, fast_loss