feat(depth): wire StreamingVideoEncoder + writer to depth encoder

This commit is contained in:
CarolinePascal
2026-04-27 15:35:44 +02:00
parent d777359662
commit de64ad3f7e
5 changed files with 104 additions and 30 deletions

View File

@@ -547,11 +547,15 @@ class LeRobotDatasetMetadata:
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
existing = self.features[key].get("info") or {}
# Repopulate when codec metadata is missing — keeps user-provided
# markers like ``video.is_depth_map`` (unless the probed stream info
# defines the same key, which then takes precedence) while still
# recording stream info on the first episode.
if not existing or "video.codec" not in existing:
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
self.info.features[key]["info"] = get_video_info(
video_path, camera_encoder_config=camera_encoder_config
)
stream_info = get_video_info(video_path, camera_encoder_config=camera_encoder_config)
merged = {**existing, **stream_info}
self.info.features[key]["info"] = merged
def update_chunk_settings(
self,

View File

@@ -509,7 +509,13 @@ class DatasetWriter:
# Update video info (only needed when first episode is encoded)
if episode_index == 0:
self._meta.update_video_info(video_key, camera_encoder_config=self._camera_encoder_config)
is_depth_key = video_key in set(self._meta.depth_keys)
cfg_for_info = (
self._depth_encoder_config
if is_depth_key and self._depth_encoder_config is not None
else self._camera_encoder_config
)
self._meta.update_video_info(video_key, camera_encoder_config=cfg_for_info)
write_info(self._meta.info, self._meta.root)
metadata = {

View File

@@ -294,10 +294,20 @@ def validate_feature_image_or_video(
# Note: The pixel-range check ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
actual_shape = tuple(value.shape)
expected = tuple(expected_shape)
if len(expected) == 2:
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
h, w = expected
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
if not valid:
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
elif len(expected) == 3:
c, h, w = expected
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
else:
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:

View File

@@ -259,6 +259,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._camera_encoder_config,
self._encoder_threads,
encoder_queue_maxsize,
depth_encoder_config=self._depth_encoder_config,
depth_keys=self.meta.depth_keys,
)
self.writer = DatasetWriter(
meta=self.meta,
@@ -309,12 +311,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
camera_encoder_config: VideoEncoderConfig,
encoder_threads: int | None,
encoder_queue_maxsize: int,
*,
depth_encoder_config: DepthEncoderConfig | None = None,
depth_keys: list[str] | None = None,
) -> StreamingVideoEncoder:
return StreamingVideoEncoder(
fps=fps,
camera_encoder_config=camera_encoder_config,
encoder_threads=encoder_threads,
queue_maxsize=encoder_queue_maxsize,
depth_encoder_config=depth_encoder_config,
depth_keys=depth_keys,
)
# ── Metadata properties ───────────────────────────────────────────
@@ -711,7 +718,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize
fps,
camera_encoder_config,
encoder_threads,
encoder_queue_maxsize,
depth_encoder_config=depth_encoder_config,
depth_keys=obj.meta.depth_keys,
)
obj.writer = DatasetWriter(
meta=obj.meta,
@@ -822,7 +834,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
streaming_enc = None
if streaming_encoding and len(obj.meta.video_keys) > 0:
streaming_enc = cls._build_streaming_encoder(
obj.meta.fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize
obj.meta.fps,
camera_encoder_config,
encoder_threads,
encoder_queue_maxsize,
depth_encoder_config=depth_encoder_config,
depth_keys=obj.meta.depth_keys,
)
obj.writer = DatasetWriter(
meta=obj.meta,

View File

@@ -40,8 +40,11 @@ from PIL import Image
from lerobot.datasets.pyav_utils import (
check_video_encoder_config_pyav,
depth_to_video_frame,
detect_available_encoders_pyav,
decode_depth_frame,
encode_depth_frame_pyav,
decode_depth_frame_pyav,
)
from lerobot.datasets.depth_utils import (
quantize_depth,
@@ -875,6 +878,7 @@ class _CameraEncoderThread(threading.Thread):
frame_queue: queue.Queue,
result_queue: queue.Queue,
stop_event: threading.Event,
depth_encoder_config: "DepthEncoderConfig | None" = None,
):
super().__init__(daemon=True)
self.video_path = video_path
@@ -885,13 +889,16 @@ class _CameraEncoderThread(threading.Thread):
self.frame_queue = frame_queue
self.result_queue = result_queue
self.stop_event = stop_event
self.depth_encoder_config = depth_encoder_config
def run(self) -> None:
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
container = None
output_stream = None
stats_tracker = RunningQuantileStats()
is_depth = self.depth_encoder_config is not None
stats_tracker = RunningQuantileStats() if not is_depth else None
frame_count = 0
try:
@@ -909,12 +916,12 @@ class _CameraEncoderThread(threading.Thread):
# Sentinel: flush and close
break
# Ensure HWC uint8 numpy array
# Transpose 3-channel CHW frames to HWC; convert to uint8 for RGB only (depth frames keep their dtype)
if isinstance(frame_data, np.ndarray):
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
# CHW -> HWC
frame_data = frame_data.transpose(1, 2, 0)
if frame_data.dtype != np.uint8:
if frame_data.dtype != np.uint8 and not is_depth:
frame_data = (frame_data * 255).astype(np.uint8)
# Open container on first frame (to get width/height)
@@ -929,21 +936,25 @@ class _CameraEncoderThread(threading.Thread):
output_stream.time_base = Fraction(1, self.fps)
# Encode frame with explicit timestamps
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
if is_depth:
video_frame = encode_depth_frame_pyav(frame_data, pix_fmt=self.pix_fmt, depth_min=self.depth_encoder_config.depth_min, depth_max=self.depth_encoder_config.depth_max, shift=self.depth_encoder_config.shift, use_log=self.depth_encoder_config.use_log)
else:
pil_img = Image.fromarray(frame_data)
video_frame = av.VideoFrame.from_image(pil_img)
video_frame.pts = frame_count
video_frame.time_base = Fraction(1, self.fps)
packet = output_stream.encode(video_frame)
if packet:
container.mux(packet)
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
img_downsampled = auto_downsample_height_width(img_chw)
# Reshape CHW to (H*W, C) for per-channel stats
channels = img_downsampled.shape[0]
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
stats_tracker.update(img_for_stats)
if not is_depth:
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
img_downsampled = auto_downsample_height_width(img_chw)
# Reshape CHW to (H*W, C) for per-channel stats
channels = img_downsampled.shape[0]
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
stats_tracker.update(img_for_stats)
frame_count += 1
@@ -958,8 +969,10 @@ class _CameraEncoderThread(threading.Thread):
av.logging.restore_default_callback()
# Get stats and put on result queue
if frame_count >= 2:
# Get stats and put on result queue (depth streams skip stats)
if is_depth:
self.result_queue.put(("ok", None))
elif frame_count >= 2:
stats = stats_tracker.get_statistics()
self.result_queue.put(("ok", stats))
else:
@@ -992,6 +1005,8 @@ class StreamingVideoEncoder:
encoder_threads: int | None = None,
*,
queue_maxsize: int = 30,
depth_encoder_config: "DepthEncoderConfig | None" = None,
depth_keys: list[str] | None = None,
):
"""
Args:
@@ -1002,11 +1017,24 @@ class StreamingVideoEncoder:
``None`` lets the codec decide.
queue_maxsize: Max frames to buffer per camera before
back-pressure drops frames.
depth_encoder_config: Optional depth encoder configuration applied
to all depth video keys listed in ``depth_keys``.
depth_keys: Video keys (matching the dataset feature names) that
must be encoded as quantized depth maps using
``depth_encoder_config``. Passing a non-empty ``depth_keys``
without a ``depth_encoder_config`` raises ``ValueError``.
"""
self.fps = fps
self._camera_encoder_config = camera_encoder_config or VideoEncoderConfig()
self._encoder_threads = encoder_threads
self.queue_maxsize = queue_maxsize
self._depth_encoder_config = depth_encoder_config
self._depth_keys: set[str] = set(depth_keys or [])
if self._depth_keys and self._depth_encoder_config is None:
raise ValueError(
"StreamingVideoEncoder received depth_keys without a depth_encoder_config; "
"either pass a DepthEncoderConfig or remove depth_keys."
)
self._frame_queues: dict[str, queue.Queue] = {}
self._result_queues: dict[str, queue.Queue] = {}
@@ -1037,19 +1065,28 @@ class StreamingVideoEncoder:
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
vcodec = self._camera_encoder_config.vcodec
codec_options = self._camera_encoder_config.get_codec_options(
self._encoder_threads, as_strings=True
)
is_depth_key = video_key in self._depth_keys
encoder_cfg: VideoEncoderConfig
depth_cfg = None
if is_depth_key:
assert self._depth_encoder_config is not None # guaranteed by __init__
encoder_cfg = self._depth_encoder_config
depth_cfg = self._depth_encoder_config
else:
encoder_cfg = self._camera_encoder_config
vcodec = encoder_cfg.vcodec
codec_options = encoder_cfg.get_codec_options(self._encoder_threads)
encoder_thread = _CameraEncoderThread(
video_path=video_path,
fps=self.fps,
vcodec=vcodec,
pix_fmt=self._camera_encoder_config.pix_fmt,
pix_fmt=encoder_cfg.pix_fmt,
codec_options=codec_options,
frame_queue=frame_queue,
result_queue=result_queue,
stop_event=stop_event,
depth_encoder_config=depth_cfg,
)
encoder_thread.start()