mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
feat(depth maps writer): adding support for raw depth maps recording with image writer
This commit is contained in:
@@ -46,6 +46,7 @@ from .io_utils import (
|
||||
write_info,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_DEPTH_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
@@ -57,6 +58,7 @@ from .video_utils import (
|
||||
concatenate_video_files,
|
||||
encode_video_frames,
|
||||
get_video_duration_in_s,
|
||||
is_depth_feature,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -149,8 +151,16 @@ class DatasetWriter:
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
def _is_depth_image_key(self, image_key: str) -> bool:
|
||||
"""Whether *image_key* is a depth feature stored as per-frame images."""
|
||||
ft = self._meta.features.get(image_key)
|
||||
if ft is None or ft.get("dtype") != "image":
|
||||
return False
|
||||
return is_depth_feature(ft.get("info") or {})
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
path_template = DEFAULT_DEPTH_PATH if self._is_depth_image_key(image_key) else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
|
||||
@@ -41,15 +41,56 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
# Single-channel dtypes that PIL natively maps to the matching mode
|
||||
# (``uint8`` → ``L``, ``uint16`` → ``I;16``, ``float32`` → ``F``).
|
||||
GRAYSCALE_DTYPES: tuple[np.dtype, ...] = (
|
||||
np.dtype("uint8"),
|
||||
np.dtype("uint16"),
|
||||
np.dtype("float32"),
|
||||
)
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
|
||||
Behaviour by shape:
|
||||
|
||||
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
|
||||
The native dtype is preserved using the matching PIL mode
|
||||
(``L`` / ``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
|
||||
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
|
||||
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
|
||||
(existing behaviour, gated by ``range_check``).
|
||||
|
||||
Other shapes / channel counts raise ``NotImplementedError`` or
|
||||
``ValueError``.
|
||||
"""
|
||||
if image_array.ndim not in (2, 3):
|
||||
raise ValueError(
|
||||
f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image."
|
||||
)
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in GRAYSCALE_DTYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported single-channel image dtype: {image_array.dtype}. "
|
||||
f"Supported dtypes: {sorted(str(d) for d in GRAYSCALE_DTYPES)}."
|
||||
)
|
||||
|
||||
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
|
||||
|
||||
# 3D path: must be RGB (3 channels), channels-first or channels-last.
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
elif image_array.shape[-1] != 3:
|
||||
raise NotImplementedError(
|
||||
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||
@@ -71,13 +112,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
|
||||
|
||||
PNG uses ``compress_level`` (0–9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
|
||||
"""
|
||||
suffix = Path(fpath).suffix.lower()
|
||||
if suffix == ".png":
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation.
|
||||
the save operation. The output format is inferred from the *fpath*
|
||||
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
|
||||
→ lossless raw depth maps (TIFF).
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
@@ -101,7 +157,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
img.save(fpath, **save_kwargs_for_path(Path(fpath), compress_level))
|
||||
except Exception as e:
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
@@ -93,6 +93,10 @@ DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
# Depth maps live alongside images on disk but use TIFF instead of PNG: PNG
|
||||
# cannot natively round-trip float32, and several common loaders silently
|
||||
# downcast 16-bit grayscale.
|
||||
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
|
||||
@@ -93,9 +93,32 @@ def test_image_array_to_pil_image_pytorch_format(img_array_factory):
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_single_channel(img_array_factory):
|
||||
# Single-channel inputs are routed to grayscale mode for raw depth maps.
|
||||
img_array = img_array_factory(channels=1)
|
||||
with pytest.raises(NotImplementedError):
|
||||
image_array_to_pil_image(img_array)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "L"
|
||||
assert np.array_equal(np.array(result_image), img_array.squeeze(-1))
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_single_channel_uint16(img_array_factory):
|
||||
img_array = img_array_factory(channels=1, dtype=np.uint16)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "I;16"
|
||||
# Bit-perfect: no rescaling, no clipping.
|
||||
assert np.array_equal(np.array(result_image), img_array.squeeze(-1))
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_single_channel_float32(img_array_factory):
|
||||
img_array = img_array_factory(channels=1, dtype=np.float32)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "F"
|
||||
assert np.array_equal(np.array(result_image), img_array.squeeze(-1))
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_4_channels(img_array_factory):
|
||||
@@ -141,6 +164,28 @@ def test_write_image_image(tmp_path, img_factory):
|
||||
assert np.array_equal(image_pil, saved_image)
|
||||
|
||||
|
||||
def test_write_image_tiff_uint16_bitperfect(tmp_path):
|
||||
"""16-bit grayscale TIFF round-trips bit-perfectly (raw depth maps)."""
|
||||
image_array = np.random.randint(0, 65535, size=(32, 48), dtype=np.uint16)
|
||||
fpath = tmp_path / "depth.tiff"
|
||||
write_image(image_array, fpath)
|
||||
assert fpath.exists()
|
||||
saved = np.array(Image.open(fpath))
|
||||
assert saved.dtype == np.uint16
|
||||
assert np.array_equal(saved, image_array)
|
||||
|
||||
|
||||
def test_write_image_tiff_float32_bitperfect(tmp_path):
|
||||
"""Float32 TIFF round-trips bit-perfectly (metric depth in meters)."""
|
||||
image_array = np.random.uniform(0.05, 4.0, size=(32, 48)).astype(np.float32)
|
||||
fpath = tmp_path / "depth.tiff"
|
||||
write_image(image_array, fpath)
|
||||
assert fpath.exists()
|
||||
saved = np.array(Image.open(fpath))
|
||||
assert saved.dtype == np.float32
|
||||
assert np.array_equal(saved, image_array)
|
||||
|
||||
|
||||
def test_write_image_exception(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
|
||||
Reference in New Issue
Block a user