fix(smolvla2): flatten say tool_calls into <say> marker before tokenizing

The chat tokenizer passed assistant `tool_calls` straight to
`apply_chat_template`, which renders them as a structured JSON
`<tool_call>` block — so the LM head was trained to emit JSON. But the
inference parser `_split_plan_and_say` looks for a `<say>...</say>`
marker, which the model never saw in training, so the `say` tool never
fired at inference.

`_flatten_say_tool_calls` is the missing training-time serializer (the
one `_split_plan_and_say`'s docstring already assumed existed): it
rewrites a `say` tool call into a `<say>...</say>` marker inside the
content text before the chat template runs, so the template only
tokenizes plain text and the supervised target span trains the model to
emit exactly the marker the runtime parses back (Pi 0.5-style flat
tool-call serialization).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 10:47:31 +02:00
parent 5e3b9ba82c
commit bfb8cfb432
2 changed files with 152 additions and 0 deletions

View File

@@ -347,6 +347,11 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
messages, target_indices = self._apply_prompt_dropout(
messages, target_indices, sample_idx
)
# Flatten ``tool_calls`` into a textual ``<say>...</say>`` marker
# *before* the chat template sees them, so the model is trained
# to emit the same marker the inference parser
# (``_split_plan_and_say``) reads back. See ``_flatten_say_tool_calls``.
messages = [_flatten_say_tool_calls(m) for m in messages]
text_messages = [_strip_lerobot_blocks(m) for m in messages]
full_ids = tokenizer.apply_chat_template(
@@ -508,6 +513,75 @@ def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
return new
def _content_to_text(content: Any) -> str:
"""Collapse a message's ``content`` (string or multimodal blocks) to plain text."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
t = block.get("text")
if isinstance(t, str):
parts.append(t)
return "\n".join(parts)
return ""
def _flatten_say_tool_calls(message: dict[str, Any]) -> dict[str, Any]:
"""Serialize assistant ``say`` tool calls into a textual ``<say>...</say>``
marker inside the message content (Pi 0.5-style flat tool-call
serialization).
SmolVLM's chat template would otherwise render ``tool_calls`` as a
structured JSON ``<tool_call>`` block, so the LM head learns to emit
JSON — but the inference parser ``_split_plan_and_say`` looks for a
``<say>...</say>`` marker (``_SAY_RE``). Rewriting the call into the
content text *before* ``apply_chat_template`` aligns the two: the
template only ever tokenizes plain text, and the supervised target
span trains the model to produce the exact marker the runtime reads.
Messages without ``say`` tool calls are returned unchanged.
"""
tool_calls = message.get("tool_calls")
if not tool_calls:
return message
say_texts: list[str] = []
for call in tool_calls:
if not isinstance(call, dict):
continue
fn = call.get("function") or {}
if fn.get("name") != "say":
continue
args = fn.get("arguments")
if isinstance(args, str):
try:
import json # noqa: PLC0415
args = json.loads(args)
except (ValueError, TypeError):
args = {}
text = args.get("text", "") if isinstance(args, dict) else ""
if text:
say_texts.append(str(text))
if not say_texts:
# No ``say`` calls (or empty text) — drop the structured calls so
# the template doesn't render a stray JSON block, but leave the
# content alone.
new = dict(message)
new.pop("tool_calls", None)
return new
new = dict(message)
base = _content_to_text(new.get("content")).strip()
marker = "".join(f"<say>{t}</say>" for t in say_texts)
new["content"] = f"{base}\n{marker}" if base else marker
new.pop("tool_calls", None)
return new
def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
@@ -572,3 +646,4 @@ def _as_token_ids(value: Any) -> list[int]:
# Re-export for tests / introspection
strip_lerobot_blocks = _strip_lerobot_blocks
flatten_say_tool_calls = _flatten_say_tool_calls