mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
fix(viz): anchor rerun DepthImage colormap to encoder depth range
This commit is contained in:
@@ -39,6 +39,7 @@ from .video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
get_safe_default_video_backend,
|
||||
seed_depth_feature_info,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -252,6 +253,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
seed_depth_feature_info(self.meta.features, self._depth_encoder_config)
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
@@ -711,6 +713,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
@@ -827,6 +830,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
obj.root = obj.meta.root
|
||||
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
@@ -1355,6 +1355,44 @@ _DEPTH_INFO_KEYS: tuple[str, ...] = (
|
||||
)
|
||||
|
||||
|
||||
def seed_depth_feature_info(
|
||||
features: dict[str, dict],
|
||||
depth_encoder_config: "DepthEncoderConfig | None",
|
||||
) -> None:
|
||||
"""Pre-populate per-feature ``video.<field>`` entries from *depth_encoder_config*.
|
||||
|
||||
``update_video_info`` only runs after the first episode video is encoded,
|
||||
so without this seeding step ``features[key]["info"]`` carries no
|
||||
quantization range until then. Consumers that read the dataset feature
|
||||
spec mid-recording (e.g. the rerun visualizer pinning the depth colormap
|
||||
to ``video.depth_min`` / ``video.depth_max``) would otherwise see no
|
||||
range during episode 1 and re-normalize per frame.
|
||||
|
||||
Stream-derived values written later by :func:`get_video_info` /
|
||||
``update_video_info`` win over these seeds (the merge is
|
||||
``{**existing, **stream_info}``), so callers can safely re-run this on
|
||||
a partially-populated info dict.
|
||||
|
||||
No-op when ``depth_encoder_config`` is ``None`` or no feature is flagged
|
||||
as a depth map.
|
||||
"""
|
||||
if depth_encoder_config is None:
|
||||
return
|
||||
encoder_fields = {
|
||||
f"video.{name}": value for name, value in asdict(depth_encoder_config).items()
|
||||
}
|
||||
for ft in features.values():
|
||||
if ft.get("dtype") != "video":
|
||||
continue
|
||||
info = ft.get("info") or {}
|
||||
if not info.get("video.is_depth_map", False):
|
||||
continue
|
||||
# Only fill fields not already set, so explicit user-provided info is preserved.
|
||||
for k, v in encoder_fields.items():
|
||||
info.setdefault(k, v)
|
||||
ft["info"] = info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
|
||||
@@ -63,8 +63,10 @@ def _is_scalar(x):
|
||||
)
|
||||
|
||||
|
||||
def _derive_depth_obs_keys(features: dict[str, dict] | None) -> set[str]:
|
||||
"""Derive the set of observation keys that correspond to depth-map features.
|
||||
def _derive_depth_obs_ranges(
|
||||
features: dict[str, dict] | None,
|
||||
) -> dict[str, tuple[float, float] | None]:
|
||||
"""Map observation keys of depth features to their ``(depth_min, depth_max)`` range.
|
||||
|
||||
A feature is considered a depth map when its ``info`` dict carries
|
||||
``video.is_depth_map=True`` (the marker set by ``hw_to_dataset_features``
|
||||
@@ -73,21 +75,37 @@ def _derive_depth_obs_keys(features: dict[str, dict] | None) -> set[str]:
|
||||
the corresponding raw observation key forms the robot is likely to emit
|
||||
(``front`` and ``front_depth``) so a single membership check covers all
|
||||
call sites.
|
||||
|
||||
The mapped value is the ``(depth_min, depth_max)`` range stored on the
|
||||
feature (matching the quantization range used at encoding time), or
|
||||
``None`` when the metadata doesn't expose a range — in which case the
|
||||
caller should let Rerun auto-normalize. Anchoring the colormap to a
|
||||
fixed range avoids per-frame re-normalization, which otherwise looks
|
||||
like flicker on near-static scenes.
|
||||
"""
|
||||
keys: set[str] = set()
|
||||
ranges: dict[str, tuple[float, float] | None] = {}
|
||||
if not features:
|
||||
return keys
|
||||
return ranges
|
||||
depth_prefix = f"{OBS_STR}.depth."
|
||||
for fk, fv in features.items():
|
||||
info = fv.get("info") if isinstance(fv, dict) else None
|
||||
if not isinstance(info, dict) or not info.get("video.is_depth_map", False):
|
||||
continue
|
||||
keys.add(fk)
|
||||
depth_min = info.get("video.depth_min")
|
||||
depth_max = info.get("video.depth_max")
|
||||
rng: tuple[float, float] | None = None
|
||||
if (
|
||||
isinstance(depth_min, (int, float))
|
||||
and isinstance(depth_max, (int, float))
|
||||
and depth_max > depth_min
|
||||
):
|
||||
rng = (float(depth_min), float(depth_max))
|
||||
ranges[fk] = rng
|
||||
if fk.startswith(depth_prefix):
|
||||
bare = fk[len(depth_prefix) :]
|
||||
keys.add(bare)
|
||||
keys.add(f"{bare}_depth")
|
||||
return keys
|
||||
ranges[bare] = rng
|
||||
ranges[f"{bare}_depth"] = rng
|
||||
return ranges
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
@@ -106,8 +124,11 @@ def log_rerun_data(
|
||||
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
|
||||
- 2D NumPy arrays whose key matches a depth feature in ``features`` (i.e. carrying
|
||||
``video.is_depth_map=True``) are logged as ``rr.DepthImage`` with the Viridis
|
||||
colormap and ``meter=1.0`` (depth values are expected in metric meters). Depth
|
||||
images are never JPEG-compressed regardless of ``compress_images``.
|
||||
colormap and ``meter=1.0`` (depth values are expected in metric meters). When
|
||||
the feature exposes ``video.depth_min`` / ``video.depth_max`` (the encoder
|
||||
quantization range, persisted in ``info.json``), the colormap is anchored to
|
||||
that range via ``depth_range`` to keep the visualization stable across frames.
|
||||
Depth images are never JPEG-compressed regardless of ``compress_images``.
|
||||
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
|
||||
- Other multi-dimensional arrays are flattened and logged as individual scalars.
|
||||
|
||||
@@ -125,7 +146,7 @@ def log_rerun_data(
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
import rerun as rr
|
||||
|
||||
depth_obs_keys = _derive_depth_obs_keys(features)
|
||||
depth_obs_ranges = _derive_depth_obs_ranges(features)
|
||||
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
@@ -137,18 +158,19 @@ def log_rerun_data(
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
is_depth = bool(depth_obs_keys) and (k in depth_obs_keys or key in depth_obs_keys)
|
||||
is_depth = bool(depth_obs_ranges) and (k in depth_obs_ranges or key in depth_obs_ranges)
|
||||
if is_depth and arr.ndim == 2:
|
||||
# Viridis-colormapped DepthImage; never JPEG-compress (lossy on float metric depth).
|
||||
rr.log(
|
||||
key,
|
||||
rr.DepthImage(
|
||||
arr,
|
||||
meter=1.0,
|
||||
colormap=rr.components.Colormap.Viridis,
|
||||
),
|
||||
static=True,
|
||||
)
|
||||
# Anchor the colormap to the encoder range when available, so the
|
||||
# visualization doesn't flicker as per-frame min/max drift.
|
||||
depth_range = depth_obs_ranges.get(k) or depth_obs_ranges.get(key)
|
||||
depth_kwargs: dict = {
|
||||
"meter": 1.0,
|
||||
"colormap": rr.components.Colormap.Viridis,
|
||||
}
|
||||
if depth_range is not None:
|
||||
depth_kwargs["depth_range"] = depth_range
|
||||
rr.log(key, rr.DepthImage(arr, **depth_kwargs), static=True)
|
||||
continue
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
|
||||
@@ -249,7 +249,11 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
|
||||
|
||||
def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun):
|
||||
"""A 2D float depth obs whose key matches a depth feature must use ``rr.DepthImage``."""
|
||||
"""A 2D float depth obs whose key matches a depth feature must use ``rr.DepthImage``.
|
||||
|
||||
Without ``video.depth_min``/``video.depth_max`` in the feature info, the
|
||||
visualizer should not pass a ``depth_range`` (Rerun then auto-normalizes).
|
||||
"""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
@@ -279,9 +283,59 @@ def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun):
|
||||
assert depth.arr.shape == (10, 12)
|
||||
assert depth.meter == pytest.approx(1.0)
|
||||
assert depth.colormap == "viridis"
|
||||
# No range available -> Rerun should auto-normalize; we must not pass `depth_range`.
|
||||
assert "depth_range" not in depth.kwargs
|
||||
assert _kwargs_for(calls, "observation.front_depth").get("static", False) is True
|
||||
|
||||
|
||||
def test_log_rerun_data_depth_range_anchored_from_info(mock_rerun):
|
||||
"""When ``video.depth_min``/``depth_max`` are present, ``depth_range`` is forwarded."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640),
|
||||
"info": {
|
||||
"video.is_depth_map": True,
|
||||
"video.depth_min": 0.05,
|
||||
"video.depth_max": 4.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
obs = {"front_depth": np.full((10, 12), 0.5, dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features)
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
assert depth.kwargs.get("depth_range") == (0.05, 4.0)
|
||||
|
||||
|
||||
def test_log_rerun_data_depth_range_ignored_when_invalid(mock_rerun):
|
||||
"""A degenerate range (max <= min, or non-numeric) must be discarded silently."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640),
|
||||
"info": {
|
||||
"video.is_depth_map": True,
|
||||
"video.depth_min": 1.0,
|
||||
"video.depth_max": 1.0, # degenerate
|
||||
},
|
||||
},
|
||||
}
|
||||
obs = {"front_depth": np.full((10, 12), 0.5, dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features)
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
assert "depth_range" not in depth.kwargs
|
||||
|
||||
|
||||
def test_log_rerun_data_depth_skips_compression(mock_rerun):
|
||||
"""Depth frames must never be JPEG-compressed even when ``compress_images=True``."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
Reference in New Issue
Block a user