feat(depth): persist depth metadata + add reader helpers

This commit is contained in:
CarolinePascal
2026-04-26 00:11:38 +02:00
parent df1648c102
commit 2c796d3352
4 changed files with 313 additions and 8 deletions

View File

@@ -23,19 +23,144 @@ from __future__ import annotations
import functools
import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
import av
import numpy as np
import torch
from lerobot.datasets.depth_utils import (
DEFAULT_DEPTH_MAX,
DEFAULT_DEPTH_MIN,
DEFAULT_DEPTH_SHIFT,
DEFAULT_DEPTH_USE_LOG,
quantize_depth,
dequantize_depth,
)
if TYPE_CHECKING:
from lerobot.datasets.video_utils import VideoEncoderConfig
logger = logging.getLogger(__name__)
# Pixel formats supported by the depth encode/decode helpers below. Both are
# 16-bit-word formats that carry 12 significant bits per sample, matching the
# ``DEPTH_QMAX = 4095`` quantization range.
DEPTH_PIX_FMTS: tuple[str, ...] = ("yuv420p12le", "gray12le")
# Neutral chroma for 12-bit YUV (the midpoint of [0, 4095]). Filling the U/V
# planes with this value keeps the encoder from spending bits on chroma noise
# when only the Y plane carries information.
_NEUTRAL_CHROMA_12BIT: int = 2048
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
def _write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
height, width = src.shape
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
if fill_value is not None:
dst.fill(fill_value)
dst[:, :width] = src
def encode_depth_frame_pyav(
depth: np.ndarray | torch.Tensor,
*,
pix_fmt: str = "yuv420p12le",
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
input_unit: Literal["auto", "m", "mm"] = "auto",
) -> av.VideoFrame:
"""Quantize depth and pack it into a 12-bit PyAV video frame.
Args:
depth: Depth frame to encode (H, W). Unit handling follows
:func:`lerobot.datasets.depth_utils.quantize_depth`.
pix_fmt: Target pixel format. Must be one of :data:`DEPTH_PIX_FMTS`.
depth_min, depth_max, shift, use_log, input_unit: Forwarded to
:func:`quantize_depth`.
Returns:
An :class:`av.VideoFrame` in ``pix_fmt`` with quantized depth in the
luminance plane.
"""
if pix_fmt not in DEPTH_PIX_FMTS:
raise ValueError(f"Unsupported depth pix_fmt={pix_fmt!r}; expected one of {DEPTH_PIX_FMTS}")
quantized_depth = quantize_depth(
depth,
depth_min=depth_min,
depth_max=depth_max,
shift=shift,
use_log=use_log,
input_unit=input_unit,
)
if quantized_depth.ndim != 2:
raise ValueError(f"depth must be a 2D frame; got shape {quantized_depth.shape}")
quantized_depth = np.ascontiguousarray(quantized_depth, dtype=np.uint16)
height, width = quantized_depth.shape
if pix_fmt == "gray12le":
frame = av.VideoFrame(width=width, height=height, format="gray12le")
_write_u16_plane(frame.planes[0], quantized_depth)
return frame
if height % 2 != 0 or width % 2 != 0:
raise ValueError("yuv420p12le requires even H and W")
frame = av.VideoFrame(width=width, height=height, format="yuv420p12le")
_write_u16_plane(frame.planes[0], quantized_depth)
neutral_chroma = np.full((height // 2, width // 2), _NEUTRAL_CHROMA_12BIT, dtype=np.uint16)
_write_u16_plane(frame.planes[1], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
_write_u16_plane(frame.planes[2], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
return frame
def decode_depth_frame_pyav(
frame: av.VideoFrame | list[av.VideoFrame],
*,
depth_min: float = DEFAULT_DEPTH_MIN,
depth_max: float = DEFAULT_DEPTH_MAX,
shift: float = DEFAULT_DEPTH_SHIFT,
use_log: bool = DEFAULT_DEPTH_USE_LOG,
return_quantized: bool = False,
output_unit: Literal["m", "mm"] = "m",
) -> np.ndarray:
"""Decode one or many depth video frames to quantized or metric depth.
Args:
frame: A single depth frame or a list of depth frames.
depth_min, depth_max, shift, use_log: Forwarded to
:func:`dequantize_depth`.
return_quantized: If ``True``, return raw 12-bit quanta as ``uint16``.
output_unit: Unit for dequantized output (``"m"`` or ``"mm"``).
Returns:
``(H, W)`` array for a single frame, or ``(N, H, W)`` for a list.
"""
frames = frame if isinstance(frame, list) else [frame]
quantized = np.stack([f.reformat(format="gray12le").to_ndarray() for f in frames]).astype(np.uint16, copy=False)
if return_quantized:
return quantized[0] if len(frames) == 1 else quantized
decoded = dequantize_depth(
quantized,
depth_min=depth_min,
depth_max=depth_max,
shift=shift,
use_log=use_log,
output_unit=output_unit,
)
return decoded[0] if len(frames) == 1 else decoded
@functools.cache
def get_codec(vcodec: str) -> av.codec.Codec | None:
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
@@ -46,7 +171,7 @@ def get_codec(vcodec: str) -> av.codec.Codec | None:
@functools.cache
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
def _get_codec_video_formats(vcodec: str) -> dict[str, av.option.Option]:
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
codec = get_codec(vcodec)
if codec is None: