From de64ad3f7e50a2b886e63fa3e3a69e76e5e6dbf9 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 27 Apr 2026 15:35:44 +0200 Subject: [PATCH] feat(depth): wire StreamingVideoEncoder + writer to depth encoder --- src/lerobot/datasets/dataset_metadata.py | 12 ++-- src/lerobot/datasets/dataset_writer.py | 8 ++- src/lerobot/datasets/feature_utils.py | 18 ++++-- src/lerobot/datasets/lerobot_dataset.py | 21 ++++++- src/lerobot/datasets/video_utils.py | 75 ++++++++++++++++++------ 5 files changed, 104 insertions(+), 30 deletions(-) diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index f663bb847..f68515b86 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -547,11 +547,15 @@ class LeRobotDatasetMetadata: video_keys = [video_key] if video_key is not None else self.video_keys for key in video_keys: - if not self.features[key].get("info", None): + existing = self.features[key].get("info") or {} + # Repopulate when codec metadata is missing — preserves user-provided + # markers like ``video.is_depth_map`` while still recording stream + # info on the first episode. + if not existing or "video.codec" not in existing: video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0) - self.info.features[key]["info"] = get_video_info( - video_path, camera_encoder_config=camera_encoder_config - ) + stream_info = get_video_info(video_path, camera_encoder_config=camera_encoder_config) + merged = {**existing, **stream_info} + self.info.features[key]["info"] = merged def update_chunk_settings( self, diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py index 48defab85..46c85d8af 100644 --- a/src/lerobot/datasets/dataset_writer.py +++ b/src/lerobot/datasets/dataset_writer.py @@ -509,7 +509,13 @@ class DatasetWriter: # Update video info (only needed when first episode is encoded) if episode_index == 0: - self._meta.update_video_info(video_key, camera_encoder_config=self._camera_encoder_config) + is_depth_key = video_key in set(self._meta.depth_keys) + cfg_for_info = ( + self._depth_encoder_config + if is_depth_key and self._depth_encoder_config is not None + else self._camera_encoder_config + ) + self._meta.update_video_info(video_key, camera_encoder_config=cfg_for_info) write_info(self._meta.info, self._meta.root) metadata = { diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py index 2ab4b0ea6..a2fee042b 100644 --- a/src/lerobot/datasets/feature_utils.py +++ b/src/lerobot/datasets/feature_utils.py @@ -294,10 +294,20 @@ def validate_feature_image_or_video( # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): - actual_shape = value.shape - c, h, w = expected_shape - if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): - error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + actual_shape = tuple(value.shape) + expected = tuple(expected_shape) + if len(expected) == 2: + # Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1) + h, w = expected + valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)} + if not valid: + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n" + elif len(expected) == 3: + c, h, w = expected + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + else: + error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n" elif isinstance(value, PILImage.Image): pass else: diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index ce132557c..cf6952e0a 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -259,6 +259,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self._camera_encoder_config, self._encoder_threads, encoder_queue_maxsize, + depth_encoder_config=self._depth_encoder_config, + depth_keys=self.meta.depth_keys, ) self.writer = DatasetWriter( meta=self.meta, @@ -309,12 +311,17 @@ class LeRobotDataset(torch.utils.data.Dataset): camera_encoder_config: VideoEncoderConfig, encoder_threads: int | None, encoder_queue_maxsize: int, + *, + depth_encoder_config: DepthEncoderConfig | None = None, + depth_keys: list[str] | None = None, ) -> StreamingVideoEncoder: return StreamingVideoEncoder( fps=fps, camera_encoder_config=camera_encoder_config, encoder_threads=encoder_threads, queue_maxsize=encoder_queue_maxsize, + depth_encoder_config=depth_encoder_config, + depth_keys=depth_keys, ) # ── Metadata properties ─────────────────────────────────────────── @@ -711,7 +718,12 @@ class LeRobotDataset(torch.utils.data.Dataset): streaming_enc = None if streaming_encoding and len(obj.meta.video_keys) > 0: streaming_enc = cls._build_streaming_encoder( - fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize + fps, + camera_encoder_config, + encoder_threads, + encoder_queue_maxsize, + depth_encoder_config=depth_encoder_config, + depth_keys=obj.meta.depth_keys, ) obj.writer = DatasetWriter( meta=obj.meta, @@ -822,7 +834,12 @@ class LeRobotDataset(torch.utils.data.Dataset): streaming_enc = None if streaming_encoding and len(obj.meta.video_keys) > 0: streaming_enc = cls._build_streaming_encoder( - obj.meta.fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize + obj.meta.fps, + camera_encoder_config, + encoder_threads, + encoder_queue_maxsize, + depth_encoder_config=depth_encoder_config, + depth_keys=obj.meta.depth_keys, ) obj.writer = DatasetWriter( meta=obj.meta, diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 209490db7..babecce1a 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -40,8 +40,11 @@ from PIL import Image from lerobot.datasets.pyav_utils import ( check_video_encoder_config_pyav, + depth_to_video_frame, detect_available_encoders_pyav, decode_depth_frame, + encode_depth_frame_pyav, + decode_depth_frame_pyav, ) from lerobot.datasets.depth_utils import ( quantize_depth, @@ -875,6 +878,7 @@ class _CameraEncoderThread(threading.Thread): frame_queue: queue.Queue, result_queue: queue.Queue, stop_event: threading.Event, + depth_encoder_config: "DepthEncoderConfig | None" = None, ): super().__init__(daemon=True) self.video_path = video_path @@ -885,13 +889,16 @@ class _CameraEncoderThread(threading.Thread): self.frame_queue = frame_queue self.result_queue = result_queue self.stop_event = stop_event + self.depth_encoder_config = depth_encoder_config + def run(self) -> None: from .compute_stats import RunningQuantileStats, auto_downsample_height_width container = None output_stream = None - stats_tracker = RunningQuantileStats() + is_depth = self.depth_encoder_config is not None + stats_tracker = RunningQuantileStats() if not is_depth else None frame_count = 0 try: @@ -909,12 +916,12 @@ class _CameraEncoderThread(threading.Thread): # Sentinel: flush and close break - # Ensure HWC uint8 numpy array + # Ensure HWC (RGB or depth) uint8 (RGB only) numpy array if isinstance(frame_data, np.ndarray): if frame_data.ndim == 3 and frame_data.shape[0] == 3: # CHW -> HWC frame_data = frame_data.transpose(1, 2, 0) - if frame_data.dtype != np.uint8: + if frame_data.dtype != np.uint8 and not is_depth: frame_data = (frame_data * 255).astype(np.uint8) # Open container on first frame (to get width/height) @@ -929,21 +936,25 @@ class _CameraEncoderThread(threading.Thread): output_stream.time_base = Fraction(1, self.fps) # Encode frame with explicit timestamps - pil_img = Image.fromarray(frame_data) - video_frame = av.VideoFrame.from_image(pil_img) + if is_depth: + video_frame = encode_depth_frame_pyav(frame_data, pix_fmt=self.pix_fmt, depth_min=self.depth_encoder_config.depth_min, depth_max=self.depth_encoder_config.depth_max, shift=self.depth_encoder_config.shift, use_log=self.depth_encoder_config.use_log) + else: + pil_img = Image.fromarray(frame_data) + video_frame = av.VideoFrame.from_image(pil_img) video_frame.pts = frame_count video_frame.time_base = Fraction(1, self.fps) packet = output_stream.encode(video_frame) if packet: container.mux(packet) - # Update stats with downsampled frame (per-channel stats like compute_episode_stats) - img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW - img_downsampled = auto_downsample_height_width(img_chw) - # Reshape CHW to (H*W, C) for per-channel stats - channels = img_downsampled.shape[0] - img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels) - stats_tracker.update(img_for_stats) + if not is_depth: + # Update stats with downsampled frame (per-channel stats like compute_episode_stats) + img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW + img_downsampled = auto_downsample_height_width(img_chw) + # Reshape CHW to (H*W, C) for per-channel stats + channels = img_downsampled.shape[0] + img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels) + stats_tracker.update(img_for_stats) frame_count += 1 @@ -958,8 +969,10 @@ class _CameraEncoderThread(threading.Thread): av.logging.restore_default_callback() - # Get stats and put on result queue - if frame_count >= 2: + # Get stats and put on result queue (depth streams skip stats) + if is_depth: + self.result_queue.put(("ok", None)) + elif frame_count >= 2: stats = stats_tracker.get_statistics() self.result_queue.put(("ok", stats)) else: @@ -992,6 +1005,8 @@ class StreamingVideoEncoder: encoder_threads: int | None = None, *, queue_maxsize: int = 30, + depth_encoder_config: "DepthEncoderConfig | None" = None, + depth_keys: list[str] | None = None, ): """ Args: @@ -1002,11 +1017,24 @@ class StreamingVideoEncoder: ``None`` lets the codec decide. queue_maxsize: Max frames to buffer per camera before back-pressure drops frames. + depth_encoder_config: Optional depth encoder configuration applied + to all depth video keys listed in ``depth_keys``. + depth_keys: Video keys (matching the dataset feature names) that + must be encoded as quantized depth maps using + ``depth_encoder_config``. Required when ``depth_encoder_config`` + is provided. """ self.fps = fps self._camera_encoder_config = camera_encoder_config or VideoEncoderConfig() self._encoder_threads = encoder_threads self.queue_maxsize = queue_maxsize + self._depth_encoder_config = depth_encoder_config + self._depth_keys: set[str] = set(depth_keys or []) + if self._depth_keys and self._depth_encoder_config is None: + raise ValueError( + "StreamingVideoEncoder received depth_keys without a depth_encoder_config; " + "either pass a DepthEncoderConfig or remove depth_keys." + ) self._frame_queues: dict[str, queue.Queue] = {} self._result_queues: dict[str, queue.Queue] = {} @@ -1037,19 +1065,28 @@ class StreamingVideoEncoder: temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir)) video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4" - vcodec = self._camera_encoder_config.vcodec - codec_options = self._camera_encoder_config.get_codec_options( - self._encoder_threads, as_strings=True - ) + is_depth_key = video_key in self._depth_keys + encoder_cfg: VideoEncoderConfig + depth_cfg = None + if is_depth_key: + assert self._depth_encoder_config is not None # guaranteed by __init__ + encoder_cfg = self._depth_encoder_config + depth_cfg = self._depth_encoder_config + else: + encoder_cfg = self._camera_encoder_config + + vcodec = encoder_cfg.vcodec + codec_options = encoder_cfg.get_codec_options(self._encoder_threads) encoder_thread = _CameraEncoderThread( video_path=video_path, fps=self.fps, vcodec=vcodec, - pix_fmt=self._camera_encoder_config.pix_fmt, + pix_fmt=encoder_cfg.pix_fmt, codec_options=codec_options, frame_queue=frame_queue, result_queue=result_queue, stop_event=stop_event, + depth_encoder_config=depth_cfg, ) encoder_thread.start()