feat(depth): wire DatasetReader to decode_depth_frames

This commit is contained in:
CarolinePascal
2026-05-19 23:46:28 +02:00
parent d39698da0f
commit e51d45dd2c
2 changed files with 33 additions and 4 deletions

View File

@@ -23,6 +23,7 @@ import datasets
import torch
from .dataset_metadata import LeRobotDatasetMetadata
from .depth_utils import dequantize_depth
from .feature_utils import (
check_delta_timestamps,
get_delta_indices,
@@ -86,6 +87,18 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
if self._meta.depth_keys:
# TODO(CarolinePascal): make this decent, this is awful.
self._dequantize_depth_configs = {
vid_key: {
"depth_min": self._meta.features[vid_key]["info"]["video.depth_min"],
"depth_max": self._meta.features[vid_key]["info"]["video.depth_max"],
"shift": self._meta.features[vid_key]["info"]["video.shift"],
"use_log": self._meta.features[vid_key]["info"]["video.use_log"],
}
for vid_key in self._meta.depth_keys
}
def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient."""
try:
@@ -247,7 +260,12 @@ class DatasetReader:
self._tolerance_s,
self._video_backend,
return_uint8=self._return_uint8,
is_depth=vid_key in self._meta.depth_keys,
)
if vid_key in self._meta.depth_keys:
frames = dequantize_depth(
frames, **self._dequantize_depth_configs[vid_key], output_tensor=True
)
return vid_key, frames.squeeze(0)
items = list(query_timestamps.items())