mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
chore(smolvla2-runtime): tensor-level obs print for both inference paths
Adds a helper that prints (once per provider lifetime) every ``observation.*`` tensor the policy is about to see, with its shape, dtype, device, and whole-tensor min/max/mean/std. It is wired into both the dry-run dataset path and the live-robot path. Now we can bisect train/inference mismatch *at the tensor level* — if the same checkpoint produces coherent text on one path's tensors and ``\n`` on the other's, and the printed tensor stats differ materially, the bug is in the observation prep, not in the model or the training distribution. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -263,6 +263,47 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
return p.parse_args(argv)
|
||||
|
||||
|
||||
def _log_obs_tensors_once(label: str, obs: Any, flag: dict) -> None:
    """Log shape / dtype / device / summary stats of each observation tensor, once.

    Used to bisect train/inference mismatches: if the dry-run path and
    the robot path produce identifiably different tensors here (e.g. one
    is batched twice, one has a different range, one is on a different
    device), the LM head's collapse on the live robot is a tensor-shape
    bug, not a distribution-shift problem. Only if the logged tensors
    agree and the head still collapses is the scene-content OOD
    hypothesis the right one.

    The statistics are reductions over the *whole* tensor (not per
    channel). Integer and bool tensors are cast to float32 before the
    reductions, because ``torch.mean`` rejects non-floating dtypes;
    the logged ``dtype`` is always the tensor's original dtype.

    Args:
        label: Short tag identifying the calling path; it is embedded in
            every emitted log line as ``obs[<label>]``.
        obs: The observation mapping about to be handed to the policy.
            Anything that is not a ``dict`` is ignored and does not
            consume the one-shot flag.
        flag: Mutable one-shot state owned by the caller. The ``"done"``
            key is set to ``True`` on the first call that receives a
            dict; later calls with the same ``flag`` return immediately.
    """
    if flag.get("done") or not isinstance(obs, dict):
        return
    # Mark as done before logging so a failure below cannot cause the
    # dump to be retried on every subsequent step.
    flag["done"] = True
    import torch as _torch  # noqa: PLC0415

    for k, v in obs.items():
        # Only policy inputs are of interest; skip actions, metadata, etc.
        if not isinstance(k, str) or not k.startswith("observation."):
            continue
        if not isinstance(v, _torch.Tensor):
            logger.warning("obs[%s] %-30s type=%s value=%r", label, k, type(v).__name__, v)
            continue
        try:
            # Floating tensors are reduced in their own dtype; anything
            # else (uint8 images, int64 indices, bool masks) is cast
            # first so mean() does not raise and hide the value range.
            numeric = v if v.is_floating_point() else v.float()
            stats = (
                f"min={float(numeric.min()):.4f} max={float(numeric.max()):.4f} "
                f"mean={float(numeric.mean()):.4f} std={float(v.float().std()):.4f}"
            )
        except Exception:  # noqa: BLE001
            # Best-effort diagnostics: e.g. empty tensors cannot be
            # reduced. Never let a debug print break inference.
            stats = "(stats unavailable)"
        logger.warning(
            "obs[%s] %-30s shape=%s dtype=%s device=%s %s",
            label,
            k,
            tuple(v.shape),
            v.dtype,
            v.device,
            stats,
        )
|
||||
|
||||
|
||||
def _load_policy_and_preprocessor(
|
||||
policy_path: str,
|
||||
dataset_repo_id: str | None,
|
||||
@@ -368,6 +409,7 @@ def _build_observation_provider(
|
||||
)
|
||||
|
||||
state = {"cursor": max(0, min(start_frame, len(ds) - 1))}
|
||||
_logged = {"done": False}
|
||||
|
||||
def _provider() -> dict | None:
|
||||
idx = state["cursor"]
|
||||
@@ -383,6 +425,8 @@ def _build_observation_provider(
|
||||
if preprocessor is not None:
|
||||
sample = preprocessor(sample)
|
||||
|
||||
_log_obs_tensors_once("dry-run", sample, _logged)
|
||||
|
||||
# Keep only observation keys; the runtime's text path will
|
||||
# merge these with its own lang_tokens / lang_masks.
|
||||
observation = {
|
||||
@@ -649,6 +693,7 @@ def _build_robot_observation_provider(
|
||||
# head's distribution at position 0 collapses to its dominant
|
||||
# mode (a memorised ``\n``-only run in this checkpoint).
|
||||
_resize_logged = {"done": False}
|
||||
_obs_logged = {"done": False}
|
||||
target_image_shapes: dict[str, tuple[int, int]] = {}
|
||||
if ds_features:
|
||||
for fkey, fmeta in ds_features.items():
|
||||
@@ -770,6 +815,8 @@ def _build_robot_observation_provider(
|
||||
return None
|
||||
obs_tensors = processed if isinstance(processed, dict) else {}
|
||||
|
||||
_log_obs_tensors_once("robot", obs_tensors, _obs_logged)
|
||||
|
||||
observation = {
|
||||
k: v
|
||||
for k, v in obs_tensors.items()
|
||||
|
||||
Reference in New Issue
Block a user