From a4aa31647025b0e4e5c338ad12df78d0d5ca051f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 7 Nov 2025 21:54:44 +0100 Subject: [PATCH] fix(dataset): fix data access bottleneck for faster training (#2408) --- src/lerobot/datasets/lerobot_dataset.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index c8bc504..48608a8 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -940,11 +940,26 @@ class LeRobotDataset(torch.utils.data.Dataset): return query_timestamps def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: - return { - key: torch.stack(self.hf_dataset[q_idx][key]) - for key, q_idx in query_indices.items() - if key not in self.meta.video_keys - } + """ + Query dataset for indices across keys, skipping video keys. + + Tries column-first [key][indices] for speed, falls back to row-first. + + Args: + query_indices: Dict mapping keys to index lists to retrieve + + Returns: + Dict with stacked tensors of queried data (video keys excluded) + """ + result: dict = {} + for key, q_idx in query_indices.items(): + if key in self.meta.video_keys: + continue + try: + result[key] = torch.stack(self.hf_dataset[key][q_idx]) + except (KeyError, TypeError, IndexError): + result[key] = torch.stack(self.hf_dataset[q_idx][key]) + return result def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function