diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index ee10df6..23bc2ef 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -110,8 +110,8 @@ def worker_thread_loop(queue: queue.Queue): if item is None: queue.task_done() break - image_array, fpath = item - write_image(image_array, fpath) + image_array, fpath, compress_level = item + write_image(image_array, fpath, compress_level) queue.task_done() @@ -169,11 +169,13 @@ class AsyncImageWriter: p.start() self.processes.append(p) - def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path): + def save_image( + self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1 + ): if isinstance(image, torch.Tensor): # Convert tensor to numpy array to minimize main process time image = image.cpu().numpy() - self.queue.put((image, fpath)) + self.queue.put((image, fpath, compress_level)) def wait_until_done(self): self.queue.join() diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index fbfb82c..5ea1884 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1092,13 +1092,15 @@ class LeRobotDataset(torch.utils.data.Dataset): def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path: return self._get_image_file_path(episode_index, image_key, frame_index=0).parent - def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: + def _save_image( + self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1 + ) -> None: if self.image_writer is None: if isinstance(image, torch.Tensor): image = image.cpu().numpy() - write_image(image, fpath) + write_image(image, fpath, compress_level=compress_level) else: - self.image_writer.save_image(image=image, fpath=fpath) + self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level) def add_frame(self, frame: dict) -> None: """ @@ -1136,7 +1138,8 @@ class LeRobotDataset(torch.utils.data.Dataset): ) if frame_index == 0: img_path.parent.mkdir(parents=True, exist_ok=True) - self._save_image(frame[key], img_path) + compress_level = 1 if self.features[key]["dtype"] == "video" else 6 + self._save_image(frame[key], img_path, compress_level) self.episode_buffer[key].append(str(img_path)) else: self.episode_buffer[key].append(frame[key])