diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index d9cc28b30..420117996 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1498,7 +1498,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_index = self.episode_buffer["episode_index"] if isinstance(episode_index, np.ndarray): episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] - for cam_key in self.meta.camera_keys: + for cam_key in self.meta.image_keys: img_dir = self._get_image_file_dir(episode_index, cam_key) if img_dir.is_dir(): shutil.rmtree(img_dir) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 4c91c55c0..157a14851 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -352,6 +352,65 @@ def test_image_array_to_pil_image_wrong_range_float_0_255(): image_array_to_pil_image(image) +def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory): + """Verify temporary image directories are removed for image features after saving episode.""" + # Image feature: images should be deleted after saving episode + image_key = "image" + features_image = { + image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]} + } + ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image) + ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + ds_img.save_episode() + img_dir = ds_img._get_image_file_dir(0, image_key) + assert not img_dir.exists(), "Temporary image directory should be removed for image features" + + +def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory): + """Verify temporary image directories are removed for video encoding when `batch_encoding_size == 1`.""" + # Video feature: when batch_encoding_size == 1 temporary images should be deleted + vid_key = "video" + features_video = { + vid_key: {"dtype": "video", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]} + } + + ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video) + ds_vid.batch_encoding_size = 1 + ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + ds_vid.save_episode() + vid_img_dir = ds_vid._get_image_file_dir(0, vid_key) + assert not vid_img_dir.exists(), ( + "Temporary image directory should be removed when batch_encoding_size == 1" + ) + + +def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): + """Verify temporary image directories are removed appropriately when both image and video features are present.""" + image_key = "image" + vid_key = "video" + features_mixed = { + image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}, + vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]}, + } + ds_mixed = empty_lerobot_dataset_factory( + root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2 + ) + ds_mixed.add_frame( + { + "image": np.random.rand(*DUMMY_CHW), + "video": np.random.rand(*DUMMY_HWC), + "task": "Dummy task", + } + ) + ds_mixed.save_episode() + img_dir = ds_mixed._get_image_file_dir(0, image_key) + vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key) + assert not img_dir.exists(), "Temporary image directory should be removed for image features" + assert vid_img_dir.exists(), ( + "Temporary image directory should not be removed for video features when batch_encoding_size == 2" + ) + + # TODO(aliberts): # - [ ] test various attributes & state from init and create # - [ ] test init with episodes and check num_frames