diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index 3720a5084..c37c0d9a3 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -72,6 +72,8 @@ class DatasetReader: self.episodes = episodes self._tolerance_s = tolerance_s self._video_backend = video_backend + if image_transforms is not None and not callable(image_transforms): + raise TypeError("image_transforms must be callable or None.") self._image_transforms = image_transforms self.hf_dataset: datasets.Dataset | None = None @@ -83,6 +85,16 @@ class DatasetReader: check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) + def set_image_transforms(self, image_transforms: Callable | None) -> None: + """Replace the transform applied to visual observations.""" + if image_transforms is not None and not callable(image_transforms): + raise TypeError("image_transforms must be callable or None.") + self._image_transforms = image_transforms + + def clear_image_transforms(self) -> None: + """Remove the transform applied to visual observations.""" + self._image_transforms = None + def try_load(self) -> bool: """Attempt to load from local cache. Returns True if data is sufficient.""" try: diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 1725046f2..f07e43307 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -194,8 +194,6 @@ class LeRobotDataset(torch.utils.data.Dataset): super().__init__() self.repo_id = repo_id self._requested_root = Path(root) if root else None - self.reader = None - self.set_image_transforms(image_transforms) self.delta_timestamps = delta_timestamps self.episodes = episodes self.tolerance_s = tolerance_s @@ -225,6 +223,7 @@ class LeRobotDataset(torch.utils.data.Dataset): delta_timestamps=delta_timestamps, image_transforms=image_transforms, ) + self.image_transforms = image_transforms # Load actual data if force_cache_sync or not self.reader.try_load(): @@ -480,15 +479,14 @@ class LeRobotDataset(torch.utils.data.Dataset): def set_image_transforms(self, image_transforms: Callable | None) -> None: """Replace the transform applied to visual observations.""" - if image_transforms is not None and not callable(image_transforms): - raise TypeError("image_transforms must be callable or None.") + self._ensure_reader().set_image_transforms(image_transforms) self.image_transforms = image_transforms - if self.reader is not None: - self.reader._image_transforms = image_transforms def clear_image_transforms(self) -> None: """Remove the transform applied to visual observations.""" - self.set_image_transforms(None) + if self.reader is not None: + self.reader.set_image_transforms(None) + self.image_transforms = None # ── Hub methods (stay on facade) ──────────────────────────────────