From 0f551df8f4bad4c504e395ea3df74fc5f714016f Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Thu, 20 Nov 2025 14:05:31 +0100 Subject: [PATCH] add `absolute_to_relative_idx` for remapping indices when a subset of data is loaded (#2490) --- src/lerobot/datasets/lerobot_dataset.py | 26 ++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 29436c4d2..9c94235c9 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -712,6 +712,15 @@ class LeRobotDataset(torch.utils.data.Dataset): self.download(download_videos) self.hf_dataset = self.load_hf_dataset() + # Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded + # Build a mapping: absolute_index -> relative_index_in_filtered_dataset + self._absolute_to_relative_idx = None + if self.episodes is not None: + self._absolute_to_relative_idx = { + abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx + for rel_idx, abs_idx in enumerate(self.hf_dataset["index"]) + } + # Setup delta_indices if self.delta_timestamps is not None: check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) @@ -930,7 +939,11 @@ class LeRobotDataset(torch.utils.data.Dataset): query_timestamps = {} for key in self.meta.video_keys: if query_indices is not None and key in query_indices: - timestamps = self.hf_dataset[query_indices[key]]["timestamp"] + if self._absolute_to_relative_idx is not None: + relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]] + timestamps = self.hf_dataset[relative_indices]["timestamp"] + else: + timestamps = self.hf_dataset[query_indices[key]]["timestamp"] query_timestamps[key] = torch.stack(timestamps).tolist() else: query_timestamps[key] = [current_ts] @@ -953,10 +966,16 @@ class LeRobotDataset(torch.utils.data.Dataset): for key, q_idx in 
query_indices.items(): if key in self.meta.video_keys: continue + # Map absolute indices to relative indices if needed + relative_indices = ( + q_idx + if self._absolute_to_relative_idx is None + else [self._absolute_to_relative_idx[idx] for idx in q_idx] + ) try: - result[key] = torch.stack(self.hf_dataset[key][q_idx]) + result[key] = torch.stack(self.hf_dataset[key][relative_indices]) except (KeyError, TypeError, IndexError): - result[key] = torch.stack(self.hf_dataset[q_idx][key]) + result[key] = torch.stack(self.hf_dataset[relative_indices][key]) return result def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: @@ -1496,6 +1515,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.image_transforms = None obj.delta_timestamps = None obj.delta_indices = None + obj._absolute_to_relative_idx = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() obj.writer = None obj.latest_episode = None