feat(viz): render depth observations as rr.DepthImage in Viridis

log_rerun_data now accepts an optional `features` dict and uses the
`video.is_depth_map=True` info marker to detect depth observations.
Matching 2D arrays are logged as `rr.DepthImage(arr, meter=1.0,
colormap=rr.components.Colormap.Viridis)` and are never JPEG-compressed
(compression is lossy on float32 metric depth).

Detection covers both the namespaced dataset key
(e.g. `observation.depth.front`) and the raw observation keys the robot
emits (`front`, `front_depth`), so it works for both the typed
LeRobotDataset.features dict and the plain robot observation flow.

When `features` is None the previous behaviour is preserved (depth
arrays fall back to the generic `rr.Image` path), so non-depth
recordings and existing call sites are unaffected.

lerobot-record now forwards `dataset.features` so depth keys are picked
up automatically when `--display_data=true`.

Made-with: Cursor
This commit is contained in:
CarolinePascal
2026-04-27 19:35:56 +02:00
parent efad15f600
commit b540fa94a9
3 changed files with 133 additions and 1 deletions

View File

@@ -328,7 +328,10 @@ def record_loop(
if display_data:
log_rerun_data(
observation=obs_processed, action=action_values, compress_images=display_compressed_images
observation=obs_processed,
action=action_values,
compress_images=display_compressed_images,
features=dataset.features if dataset is not None else None,
)
dt_s = time.perf_counter() - start_loop_t

View File

@@ -63,10 +63,38 @@ def _is_scalar(x):
)
def _derive_depth_obs_keys(features: dict[str, dict] | None) -> set[str]:
"""Derive the set of observation keys that correspond to depth-map features.
A feature is considered a depth map when its ``info`` dict carries
``video.is_depth_map=True`` (the marker set by ``hw_to_dataset_features``
and persisted in ``info.json``). For each such feature, we record both
the fully-namespaced dataset key (e.g. ``observation.depth.front``) and
the corresponding raw observation key forms the robot is likely to emit
(``front`` and ``front_depth``) so a single membership check covers all
call sites.
"""
keys: set[str] = set()
if not features:
return keys
depth_prefix = f"{OBS_STR}.depth."
for fk, fv in features.items():
info = fv.get("info") if isinstance(fv, dict) else None
if not isinstance(info, dict) or not info.get("video.is_depth_map", False):
continue
keys.add(fk)
if fk.startswith(depth_prefix):
bare = fk[len(depth_prefix) :]
keys.add(bare)
keys.add(f"{bare}_depth")
return keys
def log_rerun_data(
observation: RobotObservation | None = None,
action: RobotAction | None = None,
compress_images: bool = False,
features: dict[str, dict] | None = None,
) -> None:
"""
Logs observation and action data to Rerun for real-time visualization.
@@ -76,6 +104,10 @@ def log_rerun_data(
- Scalars values (floats, ints) are logged as `rr.Scalars`.
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
- 2D NumPy arrays whose key matches a depth feature in ``features`` (i.e. carrying
``video.is_depth_map=True``) are logged as ``rr.DepthImage`` with the Viridis
colormap and ``meter=1.0`` (depth values are expected in metric meters). Depth
images are never JPEG-compressed regardless of ``compress_images``.
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
- Other multi-dimensional arrays are flattened and logged as individual scalars.
@@ -85,11 +117,16 @@ def log_rerun_data(
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
features: Optional dataset feature spec (e.g. ``LeRobotDataset.features``). When
provided, observation entries matching a depth-map feature are rendered with
``rr.DepthImage`` instead of the generic ``rr.Image`` path.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
depth_obs_keys = _derive_depth_obs_keys(features)
if observation:
for k, v in observation.items():
if v is None:
@@ -100,6 +137,19 @@ def log_rerun_data(
rr.log(key, rr.Scalars(float(v)))
elif isinstance(v, np.ndarray):
arr = v
is_depth = bool(depth_obs_keys) and (k in depth_obs_keys or key in depth_obs_keys)
if is_depth and arr.ndim == 2:
# Viridis-colormapped DepthImage; never JPEG-compress (lossy on float metric depth).
rr.log(
key,
rr.DepthImage(
arr,
meter=1.0,
colormap=rr.components.Colormap.Viridis,
),
static=True,
)
continue
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))

View File

@@ -43,18 +43,32 @@ def mock_rerun(monkeypatch):
def __init__(self, arr):
self.arr = arr
class DummyDepthImage:
def __init__(self, arr, meter=None, colormap=None, **kwargs):
self.arr = arr
self.meter = meter
self.colormap = colormap
self.kwargs = kwargs
def dummy_log(key, obj=None, **kwargs):
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
if obj is None and "entity" in kwargs:
obj = kwargs.pop("entity")
calls.append((key, obj, kwargs))
class _Colormap:
Viridis = "viridis"
dummy_components = SimpleNamespace(Colormap=_Colormap)
dummy_rr = SimpleNamespace(
__name__="rerun",
__package__="rerun",
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
Scalars=DummyScalar,
Image=DummyImage,
DepthImage=DummyDepthImage,
components=dummy_components,
log=dummy_log,
init=lambda *a, **k: None,
spawn=lambda *a, **k: None,
@@ -232,3 +246,68 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
a = _obj_for(calls, "action.a")
assert type(a).__name__ == "DummyScalar"
assert a.value == pytest.approx(1.0)
def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun):
"""A 2D float depth obs whose key matches a depth feature must use ``rr.DepthImage``."""
vu, calls = mock_rerun
features = {
"observation.images.front": {
"dtype": "video",
"shape": (480, 640, 3),
"info": {"video.is_depth_map": False},
},
"observation.depth.front": {
"dtype": "video",
"shape": (480, 640),
"info": {"video.is_depth_map": True},
},
}
obs = {
"front": np.zeros((10, 12, 3), dtype=np.uint8),
"front_depth": np.full((10, 12), 0.7, dtype=np.float32),
}
vu.log_rerun_data(observation=obs, features=features)
rgb = _obj_for(calls, "observation.front")
assert type(rgb).__name__ == "DummyImage"
depth = _obj_for(calls, "observation.front_depth")
assert type(depth).__name__ == "DummyDepthImage"
assert depth.arr.shape == (10, 12)
assert depth.meter == pytest.approx(1.0)
assert depth.colormap == "viridis"
assert _kwargs_for(calls, "observation.front_depth").get("static", False) is True
def test_log_rerun_data_depth_skips_compression(mock_rerun):
"""Depth frames must never be JPEG-compressed even when ``compress_images=True``."""
vu, calls = mock_rerun
features = {
"observation.depth.front": {
"dtype": "video",
"shape": (8, 8),
"info": {"video.is_depth_map": True},
},
}
obs = {"front_depth": np.full((8, 8), 0.5, dtype=np.float32)}
vu.log_rerun_data(observation=obs, features=features, compress_images=True)
depth = _obj_for(calls, "observation.front_depth")
assert type(depth).__name__ == "DummyDepthImage"
def test_log_rerun_data_no_features_falls_back_to_image(mock_rerun):
"""Without ``features``, a 2D array still goes through the generic Image path (no depth detection)."""
vu, calls = mock_rerun
obs = {"front_depth": np.zeros((8, 8), dtype=np.float32)}
vu.log_rerun_data(observation=obs)
obj = _obj_for(calls, "observation.front_depth")
assert type(obj).__name__ == "DummyImage"