mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
Compare commits
14 Commits
016799dfa1
...
feat/depth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4445849b86 | ||
|
|
f43bf75f9b | ||
|
|
b540fa94a9 | ||
|
|
efad15f600 | ||
|
|
407d1882a2 | ||
|
|
0d6e4f3bad | ||
|
|
536b29d963 | ||
|
|
2744e26593 | ||
|
|
de64ad3f7e | ||
|
|
d777359662 | ||
|
|
5d0a20bd9c | ||
|
|
2c796d3352 | ||
|
|
df1648c102 | ||
|
|
3bd96a4346 |
@@ -133,6 +133,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
self.rs_pipeline: rs.pipeline | None = None
|
||||
self.rs_profile: rs.pipeline_profile | None = None
|
||||
# Meters per uint16 unit on the depth stream. Queried from the device
|
||||
# at connect() time. Typical D-series value is 0.001 (= 1 mm/unit).
|
||||
self.depth_scale: float | None = None
|
||||
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
@@ -190,6 +193,17 @@ class RealSenseCamera(Camera):
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
# Query depth scale (meters per uint16 unit) when depth is enabled so
|
||||
# consumers can convert the raw z16 stream to metric distances.
|
||||
if self.use_depth and self.rs_profile is not None:
|
||||
try:
|
||||
depth_sensor = self.rs_profile.get_device().first_depth_sensor()
|
||||
self.depth_scale = float(depth_sensor.get_depth_scale())
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"{self}: failed to query depth scale ({e}); falling back to 0.001 m/unit.")
|
||||
self.depth_scale = 0.001
|
||||
|
||||
self._start_read_thread()
|
||||
|
||||
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
@@ -532,7 +546,6 @@ class RealSenseCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -575,7 +588,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
@@ -611,6 +623,78 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""Read the latest depth frame asynchronously, in metric meters.
|
||||
|
||||
Mirrors :meth:`async_read` but returns the depth stream rather than the
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W)``.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
the background read thread is not running.
|
||||
TimeoutError: If no frame becomes available within ``timeout_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"{self}: cannot read depth — camera was configured with use_depth=False."
|
||||
)
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if depth_frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
|
||||
|
||||
return depth_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent depth frame in metric meters (peeking).
|
||||
|
||||
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
|
||||
Output is ``np.float32`` of shape ``(H, W)`` in meters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
no depth frame has been captured yet.
|
||||
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"{self}: cannot read depth — camera was configured with use_depth=False."
|
||||
)
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if depth_frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any depth frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return depth_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
@@ -634,6 +718,8 @@ class RealSenseCamera(Camera):
|
||||
self.rs_pipeline = None
|
||||
self.rs_profile = None
|
||||
|
||||
self.depth_scale = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
|
||||
@@ -49,9 +49,11 @@ from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
VideoEncodingManager,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
@@ -67,9 +69,11 @@ __all__ = [
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"StreamingLeRobotDataset",
|
||||
"DepthEncoderConfig",
|
||||
"VideoEncoderConfig",
|
||||
"VideoEncodingManager",
|
||||
"camera_encoder_defaults",
|
||||
"depth_encoder_defaults",
|
||||
"add_features",
|
||||
"aggregate_datasets",
|
||||
"aggregate_pipeline_dataset_features",
|
||||
|
||||
@@ -313,6 +313,20 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def depth_keys(self) -> list[str]:
|
||||
"""Keys to access depth-map modalities stored as videos.
|
||||
|
||||
A depth video key is a feature whose ``info`` dict carries
|
||||
``"video.is_depth_map": True`` (set either at creation time by the user
|
||||
or after the first encoded episode by :meth:`update_video_info`).
|
||||
"""
|
||||
return [
|
||||
key
|
||||
for key, ft in self.features.items()
|
||||
if ft["dtype"] == "video" and ft.get("info", {}).get("video.is_depth_map", False)
|
||||
]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
@@ -533,11 +547,15 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
existing = self.features[key].get("info") or {}
|
||||
# Repopulate when codec metadata is missing — preserves user-provided
|
||||
# markers like ``video.is_depth_map`` while still recording stream
|
||||
# info on the first episode.
|
||||
if not existing or "video.codec" not in existing:
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(
|
||||
video_path, camera_encoder_config=camera_encoder_config
|
||||
)
|
||||
stream_info = get_video_info(video_path, camera_encoder_config=camera_encoder_config)
|
||||
merged = {**existing, **stream_info}
|
||||
self.info.features[key]["info"] = merged
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
|
||||
@@ -32,7 +32,13 @@ from .io_utils import (
|
||||
hf_transform_to_torch,
|
||||
load_nested_dataset,
|
||||
)
|
||||
from .video_utils import decode_video_frames
|
||||
from .video_utils import decode_depth_frames, decode_video_frames
|
||||
from .depth_utils import (
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
)
|
||||
|
||||
|
||||
class DatasetReader:
|
||||
@@ -237,17 +243,31 @@ class DatasetReader:
|
||||
"""
|
||||
ep = self._meta.episodes[ep_idx]
|
||||
|
||||
depth_keys = set(self._meta.depth_keys)
|
||||
|
||||
def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
|
||||
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
if vid_key in depth_keys:
|
||||
feature_info = self._meta.features[vid_key].get("info") or {}
|
||||
frames = decode_depth_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
depth_min=feature_info.get("video.depth_min", DEFAULT_DEPTH_MIN),
|
||||
depth_max=feature_info.get("video.depth_max", DEFAULT_DEPTH_MAX),
|
||||
shift=feature_info.get("video.shift", DEFAULT_DEPTH_SHIFT),
|
||||
use_log=feature_info.get("video.use_log", DEFAULT_DEPTH_USE_LOG),
|
||||
)
|
||||
else:
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
@@ -46,16 +46,19 @@ from .io_utils import (
|
||||
write_info,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_DEPTH_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
concatenate_video_files,
|
||||
encode_video_frames,
|
||||
get_video_duration_in_s,
|
||||
is_depth_feature,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -100,6 +103,7 @@ class DatasetWriter:
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
initial_frames: int = 0,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
):
|
||||
"""Initialize the writer with metadata, codec, and encoder config.
|
||||
|
||||
@@ -115,14 +119,19 @@ class DatasetWriter:
|
||||
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
|
||||
for real-time encoding. ``None`` disables streaming mode.
|
||||
initial_frames: Starting frame count (non-zero when resuming).
|
||||
depth_encoder_config: Optional depth-map encoder config used in
|
||||
place of ``camera_encoder_config`` for keys present in
|
||||
``meta.depth_keys``.
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
|
||||
|
||||
# Writer state
|
||||
self.image_writer: AsyncImageWriter | None = None
|
||||
self.episode_buffer: dict = self._create_episode_buffer()
|
||||
@@ -142,8 +151,16 @@ class DatasetWriter:
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
def _is_depth_image_key(self, image_key: str) -> bool:
|
||||
"""Whether *image_key* is a depth feature stored as per-frame images."""
|
||||
ft = self._meta.features.get(image_key)
|
||||
if ft is None or ft.get("dtype") != "image":
|
||||
return False
|
||||
return is_depth_feature(ft.get("info") or {})
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
path_template = DEFAULT_DEPTH_PATH if self._is_depth_image_key(image_key) else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
@@ -502,7 +519,13 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key, camera_encoder_config=self._camera_encoder_config)
|
||||
is_depth_key = video_key in set(self._meta.depth_keys)
|
||||
cfg_for_info = (
|
||||
self._depth_encoder_config
|
||||
if is_depth_key and self._depth_encoder_config is not None
|
||||
else self._camera_encoder_config
|
||||
)
|
||||
self._meta.update_video_info(video_key, camera_encoder_config=cfg_for_info)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
|
||||
189
src/lerobot/datasets/depth_utils.py
Normal file
189
src/lerobot/datasets/depth_utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
DEPTH_QUANT_BITS: int = 12
|
||||
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
|
||||
_MM_PER_METRE: float = 1000.0
|
||||
_UINT16_MAX: int = 65535
|
||||
|
||||
DEFAULT_DEPTH_MIN: float = 0.01
|
||||
DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
if depth_min + shift <= 0:
|
||||
raise ValueError(
|
||||
f"depth_min + shift must be positive for logarithmic quantization, "
|
||||
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
|
||||
)
|
||||
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
"""Depth as float32 in the chosen unit, plus the resolved unit."""
|
||||
if isinstance(depth, torch.Tensor):
|
||||
t = depth.detach().cpu()
|
||||
arr = t.numpy()
|
||||
is_floating = t.is_floating_point()
|
||||
else:
|
||||
arr = np.asarray(depth)
|
||||
is_floating = np.issubdtype(arr.dtype, np.floating)
|
||||
|
||||
resolved_unit: Literal["m", "mm"]
|
||||
if input_unit == "auto":
|
||||
resolved_unit = "m" if is_floating else "mm"
|
||||
else:
|
||||
resolved_unit = input_unit
|
||||
|
||||
# Convert to float32 to keep typing consistency
|
||||
return np.asarray(arr, dtype=np.float32, order="K"), resolved_unit
|
||||
|
||||
|
||||
def quantize_depth(
|
||||
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
*,
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16]:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
Depth maps are packed into 12-bit integer frames so they fit in standard
|
||||
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
|
||||
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
|
||||
Logarithmic quantization is the default because it allocates more quanta
|
||||
to near-range depth, which matches the (1/depth) error profile of typical
|
||||
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
|
||||
|
||||
**Input units**:
|
||||
|
||||
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
|
||||
- ``input_unit="mm"``: interpret input values as millimetres.
|
||||
- ``input_unit="m"``: interpret input values as metres.
|
||||
|
||||
Quantization math runs in the **resolved input unit**.
|
||||
|
||||
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
|
||||
|
||||
Args:
|
||||
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
|
||||
depth_min: Depth (metres) at quantum ``0``.
|
||||
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
|
||||
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
|
||||
use_log: If ``True`` (default), quantize in log space.
|
||||
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
|
||||
|
||||
Returns:
|
||||
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
|
||||
``[0, DEPTH_QMAX]``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_u + shift_u))
|
||||
log_max = math.log(float(depth_max_u + shift_u))
|
||||
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
|
||||
else:
|
||||
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
|
||||
|
||||
out = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX)
|
||||
return out.astype(np.uint16, copy=False)
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
*,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32]:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
|
||||
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
|
||||
the requested output unit.
|
||||
|
||||
Args:
|
||||
quantized: 12-bit codes ``[0, DEPTH_QMAX]``, ``dtype=uint16``.
|
||||
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
|
||||
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
|
||||
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
ValueError: If ``output_unit`` is not ``\"m\"`` or ``\"mm\"``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
quantized = quantized.detach().cpu().numpy()
|
||||
q = np.asarray(quantized, dtype=np.uint16, order="K")
|
||||
norm = q.astype(np.float32, copy=False) / DEPTH_QMAX
|
||||
|
||||
depth_min_mm = np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_mm = np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_mm = np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_mm + shift_mm))
|
||||
log_max = math.log(float(depth_max_mm + shift_mm))
|
||||
depth_mm = np.exp(norm * (log_max - log_min) + log_min) - shift_mm
|
||||
else:
|
||||
depth_mm = norm * (depth_max_mm - depth_min_mm) + depth_min_mm
|
||||
|
||||
depth_mm = np.clip(depth_mm, depth_min_mm, depth_max_mm).astype(np.float32, copy=False)
|
||||
if output_unit == "m":
|
||||
return (depth_mm / np.float32(_MM_PER_METRE)).astype(np.float32, copy=False)
|
||||
mm = np.rint(depth_mm).clip(0, _UINT16_MAX)
|
||||
return mm.astype(np.uint16, copy=False)
|
||||
@@ -294,10 +294,20 @@ def validate_feature_image_or_video(
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
actual_shape = tuple(value.shape)
|
||||
expected = tuple(expected_shape)
|
||||
if len(expected) == 2:
|
||||
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
|
||||
h, w = expected
|
||||
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
|
||||
if not valid:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
|
||||
elif len(expected) == 3:
|
||||
c, h, w = expected
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -41,15 +41,56 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
# Single-channel dtypes that PIL natively maps to the matching mode
|
||||
# (``uint8`` → ``L``, ``uint16`` → ``I;16``, ``float32`` → ``F``).
|
||||
GRAYSCALE_DTYPES: tuple[np.dtype, ...] = (
|
||||
np.dtype("uint8"),
|
||||
np.dtype("uint16"),
|
||||
np.dtype("float32"),
|
||||
)
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
|
||||
Behaviour by shape:
|
||||
|
||||
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
|
||||
The native dtype is preserved using the matching PIL mode
|
||||
(``L`` / ``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
|
||||
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
|
||||
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
|
||||
(existing behaviour, gated by ``range_check``).
|
||||
|
||||
Other shapes / channel counts raise ``NotImplementedError`` or
|
||||
``ValueError``.
|
||||
"""
|
||||
if image_array.ndim not in (2, 3):
|
||||
raise ValueError(
|
||||
f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image."
|
||||
)
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in GRAYSCALE_DTYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported single-channel image dtype: {image_array.dtype}. "
|
||||
f"Supported dtypes: {sorted(str(d) for d in GRAYSCALE_DTYPES)}."
|
||||
)
|
||||
|
||||
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
|
||||
|
||||
# 3D path: must be RGB (3 channels), channels-first or channels-last.
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
elif image_array.shape[-1] != 3:
|
||||
raise NotImplementedError(
|
||||
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||
@@ -71,13 +112,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
|
||||
|
||||
PNG uses ``compress_level`` (0–9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
|
||||
"""
|
||||
suffix = Path(fpath).suffix.lower()
|
||||
if suffix == ".png":
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation.
|
||||
the save operation. The output format is inferred from the *fpath*
|
||||
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
|
||||
→ lossless raw depth maps (TIFF).
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
@@ -101,7 +157,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
img.save(fpath, **save_kwargs_for_path(Path(fpath), compress_level))
|
||||
except Exception as e:
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
@@ -35,9 +35,11 @@ from .utils import (
|
||||
is_valid_version,
|
||||
)
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
get_safe_default_video_backend,
|
||||
seed_depth_feature_info,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -59,6 +61,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -207,6 +210,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
if self._requested_root is not None:
|
||||
@@ -249,6 +253,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
seed_depth_feature_info(self.meta.features, self._depth_encoder_config)
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
@@ -256,11 +261,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._camera_encoder_config,
|
||||
self._encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
depth_encoder_config=self._depth_encoder_config,
|
||||
depth_keys=self.meta.depth_keys,
|
||||
)
|
||||
self.writer = DatasetWriter(
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
camera_encoder_config=self._camera_encoder_config,
|
||||
depth_encoder_config=self._depth_encoder_config,
|
||||
encoder_threads=self._encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -305,12 +313,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
encoder_threads: int | None,
|
||||
encoder_queue_maxsize: int,
|
||||
*,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
depth_keys: list[str] | None = None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
depth_keys=depth_keys,
|
||||
)
|
||||
|
||||
# ── Metadata properties ───────────────────────────────────────────
|
||||
@@ -626,6 +639,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -697,7 +711,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
@@ -705,12 +721,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize
|
||||
fps,
|
||||
camera_encoder_config,
|
||||
encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
depth_keys=obj.meta.depth_keys,
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -734,6 +756,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
@@ -804,8 +827,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
obj.root = obj.meta.root
|
||||
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
@@ -813,12 +838,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize
|
||||
obj.meta.fps,
|
||||
camera_encoder_config,
|
||||
encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
depth_keys=obj.meta.depth_keys,
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -23,19 +23,144 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.depth_utils import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
quantize_depth,
|
||||
dequantize_depth,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pixel formats supported by the depth encode/decode helpers below. Both are
|
||||
# 16-bit-word formats that carry 12 significant bits per sample, matching the
|
||||
# ``DEPTH_QMAX = 4095`` quantization range.
|
||||
DEPTH_PIX_FMTS: tuple[str, ...] = ("yuv420p12le", "gray12le")
|
||||
|
||||
# Neutral chroma for 12-bit YUV (the midpoint of [0, 4095]). Filling the U/V
|
||||
# planes with this value keeps the encoder from spending bits on chroma noise
|
||||
# when only the Y plane carries information.
|
||||
_NEUTRAL_CHROMA_12BIT: int = 2048
|
||||
|
||||
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def _write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
|
||||
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
|
||||
height, width = src.shape
|
||||
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
|
||||
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
|
||||
if fill_value is not None:
|
||||
dst.fill(fill_value)
|
||||
dst[:, :width] = src
|
||||
|
||||
|
||||
def encode_depth_frame_pyav(
|
||||
depth: np.ndarray | torch.Tensor,
|
||||
*,
|
||||
pix_fmt: str = "yuv420p12le",
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> av.VideoFrame:
|
||||
"""Quantize depth and pack it into a 12-bit PyAV video frame.
|
||||
|
||||
Args:
|
||||
depth: Depth frame to encode (H, W). Unit handling follows
|
||||
:func:`lerobot.datasets.depth_utils.quantize_depth`.
|
||||
pix_fmt: Target pixel format. Must be one of :data:`DEPTH_PIX_FMTS`.
|
||||
depth_min, depth_max, shift, use_log, input_unit: Forwarded to
|
||||
:func:`quantize_depth`.
|
||||
|
||||
Returns:
|
||||
An :class:`av.VideoFrame` in ``pix_fmt`` with quantized depth in the
|
||||
luminance plane.
|
||||
"""
|
||||
if pix_fmt not in DEPTH_PIX_FMTS:
|
||||
raise ValueError(f"Unsupported depth pix_fmt={pix_fmt!r}; expected one of {DEPTH_PIX_FMTS}")
|
||||
|
||||
quantized_depth = quantize_depth(
|
||||
depth,
|
||||
depth_min=depth_min,
|
||||
depth_max=depth_max,
|
||||
shift=shift,
|
||||
use_log=use_log,
|
||||
input_unit=input_unit,
|
||||
)
|
||||
if quantized_depth.ndim != 2:
|
||||
raise ValueError(f"depth must be a 2D frame; got shape {quantized_depth.shape}")
|
||||
|
||||
quantized_depth = np.ascontiguousarray(quantized_depth, dtype=np.uint16)
|
||||
height, width = quantized_depth.shape
|
||||
|
||||
if pix_fmt == "gray12le":
|
||||
frame = av.VideoFrame(width=width, height=height, format="gray12le")
|
||||
_write_u16_plane(frame.planes[0], quantized_depth)
|
||||
return frame
|
||||
|
||||
if height % 2 != 0 or width % 2 != 0:
|
||||
raise ValueError("yuv420p12le requires even H and W")
|
||||
|
||||
frame = av.VideoFrame(width=width, height=height, format="yuv420p12le")
|
||||
_write_u16_plane(frame.planes[0], quantized_depth)
|
||||
neutral_chroma = np.full((height // 2, width // 2), _NEUTRAL_CHROMA_12BIT, dtype=np.uint16)
|
||||
_write_u16_plane(frame.planes[1], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
|
||||
_write_u16_plane(frame.planes[2], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
|
||||
return frame
|
||||
|
||||
|
||||
def decode_depth_frame_pyav(
|
||||
frame: av.VideoFrame | list[av.VideoFrame],
|
||||
*,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
return_quantized: bool = False,
|
||||
output_unit: Literal["m", "mm"] = "m",
|
||||
) -> np.ndarray:
|
||||
"""Decode one or many depth video frames to quantized or metric depth.
|
||||
|
||||
Args:
|
||||
frame: A single depth frame or a list of depth frames.
|
||||
depth_min, depth_max, shift, use_log: Forwarded to
|
||||
:func:`dequantize_depth`.
|
||||
return_quantized: If ``True``, return raw 12-bit quanta as ``uint16``.
|
||||
output_unit: Unit for dequantized output (``"m"`` or ``"mm"``).
|
||||
|
||||
Returns:
|
||||
``(H, W)`` array for a single frame, or ``(N, H, W)`` for a list.
|
||||
"""
|
||||
frames = frame if isinstance(frame, list) else [frame]
|
||||
quantized = np.stack([f.reformat(format="gray12le").to_ndarray() for f in frames]).astype(np.uint16, copy=False)
|
||||
if return_quantized:
|
||||
return quantized[0] if len(frames) == 1 else quantized
|
||||
|
||||
decoded = dequantize_depth(
|
||||
quantized,
|
||||
depth_min=depth_min,
|
||||
depth_max=depth_max,
|
||||
shift=shift,
|
||||
use_log=use_log,
|
||||
output_unit=output_unit,
|
||||
)
|
||||
return decoded[0] if len(frames) == 1 else decoded
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
@@ -46,7 +171,7 @@ def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
|
||||
def _get_codec_video_formats(vcodec: str) -> dict[str, av.option.Option]:
|
||||
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
|
||||
@@ -93,6 +93,10 @@ DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
# Depth maps live alongside images on disk but use TIFF instead of PNG: PNG
|
||||
# cannot natively round-trip float32, and several common loaders silently
|
||||
# downcast 16-bit grayscale.
|
||||
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
|
||||
@@ -17,6 +17,7 @@ import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -39,7 +40,19 @@ from PIL import Image
|
||||
|
||||
from lerobot.datasets.pyav_utils import (
|
||||
check_video_encoder_config_pyav,
|
||||
depth_to_video_frame,
|
||||
detect_available_encoders_pyav,
|
||||
decode_depth_frame,
|
||||
encode_depth_frame_pyav,
|
||||
decode_depth_frame_pyav,
|
||||
)
|
||||
from lerobot.datasets.depth_utils import (
|
||||
quantize_depth,
|
||||
dequantize_depth,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
@@ -56,7 +69,7 @@ HW_ENCODERS = [
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "ffv1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
@@ -94,6 +107,12 @@ class VideoEncoderConfig:
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Class-level marker persisted to ``info.json`` (via ``asdict``) so the
|
||||
# reader can tell depth datasets from RGB ones without a separate dispatch
|
||||
# path. ``init=False`` keeps it out of CLI/constructor surface; subclasses
|
||||
# flip the default (see :class:`DepthEncoderConfig`).
|
||||
is_depth_map: bool = field(default=False, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
|
||||
@@ -122,6 +141,12 @@ class VideoEncoderConfig:
|
||||
also silently rewritten to ``libsvtav1`` so encoding never hard-fails on
|
||||
a host missing the requested encoder.
|
||||
"""
|
||||
# Backward compatibility: older datasets persist ``vcodec="av1"`` in
|
||||
# ``info.json``. Rewrite to the canonical encoder name *before* the
|
||||
# validation check below so loading those datasets keeps working.
|
||||
if self.vcodec == "av1":
|
||||
self.vcodec = "libsvtav1"
|
||||
|
||||
if self.vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if self.vcodec == "auto":
|
||||
@@ -191,6 +216,10 @@ class VideoEncoderConfig:
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "ffv1":
|
||||
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
|
||||
# are not meaningful.
|
||||
set_if("threads", encoder_threads)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
@@ -203,6 +232,60 @@ class VideoEncoderConfig:
|
||||
return opts
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthEncoderConfig(VideoEncoderConfig):
|
||||
"""Encoder configuration for depth-map streams.
|
||||
|
||||
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
|
||||
preset, ``extra_options``…) and adds the four parameters of the depth
|
||||
quantization pipeline (:func:`quantize_depth`). Inheritance — rather
|
||||
than composition — keeps the CLI flat: ``--dataset.depth_encoder_config.<field>``
|
||||
works identically to its RGB counterpart.
|
||||
|
||||
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
|
||||
to ``"yuv420p12le"``, the most widely available 12-bit pixel format.
|
||||
For archive-grade lossless storage use ``vcodec="ffv1"`` together with
|
||||
``pix_fmt="gray12le"`` (and clear ``crf``/``preset`` to ``None`` since
|
||||
``ffv1`` doesn't expose those tuning knobs).
|
||||
|
||||
The :attr:`is_depth_map` marker is class-fixed to ``True`` (``init=False``,
|
||||
so it's hidden from CLI and constructor args) and is what the reader
|
||||
side keys on to tell depth datasets from RGB ones.
|
||||
|
||||
Attributes:
|
||||
depth_min: Minimum depth in physical units (e.g. metres) represented
|
||||
by quantum ``0``.
|
||||
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
|
||||
shift: Pre-log offset for numerical stability near zero.
|
||||
use_log: ``True`` for logarithmic quantization (default; matches
|
||||
sensor error profile), ``False`` for linear.
|
||||
"""
|
||||
|
||||
vcodec: str = "hevc"
|
||||
pix_fmt: str = "yuv420p12le"
|
||||
|
||||
depth_min: float = DEFAULT_DEPTH_MIN
|
||||
depth_max: float = DEFAULT_DEPTH_MAX
|
||||
shift: float = DEFAULT_DEPTH_SHIFT
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG
|
||||
|
||||
# Class invariant — kept out of ``__init__`` (and CLI) but persisted
|
||||
# via ``asdict`` into ``info.json`` for the reader to detect depth.
|
||||
is_depth_map: bool = field(default=True, init=False)
|
||||
|
||||
def quantize(self, depth: torch.Tensor | np.ndarray) -> torch.Tensor:
|
||||
"""Apply :func:`quantize_depth` bound to this config's parameters."""
|
||||
return quantize_depth(depth, self.depth_min, self.depth_max, self.shift, self.use_log)
|
||||
|
||||
def dequantize(self, quantized: torch.Tensor | np.ndarray) -> torch.Tensor:
|
||||
"""Apply :func:`dequantize_depth` bound to this config's parameters."""
|
||||
return dequantize_depth(quantized, self.depth_min, self.depth_max, self.shift, self.use_log)
|
||||
|
||||
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
|
||||
return DepthEncoderConfig()
|
||||
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
@@ -486,6 +569,121 @@ def decode_video_frames_torchcodec(
|
||||
return closest_frames
|
||||
|
||||
|
||||
def decode_depth_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
*,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
return_quantized: bool = False,
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Decode depth-map frames at the requested timestamps using PyAV.
|
||||
|
||||
Mirrors the timestamp-tolerance / closest-frame contract of
|
||||
:func:`decode_video_frames` but operates entirely through PyAV (the
|
||||
``torchvision`` and ``torchcodec`` backends don't currently round-trip
|
||||
12-bit pixel formats reliably).
|
||||
|
||||
Each decoded frame is reformatted to ``gray12le`` so the same path
|
||||
handles ``yuv420p12le`` (HEVC default) and ``gray12le`` (ffv1 archive)
|
||||
sources transparently.
|
||||
|
||||
Args:
|
||||
video_path: Path to a depth video produced with a
|
||||
:class:`DepthEncoderConfig`.
|
||||
timestamps: Frame timestamps to retrieve, in seconds.
|
||||
tolerance_s: Maximum allowed deviation between the queried and the
|
||||
actually-decoded timestamps.
|
||||
depth_min, depth_max, shift, use_log: Parameters used at quantization
|
||||
time. Should match :func:`info_to_depth_kwargs` extracted from
|
||||
``info.json`` for the source dataset.
|
||||
return_quantized: If ``True``, skip the dequantization step and
|
||||
return raw 12-bit ``uint16`` quanta.
|
||||
log_loaded_timestamps: Debug logging.
|
||||
|
||||
Returns:
|
||||
``torch.Tensor`` of shape ``(N, H, W)``:
|
||||
|
||||
* ``dtype=torch.float32`` (metric depth, default)
|
||||
* ``dtype=torch.uint16`` when ``return_quantized=True``.
|
||||
|
||||
Raises:
|
||||
FrameTimestampError: If a query timestamp can't be matched within
|
||||
*tolerance_s*, or if no frames are decoded.
|
||||
"""
|
||||
video_path_str = str(video_path)
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
loaded_frames: list[np.ndarray] = []
|
||||
loaded_ts: list[float] = []
|
||||
|
||||
av.logging.set_level(av.logging.WARNING)
|
||||
with av.open(video_path_str, "r") as container:
|
||||
try:
|
||||
stream = container.streams.video[0]
|
||||
except IndexError as e:
|
||||
raise FrameTimestampError(f"No video stream in {video_path_str}") from e
|
||||
|
||||
# Seek to the keyframe at-or-before first_ts (PyAV doesn't do
|
||||
# accurate seek, so we still iterate forward to the requested range).
|
||||
seek_pts = int(first_ts / stream.time_base)
|
||||
container.seek(seek_pts, stream=stream, any_frame=False, backward=True)
|
||||
|
||||
for frame in container.decode(stream):
|
||||
if frame.pts is None:
|
||||
continue
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"depth frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(
|
||||
decode_depth_frame(
|
||||
frame,
|
||||
depth_min=depth_min,
|
||||
depth_max=depth_max,
|
||||
shift=shift,
|
||||
use_log=use_log,
|
||||
return_quantized=True,
|
||||
)
|
||||
)
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
if not loaded_frames:
|
||||
raise FrameTimestampError(
|
||||
f"No depth frames decoded from {video_path_str} for timestamps {timestamps}"
|
||||
)
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts_t = torch.tensor(loaded_ts)
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps violate the tolerance "
|
||||
f"({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts_t}"
|
||||
f"\nvideo: {video_path_str}"
|
||||
)
|
||||
|
||||
closest = np.stack([loaded_frames[i] for i in argmin_]) # (N, H, W) uint16
|
||||
quantized = torch.from_numpy(closest)
|
||||
|
||||
if return_quantized:
|
||||
return quantized
|
||||
return dequantize_depth(quantized, depth_min, depth_max, shift, use_log)
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
@@ -680,6 +878,7 @@ class _CameraEncoderThread(threading.Thread):
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
depth_encoder_config: "DepthEncoderConfig | None" = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
@@ -690,13 +889,16 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.depth_encoder_config = depth_encoder_config
|
||||
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
|
||||
container = None
|
||||
output_stream = None
|
||||
stats_tracker = RunningQuantileStats()
|
||||
is_depth = self.depth_encoder_config is not None
|
||||
stats_tracker = RunningQuantileStats() if not is_depth else None
|
||||
frame_count = 0
|
||||
|
||||
try:
|
||||
@@ -714,12 +916,12 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
if frame_data.dtype != np.uint8 and not is_depth:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
@@ -734,21 +936,25 @@ class _CameraEncoderThread(threading.Thread):
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
if is_depth:
|
||||
video_frame = encode_depth_frame_pyav(frame_data, pix_fmt=self.pix_fmt, depth_min=self.depth_encoder_config.depth_min, depth_max=self.depth_encoder_config.depth_max, shift=self.depth_encoder_config.shift, use_log=self.depth_encoder_config.use_log)
|
||||
else:
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
if not is_depth:
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
@@ -763,8 +969,10 @@ class _CameraEncoderThread(threading.Thread):
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Get stats and put on result queue
|
||||
if frame_count >= 2:
|
||||
# Get stats and put on result queue (depth streams skip stats)
|
||||
if is_depth:
|
||||
self.result_queue.put(("ok", None))
|
||||
elif frame_count >= 2:
|
||||
stats = stats_tracker.get_statistics()
|
||||
self.result_queue.put(("ok", stats))
|
||||
else:
|
||||
@@ -797,6 +1005,8 @@ class StreamingVideoEncoder:
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
queue_maxsize: int = 30,
|
||||
depth_encoder_config: "DepthEncoderConfig | None" = None,
|
||||
depth_keys: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -807,11 +1017,24 @@ class StreamingVideoEncoder:
|
||||
``None`` lets the codec decide.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
depth_encoder_config: Optional depth encoder configuration applied
|
||||
to all depth video keys listed in ``depth_keys``.
|
||||
depth_keys: Video keys (matching the dataset feature names) that
|
||||
must be encoded as quantized depth maps using
|
||||
``depth_encoder_config``. Required when ``depth_encoder_config``
|
||||
is provided.
|
||||
"""
|
||||
self.fps = fps
|
||||
self._camera_encoder_config = camera_encoder_config or VideoEncoderConfig()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._depth_keys: set[str] = set(depth_keys or [])
|
||||
if self._depth_keys and self._depth_encoder_config is None:
|
||||
raise ValueError(
|
||||
"StreamingVideoEncoder received depth_keys without a depth_encoder_config; "
|
||||
"either pass a DepthEncoderConfig or remove depth_keys."
|
||||
)
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
@@ -842,19 +1065,28 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
vcodec = self._camera_encoder_config.vcodec
|
||||
codec_options = self._camera_encoder_config.get_codec_options(
|
||||
self._encoder_threads, as_strings=True
|
||||
)
|
||||
is_depth_key = video_key in self._depth_keys
|
||||
encoder_cfg: VideoEncoderConfig
|
||||
depth_cfg = None
|
||||
if is_depth_key:
|
||||
assert self._depth_encoder_config is not None # guaranteed by __init__
|
||||
encoder_cfg = self._depth_encoder_config
|
||||
depth_cfg = self._depth_encoder_config
|
||||
else:
|
||||
encoder_cfg = self._camera_encoder_config
|
||||
|
||||
vcodec = encoder_cfg.vcodec
|
||||
codec_options = encoder_cfg.get_codec_options(self._encoder_threads)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder_config.pix_fmt,
|
||||
pix_fmt=encoder_cfg.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
depth_encoder_config=depth_cfg,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -1061,13 +1293,13 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder_config: "VideoEncoderConfig | None" = None,
|
||||
video_encoder_config: "VideoEncoderConfig | None" = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder_config: If provided, record the exact encoder settings used to encode this
|
||||
video_encoder_config: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
@@ -1087,7 +1319,6 @@ def get_video_info(
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
@@ -1101,14 +1332,67 @@ def get_video_info(
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder_config is not None:
|
||||
for field_name, field_value in asdict(camera_encoder_config).items():
|
||||
# Add additional encoder configuration if provided (no override of stream-derived values)
|
||||
# Depth related fields flow naturally through this path.
|
||||
if video_encoder_config is not None:
|
||||
for field_name, field_value in asdict(video_encoder_config).items():
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
# Fallback case where no encoder config is provided or the video is not a depth map.
|
||||
video_info.setdefault("video.is_depth_map", False)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
# ─── Depth metadata helpers (reader side) ────────────────────────────
|
||||
|
||||
|
||||
_DEPTH_INFO_KEYS: tuple[str, ...] = (
|
||||
"video.depth_min",
|
||||
"video.depth_max",
|
||||
"video.shift",
|
||||
"video.use_log",
|
||||
)
|
||||
|
||||
|
||||
def seed_depth_feature_info(
|
||||
features: dict[str, dict],
|
||||
depth_encoder_config: "DepthEncoderConfig | None",
|
||||
) -> None:
|
||||
"""Pre-populate per-feature ``video.<field>`` entries from *depth_encoder_config*.
|
||||
|
||||
``update_video_info`` only runs after the first episode video is encoded,
|
||||
so without this seeding step ``features[key]["info"]`` carries no
|
||||
quantization range until then. Consumers that read the dataset feature
|
||||
spec mid-recording (e.g. the rerun visualizer pinning the depth colormap
|
||||
to ``video.depth_min`` / ``video.depth_max``) would otherwise see no
|
||||
range during episode 1 and re-normalize per frame.
|
||||
|
||||
Stream-derived values written later by :func:`get_video_info` /
|
||||
``update_video_info`` win over these seeds (the merge is
|
||||
``{**existing, **stream_info}``), so callers can safely re-run this on
|
||||
a partially-populated info dict.
|
||||
|
||||
No-op when ``depth_encoder_config`` is ``None`` or no feature is flagged
|
||||
as a depth map.
|
||||
"""
|
||||
if depth_encoder_config is None:
|
||||
return
|
||||
encoder_fields = {
|
||||
f"video.{name}": value for name, value in asdict(depth_encoder_config).items()
|
||||
}
|
||||
for ft in features.values():
|
||||
if ft.get("dtype") != "video":
|
||||
continue
|
||||
info = ft.get("info") or {}
|
||||
if not info.get("video.is_depth_map", False):
|
||||
continue
|
||||
# Only fill fields not already set, so explicit user-provided info is preserved.
|
||||
for k, v in encoder_fields.items():
|
||||
info.setdefault(k, v)
|
||||
ft["info"] = info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
|
||||
@@ -68,9 +68,16 @@ class SOFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
features: dict[str, tuple] = {}
|
||||
for cam in self.cameras:
|
||||
cam_cfg = self.config.cameras[cam]
|
||||
features[cam] = (cam_cfg.height, cam_cfg.width, 3)
|
||||
# Cameras with a depth stream (e.g. RealSense with use_depth=True) also
|
||||
# emit a 2D depth feature; hw_to_dataset_features routes 2D shapes to
|
||||
# ``observation.depth.<bare>`` with the depth-map marker.
|
||||
if getattr(cam_cfg, "use_depth", False):
|
||||
features[f"{cam}_depth"] = (cam_cfg.height, cam_cfg.width)
|
||||
return features
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -190,6 +197,14 @@ class SOFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Cameras with a depth stream populate a sibling ``<cam>_depth`` key
|
||||
# (consumed by hw_to_dataset_features / build_dataset_frame).
|
||||
if getattr(self.config.cameras[cam_key], "use_depth", False):
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -104,10 +104,12 @@ from lerobot.common.control_utils import (
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.datasets import (
|
||||
DepthEncoderConfig,
|
||||
LeRobotDataset,
|
||||
VideoEncodingManager,
|
||||
aggregate_pipeline_dataset_features,
|
||||
create_initial_features,
|
||||
depth_encoder_defaults,
|
||||
safe_stop_image_writer,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
@@ -326,7 +328,10 @@ def record_loop(
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(
|
||||
observation=obs_processed, action=action_values, compress_images=display_compressed_images
|
||||
observation=obs_processed,
|
||||
action=action_values,
|
||||
compress_images=display_compressed_images,
|
||||
features=dataset.features if dataset is not None else None,
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
@@ -399,6 +404,7 @@ def record(
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder_config=cfg.dataset.camera_encoder_config,
|
||||
depth_encoder_config=cfg.dataset.depth_encoder_config,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
@@ -428,6 +434,7 @@ def record(
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder_config=cfg.dataset.camera_encoder_config,
|
||||
depth_encoder_config=cfg.dataset.depth_encoder_config,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
|
||||
@@ -86,11 +86,24 @@ def hw_to_dataset_features(
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
if len(shape) == 2:
|
||||
# Single-channel feature (e.g. depth map). The hardware-side key is
|
||||
# expected to use a "_depth" suffix to disambiguate from its color
|
||||
# counterpart; we strip it so the dataset feature is published as
|
||||
# ``{prefix}.depth.<bare>`` and aligned with ``observation.images.<bare>``.
|
||||
bare = key.removesuffix("_depth") if key.endswith("_depth") else key
|
||||
features[f"{prefix}.depth.{bare}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
}
|
||||
else:
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
@@ -120,7 +133,14 @@ def build_dataset_frame(
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
if key.startswith(f"{prefix}.depth."):
|
||||
bare = key.removeprefix(f"{prefix}.depth.")
|
||||
# Hardware emits depth values under "<bare>_depth" to disambiguate
|
||||
# from the color stream stored at "<bare>" — fall back to the bare
|
||||
# name when the producer already uses dataset-style keys.
|
||||
frame[key] = values.get(f"{bare}_depth", values.get(bare))
|
||||
else:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
@@ -63,10 +63,56 @@ def _is_scalar(x):
|
||||
)
|
||||
|
||||
|
||||
def _derive_depth_obs_ranges(
|
||||
features: dict[str, dict] | None,
|
||||
) -> dict[str, tuple[float, float] | None]:
|
||||
"""Map observation keys of depth features to their ``(depth_min, depth_max)`` range.
|
||||
|
||||
A feature is considered a depth map when its ``info`` dict carries
|
||||
``video.is_depth_map=True`` (the marker set by ``hw_to_dataset_features``
|
||||
and persisted in ``info.json``). For each such feature, we record both
|
||||
the fully-namespaced dataset key (e.g. ``observation.depth.front``) and
|
||||
the corresponding raw observation key forms the robot is likely to emit
|
||||
(``front`` and ``front_depth``) so a single membership check covers all
|
||||
call sites.
|
||||
|
||||
The mapped value is the ``(depth_min, depth_max)`` range stored on the
|
||||
feature (matching the quantization range used at encoding time), or
|
||||
``None`` when the metadata doesn't expose a range — in which case the
|
||||
caller should let Rerun auto-normalize. Anchoring the colormap to a
|
||||
fixed range avoids per-frame re-normalization, which otherwise looks
|
||||
like flicker on near-static scenes.
|
||||
"""
|
||||
ranges: dict[str, tuple[float, float] | None] = {}
|
||||
if not features:
|
||||
return ranges
|
||||
depth_prefix = f"{OBS_STR}.depth."
|
||||
for fk, fv in features.items():
|
||||
info = fv.get("info") if isinstance(fv, dict) else None
|
||||
if not isinstance(info, dict) or not info.get("video.is_depth_map", False):
|
||||
continue
|
||||
depth_min = info.get("video.depth_min")
|
||||
depth_max = info.get("video.depth_max")
|
||||
rng: tuple[float, float] | None = None
|
||||
if (
|
||||
isinstance(depth_min, (int, float))
|
||||
and isinstance(depth_max, (int, float))
|
||||
and depth_max > depth_min
|
||||
):
|
||||
rng = (float(depth_min), float(depth_max))
|
||||
ranges[fk] = rng
|
||||
if fk.startswith(depth_prefix):
|
||||
bare = fk[len(depth_prefix) :]
|
||||
ranges[bare] = rng
|
||||
ranges[f"{bare}_depth"] = rng
|
||||
return ranges
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
observation: RobotObservation | None = None,
|
||||
action: RobotAction | None = None,
|
||||
compress_images: bool = False,
|
||||
features: dict[str, dict] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Logs observation and action data to Rerun for real-time visualization.
|
||||
@@ -76,6 +122,13 @@ def log_rerun_data(
|
||||
- Scalars values (floats, ints) are logged as `rr.Scalars`.
|
||||
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
|
||||
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
|
||||
- 2D NumPy arrays whose key matches a depth feature in ``features`` (i.e. carrying
|
||||
``video.is_depth_map=True``) are logged as ``rr.DepthImage`` with the Viridis
|
||||
colormap and ``meter=1.0`` (depth values are expected in metric meters). When
|
||||
the feature exposes ``video.depth_min`` / ``video.depth_max`` (the encoder
|
||||
quantization range, persisted in ``info.json``), the colormap is anchored to
|
||||
that range via ``depth_range`` to keep the visualization stable across frames.
|
||||
Depth images are never JPEG-compressed regardless of ``compress_images``.
|
||||
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
|
||||
- Other multi-dimensional arrays are flattened and logged as individual scalars.
|
||||
|
||||
@@ -85,11 +138,16 @@ def log_rerun_data(
|
||||
observation: An optional dictionary containing observation data to log.
|
||||
action: An optional dictionary containing action data to log.
|
||||
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
|
||||
features: Optional dataset feature spec (e.g. ``LeRobotDataset.features``). When
|
||||
provided, observation entries matching a depth-map feature are rendered with
|
||||
``rr.DepthImage`` instead of the generic ``rr.Image`` path.
|
||||
"""
|
||||
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
import rerun as rr
|
||||
|
||||
depth_obs_ranges = _derive_depth_obs_ranges(features)
|
||||
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if v is None:
|
||||
@@ -100,6 +158,20 @@ def log_rerun_data(
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
is_depth = bool(depth_obs_ranges) and (k in depth_obs_ranges or key in depth_obs_ranges)
|
||||
if is_depth and arr.ndim == 2:
|
||||
# Viridis-colormapped DepthImage; never JPEG-compress (lossy on float metric depth).
|
||||
# Anchor the colormap to the encoder range when available, so the
|
||||
# visualization doesn't flicker as per-frame min/max drift.
|
||||
depth_range = depth_obs_ranges.get(k) or depth_obs_ranges.get(key)
|
||||
depth_kwargs: dict = {
|
||||
"meter": 1.0,
|
||||
"colormap": rr.components.Colormap.Viridis,
|
||||
}
|
||||
if depth_range is not None:
|
||||
depth_kwargs["depth_range"] = depth_range
|
||||
rr.log(key, rr.DepthImage(arr, **depth_kwargs), static=True)
|
||||
continue
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
|
||||
@@ -202,6 +202,31 @@ def test_read_latest_too_old():
|
||||
_ = camera.read_latest(max_age_ms=0) # immediately too old
|
||||
|
||||
|
||||
def test_async_read_depth_without_use_depth_raises():
|
||||
"""``async_read_depth`` must reject cameras configured without ``use_depth=True``."""
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
|
||||
with RealSenseCamera(config) as camera, pytest.raises(RuntimeError, match="use_depth=False"):
|
||||
_ = camera.async_read_depth()
|
||||
|
||||
|
||||
def test_read_latest_depth_without_use_depth_raises():
|
||||
"""``read_latest_depth`` must reject cameras configured without ``use_depth=True``."""
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
|
||||
with RealSenseCamera(config) as camera, pytest.raises(RuntimeError, match="use_depth=False"):
|
||||
_ = camera.read_latest_depth()
|
||||
|
||||
|
||||
def test_depth_to_meters_uses_depth_scale():
|
||||
"""``_depth_to_meters`` must scale uint16 raw depth into float32 metric meters."""
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.depth_scale = 0.001 # typical D-series scale (1 mm/unit)
|
||||
raw = np.array([[0, 1000, 2500], [4095, 65535, 0]], dtype=np.uint16)
|
||||
meters = camera._depth_to_meters(raw)
|
||||
assert meters.dtype == np.float32
|
||||
np.testing.assert_allclose(meters, raw.astype(np.float32) * 0.001)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rotation",
|
||||
[
|
||||
|
||||
@@ -142,6 +142,36 @@ def test_create_without_videos_has_no_video_path(tmp_path):
|
||||
assert meta.video_keys == []
|
||||
|
||||
|
||||
def test_depth_keys_property_filters_by_marker(tmp_path):
|
||||
"""``depth_keys`` selects only video features carrying ``video.is_depth_map=True``."""
|
||||
features = {
|
||||
**SIMPLE_FEATURES,
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
},
|
||||
"observation.depth.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96),
|
||||
"names": ["height", "width"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/depth_keys", fps=DEFAULT_FPS, features=features, root=tmp_path / "depth_keys"
|
||||
)
|
||||
|
||||
assert set(meta.video_keys) == {"observation.images.cam", "observation.depth.cam"}
|
||||
assert meta.depth_keys == ["observation.depth.cam"]
|
||||
|
||||
def test_depth_keys_empty_when_no_marker(tmp_path):
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/no_depth", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=tmp_path / "no_depth"
|
||||
)
|
||||
assert meta.depth_keys == []
|
||||
|
||||
def test_create_raises_on_existing_directory(tmp_path):
|
||||
"""create() raises if root directory already exists."""
|
||||
root = tmp_path / "existing"
|
||||
|
||||
@@ -1483,7 +1483,8 @@ def test_valid_video_codecs_constant():
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 10
|
||||
assert "ffv1" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 11
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
@@ -93,9 +93,32 @@ def test_image_array_to_pil_image_pytorch_format(img_array_factory):
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_single_channel(img_array_factory):
|
||||
# Single-channel inputs are routed to grayscale mode for raw depth maps.
|
||||
img_array = img_array_factory(channels=1)
|
||||
with pytest.raises(NotImplementedError):
|
||||
image_array_to_pil_image(img_array)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "L"
|
||||
assert np.array_equal(np.array(result_image), img_array.squeeze(-1))
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_single_channel_uint16(img_array_factory):
|
||||
img_array = img_array_factory(channels=1, dtype=np.uint16)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "I;16"
|
||||
# Bit-perfect: no rescaling, no clipping.
|
||||
assert np.array_equal(np.array(result_image), img_array.squeeze(-1))
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_single_channel_float32(img_array_factory):
|
||||
img_array = img_array_factory(channels=1, dtype=np.float32)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "F"
|
||||
assert np.array_equal(np.array(result_image), img_array.squeeze(-1))
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_4_channels(img_array_factory):
|
||||
@@ -141,6 +164,28 @@ def test_write_image_image(tmp_path, img_factory):
|
||||
assert np.array_equal(image_pil, saved_image)
|
||||
|
||||
|
||||
def test_write_image_tiff_uint16_bitperfect(tmp_path):
|
||||
"""16-bit grayscale TIFF round-trips bit-perfectly (raw depth maps)."""
|
||||
image_array = np.random.randint(0, 65535, size=(32, 48), dtype=np.uint16)
|
||||
fpath = tmp_path / "depth.tiff"
|
||||
write_image(image_array, fpath)
|
||||
assert fpath.exists()
|
||||
saved = np.array(Image.open(fpath))
|
||||
assert saved.dtype == np.uint16
|
||||
assert np.array_equal(saved, image_array)
|
||||
|
||||
|
||||
def test_write_image_tiff_float32_bitperfect(tmp_path):
|
||||
"""Float32 TIFF round-trips bit-perfectly (metric depth in meters)."""
|
||||
image_array = np.random.uniform(0.05, 4.0, size=(32, 48)).astype(np.float32)
|
||||
fpath = tmp_path / "depth.tiff"
|
||||
write_image(image_array, fpath)
|
||||
assert fpath.exists()
|
||||
saved = np.array(Image.open(fpath))
|
||||
assert saved.dtype == np.float32
|
||||
assert np.array_equal(saved, image_array)
|
||||
|
||||
|
||||
def test_write_image_exception(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
|
||||
@@ -436,6 +436,37 @@ def test_add_frame_works_in_write_mode(tmp_path):
|
||||
dataset.add_frame(_make_frame()) # should not raise
|
||||
|
||||
|
||||
# ── Depth-feature plumbing ───────────────────────────────────────────
|
||||
|
||||
|
||||
_DEPTH_FEATURES = {
|
||||
**SIMPLE_FEATURES,
|
||||
"observation.depth": {
|
||||
"dtype": "video",
|
||||
"shape": (32, 32),
|
||||
"names": ["height", "width"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_create_with_depth_streaming_succeeds(tmp_path):
|
||||
"""A depth dataset with streaming_encoding=True is created in write mode."""
|
||||
from lerobot.datasets.video_utils import DepthEncoderConfig
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=_DEPTH_FEATURES,
|
||||
root=tmp_path / "depth_ds",
|
||||
depth_encoder_config=DepthEncoderConfig(),
|
||||
streaming_encoding=True,
|
||||
)
|
||||
assert isinstance(dataset.writer, DatasetWriter)
|
||||
assert dataset.meta.depth_keys == ["observation.depth"]
|
||||
assert dataset._depth_encoder_config is not None
|
||||
|
||||
|
||||
# ── Resume mode ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -311,6 +311,18 @@ class TestEncoderDetection:
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
|
||||
def test_av1_alias_resolves_to_libsvtav1(self):
|
||||
"""Older datasets persist ``vcodec="av1"``; backward-compat alias must keep them loadable."""
|
||||
cfg = VideoEncoderConfig(vcodec="av1")
|
||||
assert cfg.vcodec == "libsvtav1"
|
||||
|
||||
def test_av1_alias_persisted_after_resolve(self):
|
||||
"""Repeated calls to ``resolve_vcodec`` should be idempotent (alias only fires once)."""
|
||||
cfg = VideoEncoderConfig(vcodec="av1")
|
||||
assert cfg.vcodec == "libsvtav1"
|
||||
cfg.resolve_vcodec()
|
||||
assert cfg.vcodec == "libsvtav1"
|
||||
|
||||
|
||||
ARTIFACTS = Path(__file__).parent.parent / "fixtures" / "artifacts" / "videos"
|
||||
|
||||
|
||||
77
tests/utils/test_feature_utils.py
Normal file
77
tests/utils/test_feature_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for ``lerobot.utils.feature_utils``."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_routes_3d_shape_to_images():
|
||||
hw = {"front": (480, 640, 3)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert "observation.images.front" in out
|
||||
assert out["observation.images.front"]["dtype"] == "video"
|
||||
assert out["observation.images.front"]["shape"] == (480, 640, 3)
|
||||
assert out["observation.images.front"]["names"] == ["height", "width", "channels"]
|
||||
assert "info" not in out["observation.images.front"]
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_routes_2d_shape_to_depth():
|
||||
hw = {"front_depth": (480, 640)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert "observation.depth.front" in out, out
|
||||
feat = out["observation.depth.front"]
|
||||
assert feat["dtype"] == "video"
|
||||
assert feat["shape"] == (480, 640)
|
||||
assert feat["names"] == ["height", "width"]
|
||||
assert feat["info"] == {"video.is_depth_map": True}
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_handles_paired_color_and_depth():
|
||||
"""A camera with use_depth=True is expected to emit both keys."""
|
||||
hw = {"front": (480, 640, 3), "front_depth": (480, 640)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert set(out) == {"observation.images.front", "observation.depth.front"}
|
||||
assert out["observation.images.front"]["shape"] == (480, 640, 3)
|
||||
assert out["observation.depth.front"]["shape"] == (480, 640)
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_keeps_bare_2d_key_when_no_suffix():
|
||||
"""If the producer didn't use a "_depth" suffix, the bare name flows through."""
|
||||
hw = {"top": (240, 320)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert "observation.depth.top" in out
|
||||
|
||||
|
||||
def test_build_dataset_frame_routes_depth_values():
|
||||
ds_features = hw_to_dataset_features(
|
||||
{"front": (4, 6, 3), "front_depth": (4, 6)},
|
||||
OBS_STR,
|
||||
use_video=True,
|
||||
)
|
||||
rgb = np.zeros((4, 6, 3), dtype=np.uint8)
|
||||
depth = np.full((4, 6), 0.5, dtype=np.float32)
|
||||
values = {"front": rgb, "front_depth": depth}
|
||||
|
||||
frame = build_dataset_frame(ds_features, values, OBS_STR)
|
||||
assert frame["observation.images.front"] is rgb
|
||||
assert frame["observation.depth.front"] is depth
|
||||
@@ -43,18 +43,32 @@ def mock_rerun(monkeypatch):
|
||||
def __init__(self, arr):
|
||||
self.arr = arr
|
||||
|
||||
class DummyDepthImage:
|
||||
def __init__(self, arr, meter=None, colormap=None, **kwargs):
|
||||
self.arr = arr
|
||||
self.meter = meter
|
||||
self.colormap = colormap
|
||||
self.kwargs = kwargs
|
||||
|
||||
def dummy_log(key, obj=None, **kwargs):
|
||||
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
|
||||
if obj is None and "entity" in kwargs:
|
||||
obj = kwargs.pop("entity")
|
||||
calls.append((key, obj, kwargs))
|
||||
|
||||
class _Colormap:
|
||||
Viridis = "viridis"
|
||||
|
||||
dummy_components = SimpleNamespace(Colormap=_Colormap)
|
||||
|
||||
dummy_rr = SimpleNamespace(
|
||||
__name__="rerun",
|
||||
__package__="rerun",
|
||||
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
|
||||
Scalars=DummyScalar,
|
||||
Image=DummyImage,
|
||||
DepthImage=DummyDepthImage,
|
||||
components=dummy_components,
|
||||
log=dummy_log,
|
||||
init=lambda *a, **k: None,
|
||||
spawn=lambda *a, **k: None,
|
||||
@@ -232,3 +246,122 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
a = _obj_for(calls, "action.a")
|
||||
assert type(a).__name__ == "DummyScalar"
|
||||
assert a.value == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun):
|
||||
"""A 2D float depth obs whose key matches a depth feature must use ``rr.DepthImage``.
|
||||
|
||||
Without ``video.depth_min``/``video.depth_max`` in the feature info, the
|
||||
visualizer should not pass a ``depth_range`` (Rerun then auto-normalizes).
|
||||
"""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.images.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640, 3),
|
||||
"info": {"video.is_depth_map": False},
|
||||
},
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640),
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
obs = {
|
||||
"front": np.zeros((10, 12, 3), dtype=np.uint8),
|
||||
"front_depth": np.full((10, 12), 0.7, dtype=np.float32),
|
||||
}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features)
|
||||
|
||||
rgb = _obj_for(calls, "observation.front")
|
||||
assert type(rgb).__name__ == "DummyImage"
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
assert depth.arr.shape == (10, 12)
|
||||
assert depth.meter == pytest.approx(1.0)
|
||||
assert depth.colormap == "viridis"
|
||||
# No range available -> Rerun should auto-normalize; we must not pass `depth_range`.
|
||||
assert "depth_range" not in depth.kwargs
|
||||
assert _kwargs_for(calls, "observation.front_depth").get("static", False) is True
|
||||
|
||||
|
||||
def test_log_rerun_data_depth_range_anchored_from_info(mock_rerun):
|
||||
"""When ``video.depth_min``/``depth_max`` are present, ``depth_range`` is forwarded."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640),
|
||||
"info": {
|
||||
"video.is_depth_map": True,
|
||||
"video.depth_min": 0.05,
|
||||
"video.depth_max": 4.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
obs = {"front_depth": np.full((10, 12), 0.5, dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features)
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
assert depth.kwargs.get("depth_range") == (0.05, 4.0)
|
||||
|
||||
|
||||
def test_log_rerun_data_depth_range_ignored_when_invalid(mock_rerun):
|
||||
"""A degenerate range (max <= min, or non-numeric) must be discarded silently."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640),
|
||||
"info": {
|
||||
"video.is_depth_map": True,
|
||||
"video.depth_min": 1.0,
|
||||
"video.depth_max": 1.0, # degenerate
|
||||
},
|
||||
},
|
||||
}
|
||||
obs = {"front_depth": np.full((10, 12), 0.5, dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features)
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
assert "depth_range" not in depth.kwargs
|
||||
|
||||
|
||||
def test_log_rerun_data_depth_skips_compression(mock_rerun):
|
||||
"""Depth frames must never be JPEG-compressed even when ``compress_images=True``."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (8, 8),
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
obs = {"front_depth": np.full((8, 8), 0.5, dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features, compress_images=True)
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
|
||||
|
||||
def test_log_rerun_data_no_features_falls_back_to_image(mock_rerun):
|
||||
"""Without ``features``, a 2D array still goes through the generic Image path (no depth detection)."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
obs = {"front_depth": np.zeros((8, 8), dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs)
|
||||
|
||||
obj = _obj_for(calls, "observation.front_depth")
|
||||
assert type(obj).__name__ == "DummyImage"
|
||||
|
||||
Reference in New Issue
Block a user