chore(smolvla2-runtime): tensor-level obs print for both inference paths

Add a helper that logs (once per provider lifetime) every
``observation.*`` tensor the policy is about to see, with its shape,
dtype, device, and whole-tensor min/max/mean/std. Wired into both the
dry-run dataset path and the live-robot path.

Now we can bisect train/inference mismatch *at the tensor level* —
if the same checkpoint produces coherent text on one path's tensors
and ``\n`` on the other's, and the printed tensor stats differ
materially, the bug is in the observation prep, not in the model or
the training distribution.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-12 18:19:18 +02:00
parent 4852b9f952
commit fcdae0ce8e

View File

@@ -263,6 +263,47 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
return p.parse_args(argv)
def _log_obs_tensors_once(label: str, obs: Any, flag: dict) -> None:
    """Log shape / dtype / device / summary stats of each observation entry.

    Intended as a one-shot diagnostic for bisecting train/inference
    mismatches: call it from every inference path with the dict that is
    about to be handed to the policy, then compare the emitted lines.
    Differences in shape, dtype, device or value range between paths
    point at the observation prep rather than at the model.

    The statistics are reductions over the *whole* tensor (not per
    channel). Integer and bool tensors are cast to float first, because
    ``torch.mean`` rejects non-floating dtypes and would otherwise turn
    the whole stats string into ``(stats unavailable)``.

    Args:
        label: Short tag identifying the calling path; it is embedded in
            every emitted line as ``obs[<label>]``.
        obs: The observation mapping. Anything that is not a ``dict`` is
            ignored. Only string keys starting with ``"observation."``
            are reported; tensor values get shape/dtype/device/stats,
            any other value is reported by type name and ``repr``.
        flag: Mutable one-shot state owned by the caller. ``flag["done"]``
            is set to ``True`` the first time at least one
            ``observation.*`` entry is reported; later calls return
            immediately. A call that finds no such entry (for example an
            empty dict) leaves the flag untouched, so the first real
            observation is still logged.
    """
    if flag.get("done") or not isinstance(obs, dict):
        return

    entries = [
        (key, value)
        for key, value in obs.items()
        if isinstance(key, str) and key.startswith("observation.")
    ]
    if not entries:
        # Nothing worth reporting yet; keep the one-shot armed so an empty
        # or observation-less first call does not suppress the real dump.
        return
    flag["done"] = True

    # Imported lazily so merely importing this module does not require torch.
    import torch as _torch  # noqa: PLC0415

    for key, value in entries:
        if not isinstance(value, _torch.Tensor):
            logger.warning("obs[%s] %-30s type=%s value=%r", label, key, type(value).__name__, value)
            continue

        try:
            values = value.detach()
            # mean()/std() need a floating dtype; uint8 images, int64
            # indices and bool masks must be cast before reducing.
            if not (values.is_floating_point() or values.is_complex()):
                values = values.float()
            stats = (
                f"min={float(values.min()):.4f} max={float(values.max()):.4f} "
                f"mean={float(values.mean()):.4f} std={float(values.float().std()):.4f}"
            )
        except Exception:  # noqa: BLE001
            # Best-effort diagnostic: empty or otherwise unreducible tensors
            # must never break the inference loop.
            stats = "(stats unavailable)"

        logger.warning(
            "obs[%s] %-30s shape=%s dtype=%s device=%s %s",
            label,
            key,
            tuple(value.shape),
            value.dtype,
            value.device,
            stats,
        )
def _load_policy_and_preprocessor(
policy_path: str,
dataset_repo_id: str | None,
@@ -368,6 +409,7 @@ def _build_observation_provider(
)
state = {"cursor": max(0, min(start_frame, len(ds) - 1))}
_logged = {"done": False}
def _provider() -> dict | None:
idx = state["cursor"]
@@ -383,6 +425,8 @@ def _build_observation_provider(
if preprocessor is not None:
sample = preprocessor(sample)
_log_obs_tensors_once("dry-run", sample, _logged)
# Keep only observation keys; the runtime's text path will
# merge these with its own lang_tokens / lang_masks.
observation = {
@@ -649,6 +693,7 @@ def _build_robot_observation_provider(
# head's distribution at position 0 collapses to its dominant
# mode (a memorised ``\n``-only run in this checkpoint).
_resize_logged = {"done": False}
_obs_logged = {"done": False}
target_image_shapes: dict[str, tuple[int, int]] = {}
if ds_features:
for fkey, fmeta in ds_features.items():
@@ -770,6 +815,8 @@ def _build_robot_observation_provider(
return None
obs_tensors = processed if isinstance(processed, dict) else {}
_log_obs_tensors_once("robot", obs_tensors, _obs_logged)
observation = {
k: v
for k, v in obs_tensors.items()