mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Compare commits
3 Commits
2686450d68
...
d70c810416
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d70c810416 | ||
|
|
4c3ddb1ff5 | ||
|
|
8615f3f613 |
@@ -55,7 +55,11 @@ CMD = (
|
|||||||
"--vlm.serve_ready_timeout_s=1800 "
|
"--vlm.serve_ready_timeout_s=1800 "
|
||||||
"--vlm.client_concurrency=256 "
|
"--vlm.client_concurrency=256 "
|
||||||
"--vlm.max_new_tokens=512 "
|
"--vlm.max_new_tokens=512 "
|
||||||
"--vlm.temperature=0.7 "
|
# Low temperature for VQA: bbox + keypoint are coordinate-regression
|
||||||
|
# tasks where sampling noise directly degrades localization
|
||||||
|
# (overlapping boxes, drifted points). 0.2 keeps the model decisive
|
||||||
|
# while still letting question/label phrasing vary across frames.
|
||||||
|
"--vlm.temperature=0.2 "
|
||||||
"--executor.episode_parallelism=64 "
|
"--executor.episode_parallelism=64 "
|
||||||
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
|
"--vlm.chat_template_kwargs='{\"enable_thinking\": false}' "
|
||||||
# Whole-scene agentview is the right choice for subtask reasoning +
|
# Whole-scene agentview is the right choice for subtask reasoning +
|
||||||
|
|||||||
@@ -5,15 +5,40 @@ pixel coordinates, keypoints, counts, attributes, and spatial relations.
|
|||||||
|
|
||||||
The frame shows a robot working on: "{episode_task}".
|
The frame shows a robot working on: "{episode_task}".
|
||||||
|
|
||||||
|
QUALITY BAR — read before answering:
|
||||||
|
|
||||||
|
- Only label objects you are highly confident about. If you are not
|
||||||
|
sure what an object is, do NOT include it. A short, certain answer
|
||||||
|
beats a long, speculative one.
|
||||||
|
- For coordinate-grounded answers (bbox, keypoint) only emit a label
|
||||||
|
when you can localize the object *tightly and precisely*. If the
|
||||||
|
object is occluded, ambiguous, off-frame, or you can't pin its
|
||||||
|
extent, return an empty detections list / pick a different object
|
||||||
|
rather than guessing.
|
||||||
|
- Prefer task-relevant objects (the thing the robot is manipulating
|
||||||
|
or interacting with) over background clutter.
|
||||||
|
|
||||||
Question types and the EXACT answer JSON shape required for each:
|
Question types and the EXACT answer JSON shape required for each:
|
||||||
|
|
||||||
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
bbox => {{"detections": [{{"label": "<obj>", "bbox_format": "xyxy",
|
||||||
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
"bbox": [x1, y1, x2, y2]}}, ...]}}
|
||||||
bbox is in pixel coordinates (x_min, y_min, x_max, y_max).
|
Pixel coordinates (x_min, y_min, x_max, y_max). Emit
|
||||||
|
AT MOST 3 detections, and *only* the highest-confidence
|
||||||
|
ones — 1 tight, certain detection is preferred over 3
|
||||||
|
loose ones. Each box must be tight (no >10% padding
|
||||||
|
around the object) and the label must be specific
|
||||||
|
("red mug" not "object"). Return an empty list if no
|
||||||
|
object meets the bar.
|
||||||
ECoT example: "a white cup [124, 25, 176, 113]".
|
ECoT example: "a white cup [124, 25, 176, 113]".
|
||||||
|
|
||||||
keypoint => {{"label": "<point>", "point_format": "xy",
|
keypoint => {{"label": "<point>", "point_format": "xy",
|
||||||
"point": [x, y]}}
|
"point": [x, y]}}
|
||||||
|
Pick ONE high-confidence, precisely-localizable point
|
||||||
|
(e.g. a graspable handle, a button center, the gripper
|
||||||
|
tip). The point must land within a few pixels of the
|
||||||
|
feature. Do not emit a coarse "somewhere on the object"
|
||||||
|
point — pick a different question type if no such
|
||||||
|
point exists in this frame.
|
||||||
|
|
||||||
count => {{"label": "<obj>", "count": <int>,
|
count => {{"label": "<obj>", "count": <int>,
|
||||||
"note": "<optional short note>"}}
|
"note": "<optional short note>"}}
|
||||||
|
|||||||
99
src/lerobot/configs/recipes/subtask_mem_vqa_robocasa.yaml
Normal file
99
src/lerobot/configs/recipes/subtask_mem_vqa_robocasa.yaml
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
# subtask_mem_vqa_robocasa — Hi-Robot blend tuned for RoboCasa cameras.
|
||||||
|
#
|
||||||
|
# Same supervision as ``subtask_mem.yaml`` (subtask + memory) plus
|
||||||
|
# camera-grounded VQA across the three RoboCasa camera keys produced
|
||||||
|
# by ``slurm_build_robocasa_composite_seen.py``:
|
||||||
|
#
|
||||||
|
# observation.images.robot0_agentview_left (left scene view)
|
||||||
|
# observation.images.robot0_agentview_right (right scene view)
|
||||||
|
# observation.images.robot0_eye_in_hand (wrist)
|
||||||
|
#
|
||||||
|
# The annotation pipeline (``examples/annotations/run_hf_job.py``) emits
|
||||||
|
# VQA per camera, so each anchor frame produces three (user, assistant)
|
||||||
|
# rows tagged with their source camera. Each VQA sub-recipe consumes
|
||||||
|
# the rows for one camera via ``camera=...`` resolver bindings.
|
||||||
|
#
|
||||||
|
# Spatial VQA targets (bbox / point) are rewritten from JSON to
|
||||||
|
# PaliGemma ``<locDDDD>`` tokens by ``_messages_vqa_to_loc`` —
|
||||||
|
# ``register_paligemma_loc_tokens`` already collapses them to single
|
||||||
|
# detection-vocab ids so the LM head learns the pretrained pointing /
|
||||||
|
# detection prior, not a 7-piece BPE salad.
|
||||||
|
#
|
||||||
|
# Interjections / spoken responses are intentionally absent — the
|
||||||
|
# annotation job runs with ``--interjections.enabled=false``.
|
||||||
|
|
||||||
|
blend:
|
||||||
|
|
||||||
|
high_level_subtask:
|
||||||
|
weight: 0.25
|
||||||
|
messages:
|
||||||
|
- {role: user, content: "${task}", stream: high_level}
|
||||||
|
- {role: assistant, content: "${subtask}", stream: high_level, target: true, if_present: subtask}
|
||||||
|
|
||||||
|
low_level_execution:
|
||||||
|
weight: 0.45
|
||||||
|
messages:
|
||||||
|
# Action expert is conditioned on the SUBTASK; at inference the
|
||||||
|
# high-level loop generates it via the LM head and feeds it here.
|
||||||
|
# ``stream: low_level`` flips ``predict_actions=True`` so the flow
|
||||||
|
# loss fires; subtask CE is owned by ``high_level_subtask``.
|
||||||
|
- {role: user, content: "${subtask}", stream: low_level, if_present: subtask}
|
||||||
|
|
||||||
|
memory_update:
|
||||||
|
# Trained densely with ``active_at`` — every frame inside a subtask
|
||||||
|
# interval — so the (prior_memory, completed_subtask) → current_memory
|
||||||
|
# mapping is supervised against varied observations. The *when* to
|
||||||
|
# emit lives in the inference trigger (subtask_change), not the
|
||||||
|
# model. See ``subtask_mem.yaml`` for the long version of this note.
|
||||||
|
weight: 0.15
|
||||||
|
bindings:
|
||||||
|
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||||
|
current_memory: "active_at(t, style=memory)"
|
||||||
|
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||||
|
messages:
|
||||||
|
- {role: user, content: "${task}", stream: high_level}
|
||||||
|
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||||
|
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||||
|
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||||
|
|
||||||
|
ask_vqa_agentview_left:
|
||||||
|
weight: 0.05
|
||||||
|
bindings:
|
||||||
|
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_left)"
|
||||||
|
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_left)"
|
||||||
|
messages:
|
||||||
|
- role: user
|
||||||
|
stream: high_level
|
||||||
|
if_present: vqa_query
|
||||||
|
content:
|
||||||
|
- {type: image, feature: observation.images.robot0_agentview_left}
|
||||||
|
- {type: text, text: "${vqa_query}"}
|
||||||
|
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||||
|
|
||||||
|
ask_vqa_agentview_right:
|
||||||
|
weight: 0.05
|
||||||
|
bindings:
|
||||||
|
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_agentview_right)"
|
||||||
|
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_agentview_right)"
|
||||||
|
messages:
|
||||||
|
- role: user
|
||||||
|
stream: high_level
|
||||||
|
if_present: vqa_query
|
||||||
|
content:
|
||||||
|
- {type: image, feature: observation.images.robot0_agentview_right}
|
||||||
|
- {type: text, text: "${vqa_query}"}
|
||||||
|
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||||
|
|
||||||
|
ask_vqa_wrist:
|
||||||
|
weight: 0.05
|
||||||
|
bindings:
|
||||||
|
vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.robot0_eye_in_hand)"
|
||||||
|
vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.robot0_eye_in_hand)"
|
||||||
|
messages:
|
||||||
|
- role: user
|
||||||
|
stream: high_level
|
||||||
|
if_present: vqa_query
|
||||||
|
content:
|
||||||
|
- {type: image, feature: observation.images.robot0_eye_in_hand}
|
||||||
|
- {type: text, text: "${vqa_query}"}
|
||||||
|
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||||
@@ -190,26 +190,13 @@ class PI052Config(PI05Config):
|
|||||||
# commonly cited weight; set 0 to disable entirely.
|
# commonly cited weight; set 0 to disable entirely.
|
||||||
text_ce_z_loss_weight: float = 1e-4
|
text_ce_z_loss_weight: float = 1e-4
|
||||||
|
|
||||||
# Fused kernels (Liger via HF kernels lib) ---------------------------
|
# Liger Triton kernels (rope + geglu + layer_norm) are now patched
|
||||||
# Patches PaliGemma / Gemma / Siglip ops with Liger Triton kernels
|
# unconditionally at model build time — see ``_enable_hf_kernels``
|
||||||
# before the model is built. Measured on H100 80GB at BS=16 / L=512
|
# in ``modeling_pi052``. The patch is process-global, idempotent
|
||||||
# with KI+GC on (bench job 22161421, see
|
# and degrades gracefully if ``liger-kernel`` is missing. Measured
|
||||||
# ``examples/benchmark/bench_pi052_kernels.slurm``):
|
# at -4.5% step time on H100 (bench job 22161421); peak memory
|
||||||
#
|
# unchanged. ``fused_linear_cross_entropy`` ships separately via
|
||||||
# rope only → −2.5% step time
|
# ``_shifted_lin_ce`` / ``_fast_lin_ce``.
|
||||||
# geglu only → −2.2% step time
|
|
||||||
# layer_norm only → −1.1% step time
|
|
||||||
# all three → −4.5% step time, peak_mem unchanged
|
|
||||||
#
|
|
||||||
# ``cross_entropy`` / ``fused_linear_cross_entropy`` are NOT enabled
|
|
||||||
# — pi052 calls ``F.cross_entropy`` directly and bypasses
|
|
||||||
# ``PaliGemmaForConditionalGeneration.forward``, so neither Liger
|
|
||||||
# patch fires without invasive model-code changes. Reserved for a
|
|
||||||
# follow-up.
|
|
||||||
use_hf_kernels: bool = False
|
|
||||||
"""If True, monkey-patch PaliGemma/Gemma/Siglip layers with Liger's
|
|
||||||
fused Triton kernels (rope + geglu + layer_norm). Off by default;
|
|
||||||
requires ``pip install liger-kernel``."""
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|||||||
@@ -39,12 +39,21 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Marker file the cache-hit check looks for. ``ProcessorMixin.save_pretrained``
|
||||||
|
# writes ``processor_config.json`` (NOT ``preprocessor_config.json`` —
|
||||||
|
# that's the image / feature-extractor convention). Centralised here so
|
||||||
|
# the cache-hit check and the rank-N readiness wait agree on the same
|
||||||
|
# sentinel.
|
||||||
|
_CACHE_SENTINEL = "processor_config.json"
|
||||||
|
|
||||||
|
|
||||||
def _dataset_signature(
|
def _dataset_signature(
|
||||||
dataset_repo_id: str,
|
dataset_repo_id: str,
|
||||||
@@ -111,7 +120,7 @@ def fit_fast_tokenizer(
|
|||||||
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
|
sig = _dataset_signature(dataset_repo_id, base_tokenizer_name, n_samples, chunk_size)
|
||||||
out_dir = cache_dir / sig
|
out_dir = cache_dir / sig
|
||||||
|
|
||||||
if out_dir.exists() and (out_dir / "preprocessor_config.json").exists():
|
if out_dir.exists() and (out_dir / _CACHE_SENTINEL).exists():
|
||||||
logger.info(
|
logger.info(
|
||||||
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
|
"FAST tokenizer cache hit: %s — re-using fitted tokenizer for "
|
||||||
"dataset=%s base=%s n_samples=%d",
|
"dataset=%s base=%s n_samples=%d",
|
||||||
@@ -119,6 +128,32 @@ def fit_fast_tokenizer(
|
|||||||
)
|
)
|
||||||
return str(out_dir)
|
return str(out_dir)
|
||||||
|
|
||||||
|
# DDP-safe fit: only the (local) main process actually fits + saves;
|
||||||
|
# other ranks poll the cache sentinel until the leader is done.
|
||||||
|
# Without this guard, all N ranks fit concurrently and race on
|
||||||
|
# ``save_pretrained`` + ``AutoProcessor.from_pretrained`` (the latter
|
||||||
|
# copies ``processing_action_tokenizer.py`` into ``HF_MODULES_CACHE``
|
||||||
|
# and compiles a ``.pyc`` — concurrent writers occasionally produce
|
||||||
|
# a stale / partial ``.pyc`` and the subsequent ``from .. import
|
||||||
|
# UniversalActionProcessor`` raises ``AttributeError``.
|
||||||
|
is_leader = (
|
||||||
|
int(os.environ.get("RANK", "0")) == 0
|
||||||
|
and int(os.environ.get("LOCAL_RANK", "0")) == 0
|
||||||
|
)
|
||||||
|
if not is_leader:
|
||||||
|
timeout_s = 1800.0 # 30 min — covers ~1024-sample fits on cold caches
|
||||||
|
start = time.monotonic()
|
||||||
|
while not (out_dir / _CACHE_SENTINEL).exists():
|
||||||
|
if time.monotonic() - start > timeout_s:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"FAST tokenizer fit: non-leader rank timed out after "
|
||||||
|
f"{timeout_s:.0f}s waiting for {out_dir / _CACHE_SENTINEL}. "
|
||||||
|
"Leader rank likely crashed during the fit."
|
||||||
|
)
|
||||||
|
time.sleep(2.0)
|
||||||
|
logger.info("FAST tokenizer ready (leader populated cache): %s", out_dir)
|
||||||
|
return str(out_dir)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"FAST tokenizer cache miss — fitting on dataset=%s "
|
"FAST tokenizer cache miss — fitting on dataset=%s "
|
||||||
"base=%s n_samples=%d chunk_size=%d → %s",
|
"base=%s n_samples=%d chunk_size=%d → %s",
|
||||||
|
|||||||
@@ -77,8 +77,9 @@ def _enable_hf_kernels() -> None:
|
|||||||
from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415
|
from liger_kernel.transformers import apply_liger_kernel_to_paligemma # noqa: PLC0415
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"PI052: use_hf_kernels=True but liger-kernel is not installed; "
|
"PI052: liger-kernel is not installed; skipping fused Triton "
|
||||||
"skipping. Install with `pip install liger-kernel`."
|
"kernels (rope/geglu/layer_norm). Install with "
|
||||||
|
"``pip install liger-kernel`` for a ~4.5%% step speedup."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
apply_liger_kernel_to_paligemma(
|
apply_liger_kernel_to_paligemma(
|
||||||
@@ -106,35 +107,52 @@ def _mask_per_sample(per_sample: Tensor, predict_actions_t: Tensor | None) -> Te
|
|||||||
return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
|
return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
|
||||||
|
|
||||||
|
|
||||||
def _shifted_ce(logits: Tensor, labels: Tensor, z_loss_weight: float = 0.0) -> Tensor:
|
def _shifted_lin_ce(
|
||||||
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100.
|
hidden: Tensor,
|
||||||
|
lm_head_weight: Tensor,
|
||||||
|
labels: Tensor,
|
||||||
|
z_loss_weight: float = 0.0,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Liger-fused (hidden @ W.T → softmax → CE) on shifted labels.
|
||||||
|
|
||||||
Mean over non-ignored positions across the batch. Returns 0 cleanly
|
Replaces the explicit ``lm_head(hidden) → F.cross_entropy(...)``
|
||||||
when no positions are supervised (clamp(min=1) on the denominator).
|
pair with Liger's ``LigerFusedLinearCrossEntropyLoss``: the full
|
||||||
|
``(B, T, V)`` logits tensor is never materialised — the kernel
|
||||||
|
chunks over the (B*T) axis, computing matmul + logsumexp + CE
|
||||||
|
in fused Triton blocks. On a 257k-vocab head this saves ~10 GB
|
||||||
|
of activation memory per CE branch and ~30 % step time vs the
|
||||||
|
eager ``F.cross_entropy`` path.
|
||||||
|
|
||||||
When ``z_loss_weight > 0``, also adds PaLM-style z-loss
|
Semantics:
|
||||||
(``z² · w``, where ``z = log Σ exp(logits)``) on every supervised
|
* Shift convention identical to the eager version — hidden at
|
||||||
position. Penalises the log-partition function drifting away from
|
position ``t`` predicts label at ``t+1``; ``ignore_index=-100``.
|
||||||
zero — without it, large-vocab models (PaliGemma is 257k) can let
|
* No ``.any().item()`` sync — Liger returns 0.0 cleanly when
|
||||||
``logsumexp`` grow unboundedly while CE stays low, because uniform
|
every label is ignored.
|
||||||
additive logit bias cancels in softmax. PaLM appendix B / Chinchilla
|
* ``z_loss_weight`` maps directly to Liger's ``lse_square_scale``
|
||||||
report this is essential for stable large-vocab CE; cheap insurance
|
(same ``z²·w`` formula on per-position logsumexp). Setting it
|
||||||
here especially with ``lm_head_lr_scale=5.0`` amplifying drift risk.
|
to 0 disables the z-loss term at zero cost.
|
||||||
"""
|
"""
|
||||||
shift_logits = logits[:, :-1, :].contiguous()
|
# Liger is imported lazily so the module still imports on machines
|
||||||
|
# without liger-kernel — the call site only fires from the training
|
||||||
|
# forward, which always pulls in the kernel.
|
||||||
|
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
|
||||||
|
LigerFusedLinearCrossEntropyLoss,
|
||||||
|
)
|
||||||
|
|
||||||
|
shift_hidden = hidden[:, :-1, :].contiguous()
|
||||||
shift_labels = labels[:, 1:].contiguous().long()
|
shift_labels = labels[:, 1:].contiguous().long()
|
||||||
valid = shift_labels != -100
|
B, T_1, H = shift_hidden.shape
|
||||||
if not bool(valid.any().item()):
|
flat_hidden = shift_hidden.reshape(B * T_1, H)
|
||||||
return shift_logits.sum() * 0.0
|
flat_labels = shift_labels.reshape(B * T_1)
|
||||||
valid_logits = shift_logits[valid]
|
# Match the dtype the eager path used: cast hidden to the lm_head's
|
||||||
valid_labels = shift_labels[valid]
|
# weight dtype so bf16 weights see bf16 activations.
|
||||||
ce = F.cross_entropy(valid_logits, valid_labels, reduction="mean")
|
flat_hidden = flat_hidden.to(lm_head_weight.dtype)
|
||||||
if z_loss_weight <= 0.0:
|
loss_fn = LigerFusedLinearCrossEntropyLoss(
|
||||||
return ce
|
ignore_index=-100,
|
||||||
# PaLM z-loss: penalise (log Σ exp(logits))² per supervised position.
|
lse_square_scale=float(z_loss_weight),
|
||||||
# ``logsumexp`` is numerically stable and shares the softmax kernel.
|
reduction="mean",
|
||||||
z = torch.logsumexp(valid_logits, dim=-1)
|
)
|
||||||
return ce + z_loss_weight * (z**2).mean()
|
return loss_fn(lm_head_weight, flat_hidden, flat_labels)
|
||||||
|
|
||||||
|
|
||||||
def _mark_target_span_causal(
|
def _mark_target_span_causal(
|
||||||
@@ -172,32 +190,48 @@ def _mark_target_span_causal(
|
|||||||
return att
|
return att
|
||||||
|
|
||||||
|
|
||||||
def _fast_ce(
|
def _fast_lin_ce(
|
||||||
fast_logits: Tensor,
|
hidden: Tensor,
|
||||||
|
lm_head_weight: Tensor,
|
||||||
action_tokens: Tensor,
|
action_tokens: Tensor,
|
||||||
action_code_mask: Tensor,
|
action_code_mask: Tensor,
|
||||||
predict_actions_t: Tensor | None,
|
predict_actions_t: Tensor | None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""FAST action-code CE with token-span masking and per-sample action gating.
|
"""Liger-fused FAST action-code CE with span masking + sample gating.
|
||||||
|
|
||||||
``action_code_mask`` is true only on the discrete action-code tokens,
|
Mirrors ``_shifted_lin_ce`` but with FAST-specific masking: only
|
||||||
excluding the BOS / "Action: " / delimiter wrapper. Samples whose
|
the discrete action-code positions (``action_code_mask``) are
|
||||||
recipe sets ``predict_actions=False`` get all code positions masked
|
supervised, and samples whose recipe sets ``predict_actions=False``
|
||||||
out via the per-sample gate.
|
get all code positions masked. Masked positions are folded into
|
||||||
|
Liger's ``ignore_index=-100`` so the kernel skips them without
|
||||||
|
a CPU-side gather (which would synchronise + break CUDA graphs).
|
||||||
"""
|
"""
|
||||||
shift_logits = fast_logits[:, :-1, :].contiguous()
|
from liger_kernel.transformers.fused_linear_cross_entropy import ( # noqa: PLC0415
|
||||||
|
LigerFusedLinearCrossEntropyLoss,
|
||||||
|
)
|
||||||
|
|
||||||
|
shift_hidden = hidden[:, :-1, :].contiguous()
|
||||||
shift_targets = action_tokens[:, 1:].contiguous().long()
|
shift_targets = action_tokens[:, 1:].contiguous().long()
|
||||||
shift_valid = action_code_mask[:, 1:].contiguous().bool()
|
shift_valid = action_code_mask[:, 1:].contiguous().bool()
|
||||||
if predict_actions_t is not None:
|
if predict_actions_t is not None:
|
||||||
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
|
sample_mask = predict_actions_t[:, None].expand_as(shift_valid)
|
||||||
shift_valid = shift_valid & sample_mask
|
shift_valid = shift_valid & sample_mask
|
||||||
if not bool(shift_valid.any().item()):
|
# Fold the boolean mask into the target via ignore_index. No
|
||||||
return shift_logits.sum() * 0.0
|
# ``.any().item()`` sync — Liger returns 0.0 when every position
|
||||||
return F.cross_entropy(
|
# is ignored, preserving graph capture for CUDA graphs.
|
||||||
shift_logits[shift_valid],
|
shift_targets = torch.where(
|
||||||
shift_targets[shift_valid],
|
shift_valid, shift_targets, torch.full_like(shift_targets, -100)
|
||||||
|
)
|
||||||
|
|
||||||
|
B, T_1, H = shift_hidden.shape
|
||||||
|
flat_hidden = shift_hidden.reshape(B * T_1, H).to(lm_head_weight.dtype)
|
||||||
|
flat_labels = shift_targets.reshape(B * T_1)
|
||||||
|
|
||||||
|
loss_fn = LigerFusedLinearCrossEntropyLoss(
|
||||||
|
ignore_index=-100,
|
||||||
reduction="mean",
|
reduction="mean",
|
||||||
)
|
)
|
||||||
|
return loss_fn(lm_head_weight, flat_hidden, flat_labels)
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
@@ -400,9 +434,10 @@ class PI052Policy(PI05Policy):
|
|||||||
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
|
def __init__(self, config: PI052Config, **kwargs: Any) -> None:
|
||||||
# Patch ops BEFORE the backbone is built (super().__init__ below
|
# Patch ops BEFORE the backbone is built (super().__init__ below
|
||||||
# constructs PaliGemmaWithExpertModel which instantiates the
|
# constructs PaliGemmaWithExpertModel which instantiates the
|
||||||
# Gemma/Siglip layers we want to swap).
|
# Gemma/Siglip layers we want to swap). Always-on — the patch
|
||||||
if getattr(config, "use_hf_kernels", False):
|
# is process-global / idempotent and degrades gracefully if
|
||||||
_enable_hf_kernels()
|
# liger-kernel is missing.
|
||||||
|
_enable_hf_kernels()
|
||||||
|
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
|
# ``PI05Policy.__init__`` zeroes the PaliGemma ``lm_head`` and
|
||||||
@@ -726,9 +761,12 @@ class PI052Policy(PI05Policy):
|
|||||||
text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :]
|
text_hidden = prefix_out[:, -(fast_len + lang_len) : -fast_len, :]
|
||||||
else:
|
else:
|
||||||
text_hidden = prefix_out[:, -lang_len:, :]
|
text_hidden = prefix_out[:, -lang_len:, :]
|
||||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
# Liger fused linear-CE: skip the explicit ``lm_head(...)``
|
||||||
text_loss = _shifted_ce(
|
# materialisation; the kernel multiplies on-the-fly and
|
||||||
text_logits,
|
# never holds the full (B, T, 257k) logits tensor.
|
||||||
|
text_loss = _shifted_lin_ce(
|
||||||
|
text_hidden,
|
||||||
|
lm_head.weight,
|
||||||
text_labels,
|
text_labels,
|
||||||
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
|
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
|
||||||
)
|
)
|
||||||
@@ -736,8 +774,13 @@ class PI052Policy(PI05Policy):
|
|||||||
fast_loss: Tensor | None = None
|
fast_loss: Tensor | None = None
|
||||||
if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
|
if fast_len > 0 and prefix_out is not None and action_code_mask is not None:
|
||||||
fast_hidden = prefix_out[:, -fast_len:, :]
|
fast_hidden = prefix_out[:, -fast_len:, :]
|
||||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
fast_loss = _fast_lin_ce(
|
||||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
|
fast_hidden,
|
||||||
|
lm_head.weight,
|
||||||
|
action_tokens,
|
||||||
|
action_code_mask,
|
||||||
|
predict_actions_t,
|
||||||
|
)
|
||||||
|
|
||||||
return flow_loss, text_loss, fast_loss
|
return flow_loss, text_loss, fast_loss
|
||||||
|
|
||||||
@@ -830,9 +873,9 @@ class PI052Policy(PI05Policy):
|
|||||||
text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :]
|
text_hidden = vlm_out[:, -(fast_len + lang_len):-fast_len, :]
|
||||||
else:
|
else:
|
||||||
text_hidden = vlm_out[:, -lang_len:, :]
|
text_hidden = vlm_out[:, -lang_len:, :]
|
||||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
text_loss = _shifted_lin_ce(
|
||||||
text_loss = _shifted_ce(
|
text_hidden,
|
||||||
text_logits,
|
lm_head.weight,
|
||||||
text_labels,
|
text_labels,
|
||||||
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
|
z_loss_weight=getattr(self.config, "text_ce_z_loss_weight", 0.0),
|
||||||
)
|
)
|
||||||
@@ -844,8 +887,13 @@ class PI052Policy(PI05Policy):
|
|||||||
and fast_len > 0
|
and fast_len > 0
|
||||||
):
|
):
|
||||||
fast_hidden = vlm_out[:, -fast_len:, :]
|
fast_hidden = vlm_out[:, -fast_len:, :]
|
||||||
fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))
|
fast_loss = _fast_lin_ce(
|
||||||
fast_loss = _fast_ce(fast_logits, action_tokens, action_code_mask, predict_actions_t)
|
fast_hidden,
|
||||||
|
lm_head.weight,
|
||||||
|
action_tokens,
|
||||||
|
action_code_mask,
|
||||||
|
predict_actions_t,
|
||||||
|
)
|
||||||
|
|
||||||
return text_loss, fast_loss
|
return text_loss, fast_loss
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user