fix(dataset): use revision-safe Hub cache for downloaded datasets (#3233)

* refactor(dataset): enhance dataset root directory handling and introduce hub cache support

- Updated DatasetConfig and LeRobotDatasetMetadata to clarify root directory behavior and introduce a dedicated hub cache for downloads.
- Refactored LeRobotDataset and StreamingLeRobotDataset to utilize the new hub cache and improve directory management.
- Added tests to ensure correct behavior when using the hub cache and handling different revisions without a specified root directory.

* refactor(dataset): improve root directory handling in LeRobotDataset

- Updated LeRobotDataset to store the requested root path separately from the actual root path.
- Adjusted metadata loading to use the requested root, enhancing clarity and consistency in directory management.

* refactor(dataset): minor improvements for hub cache support

* chore(datasets): guard in resume + assertion test

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: mickaelChen <mickael.chen.levinson@gmail.com>
This commit is contained in:
Steven Palma
2026-03-27 22:21:55 +01:00
committed by GitHub
parent 975d89b38d
commit 4e45acca52
8 changed files with 440 additions and 40 deletions

View File

@@ -68,7 +68,7 @@ class DatasetReader:
visual features.
"""
self._meta = meta
self._root = root
self.root = root
self.episodes = episodes
self._tolerance_s = tolerance_s
self._video_backend = video_backend
@@ -125,7 +125,7 @@ class DatasetReader:
def _load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self._meta.features)
hf_dataset = load_nested_dataset(self._root / "data", features=features, episodes=self.episodes)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -150,7 +150,7 @@ class DatasetReader:
if len(self._meta.video_keys) > 0:
for ep_idx in requested_episodes:
for vid_key in self._meta.video_keys:
video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
return False
@@ -240,7 +240,7 @@ class DatasetReader:
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self._root / self._meta.get_video_file_path(ep_idx, vid_key)
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
item[vid_key] = frames.squeeze(0)