From f43bf75f9b72b518d8fa81ea978fd41dbaf4cdd2 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 27 Apr 2026 20:15:40 +0200 Subject: [PATCH] fix(viz): anchor rerun DepthImage colormap to encoder depth range --- src/lerobot/datasets/lerobot_dataset.py | 4 ++ src/lerobot/datasets/video_utils.py | 38 ++++++++++++++ src/lerobot/utils/visualization_utils.py | 64 ++++++++++++++++-------- tests/utils/test_visualization_utils.py | 56 ++++++++++++++++++++- 4 files changed, 140 insertions(+), 22 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index cf6952e0a..76ff2b706 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -39,6 +39,7 @@ from .video_utils import ( StreamingVideoEncoder, VideoEncoderConfig, get_safe_default_video_backend, + seed_depth_feature_info, ) logger = logging.getLogger(__name__) @@ -252,6 +253,7 @@ class LeRobotDataset(torch.utils.data.Dataset): DeprecationWarning, stacklevel=2, ) + seed_depth_feature_info(self.meta.features, self._depth_encoder_config) streaming_enc = None if streaming_encoding and len(self.meta.video_keys) > 0: streaming_enc = self._build_streaming_encoder( @@ -711,6 +713,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._camera_encoder_config = camera_encoder_config obj._depth_encoder_config = depth_encoder_config obj._encoder_threads = encoder_threads + seed_depth_feature_info(obj.meta.features, depth_encoder_config) # Reader is lazily created on first access (write-only mode) obj.reader = None @@ -827,6 +830,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._depth_encoder_config = depth_encoder_config obj._encoder_threads = encoder_threads obj.root = obj.meta.root + seed_depth_feature_info(obj.meta.features, depth_encoder_config) # Reader is lazily created on first access (write-only mode) obj.reader = None diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index babecce1a..d5be59dd6 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -1355,6 +1355,44 @@ _DEPTH_INFO_KEYS: tuple[str, ...] = ( ) +def seed_depth_feature_info( + features: dict[str, dict], + depth_encoder_config: "DepthEncoderConfig | None", +) -> None: + """Pre-populate per-feature ``video.`` entries from *depth_encoder_config*. + + ``update_video_info`` only runs after the first episode video is encoded, + so without this seeding step ``features[key]["info"]`` carries no + quantization range until then. Consumers that read the dataset feature + spec mid-recording (e.g. the rerun visualizer pinning the depth colormap + to ``video.depth_min`` / ``video.depth_max``) would otherwise see no + range during episode 1 and re-normalize per frame. + + Stream-derived values written later by :func:`get_video_info` / + ``update_video_info`` win over these seeds (the merge is + ``{**existing, **stream_info}``), so callers can safely re-run this on + a partially-populated info dict. + + No-op when ``depth_encoder_config`` is ``None`` or no feature is flagged + as a depth map. + """ + if depth_encoder_config is None: + return + encoder_fields = { + f"video.{name}": value for name, value in asdict(depth_encoder_config).items() + } + for ft in features.values(): + if ft.get("dtype") != "video": + continue + info = ft.get("info") or {} + if not info.get("video.is_depth_map", False): + continue + # Only fill fields not already set, so explicit user-provided info is preserved. + for k, v in encoder_fields.items(): + info.setdefault(k, v) + ft["info"] = info + + def get_video_pixel_channels(pix_fmt: str) -> int: if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt: return 1 diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 70377bbef..9695758e9 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -63,8 +63,10 @@ def _is_scalar(x): ) -def _derive_depth_obs_keys(features: dict[str, dict] | None) -> set[str]: - """Derive the set of observation keys that correspond to depth-map features. +def _derive_depth_obs_ranges( + features: dict[str, dict] | None, +) -> dict[str, tuple[float, float] | None]: + """Map observation keys of depth features to their ``(depth_min, depth_max)`` range. A feature is considered a depth map when its ``info`` dict carries ``video.is_depth_map=True`` (the marker set by ``hw_to_dataset_features`` @@ -73,21 +75,37 @@ def _derive_depth_obs_keys(features: dict[str, dict] | None) -> set[str]: the corresponding raw observation key forms the robot is likely to emit (``front`` and ``front_depth``) so a single membership check covers all call sites. + + The mapped value is the ``(depth_min, depth_max)`` range stored on the + feature (matching the quantization range used at encoding time), or + ``None`` when the metadata doesn't expose a range — in which case the + caller should let Rerun auto-normalize. Anchoring the colormap to a + fixed range avoids per-frame re-normalization, which otherwise looks + like flicker on near-static scenes. """ - keys: set[str] = set() + ranges: dict[str, tuple[float, float] | None] = {} if not features: - return keys + return ranges depth_prefix = f"{OBS_STR}.depth." for fk, fv in features.items(): info = fv.get("info") if isinstance(fv, dict) else None if not isinstance(info, dict) or not info.get("video.is_depth_map", False): continue - keys.add(fk) + depth_min = info.get("video.depth_min") + depth_max = info.get("video.depth_max") + rng: tuple[float, float] | None = None + if ( + isinstance(depth_min, (int, float)) + and isinstance(depth_max, (int, float)) + and depth_max > depth_min + ): + rng = (float(depth_min), float(depth_max)) + ranges[fk] = rng if fk.startswith(depth_prefix): bare = fk[len(depth_prefix) :] - keys.add(bare) - keys.add(f"{bare}_depth") - return keys + ranges[bare] = rng + ranges[f"{bare}_depth"] = rng + return ranges def log_rerun_data( @@ -106,8 +124,11 @@ def log_rerun_data( from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`. - 2D NumPy arrays whose key matches a depth feature in ``features`` (i.e. carrying ``video.is_depth_map=True``) are logged as ``rr.DepthImage`` with the Viridis - colormap and ``meter=1.0`` (depth values are expected in metric meters). Depth - images are never JPEG-compressed regardless of ``compress_images``. + colormap and ``meter=1.0`` (depth values are expected in metric meters). When + the feature exposes ``video.depth_min`` / ``video.depth_max`` (the encoder + quantization range, persisted in ``info.json``), the colormap is anchored to + that range via ``depth_range`` to keep the visualization stable across frames. + Depth images are never JPEG-compressed regardless of ``compress_images``. - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed. - Other multi-dimensional arrays are flattened and logged as individual scalars. @@ -125,7 +146,7 @@ def log_rerun_data( require_package("rerun-sdk", extra="viz", import_name="rerun") import rerun as rr - depth_obs_keys = _derive_depth_obs_keys(features) + depth_obs_ranges = _derive_depth_obs_ranges(features) if observation: for k, v in observation.items(): @@ -137,18 +158,19 @@ def log_rerun_data( rr.log(key, rr.Scalars(float(v))) elif isinstance(v, np.ndarray): arr = v - is_depth = bool(depth_obs_keys) and (k in depth_obs_keys or key in depth_obs_keys) + is_depth = bool(depth_obs_ranges) and (k in depth_obs_ranges or key in depth_obs_ranges) if is_depth and arr.ndim == 2: # Viridis-colormapped DepthImage; never JPEG-compress (lossy on float metric depth). - rr.log( - key, - rr.DepthImage( - arr, - meter=1.0, - colormap=rr.components.Colormap.Viridis, - ), - static=True, - ) + # Anchor the colormap to the encoder range when available, so the + # visualization doesn't flicker as per-frame min/max drift. + depth_range = depth_obs_ranges.get(k) or depth_obs_ranges.get(key) + depth_kwargs: dict = { + "meter": 1.0, + "colormap": rr.components.Colormap.Viridis, + } + if depth_range is not None: + depth_kwargs["depth_range"] = depth_range + rr.log(key, rr.DepthImage(arr, **depth_kwargs), static=True) continue # Convert CHW -> HWC when needed if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index ec2c6fc1d..bf397651d 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -249,7 +249,11 @@ def test_log_rerun_data_kwargs_only(mock_rerun): def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun): - """A 2D float depth obs whose key matches a depth feature must use ``rr.DepthImage``.""" + """A 2D float depth obs whose key matches a depth feature must use ``rr.DepthImage``. + + Without ``video.depth_min``/``video.depth_max`` in the feature info, the + visualizer should not pass a ``depth_range`` (Rerun then auto-normalizes). + """ vu, calls = mock_rerun features = { @@ -279,9 +283,59 @@ def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun): assert depth.arr.shape == (10, 12) assert depth.meter == pytest.approx(1.0) assert depth.colormap == "viridis" + # No range available -> Rerun should auto-normalize; we must not pass `depth_range`. + assert "depth_range" not in depth.kwargs assert _kwargs_for(calls, "observation.front_depth").get("static", False) is True +def test_log_rerun_data_depth_range_anchored_from_info(mock_rerun): + """When ``video.depth_min``/``depth_max`` are present, ``depth_range`` is forwarded.""" + vu, calls = mock_rerun + + features = { + "observation.depth.front": { + "dtype": "video", + "shape": (480, 640), + "info": { + "video.is_depth_map": True, + "video.depth_min": 0.05, + "video.depth_max": 4.0, + }, + }, + } + obs = {"front_depth": np.full((10, 12), 0.5, dtype=np.float32)} + + vu.log_rerun_data(observation=obs, features=features) + + depth = _obj_for(calls, "observation.front_depth") + assert type(depth).__name__ == "DummyDepthImage" + assert depth.kwargs.get("depth_range") == (0.05, 4.0) + + +def test_log_rerun_data_depth_range_ignored_when_invalid(mock_rerun): + """A degenerate range (max <= min, or non-numeric) must be discarded silently.""" + vu, calls = mock_rerun + + features = { + "observation.depth.front": { + "dtype": "video", + "shape": (480, 640), + "info": { + "video.is_depth_map": True, + "video.depth_min": 1.0, + "video.depth_max": 1.0, # degenerate + }, + }, + } + obs = {"front_depth": np.full((10, 12), 0.5, dtype=np.float32)} + + vu.log_rerun_data(observation=obs, features=features) + + depth = _obj_for(calls, "observation.front_depth") + assert type(depth).__name__ == "DummyDepthImage" + assert "depth_range" not in depth.kwargs + + def test_log_rerun_data_depth_skips_compression(mock_rerun): """Depth frames must never be JPEG-compressed even when ``compress_images=True``.""" vu, calls = mock_rerun