feat(smolvla2-runtime): --text_min_new_tokens / --text_temperature CLI debug knobs

The recipe fix (target=${subtask} instead of ${next_subtask}) shifted
the LM head's failure mode from "emit newlines" to "emit EOS at
position 0". On the new ``_tool-good`` checkpoint inference produces
exactly one token (``<end_of_utterance>``, id 49279) and decodes to
empty. That's the chat-pretrained backbone's short-turn EOS prior
not yet being overridden by 2000 steps of fine-tuning supervision.

Expose three knobs so the operator can probe whether the head has
real subtask-token probability mass *under* the EOS argmax without
recompiling or retraining:

  --text_min_new_tokens=N    suppress EOS for the first N tokens
  --text_temperature=T       sample at temperature T
  --text_top_p=P             nucleus filtering at top-p

These are explicitly off-policy (training was greedy / no min-tokens),
so they shouldn't ship in production runs — but they let us tell
whether the model has *learned* subtask prediction (just under EOS)
or hasn't yet. If forcing min_new_tokens=3 with temperature=0.5
produces a sensible subtask, the model is fine and just needs more
training steps to walk EOS down. If it produces gibberish, training
hasn't progressed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-12 21:39:33 +02:00
parent b6fb536460
commit 3a20ea337e
2 changed files with 47 additions and 12 deletions

View File

@@ -365,24 +365,22 @@ class HighLevelSubtaskFwd(InferenceStep):
return None
ctx = _msgs_for_subtask(state)
observation = _maybe_observation(self.observation_provider)
# Match training: greedy argmax, no min_new_tokens, no
# special-token suppression. Earlier experiments forced
# min_new_tokens=5 + sampling because the LM head was
# collapsing to EOS at position 0 — but that turned out to
# be a visual-distribution shift (camera frames being fed
# at the camera's native resolution rather than the
# dataset's recorded resolution), not a head pathology.
# With the camera frame resized to the dataset's
# ``ds_features['observation.images.*']['shape']`` shape,
# the visual prefix is back on-distribution and the same
# greedy decoding that works in ``--no_robot`` dry-run also
# works on the live robot.
# Default: greedy argmax, no min_new_tokens, no special-token
# suppression — matches training. Operator can override via
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
# on the CLI; useful for under-trained checkpoints whose LM
# head still favours EOS at position 0 (pre-trained chat
# backbone's short-turn prior hasn't been fully overridden
# by the fine-tuning supervision yet).
msg = _generate_with_policy(
self.policy,
ctx,
observation=observation,
state=state,
label="subtask gen",
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
temperature=float(state.get("text_gen_temperature") or 0.0),
top_p=float(state.get("text_gen_top_p") or 1.0),
)
# Diagnostics: surface what the model is *actually* producing
# at chunk boundaries, even when the output gets rejected or

View File

@@ -259,6 +259,35 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
default=None,
help="Stop after N ticks (debug / smoke-test).",
)
p.add_argument(
"--text_min_new_tokens",
type=int,
default=0,
help=(
"Debug knob for under-trained checkpoints: force the LM head "
"to emit at least N non-EOS tokens before EOS is allowed. "
"Use when the head's prior at position 0 still favours EOS "
"(short training run on a chat-pretrained backbone). 3-5 "
"is usually enough to reveal whether the model has real "
"subtask-token mass under the EOS argmax."
),
)
p.add_argument(
"--text_temperature",
type=float,
default=0.0,
help=(
"Sampling temperature for high-level text gen. 0 = greedy "
"argmax (default, matches training). Set 0.3-0.7 with an "
"under-trained checkpoint to escape stuck-at-EOS argmax."
),
)
p.add_argument(
"--text_top_p",
type=float,
default=1.0,
help="Nucleus filtering for high-level text gen.",
)
p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
return p.parse_args(argv)
@@ -1296,6 +1325,14 @@ def main(argv: list[str] | None = None) -> int:
ctrl_hz=args.ctrl_hz,
high_level_hz=args.high_level_hz,
)
# Stash text-gen knobs on the state dict so the high-level steps
# (which read state) can pick them up and forward them to
# policy.select_message. Letting the operator try
# ``--text_min_new_tokens=5 --text_temperature=0.6`` on an
# under-trained checkpoint without recompiling.
runtime.state["text_gen_min_new_tokens"] = int(getattr(args, "text_min_new_tokens", 0) or 0)
runtime.state["text_gen_temperature"] = float(getattr(args, "text_temperature", 0.0) or 0.0)
runtime.state["text_gen_top_p"] = float(getattr(args, "text_top_p", 1.0) or 1.0)
if args.task:
runtime.set_task(args.task)
# Seed plan/memory/subtask so the first prompt the runtime builds