diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index beaaeb576..a8fc24714 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -328,7 +328,10 @@ def record_loop( if display_data: log_rerun_data( - observation=obs_processed, action=action_values, compress_images=display_compressed_images + observation=obs_processed, + action=action_values, + compress_images=display_compressed_images, + features=dataset.features if dataset is not None else None, ) dt_s = time.perf_counter() - start_loop_t diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index d9d5bf6b5..70377bbef 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -63,10 +63,38 @@ def _is_scalar(x): ) +def _derive_depth_obs_keys(features: dict[str, dict] | None) -> set[str]: + """Derive the set of observation keys that correspond to depth-map features. + + A feature is considered a depth map when its ``info`` dict carries + ``video.is_depth_map=True`` (the marker set by ``hw_to_dataset_features`` + and persisted in ``info.json``). For each such feature, we record both + the fully-namespaced dataset key (e.g. ``observation.depth.front``) and + the corresponding raw observation key forms the robot is likely to emit + (``front`` and ``front_depth``) so a single membership check covers all + call sites. + """ + keys: set[str] = set() + if not features: + return keys + depth_prefix = f"{OBS_STR}.depth." 
+ for fk, fv in features.items(): + info = fv.get("info") if isinstance(fv, dict) else None + if not isinstance(info, dict) or not info.get("video.is_depth_map", False): + continue + keys.add(fk) + if fk.startswith(depth_prefix): + bare = fk[len(depth_prefix) :] + keys.add(bare) + keys.add(f"{bare}_depth") + return keys + + def log_rerun_data( observation: RobotObservation | None = None, action: RobotAction | None = None, compress_images: bool = False, + features: dict[str, dict] | None = None, ) -> None: """ Logs observation and action data to Rerun for real-time visualization. @@ -76,6 +104,10 @@ def log_rerun_data( - Scalars values (floats, ints) are logged as `rr.Scalars`. - 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`. + - 2D NumPy arrays whose key matches a depth feature in ``features`` (i.e. carrying + ``video.is_depth_map=True``) are logged as ``rr.DepthImage`` with the Viridis + colormap and ``meter=1.0`` (depth values are expected in metric meters). Depth + images are never JPEG-compressed regardless of ``compress_images``. - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed. - Other multi-dimensional arrays are flattened and logged as individual scalars. @@ -85,11 +117,16 @@ def log_rerun_data( observation: An optional dictionary containing observation data to log. action: An optional dictionary containing action data to log. compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality. + features: Optional dataset feature spec (e.g. ``LeRobotDataset.features``). When + provided, observation entries matching a depth-map feature are rendered with + ``rr.DepthImage`` instead of the generic ``rr.Image`` path. 
""" require_package("rerun-sdk", extra="viz", import_name="rerun") import rerun as rr + depth_obs_keys = _derive_depth_obs_keys(features) + if observation: for k, v in observation.items(): if v is None: @@ -100,6 +137,19 @@ def log_rerun_data( rr.log(key, rr.Scalars(float(v))) elif isinstance(v, np.ndarray): arr = v + is_depth = bool(depth_obs_keys) and (k in depth_obs_keys or key in depth_obs_keys) + if is_depth and arr.ndim == 2: + # Viridis-colormapped DepthImage; never JPEG-compress (lossy on float metric depth). + rr.log( + key, + rr.DepthImage( + arr, + meter=1.0, + colormap=rr.components.Colormap.Viridis, + ), + static=True, + ) + continue # Convert CHW -> HWC when needed if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): arr = np.transpose(arr, (1, 2, 0)) diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 63ff76c77..ec2c6fc1d 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -43,18 +43,32 @@ def mock_rerun(monkeypatch): def __init__(self, arr): self.arr = arr + class DummyDepthImage: + def __init__(self, arr, meter=None, colormap=None, **kwargs): + self.arr = arr + self.meter = meter + self.colormap = colormap + self.kwargs = kwargs + def dummy_log(key, obj=None, **kwargs): # Accept either positional `obj` or keyword `entity` and record remaining kwargs. 
def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun):
    """A 2D depth observation matching a depth feature is logged as ``DepthImage``.

    The RGB observation logged alongside it must still use the generic image path.
    """
    vu, calls = mock_rerun

    depth_entity_key = "observation.front_depth"
    rgb_spec = {
        "dtype": "video",
        "shape": (480, 640, 3),
        "info": {"video.is_depth_map": False},
    }
    depth_spec = {
        "dtype": "video",
        "shape": (480, 640),
        "info": {"video.is_depth_map": True},
    }
    features = {
        "observation.images.front": rgb_spec,
        "observation.depth.front": depth_spec,
    }
    rgb_frame = np.zeros((10, 12, 3), dtype=np.uint8)
    depth_frame = np.full((10, 12), 0.7, dtype=np.float32)

    vu.log_rerun_data(
        observation={"front": rgb_frame, "front_depth": depth_frame},
        features=features,
    )

    # The 3D colour frame is not routed to the depth entity.
    assert type(_obj_for(calls, "observation.front")).__name__ == "DummyImage"

    logged_depth = _obj_for(calls, depth_entity_key)
    assert type(logged_depth).__name__ == "DummyDepthImage"
    assert logged_depth.arr.shape == (10, 12)
    assert logged_depth.meter == pytest.approx(1.0)
    assert logged_depth.colormap == "viridis"

    log_kwargs = _kwargs_for(calls, depth_entity_key)
    assert log_kwargs.get("static", False) is True


def test_log_rerun_data_depth_skips_compression(mock_rerun):
    """With ``compress_images=True`` a depth frame is still logged as ``DepthImage``."""
    vu, calls = mock_rerun

    depth_spec = {
        "dtype": "video",
        "shape": (8, 8),
        "info": {"video.is_depth_map": True},
    }
    depth_frame = np.full((8, 8), 0.5, dtype=np.float32)

    vu.log_rerun_data(
        observation={"front_depth": depth_frame},
        features={"observation.depth.front": depth_spec},
        compress_images=True,
    )

    logged_depth = _obj_for(calls, "observation.front_depth")
    assert type(logged_depth).__name__ == "DummyDepthImage"


def test_log_rerun_data_no_features_falls_back_to_image(mock_rerun):
    """Without ``features`` no depth detection happens; a 2D array is logged as an image."""
    vu, calls = mock_rerun

    frame = np.zeros((8, 8), dtype=np.float32)

    vu.log_rerun_data(observation={"front_depth": frame})

    logged = _obj_for(calls, "observation.front_depth")
    assert type(logged).__name__ == "DummyImage"