try fix 8

This commit is contained in:
Steven Palma
2025-11-05 21:49:04 +01:00
parent ce23681d4b
commit 76f25f6afd

View File

@@ -1044,39 +1044,71 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Ensure dataset is loaded when we actually need to read from it
self._ensure_hf_dataset_loaded()
# Get the single, current-timestep item
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
# 1. Get query indices if deltas are needed
query_indices = None
padding = {}
if self.delta_indices is not None:
# 1. Get indices for all deltas
query_indices, padding = self._get_query_indices(idx, ep_idx)
# We need the episode index *first* to get boundaries.
# This is a small read for just one item.
ep_idx_only = self.hf_dataset[idx : idx + 1]["episode_index"][0].item()
query_indices, padding = self._get_query_indices(idx, ep_idx_only)
# 2. Query non-image, non-video features
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
item[key] = val
# 2. Fetch all data (including images)
if query_indices is not None:
# --- Delta path ---
# Fetch all keys (state, action, AND images) for all deltas
item_batch = self.hf_dataset[query_indices["index"]]
# 3. Query image features (which are not in _query_hf_dataset)
for key in self.meta.image_keys:
if key in query_indices:
# hf_dataset[query_indices[key]][key] returns a LIST of PIL.Image objects
item[key] = torch.stack(self.hf_dataset[query_indices[key]][key])
# hf_transform_to_torch (item-level) has already run,
# so all values are tensors. We stack them.
item = {}
for key in item_batch:
item[key] = torch.stack(item_batch[key])
item.update(padding)
# Use the "current" item's index/timestamp/ep_idx
# (assuming 'index' is the key for the main query)
current_idx_in_batch = query_indices["index"].index(idx)
item["index"] = item["index"][current_idx_in_batch]
item["timestamp"] = item["timestamp"][current_idx_in_batch]
item["episode_index"] = item["episode_index"][current_idx_in_batch]
item["task_index"] = item["task_index"][current_idx_in_batch]
ep_idx = item["episode_index"].item()
else:
# --- Single-frame path ---
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
# 3. Handle videos (which are always separate)
if len(self.meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
current_ts = (
item["timestamp"].item()
if query_indices is None
else item["timestamp"][current_idx_in_batch].item()
)
video_query_indices = query_indices
if video_query_indices is None:
# If no deltas, create a dummy query_indices for _get_query_timestamps
video_query_indices = {key: [idx] for key in self.meta.video_keys}
query_timestamps = self._get_query_timestamps(current_ts, video_query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
# video_frames are already stacked tensors (B, C, H, W) or (C, H, W)
item = {**video_frames, **item}
# 4. Apply image transforms
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
if cam in item: # videos or images
item[cam] = self.image_transforms(item[cam])
# Add task as a string
# 5. Add task string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
return item