mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
feat(smolvla2-runtime): --text_min_new_tokens / --text_temperature CLI debug knobs
The recipe fix (target=${subtask} instead of ${next_subtask}) shifted
the LM head's failure mode from "emit newlines" to "emit EOS at
position 0". On the new ``_tool-good`` checkpoint inference produces
exactly one token (``<end_of_utterance>``, id 49279) and decodes to
empty. That's the chat-pretrained backbone's short-turn EOS prior
not yet being overridden by 2000 steps of fine-tuning supervision.
Expose three knobs so the operator can probe whether the head has
real subtask-token probability mass *under* the EOS argmax without
recompiling or retraining:
--text_min_new_tokens=N suppress EOS for the first N tokens
--text_temperature=T sample at temperature T
--text_top_p=P nucleus filtering at top-p
These are explicitly off-policy (training was greedy / no min-tokens),
so they shouldn't ship in production runs — but they let us tell
whether the model has *learned* subtask prediction (just under EOS)
or hasn't yet. If forcing min_new_tokens=3 with temperature=0.5
produces a sensible subtask, the model is fine and just needs more
training steps to walk EOS down. If it produces gibberish, training
hasn't progressed.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -365,24 +365,22 @@ class HighLevelSubtaskFwd(InferenceStep):
|
||||
return None
|
||||
ctx = _msgs_for_subtask(state)
|
||||
observation = _maybe_observation(self.observation_provider)
|
||||
# Match training: greedy argmax, no min_new_tokens, no
|
||||
# special-token suppression. Earlier experiments forced
|
||||
# min_new_tokens=5 + sampling because the LM head was
|
||||
# collapsing to EOS at position 0 — but that turned out to
|
||||
# be a visual-distribution shift (camera frames being fed
|
||||
# at the camera's native resolution rather than the
|
||||
# dataset's recorded resolution), not a head pathology.
|
||||
# With the camera frame resized to the dataset's
|
||||
# ``ds_features['observation.images.*']['shape']`` shape,
|
||||
# the visual prefix is back on-distribution and the same
|
||||
# greedy decoding that works in ``--no_robot`` dry-run also
|
||||
# works on the live robot.
|
||||
# Default: greedy argmax, no min_new_tokens, no special-token
|
||||
# suppression — matches training. Operator can override via
|
||||
# ``--text_min_new_tokens=N --text_temperature=T --text_top_p=P``
|
||||
# on the CLI; useful for under-trained checkpoints whose LM
|
||||
# head still favours EOS at position 0 (pre-trained chat
|
||||
# backbone's short-turn prior hasn't been fully overridden
|
||||
# by the fine-tuning supervision yet).
|
||||
msg = _generate_with_policy(
|
||||
self.policy,
|
||||
ctx,
|
||||
observation=observation,
|
||||
state=state,
|
||||
label="subtask gen",
|
||||
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
|
||||
temperature=float(state.get("text_gen_temperature") or 0.0),
|
||||
top_p=float(state.get("text_gen_top_p") or 1.0),
|
||||
)
|
||||
# Diagnostics: surface what the model is *actually* producing
|
||||
# at chunk boundaries, even when the output gets rejected or
|
||||
|
||||
@@ -259,6 +259,35 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
default=None,
|
||||
help="Stop after N ticks (debug / smoke-test).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--text_min_new_tokens",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Debug knob for under-trained checkpoints: force the LM head "
|
||||
"to emit at least N non-EOS tokens before EOS is allowed. "
|
||||
"Use when the head's prior at position 0 still favours EOS "
|
||||
"(short training run on a chat-pretrained backbone). 3-5 "
|
||||
"is usually enough to reveal whether the model has real "
|
||||
"subtask-token mass under the EOS argmax."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--text_temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help=(
|
||||
"Sampling temperature for high-level text gen. 0 = greedy "
|
||||
"argmax (default, matches training). Set 0.3-0.7 with an "
|
||||
"under-trained checkpoint to escape stuck-at-EOS argmax."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--text_top_p",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Nucleus filtering for high-level text gen.",
|
||||
)
|
||||
p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
|
||||
return p.parse_args(argv)
|
||||
|
||||
@@ -1296,6 +1325,14 @@ def main(argv: list[str] | None = None) -> int:
|
||||
ctrl_hz=args.ctrl_hz,
|
||||
high_level_hz=args.high_level_hz,
|
||||
)
|
||||
# Stash text-gen knobs on the state dict so the high-level steps
|
||||
# (which read state) can pick them up and forward them to
|
||||
# policy.select_message. Letting the operator try
|
||||
# ``--text_min_new_tokens=5 --text_temperature=0.6`` on an
|
||||
# under-trained checkpoint without recompiling.
|
||||
runtime.state["text_gen_min_new_tokens"] = int(getattr(args, "text_min_new_tokens", 0) or 0)
|
||||
runtime.state["text_gen_temperature"] = float(getattr(args, "text_temperature", 0.0) or 0.0)
|
||||
runtime.state["text_gen_top_p"] = float(getattr(args, "text_top_p", 1.0) or 1.0)
|
||||
if args.task:
|
||||
runtime.set_task(args.task)
|
||||
# Seed plan/memory/subtask so the first prompt the runtime builds
|
||||
|
||||
Reference in New Issue
Block a user