From 76f25f6afd57da3aea670118853dacb79c8f0c8c Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 5 Nov 2025 21:49:04 +0100 Subject: [PATCH] try fix 8 --- src/lerobot/datasets/lerobot_dataset.py | 72 ++++++++++++++++++------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 16127056a..72c277b67 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1044,39 +1044,71 @@ class LeRobotDataset(torch.utils.data.Dataset): # Ensure dataset is loaded when we actually need to read from it self._ensure_hf_dataset_loaded() - # Get the single, current-timestep item - item = self.hf_dataset[idx] - ep_idx = item["episode_index"].item() - + # 1. Get query indices if deltas are needed query_indices = None + padding = {} if self.delta_indices is not None: - # 1. Get indices for all deltas - query_indices, padding = self._get_query_indices(idx, ep_idx) + # We need the episode index *first* to get boundaries. + # This is a small read for just one item. + ep_idx_only = self.hf_dataset[idx : idx + 1]["episode_index"][0].item() + query_indices, padding = self._get_query_indices(idx, ep_idx_only) - # 2. Query non-image, non-video features - query_result = self._query_hf_dataset(query_indices) - item = {**item, **padding} - for key, val in query_result.items(): - item[key] = val + # 2. Fetch all data (including images) + if query_indices is not None: + # --- Delta path --- + # Fetch all keys (state, action, AND images) for all deltas + item_batch = self.hf_dataset[query_indices["index"]] - # 3. Query image features (which are not in _query_hf_dataset) - for key in self.meta.image_keys: - if key in query_indices: - # hf_dataset[query_indices[key]][key] returns a LIST of PIL.Image objects - item[key] = torch.stack(self.hf_dataset[query_indices[key]][key]) + # hf_transform_to_torch (item-level) has already run, + # so all values are tensors. We stack them. + item = {} + for key in item_batch: + item[key] = torch.stack(item_batch[key]) + item.update(padding) + + # Use the "current" item's index/timestamp/ep_idx + # (assuming 'index' is the key for the main query) + current_idx_in_batch = query_indices["index"].index(idx) + item["index"] = item["index"][current_idx_in_batch] + item["timestamp"] = item["timestamp"][current_idx_in_batch] + item["episode_index"] = item["episode_index"][current_idx_in_batch] + item["task_index"] = item["task_index"][current_idx_in_batch] + + ep_idx = item["episode_index"].item() + + else: + # --- Single-frame path --- + item = self.hf_dataset[idx] + ep_idx = item["episode_index"].item() + + # 3. Handle videos (which are always separate) if len(self.meta.video_keys) > 0: - current_ts = item["timestamp"].item() - query_timestamps = self._get_query_timestamps(current_ts, query_indices) + current_ts = ( + item["timestamp"].item() + if query_indices is None + else item["timestamp"][current_idx_in_batch].item() + ) + + video_query_indices = query_indices + if video_query_indices is None: + # If no deltas, create a dummy query_indices for _get_query_timestamps + video_query_indices = {key: [idx] for key in self.meta.video_keys} + + query_timestamps = self._get_query_timestamps(current_ts, video_query_indices) video_frames = self._query_videos(query_timestamps, ep_idx) + + # video_frames are already stacked tensors (B, C, H, W) or (C, H, W) item = {**video_frames, **item} + # 4. Apply image transforms if self.image_transforms is not None: image_keys = self.meta.camera_keys for cam in image_keys: - item[cam] = self.image_transforms(item[cam]) + if cam in item: # videos or images + item[cam] = self.image_transforms(item[cam]) - # Add task as a string + # 5. Add task string task_idx = item["task_index"].item() item["task"] = self.meta.tasks.iloc[task_idx].name return item