feat(smolvla2): inference runtime — select_message + multi-rate REPL

Closes the loop on PR 3: SmolVLA2 can now be queried interactively at
inference, dispatching the same five sub-recipe shapes it was trained
on (action chunks, subtask gen, memory updates, plan/speech on
interjection, VQA on questions).

Modeling fixes + additions
--------------------------

- ``_compute_text_loss``: standard next-token CE shift was missing
  (logits at position t were CE'd against the label at t — identity-
  mapped, learning nothing). Adds ``logits[:, :-1]`` /
  ``labels[:, 1:]`` shift to match HuggingFace ``LlamaForCausalLM``.

- New ``select_message`` on ``SmolVLA2Policy``: AR text generation
  with KV caching, mirroring SmolVLA's ``select_action`` pattern.
  Single prefix forward fills the cache, then per-token forwards
  reuse it. Greedy + top-p nucleus sampling. Returns the decoded
  string with the prompt stripped.

Runtime package — ``src/lerobot/policies/smolvla2/inference/``
-------------------------------------------------------------

- ``triggers.py`` — ``Trigger`` Protocol + ``HzTrigger`` /
  ``EventTrigger`` + ``TickClock``. The whole runtime ticks at
  ``max_rate_hz=50`` and each step gates itself off its own
  cadence.

- ``runtime_state.py`` — runtime state dict factory plus tiny
  helpers (``take_event``, ``set_if_changed``, ``push_log``).
  Stable keys are documented at the top of the module.

- ``steps.py`` — :class:`InferenceStep` base + concrete steps:
  ``LowLevelForward`` / ``DispatchAction`` (action path),
  ``HighLevelSubtaskFwd`` / ``MemoryUpdateFwd`` /
  ``UserInterjectionFwd`` / ``AskVQAFwd`` (text paths),
  ``DispatchToolCalls`` (tool registry → ``Tool.call``). Each
  text step builds a chat-template prompt from current
  ``RuntimeState`` (task / plan / memory / subtask) matching
  what ``smolvla2_hirobot.yaml`` renders during training.
  Includes a tiny ``<say>...</say>`` parser for the
  ``user_interjection_response`` branch's combined plan + speech
  output.

- ``runtime.py`` — :class:`SmolVLA2Runtime` composes the pipeline,
  drives ticks via ``TickClock``, polls a user-supplied
  ``event_collector`` per tick, and prints state-change log lines.

- ``repl.py`` — :class:`StdinReader` non-blocking line reader
  with simple intent classification: ``stop`` / ``quit`` /
  ``exit`` → terminate; ``?`` suffix → ``user_vqa_query`` event;
  first line → set task; other lines → ``user_interjection``.

CLI
---

- ``src/lerobot/scripts/lerobot_smolvla2_runtime.py``: console
  script ``lerobot-smolvla2-runtime`` that loads a checkpoint,
  optionally instantiates ``SayTool`` (pocket-tts), wires up
  ``SmolVLA2Runtime`` + ``StdinReader``, and runs.

  Real-robot wiring (observation_provider / robot_executor) is
  intentionally left as a follow-up — v1 is dry-run / language-
  only so the REPL works without robot hardware.

  Registered in ``pyproject.toml`` ``[project.scripts]``.

Known follow-ups
----------------

- Real-robot integration: today ``LowLevelForward`` only fires when
  an observation_provider is wired. The CLI prints a warning if
  ``--no_robot`` is omitted.
- ``select_message`` runs an extra prefix forward; could share with
  the action path's prefix when both are needed in the same tick.
- Tests: no end-to-end runtime test yet (would need a tiny SmolVLM
  fixture). The components compile and the public surface is
  exercised by the CLI's argument-parsing path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-30 22:04:00 +02:00
parent af6d8ebd5b
commit 223cc8a9e2
9 changed files with 1230 additions and 2 deletions

View File

@@ -0,0 +1,196 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``lerobot-smolvla2-runtime`` — interactive REPL for trained SmolVLA2.
Drives the multi-rate runtime defined in
:mod:`lerobot.policies.smolvla2.inference`. Stdin becomes the user
channel: type a task, then natural-language interjections / questions.
The runtime prints state changes (plan / subtask / memory / vqa /
speech) as they happen.
Examples
--------
Dry run on a checkpoint, no robot connected — useful for sanity-
checking text generation::
uv run lerobot-smolvla2-runtime \\
--policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\
--no_robot \\
--task="please clean the kitchen"
With a real robot::
uv run lerobot-smolvla2-runtime \\
--policy.path=... \\
--robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
--tts.voice=alba
Tool dispatch (TTS via ``SayTool``) is enabled by default when
``pocket-tts`` is installed; pass ``--no_tts`` to disable.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
from typing import Any
logger = logging.getLogger("lerobot.smolvla2.runtime")
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
p = argparse.ArgumentParser(
prog="lerobot-smolvla2-runtime",
description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.",
)
p.add_argument(
"--policy.path",
dest="policy_path",
type=Path,
required=True,
help="Path to a trained SmolVLA2 ``pretrained_model`` directory.",
)
p.add_argument(
"--task",
dest="task",
type=str,
default=None,
help="Initial task. If omitted, the first stdin line is treated as the task.",
)
p.add_argument(
"--no_robot",
action="store_true",
help="Skip robot connection — language-only / dry-run mode.",
)
p.add_argument(
"--no_tts",
action="store_true",
help="Disable the ``say`` tool dispatch.",
)
p.add_argument(
"--tts.voice",
dest="tts_voice",
type=str,
default="alba",
help="Pocket-tts voice name (or path to a .wav for cloning).",
)
p.add_argument(
"--chunk_hz", type=float, default=4.0, help="Action-chunk generation rate."
)
p.add_argument(
"--ctrl_hz", type=float, default=50.0, help="Action dispatch rate."
)
p.add_argument(
"--high_level_hz",
type=float,
default=1.0,
help="High-level subtask generation rate.",
)
p.add_argument(
"--max_ticks",
type=int,
default=None,
help="Stop after N ticks (debug / smoke-test).",
)
p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
return p.parse_args(argv)
def _load_policy(path: Path): # noqa: ANN202
"""Load a SmolVLA2 checkpoint from ``path``."""
from lerobot.policies.factory import make_policy_from_path # noqa: PLC0415
policy = make_policy_from_path(str(path))
policy.eval()
return policy
def _build_tools(policy_path: Path, no_tts: bool, tts_voice: str) -> dict[str, Any]:
"""Instantiate the tools declared on this dataset/policy."""
if no_tts:
return {}
try:
from lerobot.tools import SayTool # noqa: PLC0415
return {"say": SayTool(voice=tts_voice)}
except Exception as exc: # noqa: BLE001
logger.warning("Could not initialise SayTool (%s) — speech disabled.", exc)
return {}
def main(argv: list[str] | None = None) -> int:
args = _parse_args(argv)
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
)
if not args.policy_path.exists():
print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr)
return 1
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
policy = _load_policy(args.policy_path)
tools = _build_tools(args.policy_path, args.no_tts, args.tts_voice)
if tools:
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
# Robot wiring is left as a follow-up — for v1 we run language-only
# / dry-run so REPL development doesn't require a connected robot.
observation_provider = None
robot_executor = None
if not args.no_robot:
print(
"[smolvla2] WARNING: real-robot integration is a follow-up. "
"Running in dry-run mode for now (no actions executed).",
flush=True,
)
from lerobot.policies.smolvla2.inference import ( # noqa: PLC0415
SmolVLA2Runtime,
StdinReader,
)
runtime = SmolVLA2Runtime(
policy=policy,
tools=tools,
observation_provider=observation_provider,
robot_executor=robot_executor,
event_collector=StdinReader().poll,
chunk_hz=args.chunk_hz,
ctrl_hz=args.ctrl_hz,
high_level_hz=args.high_level_hz,
)
if args.task:
runtime.set_task(args.task)
print(
"[smolvla2] runtime ready. Type a task to begin, then any line for "
"interjections, questions ending in '?' for VQA, or 'stop' to exit.",
flush=True,
)
try:
runtime.run(max_ticks=args.max_ticks)
except KeyboardInterrupt:
runtime.stop()
print("\n[smolvla2] interrupted by user", flush=True)
return 0
if __name__ == "__main__":
sys.exit(main())