mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
feat(features): route 2D camera shapes to observation.depth.<key>
hw_to_dataset_features now treats a camera entry whose shape has
length 2 as a single-channel depth feature: it emits the feature as
"{prefix}.depth.<bare>" with names=["height", "width"] and an
info={"video.is_depth_map": True} marker so the depth-encoder branch
in LeRobotDataset is engaged. The "_depth" hardware-side suffix (if
present) is stripped so a paired RGB + depth camera ends up as
"observation.images.<cam>" + "observation.depth.<cam>".
build_dataset_frame mirrors the routing: depth feature keys read
their value from "<bare>_depth" in the raw observation dict, with
fallback to the bare name for producers that already emit
dataset-style keys.
Tests: add tests/utils/test_feature_utils.py covering the routing
of 2D vs 3D camera shapes, the paired RGB+depth case, and the
build_dataset_frame value routing.
Made-with: Cursor
This commit is contained in:
@@ -86,11 +86,24 @@ def hw_to_dataset_features(
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
if len(shape) == 2:
|
||||
# Single-channel feature (e.g. depth map). The hardware-side key is
|
||||
# expected to use a "_depth" suffix to disambiguate from its color
|
||||
# counterpart; we strip it so the dataset feature is published as
|
||||
# ``{prefix}.depth.<bare>`` and aligned with ``observation.images.<bare>``.
|
||||
bare = key.removesuffix("_depth") if key.endswith("_depth") else key
|
||||
features[f"{prefix}.depth.{bare}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
}
|
||||
else:
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
@@ -120,7 +133,14 @@ def build_dataset_frame(
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
if key.startswith(f"{prefix}.depth."):
|
||||
bare = key.removeprefix(f"{prefix}.depth.")
|
||||
# Hardware emits depth values under "<bare>_depth" to disambiguate
|
||||
# from the color stream stored at "<bare>" — fall back to the bare
|
||||
# name when the producer already uses dataset-style keys.
|
||||
frame[key] = values.get(f"{bare}_depth", values.get(bare))
|
||||
else:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
77
tests/utils/test_feature_utils.py
Normal file
77
tests/utils/test_feature_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for ``lerobot.utils.feature_utils``."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_routes_3d_shape_to_images():
|
||||
hw = {"front": (480, 640, 3)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert "observation.images.front" in out
|
||||
assert out["observation.images.front"]["dtype"] == "video"
|
||||
assert out["observation.images.front"]["shape"] == (480, 640, 3)
|
||||
assert out["observation.images.front"]["names"] == ["height", "width", "channels"]
|
||||
assert "info" not in out["observation.images.front"]
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_routes_2d_shape_to_depth():
|
||||
hw = {"front_depth": (480, 640)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert "observation.depth.front" in out, out
|
||||
feat = out["observation.depth.front"]
|
||||
assert feat["dtype"] == "video"
|
||||
assert feat["shape"] == (480, 640)
|
||||
assert feat["names"] == ["height", "width"]
|
||||
assert feat["info"] == {"video.is_depth_map": True}
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_handles_paired_color_and_depth():
|
||||
"""A camera with use_depth=True is expected to emit both keys."""
|
||||
hw = {"front": (480, 640, 3), "front_depth": (480, 640)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert set(out) == {"observation.images.front", "observation.depth.front"}
|
||||
assert out["observation.images.front"]["shape"] == (480, 640, 3)
|
||||
assert out["observation.depth.front"]["shape"] == (480, 640)
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_keeps_bare_2d_key_when_no_suffix():
|
||||
"""If the producer didn't use a "_depth" suffix, the bare name flows through."""
|
||||
hw = {"top": (240, 320)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert "observation.depth.top" in out
|
||||
|
||||
|
||||
def test_build_dataset_frame_routes_depth_values():
|
||||
ds_features = hw_to_dataset_features(
|
||||
{"front": (4, 6, 3), "front_depth": (4, 6)},
|
||||
OBS_STR,
|
||||
use_video=True,
|
||||
)
|
||||
rgb = np.zeros((4, 6, 3), dtype=np.uint8)
|
||||
depth = np.full((4, 6), 0.5, dtype=np.float32)
|
||||
values = {"front": rgb, "front_depth": depth}
|
||||
|
||||
frame = build_dataset_frame(ds_features, values, OBS_STR)
|
||||
assert frame["observation.images.front"] is rgb
|
||||
assert frame["observation.depth.front"] is depth
|
||||
Reference in New Issue
Block a user