diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py index 46c85d8af..fe079541e 100644 --- a/src/lerobot/datasets/dataset_writer.py +++ b/src/lerobot/datasets/dataset_writer.py @@ -46,6 +46,7 @@ from .io_utils import ( write_info, ) from .utils import ( + DEFAULT_DEPTH_PATH, DEFAULT_EPISODES_PATH, DEFAULT_IMAGE_PATH, update_chunk_file_indices, @@ -57,6 +58,7 @@ from .video_utils import ( concatenate_video_files, encode_video_frames, get_video_duration_in_s, + is_depth_feature, ) logger = logging.getLogger(__name__) @@ -149,8 +151,16 @@ class DatasetWriter: ep_buffer[key] = current_ep_idx if key == "episode_index" else [] return ep_buffer + def _is_depth_image_key(self, image_key: str) -> bool: + """Whether *image_key* is a depth feature stored as per-frame images.""" + ft = self._meta.features.get(image_key) + if ft is None or ft.get("dtype") != "image": + return False + return is_depth_feature(ft.get("info") or {}) + def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: - fpath = DEFAULT_IMAGE_PATH.format( + path_template = DEFAULT_DEPTH_PATH if self._is_depth_image_key(image_key) else DEFAULT_IMAGE_PATH + fpath = path_template.format( image_key=image_key, episode_index=episode_index, frame_index=frame_index ) return self._root / fpath diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 8fb5804a5..80379d55d 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -41,15 +41,56 @@ def safe_stop_image_writer(func): return wrapper -def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: - # TODO(aliberts): handle 1 channel and 4 for depth images - if image_array.ndim != 3: - raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.") +# Single-channel dtypes that PIL natively maps to the matching mode +# (``uint8`` → ``L``, ``uint16`` → ``I;16``, ``float32`` → ``F``). +GRAYSCALE_DTYPES: tuple[np.dtype, ...] = ( + np.dtype("uint8"), + np.dtype("uint16"), + np.dtype("float32"), +) + +def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: + """Convert a NumPy array to a PIL Image, preserving precision for grayscale. + + Behaviour by shape: + + - ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale. + The native dtype is preserved using the matching PIL mode + (``L`` / ``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting) + - ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed + to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8`` + (existing behaviour, gated by ``range_check``). + + Other shapes / channel counts raise ``NotImplementedError`` or + ``ValueError``. + """ + if image_array.ndim not in (2, 3): + raise ValueError( + f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image." + ) + + # Squeeze 3D single-channel inputs to 2D so depth maps work whether the + # caller emits (H, W), (1, H, W), or (H, W, 1). + if image_array.ndim == 3: + if image_array.shape[0] == 1: + image_array = image_array[0] + elif image_array.shape[-1] == 1: + image_array = image_array[..., 0] + + if image_array.ndim == 2: + if image_array.dtype not in GRAYSCALE_DTYPES: + raise ValueError( + f"Unsupported single-channel image dtype: {image_array.dtype}. " + f"Supported dtypes: {sorted(str(d) for d in GRAYSCALE_DTYPES)}." + ) + + return PIL.Image.fromarray(np.ascontiguousarray(image_array)) + + # 3D path: must be RGB (3 channels), channels-first or channels-last. if image_array.shape[0] == 3: # Transpose from pytorch convention (C, H, W) to (H, W, C) image_array = image_array.transpose(1, 2, 0) - elif image_array.shape[-1] != 3: raise NotImplementedError( f"The image has {image_array.shape[-1]} channels, but 3 is required for now." @@ -71,13 +112,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) return PIL.Image.fromarray(image_array) +def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict: + """Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`. + + PNG uses ``compress_level`` (0–9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps. + """ + suffix = Path(fpath).suffix.lower() + if suffix == ".png": + return {"compress_level": compress_level} + if suffix in (".tif", ".tiff"): + return {"compression": "raw"} + return {} + + def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1): """ Saves a NumPy array or PIL Image to a file. This function handles both NumPy arrays and PIL Image objects, converting the former to a PIL Image before saving. It includes error handling for - the save operation. + the save operation. The output format is inferred from the *fpath* + extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif`` + → lossless raw depth maps (TIFF). Args: image (np.ndarray | PIL.Image.Image): The image data to save. @@ -101,7 +157,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level img = image else: raise TypeError(f"Unsupported image type: {type(image)}") - img.save(fpath, compress_level=compress_level) + img.save(fpath, **save_kwargs_for_path(Path(fpath), compress_level)) except Exception as e: logger.error("Error writing image %s: %s", fpath, e) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 715bd2f9b..ac001b49f 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -93,6 +93,10 @@ DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png" +# Depth maps live alongside images on disk but use TIFF instead of PNG: PNG +# cannot natively round-trip float32, and several common loaders silently +# downcast 16-bit grayscale. +DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff" LEGACY_EPISODES_PATH = "meta/episodes.jsonl" LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py index 916b8f017..613cb8cf6 100644 --- a/tests/datasets/test_image_writer.py +++ b/tests/datasets/test_image_writer.py @@ -93,9 +93,32 @@ def test_image_array_to_pil_image_pytorch_format(img_array_factory): def test_image_array_to_pil_image_single_channel(img_array_factory): + # Single-channel inputs are routed to grayscale mode for raw depth maps. img_array = img_array_factory(channels=1) - with pytest.raises(NotImplementedError): - image_array_to_pil_image(img_array) + result_image = image_array_to_pil_image(img_array) + assert isinstance(result_image, Image.Image) + assert result_image.size == (100, 100) + assert result_image.mode == "L" + assert np.array_equal(np.array(result_image), img_array.squeeze(-1)) + + +def test_image_array_to_pil_image_single_channel_uint16(img_array_factory): + img_array = img_array_factory(channels=1, dtype=np.uint16) + result_image = image_array_to_pil_image(img_array) + assert isinstance(result_image, Image.Image) + assert result_image.size == (100, 100) + assert result_image.mode == "I;16" + # Bit-perfect: no rescaling, no clipping. + assert np.array_equal(np.array(result_image), img_array.squeeze(-1)) + + +def test_image_array_to_pil_image_single_channel_float32(img_array_factory): + img_array = img_array_factory(channels=1, dtype=np.float32) + result_image = image_array_to_pil_image(img_array) + assert isinstance(result_image, Image.Image) + assert result_image.size == (100, 100) + assert result_image.mode == "F" + assert np.array_equal(np.array(result_image), img_array.squeeze(-1)) def test_image_array_to_pil_image_4_channels(img_array_factory): @@ -141,6 +164,28 @@ def test_write_image_image(tmp_path, img_factory): assert np.array_equal(image_pil, saved_image) +def test_write_image_tiff_uint16_bitperfect(tmp_path): + """16-bit grayscale TIFF round-trips bit-perfectly (raw depth maps).""" + image_array = np.random.randint(0, 65535, size=(32, 48), dtype=np.uint16) + fpath = tmp_path / "depth.tiff" + write_image(image_array, fpath) + assert fpath.exists() + saved = np.array(Image.open(fpath)) + assert saved.dtype == np.uint16 + assert np.array_equal(saved, image_array) + + +def test_write_image_tiff_float32_bitperfect(tmp_path): + """Float32 TIFF round-trips bit-perfectly (metric depth in meters).""" + image_array = np.random.uniform(0.05, 4.0, size=(32, 48)).astype(np.float32) + fpath = tmp_path / "depth.tiff" + write_image(image_array, fpath) + assert fpath.exists() + saved = np.array(Image.open(fpath)) + assert saved.dtype == np.float32 + assert np.array_equal(saved, image_array) + + def test_write_image_exception(tmp_path): image_array = "invalid data" fpath = tmp_path / DUMMY_IMAGE