feat(pi052): train VQA spatial answers in PaliGemma <loc> format

Spatial VQA answers (bbox / keypoint) were trained as pixel-coordinate
JSON, which fights PaliGemma's detection prior and leaks <loc>-token
salad at inference. Convert them to PaliGemma's native <locNNNN>
vocabulary instead so the LM head reuses that prior.

Training side (text_processor_pi052.py): a target turn whose content
parses as a bbox/keypoint answer is rewritten to <loc> text, using the
camera frame's native (H, W) from the observation and the preceding
image block. Non-spatial answers, subtask/memory targets and SmolVLA2
keep their JSON form — the dataset stays backbone-agnostic.

Runtime side (smolvla2/inference/vqa.py): parse_vqa_answer detects
<loc> answers (2 locs -> keypoint, 4 -> bbox), returning normalized
[0,1] coords with a normalized flag; draw_vqa_overlay denormalizes
against the chosen camera frame's pixel size.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-19 20:23:46 +02:00
parent e425dfd624
commit c026aed8f8
5 changed files with 1917 additions and 141 deletions

View File

@@ -36,6 +36,7 @@ Outputs:
from __future__ import annotations
import json
import logging
import os
from dataclasses import dataclass
@@ -234,6 +235,134 @@ def _sample_indices(value: Any, batch_size: int) -> list[int | None]:
return [int(value)] * batch_size
# ---------------------------------------------------------------------------
# VQA spatial answers → PaliGemma <loc> format (PI052 only)
#
# PaliGemma is pre-trained on detection / pointing with a ``<locNNNN>``
# vocabulary (normalized [0, 1023]). The recipe's bbox / keypoint VQA
# answers are stored as JSON with *pixel* coordinates. Training those in
# ``<loc>`` form leverages PaliGemma's prior instead of fighting it
# (pixel-coordinate JSON targets make the model leak stray ``<loc>``
# tokens at inference). The conversion lives here — not in the dataset
# — so the dataset stays backbone-agnostic (SmolVLA2 keeps the JSON).
# ---------------------------------------------------------------------------
def _camera_image_shapes(observation: dict[str, Any]) -> dict[str, tuple[int, int]]:
    """Collect the ``(height, width)`` of every ``observation.images.*`` entry.

    The last two entries of each value's ``shape`` are read as ``(H, W)``,
    which covers both ``(B, C, H, W)`` and ``(C, H, W)`` layouts. Per the
    original author's note, PI052's input pipeline applies no spatial
    resize before this step, so these are the native camera resolutions
    that VQA pixel coordinates refer to.

    Args:
        observation: Mapping of feature keys to array-like values. A falsy
            value (``None`` / empty) yields an empty result.

    Returns:
        ``{feature_key: (height, width)}`` for each string key starting
        with ``observation.images.`` whose value exposes a ``shape`` of at
        least two dimensions. All other entries are ignored.
    """
    if not observation:
        return {}

    resolved: dict[str, tuple[int, int]] = {}
    for name, tensor in observation.items():
        is_camera = isinstance(name, str) and name.startswith("observation.images.")
        # Only inspect ``shape`` on camera entries; everything else is skipped.
        dims = getattr(tensor, "shape", None) if is_camera else None
        if dims is not None and len(dims) >= 2:
            resolved[name] = (int(dims[-2]), int(dims[-1]))  # trailing axes are (H, W)
    return resolved
def _loc_token(coord: float, dim: int) -> str:
    """Render pixel ``coord`` on an axis of length ``dim`` as ``<locNNNN>``.

    The coordinate is scaled to ``coord / dim * 1023``, rounded with the
    built-in ``round`` and clamped to ``[0, 1023]``. A non-positive ``dim``
    maps to index 0. The index is zero-padded to four digits.
    """
    if dim > 0:
        bin_index = round(float(coord) / dim * 1023)
    else:
        # Degenerate axis: nothing to normalize against.
        bin_index = 0
    upper_bounded = min(1023, bin_index)
    clamped = max(0, upper_bounded)
    return f"<loc{clamped:04d}>"
def _vqa_answer_to_loc(answer: dict[str, Any], height: int, width: int) -> str | None:
"""Convert a bbox / keypoint VQA answer dict to PaliGemma ``<loc>`` text.
PaliGemma convention: a point is ``<locY><locX> label``; a box is
``<locY0><locX0><locY1><locX1> label`` (y before x, each index in
[0, 1023]). Returns ``None`` for non-spatial answers (count /
attribute / spatial-relation) — those keep their JSON form.
"""
point = answer.get("point")
if isinstance(point, list | tuple) and len(point) == 2 and "point_format" in answer:
try:
x, y = float(point[0]), float(point[1])
except (TypeError, ValueError):
return None
label = str(answer.get("label", "")).strip()
return f"{_loc_token(y, height)}{_loc_token(x, width)} {label}".strip()
detections = answer.get("detections")
if isinstance(detections, list) and detections:
parts: list[str] = []
for det in detections:
if not isinstance(det, dict):
continue
box = det.get("bbox")
if not (isinstance(box, list | tuple) and len(box) == 4):
continue
try:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
label = str(det.get("label", "")).strip()
toks = (
f"{_loc_token(y1, height)}{_loc_token(x1, width)}"
f"{_loc_token(y2, height)}{_loc_token(x2, width)}"
)
parts.append(f"{toks} {label}".strip())
return " ; ".join(parts) if parts else None
return None
def _preceding_image_feature(messages: list[dict[str, Any]], idx: int) -> str | None:
    """Find the camera ``feature`` of the closest image block at or before ``idx``.

    Messages are scanned backwards starting at ``idx`` (capped to the last
    message). Within a message, content blocks are scanned in order and the
    first ``{"type": "image"}`` block with a string ``feature`` wins.
    Messages whose ``content`` is not a list are skipped.

    Returns:
        The feature key, or ``None`` when no such block exists (including
        a negative ``idx`` or an empty ``messages``).
    """
    last = min(idx, len(messages) - 1)
    if last < 0:
        # Nothing at or before ``idx`` to look at.
        return None

    for message in reversed(messages[: last + 1]):
        blocks = message.get("content")
        if not isinstance(blocks, list):
            continue
        for block in blocks:
            if not isinstance(block, dict) or block.get("type") != "image":
                continue
            candidate = block.get("feature")
            if isinstance(candidate, str):
                return candidate
    return None
def _messages_vqa_to_loc(
    messages: list[dict[str, Any]],
    target_indices: list[int],
    image_shapes: dict[str, tuple[int, int]] | None,
) -> list[dict[str, Any]]:
    """Replace spatial VQA *target* answers (JSON) with PaliGemma ``<loc>`` text.

    For every in-range target index whose ``content`` is a non-blank string
    that parses as a JSON object, the camera frame is looked up from the
    nearest preceding image block and the answer is converted. A turn is
    left untouched when its content is not a JSON object (plain-text
    subtask / memory targets), when no image block or matching shape is
    found, or when the answer is not spatial.

    Args:
        messages: Chat messages; never mutated.
        target_indices: Indices of the supervised target turns.
        image_shapes: ``{feature_key: (height, width)}`` per camera.

    Returns:
        ``messages`` itself when ``image_shapes`` or ``target_indices`` is
        empty / ``None``; otherwise a shallow copy of the list in which
        converted turns are replaced by new dicts.
    """
    if not (image_shapes and target_indices):
        return messages

    rewritten = list(messages)
    total = len(rewritten)
    for target in target_indices:
        if target < 0 or target >= total:
            continue
        turn = rewritten[target]
        raw = turn.get("content")
        if not (isinstance(raw, str) and raw.strip()):
            continue
        try:
            parsed = json.loads(raw)
        except (ValueError, TypeError):
            # Subtask / memory targets are plain text, not JSON.
            continue
        if not isinstance(parsed, dict):
            continue

        camera = _preceding_image_feature(rewritten, target)
        if camera is None or camera not in image_shapes:
            continue
        frame_h, frame_w = image_shapes[camera]
        converted = _vqa_answer_to_loc(parsed, frame_h, frame_w)
        if converted is not None:
            # Build a new dict so the caller's message object stays intact.
            rewritten[target] = {**turn, "content": converted}
    return rewritten
def _format_messages(
messages: list[dict[str, Any]],
target_indices: list[int] | None = None,
@@ -329,6 +458,9 @@ class PI052TextTokenizerStep(ProcessorStep):
return transition
tokenizer = self._ensure_tokenizer()
# Native camera resolutions — the reference frame for converting
# VQA pixel coordinates to PaliGemma <loc> tokens.
image_shapes = _camera_image_shapes(transition.get(TransitionKey.OBSERVATION) or {})
if _is_batched_messages(messages):
indices_iter = _sample_indices(complementary.get("index"), len(messages))
encoded = [
@@ -339,6 +471,7 @@ class PI052TextTokenizerStep(ProcessorStep):
list(tgt_indices),
complementary,
sample_idx=int(s_idx) if s_idx is not None else None,
image_shapes=image_shapes,
)
for msg, streams, tgt_indices, s_idx in zip(
messages,
@@ -358,6 +491,7 @@ class PI052TextTokenizerStep(ProcessorStep):
list(complementary.get("target_message_indices") or []),
complementary,
sample_idx=sample_idx,
image_shapes=image_shapes,
)
]
@@ -411,6 +545,7 @@ class PI052TextTokenizerStep(ProcessorStep):
target_indices: list[int],
complementary: dict[str, Any],
sample_idx: int | None = None,
image_shapes: dict[str, tuple[int, int]] | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor, str]:
# Optional: drop non-target messages per the dropout config.
# Keeps the supervised-target indices stable by re-mapping
@@ -428,6 +563,11 @@ class PI052TextTokenizerStep(ProcessorStep):
sample_idx=sample_idx,
)
# Rewrite bbox / keypoint VQA target answers from JSON to
# PaliGemma <loc> text — done before stripping so the image
# block (camera frame) is still available to normalize against.
messages = _messages_vqa_to_loc(messages, target_indices, image_shapes)
# Flatten ``say`` tool calls into ``<say>...</say>`` text before
# stripping, so the spoken reply is actually tokenized and
# supervised (PaliGemma's flat prompt has no structured calls).

View File

@@ -37,6 +37,7 @@ from __future__ import annotations
import json
import logging
import os
import re
import subprocess
import sys
import time
@@ -50,6 +51,14 @@ logger = logging.getLogger(__name__)
# NOTE(review): presumably the key prefix of camera image features in an
# observation dict — its usage is outside this view; confirm.
_IMAGE_PREFIX = "observation.images."
# PaliGemma detection / pointing vocabulary. PI052 trains spatial VQA
# answers in this native ``<locNNNN>`` format (index in [0, 1023],
# normalized to the image axis) instead of pixel-coordinate JSON, so the
# answer string the runtime parses can be e.g.
# ``<loc0512><loc0301> blue cube`` (point) or
# ``<loc0100><loc0080><loc0400><loc0360> blue cube`` (box).
# The single capture group is the digit run; one to four digits are
# accepted, so unpadded forms such as ``<loc7>`` match as well.
_LOC_RE = re.compile(r"<loc(\d{1,4})>")
# Iteration order for shape matching — most specific keys first so an
# answer is classified deterministically.
_SHAPE_ORDER = ("bbox", "keypoint", "count", "attribute", "spatial")
@@ -115,16 +124,74 @@ def prompt_camera_choice(
# ---------------------------------------------------------------------------
def _loc_to_norm(idx: int) -> float:
    """Map a PaliGemma ``<locNNNN>`` index onto a normalized [0, 1] coordinate.

    The index is clamped to ``[0, 1023]`` before dividing by 1023, so
    out-of-range values saturate at 0.0 or 1.0.
    """
    upper_bounded = min(1023.0, float(idx))
    bounded = max(0.0, upper_bounded)
    return bounded / 1023.0
def parse_loc_answer(answer: str) -> dict | None:
    """Parse a PaliGemma ``<loc>``-format spatial VQA answer.

    The answer is split on ``;``. Each segment containing ``<loc`` is
    classified by its number of ``<locNNNN>`` tokens: exactly two form a
    point (``<locY><locX> label``); four or more form a box from the first
    four (``<locY0><locX0><locY1><locX1> label``). Segments with any other
    count are ignored. The label is the segment text with the tokens
    removed. Coordinates are returned *normalized* to [0, 1] in x-before-y
    order, mirroring the JSON payload shapes.

    Returns:
        ``{"kind", "payload", "normalized": True}`` where

        * any box present -> ``kind="bbox"`` with the boxes only (points in
          the same answer are not included);
        * exactly one point -> ``kind="keypoint"``;
        * several points -> ``kind="bbox"`` with zero-area boxes
          ``[x, y, x, y]``;

        or ``None`` when the answer is empty, has no ``<loc`` text, or no
        segment yields a point or box.
    """
    if not answer or "<loc" not in answer:
        return None

    point_hits: list[tuple[float, float, str]] = []
    box_hits: list[tuple[float, float, float, float, str]] = []
    for chunk in answer.split(";"):
        if "<loc" not in chunk:
            continue
        indices = [int(digits) for digits in _LOC_RE.findall(chunk)]
        name = _LOC_RE.sub("", chunk).strip()
        found = len(indices)
        if found == 2:
            # Tokens arrive y-first; store x-first like the JSON payloads.
            norm_y, norm_x = _loc_to_norm(indices[0]), _loc_to_norm(indices[1])
            point_hits.append((norm_x, norm_y, name))
        elif found >= 4:
            top, left, bottom, right = [_loc_to_norm(v) for v in indices[:4]]
            box_hits.append((left, top, right, bottom, name))

    if box_hits:
        detections = [
            {"label": name, "bbox_format": "xyxy", "bbox": [left, top, right, bottom]}
            for (left, top, right, bottom, name) in box_hits
        ]
    elif len(point_hits) == 1:
        norm_x, norm_y, name = point_hits[0]
        return {
            "kind": "keypoint",
            "payload": {"label": name, "point_format": "xy", "point": [norm_x, norm_y]},
            "normalized": True,
        }
    elif point_hits:
        # Several bare points: represent each as a zero-area detection.
        detections = [
            {"label": name, "bbox_format": "xyxy", "bbox": [norm_x, norm_y, norm_x, norm_y]}
            for (norm_x, norm_y, name) in point_hits
        ]
    else:
        return None

    return {"kind": "bbox", "payload": {"detections": detections}, "normalized": True}
def parse_vqa_answer(answer: str) -> dict | None:
"""Parse a VQA answer string into ``{"kind", "payload"}``.
``kind`` is one of the ``VQA_ANSWER_SHAPES`` names (``bbox``,
``keypoint``, ``count``, ``attribute``, ``spatial``) or ``"unknown"``
when the JSON doesn't match any known shape. Returns ``None`` when
the answer is not parseable JSON / not a JSON object.
when the JSON doesn't match any known shape. PaliGemma ``<loc>``
spatial answers are detected first (PI052 trains them in that native
format). Returns ``None`` when the answer is neither ``<loc>`` text
nor a parseable JSON object.
"""
if not answer or not answer.strip():
return None
loc_parsed = parse_loc_answer(answer)
if loc_parsed is not None:
return loc_parsed
try:
payload = json.loads(answer)
except (ValueError, TypeError):
@@ -189,7 +256,9 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
"""Draw ``bbox`` / ``keypoint`` answers onto a copy of ``image``.
Non-spatial answers (``count`` / ``attribute`` / ``spatial`` /
``unknown``) are returned as an unmodified copy.
``unknown``) are returned as an unmodified copy. When ``parsed`` has
``normalized=True`` (PaliGemma ``<loc>`` answers) the [0, 1]
coordinates are scaled to the image's pixel size.
"""
from PIL import ImageDraw # noqa: PLC0415
@@ -197,6 +266,8 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
kind = parsed.get("kind")
payload = parsed.get("payload") or {}
draw = ImageDraw.Draw(img)
w, h = img.size
sx, sy = (w, h) if parsed.get("normalized") else (1, 1)
if kind == "bbox":
for det in payload.get("detections") or []:
@@ -209,6 +280,8 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
x1, y1, x2, y2 = (float(v) for v in box)
except (TypeError, ValueError):
continue
x1, x2 = x1 * sx, x2 * sx
y1, y2 = y1 * sy, y2 * sy
draw.rectangle([x1, y1, x2, y2], outline=_BBOX_COLOR, width=3)
label = str(det.get("label", "")).strip()
if label:
@@ -217,7 +290,7 @@ def draw_vqa_overlay(image: Any, parsed: dict) -> Any:
point = payload.get("point")
if isinstance(point, list | tuple) and len(point) == 2:
try:
x, y = float(point[0]), float(point[1])
x, y = float(point[0]) * sx, float(point[1]) * sy
except (TypeError, ValueError):
return img
r = 6