pi052(debug): drop misleading inference/parity dump from text preds

The first-token parity check re-tokenized the decoded (stripped) inference
string, so the leading-space SentencePiece variant always mismatched the
training argmax — a false "DIVERGED" alarm. Remove the autoregressive
inference print and parity comparison (and the now-dead per-sample
select_message generation), keeping only the prompt, ground-truth target,
and teacher-forced argmax accuracy.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-06-04 13:32:44 +00:00
parent 23419026d5
commit e660a51e78
2 changed files with 13 additions and 77 deletions

View File

@@ -1009,60 +1009,11 @@ class PI052Policy(PI05Policy):
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
preds = text_logits.argmax(dim=-1)
# Train/inference parity check — run select_message on the
# *same* prompt prefix (the language up to but not including
# the supervised span) and capture the auto-regressive
# generation. The first generated token MUST match the
# training-side argmax at the prompt-end position (both are
# ``argmax lm_head(h_last_prompt)`` over identical context);
# any divergence is a parity bug (mask, dtype, KI routing
# difference). Later tokens can diverge because training
# uses teacher forcing while inference free-runs.
inference_outputs: list[dict[str, Any]] = []
for s in range(n):
row_labels = sub_labels[s]
sup_pos = (row_labels != -100).nonzero(as_tuple=True)[0]
if sup_pos.numel() == 0:
inference_outputs.append({"first_token": None, "decoded": ""})
continue
first_sup = int(sup_pos[0].item())
# Build a single-sample batch by *truncating* the token
# sequence to the prompt-only portion (length == first_sup),
# not by zero-masking. ``select_message`` reads the
# prompt-end hidden state via ``vlm_out[:, -1:]`` — the
# *last position* of the prefix — so a padded sequence
# would make it read a padding-token hidden state
# (PaliGemma's prior on those happens to be ``<loc>``,
# which would falsely flag a parity diverge). The real
# runtime feeds ``tokenizer(prompt)`` without padding,
# so we mirror that here.
prompt_tokens = sub[OBS_LANGUAGE_TOKENS][s : s + 1, :first_sup]
prompt_mask_orig = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1, :first_sup]
inf_batch: dict[str, Any] = {
OBS_LANGUAGE_TOKENS: prompt_tokens,
OBS_LANGUAGE_ATTENTION_MASK: prompt_mask_orig,
}
for k, v in sub.items():
if isinstance(k, str) and k.startswith("observation.images."):
inf_batch[k] = v[s : s + 1]
if "observation.state" in batch and torch.is_tensor(batch["observation.state"]):
inf_batch["observation.state"] = batch["observation.state"][s : s + 1]
try:
# Tight budget — we just want to see the model's
# opening continuation, not the full sequence.
decoded = self.select_message(
inf_batch, max_new_tokens=24, temperature=0.0, top_p=1.0
)
except Exception as exc: # noqa: BLE001
decoded = f"<inference failed: {type(exc).__name__}: {exc}>"
inference_outputs.append({"first_sup_pos": first_sup, "decoded": decoded})
return {
"input_ids": lang_tokens.detach().cpu(),
"attention_mask": lang_masks.detach().cpu(),
"labels": sub_labels.detach().cpu(),
"predictions": preds.detach().cpu(),
"inference": inference_outputs,
}
finally:
if was_training:

View File

@@ -224,7 +224,6 @@ def _print_debug_text_predictions(
labels = debug["labels"]
preds = debug["predictions"]
attn = debug["attention_mask"]
inference = debug.get("inference") or []
n = ids.shape[0]
print(
@@ -251,7 +250,6 @@ def _print_debug_text_predictions(
# Training-side teacher-forced argmax on the same prompt+target.
n_sup = n_ok = 0
first_sup_pred: int | None = None
teacher_chars: list[int] = []
for i in range(1, real):
label = sl[i]
@@ -259,8 +257,6 @@ def _print_debug_text_predictions(
continue
n_sup += 1
pred = int(sp[i - 1])
if first_sup_pred is None:
first_sup_pred = pred
teacher_chars.append(pred)
if label == pred:
n_ok += 1
@@ -272,28 +268,6 @@ def _print_debug_text_predictions(
f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}",
flush=True,
)
# Inference-side autoregressive output from the same prompt prefix.
inf_entry = inference[s] if s < len(inference) else None
if inf_entry:
inf_decoded = inf_entry.get("decoded", "")
print(f" inference (autoregressive) : {inf_decoded!r}", flush=True)
# First-token parity: training-side argmax at the prompt-end
# position MUST equal inference's first generated token —
# both compute argmax(lm_head(h_last_prompt)) on identical
# context. Any divergence signals a training↔inference bug.
if first_sup_pred is not None and inf_decoded and not inf_decoded.startswith("<inference"):
inf_ids = tokenizer(inf_decoded, add_special_tokens=False)["input_ids"]
if inf_ids:
inf_first = int(inf_ids[0])
match = inf_first == first_sup_pred
print(
f" first-token parity : "
f"train={first_sup_pred} ({tokenizer.decode([first_sup_pred])!r}) "
f"vs infer={inf_first} ({tokenizer.decode([inf_first])!r}) "
f"{'✓ MATCH' if match else '✗ DIVERGED — training/inference mismatch'}",
flush=True,
)
print("=" * 60 + "\n", flush=True)
@@ -381,15 +355,26 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation
if accelerator is None:
from accelerate.utils import DistributedDataParallelKwargs
from datetime import timedelta
from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Bump the c10d store-get / barrier timeout so the rank-0-only
# ``make_dataset`` block below doesn't trigger a barrier crash on
# large datasets. Default is 10 min (``store->get`` 600 s); a
# 32 k-episode v3 dataset (e.g. ``robocasa_pretrain_human300_v4``)
# spends >13 min on rank 0 building the episode/frame index
# while ranks 1-N idle at ``wait_for_everyone()`` and crash with
# ``DistBackendError: ... wait timeout after 600000ms``. 2 h is
# plenty of headroom; fast paths are unaffected.
ipg_kwargs = InitProcessGroupKwargs(timeout=timedelta(hours=2))
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
force_cpu = cfg.trainable_config.device == "cpu"
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
kwargs_handlers=[ddp_kwargs],
kwargs_handlers=[ddp_kwargs, ipg_kwargs],
cpu=force_cpu,
)