mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
fix(smolvla2): flatten say tool_calls into <say> marker before tokenizing
The chat tokenizer passed assistant `tool_calls` straight to `apply_chat_template`, which renders them as a structured JSON `<tool_call>` block — so the LM head was trained to emit JSON. But the inference parser `_split_plan_and_say` looks for a `<say>...</say>` marker, which the model never saw in training, so the `say` tool never fired at inference. `_flatten_say_tool_calls` is the missing training-time serializer (the one `_split_plan_and_say`'s docstring already assumed existed): it rewrites a `say` tool call into a `<say>...</say>` marker inside the content text before the chat template runs, so the template only tokenizes plain text and the supervised target span trains the model to emit exactly the marker the runtime parses back (Pi 0.5-style flat tool-call serialization). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -347,6 +347,11 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
messages, target_indices = self._apply_prompt_dropout(
|
||||
messages, target_indices, sample_idx
|
||||
)
|
||||
# Flatten ``tool_calls`` into a textual ``<say>...</say>`` marker
|
||||
# *before* the chat template sees them, so the model is trained
|
||||
# to emit the same marker the inference parser
|
||||
# (``_split_plan_and_say``) reads back. See ``_flatten_say_tool_calls``.
|
||||
messages = [_flatten_say_tool_calls(m) for m in messages]
|
||||
text_messages = [_strip_lerobot_blocks(m) for m in messages]
|
||||
|
||||
full_ids = tokenizer.apply_chat_template(
|
||||
@@ -508,6 +513,75 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
return new
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
"""Collapse a message's ``content`` (string or multimodal blocks) to plain text."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
t = block.get("text")
|
||||
if isinstance(t, str):
|
||||
parts.append(t)
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Serialize assistant ``say`` tool calls into a textual ``<say>...</say>``
|
||||
marker inside the message content (Pi 0.5-style flat tool-call
|
||||
serialization).
|
||||
|
||||
SmolVLM's chat template would otherwise render ``tool_calls`` as a
|
||||
structured JSON ``<tool_call>`` block, so the LM head learns to emit
|
||||
JSON — but the inference parser ``_split_plan_and_say`` looks for a
|
||||
``<say>...</say>`` marker (``_SAY_RE``). Rewriting the call into the
|
||||
content text *before* ``apply_chat_template`` aligns the two: the
|
||||
template only ever tokenizes plain text, and the supervised target
|
||||
span trains the model to produce the exact marker the runtime reads.
|
||||
|
||||
Messages without ``say`` tool calls are returned unchanged.
|
||||
"""
|
||||
tool_calls = message.get("tool_calls")
|
||||
if not tool_calls:
|
||||
return message
|
||||
|
||||
say_texts: list[str] = []
|
||||
for call in tool_calls:
|
||||
if not isinstance(call, dict):
|
||||
continue
|
||||
fn = call.get("function") or {}
|
||||
if fn.get("name") != "say":
|
||||
continue
|
||||
args = fn.get("arguments")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
import json # noqa: PLC0415
|
||||
|
||||
args = json.loads(args)
|
||||
except (ValueError, TypeError):
|
||||
args = {}
|
||||
text = args.get("text", "") if isinstance(args, dict) else ""
|
||||
if text:
|
||||
say_texts.append(str(text))
|
||||
|
||||
if not say_texts:
|
||||
# No ``say`` calls (or empty text) — drop the structured calls so
|
||||
# the template doesn't render a stray JSON block, but leave the
|
||||
# content alone.
|
||||
new = dict(message)
|
||||
new.pop("tool_calls", None)
|
||||
return new
|
||||
|
||||
new = dict(message)
|
||||
base = _content_to_text(new.get("content")).strip()
|
||||
marker = "".join(f"<say>{t}</say>" for t in say_texts)
|
||||
new["content"] = f"{base}\n{marker}" if base else marker
|
||||
new.pop("tool_calls", None)
|
||||
return new
|
||||
|
||||
|
||||
def _is_batched_messages(messages: Any) -> bool:
|
||||
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||
|
||||
@@ -572,3 +646,4 @@ def _as_token_ids(value: Any) -> list[int]:
|
||||
|
||||
# Re-export for tests / introspection
|
||||
strip_lerobot_blocks = _strip_lerobot_blocks
|
||||
flatten_say_tool_calls = _flatten_say_tool_calls
|
||||
|
||||
Reference in New Issue
Block a user