mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
try fix 8
This commit is contained in:
@@ -1044,39 +1044,71 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Ensure dataset is loaded when we actually need to read from it
|
||||
self._ensure_hf_dataset_loaded()
|
||||
|
||||
# Get the single, current-timestep item
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
# 1. Get query indices if deltas are needed
|
||||
query_indices = None
|
||||
padding = {}
|
||||
if self.delta_indices is not None:
|
||||
# 1. Get indices for all deltas
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
# We need the episode index *first* to get boundaries.
|
||||
# This is a small read for just one item.
|
||||
ep_idx_only = self.hf_dataset[idx : idx + 1]["episode_index"][0].item()
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx_only)
|
||||
|
||||
# 2. Query non-image, non-video features
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
# 2. Fetch all data (including images)
|
||||
if query_indices is not None:
|
||||
# --- Delta path ---
|
||||
# Fetch all keys (state, action, AND images) for all deltas
|
||||
item_batch = self.hf_dataset[query_indices["index"]]
|
||||
|
||||
# 3. Query image features (which are not in _query_hf_dataset)
|
||||
for key in self.meta.image_keys:
|
||||
if key in query_indices:
|
||||
# hf_dataset[query_indices[key]][key] returns a LIST of PIL.Image objects
|
||||
item[key] = torch.stack(self.hf_dataset[query_indices[key]][key])
|
||||
# hf_transform_to_torch (item-level) has already run,
|
||||
# so all values are tensors. We stack them.
|
||||
item = {}
|
||||
for key in item_batch:
|
||||
item[key] = torch.stack(item_batch[key])
|
||||
|
||||
item.update(padding)
|
||||
|
||||
# Use the "current" item's index/timestamp/ep_idx
|
||||
# (assuming 'index' is the key for the main query)
|
||||
current_idx_in_batch = query_indices["index"].index(idx)
|
||||
item["index"] = item["index"][current_idx_in_batch]
|
||||
item["timestamp"] = item["timestamp"][current_idx_in_batch]
|
||||
item["episode_index"] = item["episode_index"][current_idx_in_batch]
|
||||
item["task_index"] = item["task_index"][current_idx_in_batch]
|
||||
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
else:
|
||||
# --- Single-frame path ---
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
# 3. Handle videos (which are always separate)
|
||||
if len(self.meta.video_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
current_ts = (
|
||||
item["timestamp"].item()
|
||||
if query_indices is None
|
||||
else item["timestamp"][current_idx_in_batch].item()
|
||||
)
|
||||
|
||||
video_query_indices = query_indices
|
||||
if video_query_indices is None:
|
||||
# If no deltas, create a dummy query_indices for _get_query_timestamps
|
||||
video_query_indices = {key: [idx] for key in self.meta.video_keys}
|
||||
|
||||
query_timestamps = self._get_query_timestamps(current_ts, video_query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
# video_frames are already stacked tensors (B, C, H, W) or (C, H, W)
|
||||
item = {**video_frames, **item}
|
||||
|
||||
# 4. Apply image transforms
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
if cam in item: # videos or images
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
# Add task as a string
|
||||
# 5. Add task string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
return item
|
||||
|
||||
Reference in New Issue
Block a user