fix(viz): anchor rerun DepthImage colormap to encoder depth range

This commit is contained in:
CarolinePascal
2026-04-27 20:15:40 +02:00
parent b540fa94a9
commit f43bf75f9b
4 changed files with 140 additions and 22 deletions

View File

@@ -39,6 +39,7 @@ from .video_utils import (
StreamingVideoEncoder,
VideoEncoderConfig,
get_safe_default_video_backend,
seed_depth_feature_info,
)
logger = logging.getLogger(__name__)
@@ -252,6 +253,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
DeprecationWarning,
stacklevel=2,
)
seed_depth_feature_info(self.meta.features, self._depth_encoder_config)
streaming_enc = None
if streaming_encoding and len(self.meta.video_keys) > 0:
streaming_enc = self._build_streaming_encoder(
@@ -711,6 +713,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._camera_encoder_config = camera_encoder_config
obj._depth_encoder_config = depth_encoder_config
obj._encoder_threads = encoder_threads
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
# Reader is lazily created on first access (write-only mode)
obj.reader = None
@@ -827,6 +830,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj._depth_encoder_config = depth_encoder_config
obj._encoder_threads = encoder_threads
obj.root = obj.meta.root
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
# Reader is lazily created on first access (write-only mode)
obj.reader = None

View File

@@ -1355,6 +1355,44 @@ _DEPTH_INFO_KEYS: tuple[str, ...] = (
)
def seed_depth_feature_info(
    features: dict[str, dict],
    depth_encoder_config: "DepthEncoderConfig | None",
) -> None:
    """Pre-populate per-feature ``video.<field>`` entries from *depth_encoder_config*.

    ``update_video_info`` only runs after the first episode video is encoded,
    so without this seeding step ``features[key]["info"]`` carries no
    quantization range until then. Consumers that read the dataset feature
    spec mid-recording (e.g. the rerun visualizer pinning the depth colormap
    to ``video.depth_min`` / ``video.depth_max``) would otherwise see no
    range during episode 1 and re-normalize per frame.

    Stream-derived values written later by :func:`get_video_info` /
    ``update_video_info`` win over these seeds (the merge is
    ``{**existing, **stream_info}``), so callers can safely re-run this on
    a partially-populated info dict.

    Args:
        features: Mapping of feature name to feature spec. The ``info`` dict of
            every matching feature is mutated in place.
        depth_encoder_config: Dataclass instance whose fields are copied under
            ``video.<field name>`` keys, or ``None`` to skip seeding.

    Returns:
        None. No-op when ``depth_encoder_config`` is ``None`` or no feature is
        flagged as a depth map. Entries whose spec or ``info`` is not a dict are
        skipped rather than raising.
    """
    if depth_encoder_config is None:
        return

    # Computed once: the same seed values apply to every depth feature.
    encoder_fields = {
        f"video.{name}": value for name, value in asdict(depth_encoder_config).items()
    }

    for ft in features.values():
        if not isinstance(ft, dict) or ft.get("dtype") != "video":
            continue
        info = ft.get("info")
        # A depth feature is identified by a truthy ``video.is_depth_map`` marker,
        # which can only live in an existing dict; anything else is not a depth map.
        if not isinstance(info, dict) or not info.get("video.is_depth_map", False):
            continue
        # Only fill fields not already set, so explicit user-provided info is preserved.
        # ``info`` is the very dict stored on the feature, so no write-back is needed.
        for key, value in encoder_fields.items():
            info.setdefault(key, value)
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1

View File

@@ -63,8 +63,10 @@ def _is_scalar(x):
)
def _derive_depth_obs_keys(features: dict[str, dict] | None) -> set[str]:
"""Derive the set of observation keys that correspond to depth-map features.
def _derive_depth_obs_ranges(
features: dict[str, dict] | None,
) -> dict[str, tuple[float, float] | None]:
"""Map observation keys of depth features to their ``(depth_min, depth_max)`` range.
A feature is considered a depth map when its ``info`` dict carries
``video.is_depth_map=True`` (the marker set by ``hw_to_dataset_features``
@@ -73,21 +75,37 @@ def _derive_depth_obs_keys(features: dict[str, dict] | None) -> set[str]:
the corresponding raw observation key forms the robot is likely to emit
(``front`` and ``front_depth``) so a single membership check covers all
call sites.
The mapped value is the ``(depth_min, depth_max)`` range stored on the
feature (matching the quantization range used at encoding time), or
``None`` when the metadata doesn't expose a range — in which case the
caller should let Rerun auto-normalize. Anchoring the colormap to a
fixed range avoids per-frame re-normalization, which otherwise looks
like flicker on near-static scenes.
"""
keys: set[str] = set()
ranges: dict[str, tuple[float, float] | None] = {}
if not features:
return keys
return ranges
depth_prefix = f"{OBS_STR}.depth."
for fk, fv in features.items():
info = fv.get("info") if isinstance(fv, dict) else None
if not isinstance(info, dict) or not info.get("video.is_depth_map", False):
continue
keys.add(fk)
depth_min = info.get("video.depth_min")
depth_max = info.get("video.depth_max")
rng: tuple[float, float] | None = None
if (
isinstance(depth_min, (int, float))
and isinstance(depth_max, (int, float))
and depth_max > depth_min
):
rng = (float(depth_min), float(depth_max))
ranges[fk] = rng
if fk.startswith(depth_prefix):
bare = fk[len(depth_prefix) :]
keys.add(bare)
keys.add(f"{bare}_depth")
return keys
ranges[bare] = rng
ranges[f"{bare}_depth"] = rng
return ranges
def log_rerun_data(
@@ -106,8 +124,11 @@ def log_rerun_data(
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
- 2D NumPy arrays whose key matches a depth feature in ``features`` (i.e. carrying
``video.is_depth_map=True``) are logged as ``rr.DepthImage`` with the Viridis
colormap and ``meter=1.0`` (depth values are expected in metric meters). Depth
images are never JPEG-compressed regardless of ``compress_images``.
colormap and ``meter=1.0`` (depth values are expected in metric meters). When
the feature exposes ``video.depth_min`` / ``video.depth_max`` (the encoder
quantization range, persisted in ``info.json``), the colormap is anchored to
that range via ``depth_range`` to keep the visualization stable across frames.
Depth images are never JPEG-compressed regardless of ``compress_images``.
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
- Other multi-dimensional arrays are flattened and logged as individual scalars.
@@ -125,7 +146,7 @@ def log_rerun_data(
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
depth_obs_keys = _derive_depth_obs_keys(features)
depth_obs_ranges = _derive_depth_obs_ranges(features)
if observation:
for k, v in observation.items():
@@ -137,18 +158,19 @@ def log_rerun_data(
rr.log(key, rr.Scalars(float(v)))
elif isinstance(v, np.ndarray):
arr = v
is_depth = bool(depth_obs_keys) and (k in depth_obs_keys or key in depth_obs_keys)
is_depth = bool(depth_obs_ranges) and (k in depth_obs_ranges or key in depth_obs_ranges)
if is_depth and arr.ndim == 2:
# Viridis-colormapped DepthImage; never JPEG-compress (lossy on float metric depth).
rr.log(
key,
rr.DepthImage(
arr,
meter=1.0,
colormap=rr.components.Colormap.Viridis,
),
static=True,
)
# Anchor the colormap to the encoder range when available, so the
# visualization doesn't flicker as per-frame min/max drift.
depth_range = depth_obs_ranges.get(k) or depth_obs_ranges.get(key)
depth_kwargs: dict = {
"meter": 1.0,
"colormap": rr.components.Colormap.Viridis,
}
if depth_range is not None:
depth_kwargs["depth_range"] = depth_range
rr.log(key, rr.DepthImage(arr, **depth_kwargs), static=True)
continue
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):

View File

@@ -249,7 +249,11 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun):
"""A 2D float depth obs whose key matches a depth feature must use ``rr.DepthImage``."""
"""A 2D float depth obs whose key matches a depth feature must use ``rr.DepthImage``.
Without ``video.depth_min``/``video.depth_max`` in the feature info, the
visualizer should not pass a ``depth_range`` (Rerun then auto-normalizes).
"""
vu, calls = mock_rerun
features = {
@@ -279,9 +283,59 @@ def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun):
assert depth.arr.shape == (10, 12)
assert depth.meter == pytest.approx(1.0)
assert depth.colormap == "viridis"
# No range available -> Rerun should auto-normalize; we must not pass `depth_range`.
assert "depth_range" not in depth.kwargs
assert _kwargs_for(calls, "observation.front_depth").get("static", False) is True
def test_log_rerun_data_depth_range_anchored_from_info(mock_rerun):
    """A depth feature exposing ``video.depth_min``/``video.depth_max`` must forward them as ``depth_range``."""
    vu, calls = mock_rerun

    # Feature metadata flagged as a depth map and carrying a valid (min < max) range.
    depth_info = {
        "video.is_depth_map": True,
        "video.depth_min": 0.05,
        "video.depth_max": 4.0,
    }
    features = {
        "observation.depth.front": {
            "dtype": "video",
            "shape": (480, 640),
            "info": depth_info,
        },
    }
    frame = np.full((10, 12), 0.5, dtype=np.float32)

    vu.log_rerun_data(observation={"front_depth": frame}, features=features)

    logged = _obj_for(calls, "observation.front_depth")
    assert type(logged).__name__ == "DummyDepthImage"
    assert logged.kwargs.get("depth_range") == (0.05, 4.0)
def test_log_rerun_data_depth_range_ignored_when_invalid(mock_rerun):
    """A degenerate range (here ``depth_max == depth_min``) must not be forwarded as ``depth_range``."""
    vu, calls = mock_rerun

    degenerate_info = {
        "video.is_depth_map": True,
        "video.depth_min": 1.0,
        "video.depth_max": 1.0,  # degenerate: max is not strictly greater than min
    }
    features = {
        "observation.depth.front": {
            "dtype": "video",
            "shape": (480, 640),
            "info": degenerate_info,
        },
    }
    frame = np.full((10, 12), 0.5, dtype=np.float32)

    vu.log_rerun_data(observation={"front_depth": frame}, features=features)

    # Still routed to a depth image, just without an anchored range.
    logged = _obj_for(calls, "observation.front_depth")
    assert type(logged).__name__ == "DummyDepthImage"
    assert "depth_range" not in logged.kwargs
def test_log_rerun_data_depth_skips_compression(mock_rerun):
"""Depth frames must never be JPEG-compressed even when ``compress_images=True``."""
vu, calls = mock_rerun