mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
feat(smolvla2): inference runtime — select_message + multi-rate REPL
Closes the loop on PR 3: SmolVLA2 can now be queried interactively at
inference, dispatching the same five sub-recipe shapes it was trained
on (action chunks, subtask gen, memory updates, plan/speech on
interjection, VQA on questions).

Modeling fixes + additions
--------------------------

- ``_compute_text_loss``: standard next-token CE shift was missing
  (logits at position t were CE'd against the label at t — identity-mapped,
  learning nothing). Adds ``logits[:, :-1]`` / ``labels[:, 1:]`` shift to
  match HuggingFace ``LlamaForCausalLM``.
- New ``select_message`` on ``SmolVLA2Policy``: AR text generation with
  KV caching, mirroring SmolVLA's ``select_action`` pattern. Single
  prefix forward fills the cache, then per-token forwards reuse it.
  Greedy + top-p nucleus sampling. Returns the decoded string with the
  prompt stripped.

Runtime package — ``src/lerobot/policies/smolvla2/inference/``
-------------------------------------------------------------

- ``triggers.py`` — ``Trigger`` Protocol + ``HzTrigger`` /
  ``EventTrigger`` + ``TickClock``. The whole runtime ticks at
  ``max_rate_hz=50`` and each step gates itself off its own cadence.
- ``runtime_state.py`` — runtime state dict factory plus tiny helpers
  (``take_event``, ``set_if_changed``, ``push_log``). Stable keys are
  documented at the top of the module.
- ``steps.py`` — :class:`InferenceStep` base + concrete steps:
  ``LowLevelForward`` / ``DispatchAction`` (action path),
  ``HighLevelSubtaskFwd`` / ``MemoryUpdateFwd`` /
  ``UserInterjectionFwd`` / ``AskVQAFwd`` (text paths),
  ``DispatchToolCalls`` (tool registry → ``Tool.call``). Each text step
  builds a chat-template prompt from current ``RuntimeState`` (task /
  plan / memory / subtask) matching what ``smolvla2_hirobot.yaml``
  renders during training. Includes a tiny ``<say>...</say>`` parser
  for the ``user_interjection_response`` branch's combined plan +
  speech output.
- ``runtime.py`` — :class:`SmolVLA2Runtime` composes the pipeline,
  drives ticks via ``TickClock``, polls a user-supplied
  ``event_collector`` per tick, and prints state-change log lines.
- ``repl.py`` — :class:`StdinReader` non-blocking line reader with
  simple intent classification: ``stop`` / ``quit`` / ``exit`` →
  terminate; ``?`` suffix → ``user_vqa_query`` event; first line → set
  task; other lines → ``user_interjection``.

CLI
---

- ``src/lerobot/scripts/lerobot_smolvla2_runtime.py``: console script
  ``lerobot-smolvla2-runtime`` that loads a checkpoint, optionally
  instantiates ``SayTool`` (pocket-tts), wires up ``SmolVLA2Runtime`` +
  ``StdinReader``, and runs. Real-robot wiring (observation_provider /
  robot_executor) is intentionally left as a follow-up — v1 is dry-run /
  language-only so the REPL works without robot hardware. Registered in
  ``pyproject.toml`` ``[project.scripts]``.

Known follow-ups
----------------

- Real-robot integration: today ``LowLevelForward`` only fires when an
  observation_provider is wired. The CLI prints a warning if
  ``--no_robot`` is omitted.
- ``select_message`` runs an extra prefix forward; could share with the
  action path's prefix when both are needed in the same tick.
- Tests: no end-to-end runtime test yet (would need a tiny SmolVLM
  fixture). The components compile and the public surface is exercised
  by the CLI's argument-parsing path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
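
The next-token shift described for ``_compute_text_loss`` can be illustrated with a small, dependency-free sketch. It is not the code from this commit: the function names are hypothetical, it handles a single sequence rather than a batch, and the ``-100`` ignore value is an assumption borrowed from the common PyTorch convention. It only shows why scoring ``logits[:-1]`` against ``labels[1:]`` rewards predicting the following token, while the unshifted pairing does not.

```python
import math


def _nll(row, target):
    """Negative log-likelihood of `target` under softmax(row), computed stably."""
    peak = max(row)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
    return log_norm - row[target]


def shifted_ce(logits, labels, ignore_index=-100):
    """Mean CE where the logits at position t are scored against labels[t + 1]."""
    # Drop the last logit row and the first label: logits[:-1] vs labels[1:].
    pairs = [(r, t) for r, t in zip(logits[:-1], labels[1:]) if t != ignore_index]
    return sum(_nll(r, t) for r, t in pairs) / len(pairs) if pairs else 0.0


def unshifted_ce(logits, labels, ignore_index=-100):
    """Mean CE with logits[t] scored against labels[t] (the pairing the fix removes)."""
    pairs = [(r, t) for r, t in zip(logits, labels) if t != ignore_index]
    return sum(_nll(r, t) for r, t in pairs) / len(pairs) if pairs else 0.0


# Toy sequence over a 3-token vocabulary. Each row of logits strongly
# predicts the *next* label: row 0 favours token 2, row 1 favours token 1.
labels = [0, 2, 1]
logits = [[0.0, 0.0, 10.0], [0.0, 10.0, 0.0], [0.0, 0.0, 0.0]]

shifted = shifted_ce(logits, labels)
unshifted = unshifted_ce(logits, labels)
```

On this toy input the shifted loss is close to zero because the model is a good next-token predictor, whereas the unshifted loss is large because it penalises the same logits for not reproducing the current token.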
196
src/lerobot/scripts/lerobot_smolvla2_runtime.py
Normal file
@@ -0,0 +1,196 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""``lerobot-smolvla2-runtime`` — interactive REPL for trained SmolVLA2.

Drives the multi-rate runtime defined in
:mod:`lerobot.policies.smolvla2.inference`. Stdin becomes the user
channel: type a task, then natural-language interjections / questions.
The runtime prints state changes (plan / subtask / memory / vqa /
speech) as they happen.

Examples
--------

Dry run on a checkpoint, no robot connected — useful for sanity-
checking text generation::

    uv run lerobot-smolvla2-runtime \\
        --policy.path=outputs/train/smolvla2_super_poulain/000020000/pretrained_model \\
        --no_robot \\
        --task="please clean the kitchen"

With a real robot::

    uv run lerobot-smolvla2-runtime \\
        --policy.path=... \\
        --robot.type=so101 --robot.port=/dev/tty.usbmodem... \\
        --tts.voice=alba

Tool dispatch (TTS via ``SayTool``) is enabled by default when
``pocket-tts`` is installed; pass ``--no_tts`` to disable.
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path
from typing import Any

logger = logging.getLogger("lerobot.smolvla2.runtime")


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(
        prog="lerobot-smolvla2-runtime",
        description="Interactive REPL runtime for a trained SmolVLA2 checkpoint.",
    )
    p.add_argument(
        "--policy.path",
        dest="policy_path",
        type=Path,
        required=True,
        help="Path to a trained SmolVLA2 ``pretrained_model`` directory.",
    )
    p.add_argument(
        "--task",
        dest="task",
        type=str,
        default=None,
        help="Initial task. If omitted, the first stdin line is treated as the task.",
    )
    p.add_argument(
        "--no_robot",
        action="store_true",
        help="Skip robot connection — language-only / dry-run mode.",
    )
    p.add_argument(
        "--no_tts",
        action="store_true",
        help="Disable the ``say`` tool dispatch.",
    )
    p.add_argument(
        "--tts.voice",
        dest="tts_voice",
        type=str,
        default="alba",
        help="Pocket-tts voice name (or path to a .wav for cloning).",
    )
    p.add_argument(
        "--chunk_hz", type=float, default=4.0, help="Action-chunk generation rate."
    )
    p.add_argument(
        "--ctrl_hz", type=float, default=50.0, help="Action dispatch rate."
    )
    p.add_argument(
        "--high_level_hz",
        type=float,
        default=1.0,
        help="High-level subtask generation rate.",
    )
    p.add_argument(
        "--max_ticks",
        type=int,
        default=None,
        help="Stop after N ticks (debug / smoke-test).",
    )
    p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
    return p.parse_args(argv)


def _load_policy(path: Path):  # noqa: ANN202
    """Load a SmolVLA2 checkpoint from ``path``."""
    from lerobot.policies.factory import make_policy_from_path  # noqa: PLC0415

    policy = make_policy_from_path(str(path))
    policy.eval()
    return policy


def _build_tools(policy_path: Path, no_tts: bool, tts_voice: str) -> dict[str, Any]:
    """Instantiate the tools declared on this dataset/policy."""
    if no_tts:
        return {}
    try:
        from lerobot.tools import SayTool  # noqa: PLC0415

        return {"say": SayTool(voice=tts_voice)}
    except Exception as exc:  # noqa: BLE001
        logger.warning("Could not initialise SayTool (%s) — speech disabled.", exc)
        return {}


def main(argv: list[str] | None = None) -> int:
    args = _parse_args(argv)
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )

    if not args.policy_path.exists():
        print(f"[smolvla2] policy path not found: {args.policy_path}", file=sys.stderr)
        return 1

    print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
    policy = _load_policy(args.policy_path)

    tools = _build_tools(args.policy_path, args.no_tts, args.tts_voice)
    if tools:
        print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)

    # Robot wiring is left as a follow-up — for v1 we run language-only
    # / dry-run so REPL development doesn't require a connected robot.
    observation_provider = None
    robot_executor = None
    if not args.no_robot:
        print(
            "[smolvla2] WARNING: real-robot integration is a follow-up. "
            "Running in dry-run mode for now (no actions executed).",
            flush=True,
        )

    from lerobot.policies.smolvla2.inference import (  # noqa: PLC0415
        SmolVLA2Runtime,
        StdinReader,
    )

    runtime = SmolVLA2Runtime(
        policy=policy,
        tools=tools,
        observation_provider=observation_provider,
        robot_executor=robot_executor,
        event_collector=StdinReader().poll,
        chunk_hz=args.chunk_hz,
        ctrl_hz=args.ctrl_hz,
        high_level_hz=args.high_level_hz,
    )
    if args.task:
        runtime.set_task(args.task)
    print(
        "[smolvla2] runtime ready. Type a task to begin, then any line for "
        "interjections, questions ending in '?' for VQA, or 'stop' to exit.",
        flush=True,
    )
    try:
        runtime.run(max_ticks=args.max_ticks)
    except KeyboardInterrupt:
        runtime.stop()
        print("\n[smolvla2] interrupted by user", flush=True)
    return 0


if __name__ == "__main__":
    sys.exit(main())
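
The script's dotted flags (``--policy.path``, ``--tts.voice``) each carry an explicit ``dest``. The standalone sketch below suggests why: without ``dest``, argparse would derive an attribute name that still contains the dot and could only be read through ``getattr``. The parser here is a cut-down stand-in with a hypothetical ``prog`` name and example path, not the script's real parser.

```python
import argparse
from pathlib import Path


def parse_demo_args(argv):
    """Parse a reduced set of flags shaped like the CLI's dotted options."""
    parser = argparse.ArgumentParser(prog="demo-runtime")  # hypothetical name
    # `dest` maps the dotted flag onto a plain attribute name.
    parser.add_argument("--policy.path", dest="policy_path", type=Path, required=True)
    parser.add_argument("--tts.voice", dest="tts_voice", type=str, default="alba")
    parser.add_argument("--no_robot", action="store_true")
    return parser.parse_args(argv)


def parse_without_dest(argv):
    """Same dotted flag, but letting argparse derive the attribute name."""
    parser = argparse.ArgumentParser(prog="demo-runtime")
    parser.add_argument("--policy.path", type=Path, required=True)
    return parser.parse_args(argv)


args = parse_demo_args(["--policy.path=outputs/pretrained_model", "--no_robot"])
raw = parse_without_dest(["--policy.path=outputs/pretrained_model"])

# With `dest`, normal attribute access works; without it the derived name
# keeps the dot, so only getattr(raw, "policy.path") can reach the value.
explicit_path = args.policy_path
derived_path = getattr(raw, "policy.path")
```

Passing ``argv`` explicitly, as ``main(argv)`` does in the script, also makes this kind of parser easy to exercise from a test without touching ``sys.argv``.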