From 2c796d3352e4a2cfb1bdf08d120c82da2d5bd244 Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Sun, 26 Apr 2026 00:11:38 +0200
Subject: [PATCH] feat(depth): persist depth metadata + add reader helpers

---
 src/lerobot/datasets/dataset_metadata.py |  15 +++
 src/lerobot/datasets/pyav_utils.py       | 144 ++++++++++++++++++++-
 src/lerobot/datasets/video_utils.py      | 152 ++++++++++++++++++++++-
 tests/datasets/test_dataset_metadata.py  |  33 +++++
 4 files changed, 337 insertions(+), 7 deletions(-)

diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py
index 57b967ac5..f663bb847 100644
--- a/src/lerobot/datasets/dataset_metadata.py
+++ b/src/lerobot/datasets/dataset_metadata.py
@@ -313,6 +313,21 @@ class LeRobotDatasetMetadata:
         """Keys to access visual modalities stored as videos."""
         return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
 
+    @property
+    def depth_keys(self) -> list[str]:
+        """Keys to access depth-map modalities stored as videos.
+
+        A depth video key is a feature whose ``info`` dict carries
+        ``"video.is_depth_map": True`` (set either at creation time by the user
+        or after the first encoded episode by :meth:`update_video_info`).
+        Features whose ``info`` is missing or ``None`` are not depth keys.
+        """
+        return [
+            key
+            for key, ft in self.features.items()
+            if ft["dtype"] == "video" and (ft.get("info") or {}).get("video.is_depth_map", False)
+        ]
+
     @property
     def camera_keys(self) -> list[str]:
         """Keys to access visual modalities (regardless of their storage method)."""
diff --git a/src/lerobot/datasets/pyav_utils.py b/src/lerobot/datasets/pyav_utils.py
index c13c66b89..949f9b1d7 100644
--- a/src/lerobot/datasets/pyav_utils.py
+++ b/src/lerobot/datasets/pyav_utils.py
@@ -23,19 +23,161 @@ from __future__ import annotations
 
 import functools
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
 
 import av
+import numpy as np
+import torch
+
+from lerobot.datasets.depth_utils import (
+    DEFAULT_DEPTH_MAX,
+    DEFAULT_DEPTH_MIN,
+    DEFAULT_DEPTH_SHIFT,
+    DEFAULT_DEPTH_USE_LOG,
+    dequantize_depth,
+    quantize_depth,
+)
 
 if TYPE_CHECKING:
     from lerobot.datasets.video_utils import VideoEncoderConfig
 
 logger = logging.getLogger(__name__)
 
+# Pixel formats supported by the depth encode/decode helpers below. Both are
+# 16-bit-word formats that carry 12 significant bits per sample, matching the
+# ``DEPTH_QMAX = 4095`` quantization range.
+DEPTH_PIX_FMTS: tuple[str, ...] = ("yuv420p12le", "gray12le")
+
+# Neutral chroma for 12-bit YUV (the midpoint of [0, 4095]). Filling the U/V
+# planes with this value keeps the encoder from spending bits on chroma noise
+# when only the Y plane carries information.
+_NEUTRAL_CHROMA_12BIT: int = 2048
+
 FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
 FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
 
 
+def _write_u16_plane(
+    plane: av.video.plane.VideoPlane,
+    src: np.ndarray,
+    fill_value: int | None = None,
+) -> None:
+    """Copy ``src`` into a uint16 plane respecting FFmpeg line padding.
+
+    Args:
+        plane: Destination plane, viewed as ``(height, line_size // 2)`` uint16 samples.
+        src: 2D ``(height, width)`` array written into the first ``width`` columns.
+        fill_value: If given, the whole plane (row padding included) is set to
+            this value before ``src`` is copied in.
+    """
+    height, width = src.shape
+    stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
+    dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
+    if fill_value is not None:
+        dst.fill(fill_value)
+    dst[:, :width] = src
+
+
+def encode_depth_frame_pyav(
+    depth: np.ndarray | torch.Tensor,
+    *,
+    pix_fmt: str = "yuv420p12le",
+    depth_min: float = DEFAULT_DEPTH_MIN,
+    depth_max: float = DEFAULT_DEPTH_MAX,
+    shift: float = DEFAULT_DEPTH_SHIFT,
+    use_log: bool = DEFAULT_DEPTH_USE_LOG,
+    input_unit: Literal["auto", "m", "mm"] = "auto",
+) -> av.VideoFrame:
+    """Quantize depth and pack it into a 12-bit PyAV video frame.
+
+    Args:
+        depth: Depth frame to encode (H, W). Unit handling follows
+            :func:`lerobot.datasets.depth_utils.quantize_depth`.
+        pix_fmt: Target pixel format. Must be one of :data:`DEPTH_PIX_FMTS`.
+        depth_min, depth_max, shift, use_log, input_unit: Forwarded to
+            :func:`quantize_depth`.
+
+    Returns:
+        An :class:`av.VideoFrame` in ``pix_fmt`` with quantized depth in the
+        luminance plane.
+
+    Raises:
+        ValueError: If ``pix_fmt`` is unsupported, ``depth`` is not 2D, or
+            ``yuv420p12le`` is requested with an odd height or width.
+    """
+    if pix_fmt not in DEPTH_PIX_FMTS:
+        raise ValueError(f"Unsupported depth pix_fmt={pix_fmt!r}; expected one of {DEPTH_PIX_FMTS}")
+
+    quantized_depth = quantize_depth(
+        depth,
+        depth_min=depth_min,
+        depth_max=depth_max,
+        shift=shift,
+        use_log=use_log,
+        input_unit=input_unit,
+    )
+    if quantized_depth.ndim != 2:
+        raise ValueError(f"depth must be a 2D frame; got shape {quantized_depth.shape}")
+
+    quantized_depth = np.ascontiguousarray(quantized_depth, dtype=np.uint16)
+    height, width = quantized_depth.shape
+
+    if pix_fmt == "gray12le":
+        frame = av.VideoFrame(width=width, height=height, format="gray12le")
+        _write_u16_plane(frame.planes[0], quantized_depth)
+        return frame
+
+    if height % 2 != 0 or width % 2 != 0:
+        raise ValueError("yuv420p12le requires even H and W")
+
+    frame = av.VideoFrame(width=width, height=height, format="yuv420p12le")
+    _write_u16_plane(frame.planes[0], quantized_depth)
+    neutral_chroma = np.full((height // 2, width // 2), _NEUTRAL_CHROMA_12BIT, dtype=np.uint16)
+    _write_u16_plane(frame.planes[1], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
+    _write_u16_plane(frame.planes[2], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
+    return frame
+
+
+def decode_depth_frame_pyav(
+    frame: av.VideoFrame | list[av.VideoFrame],
+    *,
+    depth_min: float = DEFAULT_DEPTH_MIN,
+    depth_max: float = DEFAULT_DEPTH_MAX,
+    shift: float = DEFAULT_DEPTH_SHIFT,
+    use_log: bool = DEFAULT_DEPTH_USE_LOG,
+    return_quantized: bool = False,
+    output_unit: Literal["m", "mm"] = "m",
+) -> np.ndarray:
+    """Decode one or many depth video frames to quantized or metric depth.
+
+    Args:
+        frame: A single depth frame or a list of depth frames.
+        depth_min, depth_max, shift, use_log: Forwarded to
+            :func:`dequantize_depth`.
+        return_quantized: If ``True``, return raw 12-bit quanta as ``uint16``.
+        output_unit: Unit for dequantized output (``"m"`` or ``"mm"``).
+
+    Returns:
+        ``(H, W)`` array for a single frame, or ``(N, H, W)`` for a list (also when ``N == 1``).
+    """
+    is_batch = isinstance(frame, list)
+    frames = frame if is_batch else [frame]
+    planes = [f.reformat(format="gray12le").to_ndarray() for f in frames]
+    quantized = np.stack(planes).astype(np.uint16, copy=False)
+    if return_quantized:
+        return quantized if is_batch else quantized[0]
+
+    decoded = dequantize_depth(
+        quantized,
+        depth_min=depth_min,
+        depth_max=depth_max,
+        shift=shift,
+        use_log=use_log,
+        output_unit=output_unit,
+    )
+    return decoded if is_batch else decoded[0]
+
+
 @functools.cache
 def get_codec(vcodec: str) -> av.codec.Codec | None:
     """PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py
index da58c8aa8..0e930b5f3 100644
--- a/src/lerobot/datasets/video_utils.py
+++ b/src/lerobot/datasets/video_utils.py
@@ -41,6 +41,7 @@ from PIL import Image
 from lerobot.datasets.pyav_utils import (
     check_video_encoder_config_pyav,
     detect_available_encoders_pyav,
+    decode_depth_frame_pyav,
 )
 from lerobot.datasets.depth_utils import (
     quantize_depth,
@@ -103,6 +104,12 @@ class VideoEncoderConfig:
     video_backend: str = "pyav"
     extra_options: dict[str, Any] = field(default_factory=dict)
 
+    # Class-level marker persisted to ``info.json`` (via ``asdict``) so the
+    # reader can tell depth datasets from RGB ones without a separate dispatch
+    # path. ``init=False`` keeps it out of CLI/constructor surface; subclasses
+    # flip the default (see :class:`DepthEncoderConfig`).
+    is_depth_map: bool = field(default=False, init=False)
+
     def __post_init__(self) -> None:
         self.resolve_vcodec()
 
@@ -553,6 +560,125 @@ def decode_video_frames_torchcodec(
     return closest_frames
 
 
+def decode_depth_frames(
+    video_path: Path | str,
+    timestamps: list[float],
+    tolerance_s: float,
+    *,
+    depth_min: float = DEFAULT_DEPTH_MIN,
+    depth_max: float = DEFAULT_DEPTH_MAX,
+    shift: float = DEFAULT_DEPTH_SHIFT,
+    use_log: bool = DEFAULT_DEPTH_USE_LOG,
+    return_quantized: bool = False,
+    log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+    """Decode depth-map frames at the requested timestamps using PyAV.
+
+    Mirrors the timestamp-tolerance / closest-frame contract of
+    :func:`decode_video_frames` but operates entirely through PyAV (the
+    ``torchvision`` and ``torchcodec`` backends don't currently round-trip
+    12-bit pixel formats reliably).
+
+    Each decoded frame is reformatted to ``gray12le`` so the same path
+    handles ``yuv420p12le`` (HEVC default) and ``gray12le`` (ffv1 archive)
+    sources transparently.
+
+    Args:
+        video_path: Path to a depth video produced with a
+            :class:`DepthEncoderConfig`.
+        timestamps: Frame timestamps to retrieve, in seconds. Must not be empty.
+        tolerance_s: Maximum allowed deviation between the queried and the
+            actually-decoded timestamps.
+        depth_min, depth_max, shift, use_log: Parameters used at quantization
+            time. Should match :func:`info_to_depth_kwargs` extracted from
+            ``info.json`` for the source dataset.
+        return_quantized: If ``True``, skip the dequantization step and
+            return raw 12-bit ``uint16`` quanta.
+        log_loaded_timestamps: Debug logging.
+
+    Returns:
+        ``torch.Tensor`` of shape ``(N, H, W)``:
+
+        * ``dtype=torch.float32`` (metric depth, default)
+        * ``dtype=torch.uint16`` when ``return_quantized=True``.
+
+    Raises:
+        ValueError: If *timestamps* is empty.
+        FrameTimestampError: If a query timestamp can't be matched within
+            *tolerance_s*, or if no frames are decoded.
+    """
+    if len(timestamps) == 0:
+        raise ValueError("timestamps must contain at least one entry")
+
+    video_path_str = str(video_path)
+    first_ts = min(timestamps)
+    last_ts = max(timestamps)
+
+    loaded_frames: list[np.ndarray] = []
+    loaded_ts: list[float] = []
+
+    av.logging.set_level(av.logging.WARNING)
+    try:
+        with av.open(video_path_str, "r") as container:
+            try:
+                stream = container.streams.video[0]
+            except IndexError as e:
+                raise FrameTimestampError(f"No video stream in {video_path_str}") from e
+
+            # Seek to the keyframe at-or-before first_ts (PyAV doesn't do
+            # accurate seek, so we still iterate forward to the requested range).
+            seek_pts = int(first_ts / stream.time_base)
+            container.seek(seek_pts, stream=stream, any_frame=False, backward=True)
+
+            for frame in container.decode(stream):
+                if frame.pts is None:
+                    continue
+                current_ts = float(frame.pts * stream.time_base)
+                if log_loaded_timestamps:
+                    logger.info(f"depth frame loaded at timestamp={current_ts:.4f}")
+                # Keep raw quanta here; dequantization runs once on the selected frames below.
+                loaded_frames.append(decode_depth_frame_pyav(frame, return_quantized=True))
+                loaded_ts.append(current_ts)
+                if current_ts >= last_ts:
+                    break
+    finally:
+        # Restore PyAV logging even if opening or decoding raised.
+        av.logging.restore_default_callback()
+
+    if not loaded_frames:
+        raise FrameTimestampError(
+            f"No depth frames decoded from {video_path_str} for timestamps {timestamps}"
+        )
+
+    query_ts = torch.tensor(timestamps)
+    loaded_ts_t = torch.tensor(loaded_ts)
+    dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
+    min_, argmin_ = dist.min(1)
+
+    is_within_tol = min_ < tolerance_s
+    if not is_within_tol.all():
+        raise FrameTimestampError(
+            f"One or several query timestamps violate the tolerance "
+            f"({min_[~is_within_tol]} > {tolerance_s=})."
+            f"\nqueried timestamps: {query_ts}"
+            f"\nloaded timestamps: {loaded_ts_t}"
+            f"\nvideo: {video_path_str}"
+        )
+
+    closest = np.stack([loaded_frames[i] for i in argmin_])  # (N, H, W) uint16
+    quantized = torch.from_numpy(closest)
+
+    if return_quantized:
+        return quantized
+    return dequantize_depth(
+        quantized,
+        depth_min=depth_min,
+        depth_max=depth_max,
+        shift=shift,
+        use_log=use_log,
+    )
+
+
 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
@@ -1128,13 +1254,13 @@ def get_audio_info(video_path: Path | str) -> dict:
 
 def get_video_info(
     video_path: Path | str,
-    camera_encoder_config: "VideoEncoderConfig | None" = None,
+    video_encoder_config: "VideoEncoderConfig | None" = None,
 ) -> dict:
     """Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
 
     Args:
         video_path: Path to the encoded video file to probe.
-        camera_encoder_config: If provided, record the exact encoder settings used to encode this
+        video_encoder_config: If provided, record the exact encoder settings used to encode this
             video. Stream-derived values take precedence — encoder fields are only written for keys
             not already populated from the video file itself.
     """
@@ -1154,7 +1280,6 @@ def get_video_info(
         video_info["video.width"] = video_stream.width
         video_info["video.codec"] = video_stream.codec.canonical_name
         video_info["video.pix_fmt"] = video_stream.pix_fmt
-        video_info["video.is_depth_map"] = False
 
         # Calculate fps from r_frame_rate
         video_info["video.fps"] = int(video_stream.base_rate)
@@ -1168,14 +1293,29 @@ def get_video_info(
     # Adding audio stream information
     video_info.update(**get_audio_info(video_path))
 
-    # Add additional encoder configuration if provided
-    if camera_encoder_config is not None:
-        for field_name, field_value in asdict(camera_encoder_config).items():
+    # Add additional encoder configuration if provided (no override of stream-derived values)
+    # Depth related fields flow naturally through this path.
+    if video_encoder_config is not None:
+        for field_name, field_value in asdict(video_encoder_config).items():
             video_info.setdefault(f"video.{field_name}", field_value)
 
+    # Fallback case where no encoder config is provided or the video is not a depth map.
+    video_info.setdefault("video.is_depth_map", False)
+
     return video_info
 
 
+# ─── Depth metadata helpers (reader side) ────────────────────────────
+
+
+_DEPTH_INFO_KEYS: tuple[str, ...] = (
+    "video.depth_min",
+    "video.depth_max",
+    "video.shift",
+    "video.use_log",
+)
+
+
 def get_video_pixel_channels(pix_fmt: str) -> int:
     if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
         return 1
diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py
index 6c784c90b..ab67e6ce3 100644
--- a/tests/datasets/test_dataset_metadata.py
+++ b/tests/datasets/test_dataset_metadata.py
@@ -142,6 +142,39 @@ def test_create_without_videos_has_no_video_path(tmp_path):
     assert meta.video_keys == []
 
 
+def test_depth_keys_property_filters_by_marker(tmp_path):
+    """``depth_keys`` selects only video features carrying ``video.is_depth_map=True``."""
+    features = {
+        **SIMPLE_FEATURES,
+        "observation.images.cam": {
+            "dtype": "video",
+            "shape": (64, 96, 3),
+            "names": ["height", "width", "channels"],
+            "info": None,
+        },
+        "observation.depth.cam": {
+            "dtype": "video",
+            "shape": (64, 96),
+            "names": ["height", "width"],
+            "info": {"video.is_depth_map": True},
+        },
+    }
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/depth_keys", fps=DEFAULT_FPS, features=features, root=tmp_path / "depth_keys"
+    )
+
+    assert set(meta.video_keys) == {"observation.images.cam", "observation.depth.cam"}
+    assert meta.depth_keys == ["observation.depth.cam"]
+
+
+def test_depth_keys_empty_when_no_marker(tmp_path):
+    """``depth_keys`` is empty when no video feature carries the depth marker."""
+    meta = LeRobotDatasetMetadata.create(
+        repo_id="test/no_depth", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=tmp_path / "no_depth"
+    )
+    assert meta.depth_keys == []
+
+
 def test_create_raises_on_existing_directory(tmp_path):
     """create() raises if root directory already exists."""
     root = tmp_path / "existing"