mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
feat(depth): persist depth metadata + add reader helpers
This commit is contained in:
@@ -313,6 +313,20 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
def depth_keys(self) -> list[str]:
    """Keys to access depth-map modalities stored as videos.

    A depth video key is a feature whose ``info`` dict carries
    ``"video.is_depth_map": True`` (set either at creation time by the user
    or after the first encoded episode by :meth:`update_video_info`).

    Video features whose ``info`` entry is absent or ``None`` are treated
    as non-depth features instead of raising.
    """
    return [
        key
        for key, ft in self.features.items()
        # ``info`` can be present with a ``None`` value, in which case
        # ``ft.get("info", {})`` would return ``None``; ``or {}`` covers both
        # the missing-key and the ``None`` case.
        if ft["dtype"] == "video" and (ft.get("info") or {}).get("video.is_depth_map", False)
    ]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
|
||||
@@ -23,19 +23,144 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.depth_utils import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
quantize_depth,
|
||||
dequantize_depth,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pixel formats supported by the depth encode/decode helpers below. Both are
|
||||
# 16-bit-word formats that carry 12 significant bits per sample, matching the
|
||||
# ``DEPTH_QMAX = 4095`` quantization range.
|
||||
DEPTH_PIX_FMTS: tuple[str, ...] = ("yuv420p12le", "gray12le")
|
||||
|
||||
# Neutral chroma for 12-bit YUV (the midpoint of [0, 4095]). Filling the U/V
|
||||
# planes with this value keeps the encoder from spending bits on chroma noise
|
||||
# when only the Y plane carries information.
|
||||
_NEUTRAL_CHROMA_12BIT: int = 2048
|
||||
|
||||
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def _write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
    """Copy the 2D array ``src`` into ``plane``, viewed as rows of uint16 words.

    The row length of the destination view is derived from
    ``plane.line_size`` (bytes per row), so any trailing padding words of each
    row are left alone unless ``fill_value`` is given, in which case the whole
    plane is first set to that value.
    """
    rows, cols = src.shape
    # ``line_size`` is expressed in bytes; convert it to a count of 16-bit words.
    words_per_row = plane.line_size // np.uint16().itemsize
    view = np.frombuffer(plane, dtype=np.uint16)
    view = view.reshape(rows, words_per_row)
    if fill_value is not None:
        view.fill(fill_value)
    view[:, :cols] = src
|
||||
|
||||
|
||||
def encode_depth_frame_pyav(
    depth: np.ndarray | torch.Tensor,
    *,
    pix_fmt: str = "yuv420p12le",
    depth_min: float = DEFAULT_DEPTH_MIN,
    depth_max: float = DEFAULT_DEPTH_MAX,
    shift: float = DEFAULT_DEPTH_SHIFT,
    use_log: bool = DEFAULT_DEPTH_USE_LOG,
    input_unit: Literal["auto", "m", "mm"] = "auto",
) -> av.VideoFrame:
    """Quantize a depth map and store it in a PyAV frame of the given format.

    Args:
        depth: Depth frame to encode; the quantized result must be 2D (H, W).
            Unit handling is delegated to
            :func:`lerobot.datasets.depth_utils.quantize_depth`.
        pix_fmt: Target pixel format, one of :data:`DEPTH_PIX_FMTS`.
        depth_min, depth_max, shift, use_log, input_unit: Passed through to
            :func:`quantize_depth` unchanged.

    Returns:
        An :class:`av.VideoFrame` whose first plane holds the quantized depth.
        For the non-gray format the two remaining planes are filled with
        ``_NEUTRAL_CHROMA_12BIT``.

    Raises:
        ValueError: If ``pix_fmt`` is unsupported, the quantized depth is not
            2D, or the YUV 4:2:0 path is requested with an odd height/width.
    """
    if pix_fmt not in DEPTH_PIX_FMTS:
        raise ValueError(f"Unsupported depth pix_fmt={pix_fmt!r}; expected one of {DEPTH_PIX_FMTS}")

    quantize_kwargs = {
        "depth_min": depth_min,
        "depth_max": depth_max,
        "shift": shift,
        "use_log": use_log,
        "input_unit": input_unit,
    }
    quanta = quantize_depth(depth, **quantize_kwargs)
    if quanta.ndim != 2:
        raise ValueError(f"depth must be a 2D frame; got shape {quanta.shape}")

    luma = np.ascontiguousarray(quanta, dtype=np.uint16)
    rows, cols = luma.shape

    # Single-plane path: only the gray plane has to be written.
    if pix_fmt == "gray12le":
        gray_frame = av.VideoFrame(width=cols, height=rows, format="gray12le")
        _write_u16_plane(gray_frame.planes[0], luma)
        return gray_frame

    # The chroma planes below are half-size in both directions, so the frame
    # dimensions have to be even.
    if rows % 2 != 0 or cols % 2 != 0:
        raise ValueError("yuv420p12le requires even H and W")

    yuv_frame = av.VideoFrame(width=cols, height=rows, format="yuv420p12le")
    _write_u16_plane(yuv_frame.planes[0], luma)

    flat_chroma = np.full((rows // 2, cols // 2), _NEUTRAL_CHROMA_12BIT, dtype=np.uint16)
    for plane_index in (1, 2):
        _write_u16_plane(yuv_frame.planes[plane_index], flat_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
    return yuv_frame
|
||||
|
||||
|
||||
def decode_depth_frame_pyav(
    frame: av.VideoFrame | list[av.VideoFrame],
    *,
    depth_min: float = DEFAULT_DEPTH_MIN,
    depth_max: float = DEFAULT_DEPTH_MAX,
    shift: float = DEFAULT_DEPTH_SHIFT,
    use_log: bool = DEFAULT_DEPTH_USE_LOG,
    return_quantized: bool = False,
    output_unit: Literal["m", "mm"] = "m",
) -> np.ndarray:
    """Decode one or many depth video frames to quantized or metric depth.

    Each frame is reformatted to ``gray12le`` before being converted to an
    array, so the source pixel format does not need to be ``gray12le``.

    Args:
        frame: A single depth frame, or a list (or tuple) of depth frames.
        depth_min, depth_max, shift, use_log: Forwarded to
            :func:`dequantize_depth`.
        return_quantized: If ``True``, return raw 12-bit quanta as ``uint16``
            and skip dequantization.
        output_unit: Unit for dequantized output (``"m"`` or ``"mm"``).

    Returns:
        ``(H, W)`` array when a single frame is passed, ``(N, H, W)`` when a
        sequence is passed. A one-element sequence yields ``(1, H, W)``: the
        output rank depends on the kind of input, not on its length.

    Raises:
        ValueError: If an empty sequence of frames is passed.
    """
    # Decide the output rank from the input kind once, so that a list holding
    # a single frame keeps its leading batch axis.
    is_batch = isinstance(frame, (list, tuple))
    frames = list(frame) if is_batch else [frame]
    if not frames:
        raise ValueError("frame must be a video frame or a non-empty list of video frames")

    quantized = np.stack([f.reformat(format="gray12le").to_ndarray() for f in frames]).astype(np.uint16, copy=False)
    if return_quantized:
        return quantized if is_batch else quantized[0]

    decoded = dequantize_depth(
        quantized,
        depth_min=depth_min,
        depth_max=depth_max,
        shift=shift,
        use_log=use_log,
        output_unit=output_unit,
    )
    return decoded if is_batch else decoded[0]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
@@ -46,7 +171,7 @@ def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
|
||||
def _get_codec_video_formats(vcodec: str) -> dict[str, av.option.Option]:
|
||||
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
|
||||
@@ -41,6 +41,7 @@ from PIL import Image
|
||||
from lerobot.datasets.pyav_utils import (
|
||||
check_video_encoder_config_pyav,
|
||||
detect_available_encoders_pyav,
|
||||
decode_depth_frame,
|
||||
)
|
||||
from lerobot.datasets.depth_utils import (
|
||||
quantize_depth,
|
||||
@@ -103,6 +104,12 @@ class VideoEncoderConfig:
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Class-level marker persisted to ``info.json`` (via ``asdict``) so the
|
||||
# reader can tell depth datasets from RGB ones without a separate dispatch
|
||||
# path. ``init=False`` keeps it out of CLI/constructor surface; subclasses
|
||||
# flip the default (see :class:`DepthEncoderConfig`).
|
||||
is_depth_map: bool = field(default=False, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
|
||||
@@ -553,6 +560,121 @@ def decode_video_frames_torchcodec(
|
||||
return closest_frames
|
||||
|
||||
|
||||
def decode_depth_frames(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    *,
    depth_min: float = DEFAULT_DEPTH_MIN,
    depth_max: float = DEFAULT_DEPTH_MAX,
    shift: float = DEFAULT_DEPTH_SHIFT,
    use_log: bool = DEFAULT_DEPTH_USE_LOG,
    return_quantized: bool = False,
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Decode depth-map frames at the requested timestamps using PyAV.

    Mirrors the timestamp-tolerance / closest-frame contract of
    :func:`decode_video_frames` but operates entirely through PyAV (the
    ``torchvision`` and ``torchcodec`` backends don't currently round-trip
    12-bit pixel formats reliably).

    Each decoded frame is reformatted to ``gray12le`` so the same path
    handles ``yuv420p12le`` (HEVC default) and ``gray12le`` (ffv1 archive)
    sources transparently.

    Args:
        video_path: Path to a depth video produced with a
            :class:`DepthEncoderConfig`.
        timestamps: Frame timestamps to retrieve, in seconds. Must not be
            empty.
        tolerance_s: Maximum allowed deviation between the queried and the
            actually-decoded timestamps.
        depth_min, depth_max, shift, use_log: Parameters used at quantization
            time. Should match :func:`info_to_depth_kwargs` extracted from
            ``info.json`` for the source dataset.
        return_quantized: If ``True``, skip the dequantization step and
            return raw 12-bit ``uint16`` quanta.
        log_loaded_timestamps: Debug logging.

    Returns:
        ``torch.Tensor`` of shape ``(N, H, W)``:

        * ``dtype=torch.float32`` (metric depth, default)
        * ``dtype=torch.uint16`` when ``return_quantized=True``.

    Raises:
        ValueError: If *timestamps* is empty.
        FrameTimestampError: If a query timestamp can't be matched within
            *tolerance_s*, or if no frames are decoded.
    """
    if len(timestamps) == 0:
        # ``min()`` / ``max()`` below would raise an opaque ValueError anyway;
        # fail early with a message that names the offending argument.
        raise ValueError("timestamps must contain at least one value")

    video_path_str = str(video_path)
    first_ts = min(timestamps)
    last_ts = max(timestamps)

    loaded_frames: list[np.ndarray] = []
    loaded_ts: list[float] = []

    av.logging.set_level(av.logging.WARNING)
    try:
        with av.open(video_path_str, "r") as container:
            try:
                stream = container.streams.video[0]
            except IndexError as e:
                raise FrameTimestampError(f"No video stream in {video_path_str}") from e

            # Seek to the keyframe at-or-before first_ts (PyAV doesn't do
            # accurate seek, so we still iterate forward to the requested range).
            seek_pts = int(first_ts / stream.time_base)
            container.seek(seek_pts, stream=stream, any_frame=False, backward=True)

            for frame in container.decode(stream):
                if frame.pts is None:
                    continue
                current_ts = float(frame.pts * stream.time_base)
                if log_loaded_timestamps:
                    logger.info(f"depth frame loaded at timestamp={current_ts:.4f}")
                # Frames are kept as raw quanta here; dequantization happens
                # once, after the closest frames have been selected.
                # NOTE(review): pyav_utils appears to define
                # ``decode_depth_frame_pyav``; confirm that the imported name
                # ``decode_depth_frame`` actually resolves.
                loaded_frames.append(
                    decode_depth_frame(
                        frame,
                        depth_min=depth_min,
                        depth_max=depth_max,
                        shift=shift,
                        use_log=use_log,
                        return_quantized=True,
                    )
                )
                loaded_ts.append(current_ts)
                if current_ts >= last_ts:
                    break
    finally:
        # Restore the logging callback even when opening or decoding fails, so
        # an error in this function does not leave PyAV logging altered.
        av.logging.restore_default_callback()

    if not loaded_frames:
        raise FrameTimestampError(
            f"No depth frames decoded from {video_path_str} for timestamps {timestamps}"
        )

    # Pairwise |query - loaded| distances, then the nearest loaded frame per query.
    query_ts = torch.tensor(timestamps)
    loaded_ts_t = torch.tensor(loaded_ts)
    dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    if not is_within_tol.all():
        raise FrameTimestampError(
            f"One or several query timestamps violate the tolerance "
            f"({min_[~is_within_tol]} > {tolerance_s=})."
            f"\nqueried timestamps: {query_ts}"
            f"\nloaded timestamps: {loaded_ts_t}"
            f"\nvideo: {video_path_str}"
        )

    closest = np.stack([loaded_frames[i] for i in argmin_])  # (N, H, W) uint16
    quantized = torch.from_numpy(closest)

    if return_quantized:
        return quantized
    # Keyword arguments keep this call aligned with the other
    # ``dequantize_depth`` call site and independent of parameter order.
    return dequantize_depth(
        quantized,
        depth_min=depth_min,
        depth_max=depth_max,
        shift=shift,
        use_log=use_log,
    )
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
@@ -1128,13 +1250,13 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder_config: "VideoEncoderConfig | None" = None,
|
||||
video_encoder_config: "VideoEncoderConfig | None" = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder_config: If provided, record the exact encoder settings used to encode this
|
||||
video_encoder_config: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
@@ -1154,7 +1276,6 @@ def get_video_info(
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
@@ -1168,14 +1289,29 @@ def get_video_info(
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder_config is not None:
|
||||
for field_name, field_value in asdict(camera_encoder_config).items():
|
||||
# Add additional encoder configuration if provided (no override of stream-derived values)
|
||||
# Depth related fields flow naturally through this path.
|
||||
if video_encoder_config is not None:
|
||||
for field_name, field_value in asdict(video_encoder_config).items():
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
# Fallback case where no encoder config is provided or the video is not a depth map.
|
||||
video_info.setdefault("video.is_depth_map", False)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
# ─── Depth metadata helpers (reader side) ────────────────────────────
|
||||
|
||||
|
||||
_DEPTH_INFO_KEYS: tuple[str, ...] = (
|
||||
"video.depth_min",
|
||||
"video.depth_max",
|
||||
"video.shift",
|
||||
"video.use_log",
|
||||
)
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
|
||||
@@ -142,6 +142,36 @@ def test_create_without_videos_has_no_video_path(tmp_path):
|
||||
assert meta.video_keys == []
|
||||
|
||||
|
||||
def test_depth_keys_property_filters_by_marker(tmp_path):
    """Only the video feature flagged with ``video.is_depth_map=True`` shows up in ``depth_keys``."""
    rgb_key = "observation.images.cam"
    depth_key = "observation.depth.cam"

    # Plain RGB video feature: no ``info`` payload at all.
    rgb_feature = {
        "dtype": "video",
        "shape": (64, 96, 3),
        "names": ["height", "width", "channels"],
        "info": None,
    }
    # Depth video feature: carries the depth marker in ``info``.
    depth_feature = {
        "dtype": "video",
        "shape": (64, 96),
        "names": ["height", "width"],
        "info": {"video.is_depth_map": True},
    }

    features = dict(SIMPLE_FEATURES)
    features[rgb_key] = rgb_feature
    features[depth_key] = depth_feature

    dataset_root = tmp_path / "depth_keys"
    meta = LeRobotDatasetMetadata.create(
        repo_id="test/depth_keys",
        fps=DEFAULT_FPS,
        features=features,
        root=dataset_root,
    )

    assert set(meta.video_keys) == {rgb_key, depth_key}
    assert meta.depth_keys == [depth_key]
|
||||
|
||||
def test_depth_keys_empty_when_no_marker(tmp_path):
    """``depth_keys`` is expected to be empty for metadata built from ``VIDEO_FEATURES``."""
    dataset_root = tmp_path / "no_depth"
    create_kwargs = {
        "repo_id": "test/no_depth",
        "fps": DEFAULT_FPS,
        "features": VIDEO_FEATURES,
        "root": dataset_root,
    }
    meta = LeRobotDatasetMetadata.create(**create_kwargs)
    assert meta.depth_keys == []
|
||||
|
||||
def test_create_raises_on_existing_directory(tmp_path):
|
||||
"""create() raises if root directory already exists."""
|
||||
root = tmp_path / "existing"
|
||||
|
||||
Reference in New Issue
Block a user