mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Compare commits
10 Commits
feat/fraca
...
fix/datase
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bcc13f1d90 | ||
|
|
76f25f6afd | ||
|
|
ce23681d4b | ||
|
|
e195f8d287 | ||
|
|
bbcffc4999 | ||
|
|
20333abc72 | ||
|
|
00a4e6bfb3 | ||
|
|
a19bd6e84d | ||
|
|
550866a3c5 | ||
|
|
3ec4e4ce37 |
@@ -15,10 +15,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
@@ -109,7 +107,6 @@ def update_meta_data(
|
||||
dst_meta,
|
||||
meta_idx,
|
||||
data_idx,
|
||||
data_file_map,
|
||||
videos_idx,
|
||||
):
|
||||
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
||||
@@ -130,25 +127,8 @@ def update_meta_data(
|
||||
|
||||
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
|
||||
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
|
||||
# Remap data chunk/file indices per-source-file using the actual destination
|
||||
# file chosen during data aggregation. A flat offset is incorrect when
|
||||
# multiple source files are concatenated into a single destination file.
|
||||
if data_file_map:
|
||||
new_data_chunk = []
|
||||
new_data_file = []
|
||||
for idx in df.index:
|
||||
src_chunk = int(df.at[idx, "data/chunk_index"]) # original source file location
|
||||
src_file = int(df.at[idx, "data/file_index"]) # original source file location
|
||||
dst_chunk, dst_file = data_file_map.get(
|
||||
(src_chunk, src_file), (src_chunk + data_idx["chunk"], src_file + data_idx["file"])
|
||||
)
|
||||
new_data_chunk.append(dst_chunk)
|
||||
new_data_file.append(dst_file)
|
||||
df["data/chunk_index"] = new_data_chunk
|
||||
df["data/file_index"] = new_data_file
|
||||
else:
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
|
||||
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
|
||||
for key, video_idx in videos_idx.items():
|
||||
# Store original video file indices before updating
|
||||
orig_chunk_col = f"videos/{key}/chunk_index"
|
||||
@@ -186,7 +166,7 @@ def update_meta_data(
|
||||
return df
|
||||
|
||||
|
||||
def _aggregate_datasets(
|
||||
def aggregate_datasets(
|
||||
repo_ids: list[str],
|
||||
aggr_repo_id: str,
|
||||
roots: list[Path] | None = None,
|
||||
@@ -195,24 +175,39 @@ def _aggregate_datasets(
|
||||
video_files_size_in_mb: float | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Serial aggregation kernel: combines datasets into a destination dataset.
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
|
||||
This function performs a single-process aggregation. It assumes it is the
|
||||
sole writer for its destination `aggr_root`.
|
||||
This is the main function that orchestrates the aggregation process by:
|
||||
1. Loading and validating all source dataset metadata
|
||||
2. Creating a new destination dataset with unified tasks
|
||||
3. Aggregating videos, data, and metadata from all source datasets
|
||||
4. Finalizing the aggregated dataset with proper statistics
|
||||
|
||||
Args:
|
||||
repo_ids: List of repository IDs for the datasets to aggregate.
|
||||
aggr_repo_id: Repository ID for the aggregated output dataset.
|
||||
roots: Optional list of root paths for the source datasets.
|
||||
aggr_root: Optional root path for the aggregated dataset.
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
"""
|
||||
# Build metadata objects, supporting a per-dataset "root" that may be None.
|
||||
# When root is provided we load from the local filesystem, otherwise from Hub cache.
|
||||
if roots is None:
|
||||
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
||||
else:
|
||||
all_metadata = [
|
||||
(
|
||||
LeRobotDatasetMetadata(repo_id, root=root)
|
||||
if root is not None
|
||||
else LeRobotDatasetMetadata(repo_id)
|
||||
)
|
||||
for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
if data_files_size_in_mb is None:
|
||||
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_files_size_in_mb is None:
|
||||
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if chunk_size is None:
|
||||
chunk_size = DEFAULT_CHUNK_SIZE
|
||||
|
||||
all_metadata = (
|
||||
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
||||
if roots is None
|
||||
else [
|
||||
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||
]
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
|
||||
@@ -242,11 +237,9 @@ def _aggregate_datasets(
|
||||
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||
data_idx, data_file_map = aggregate_data(
|
||||
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size
|
||||
)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx)
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
@@ -255,168 +248,6 @@ def _aggregate_datasets(
|
||||
logging.info("Aggregation complete.")
|
||||
|
||||
|
||||
def aggregate_datasets(
|
||||
repo_ids: list[str],
|
||||
aggr_repo_id: str,
|
||||
roots: list[Path] | None = None,
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
chunk_size: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
tmp_root: Path | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
|
||||
This is the main function that orchestrates the aggregation process by:
|
||||
1. Loading and validating all source dataset metadata
|
||||
2. Creating a new destination dataset with unified tasks
|
||||
3. Aggregating videos, data, and metadata from all source datasets
|
||||
4. Finalizing the aggregated dataset with proper statistics
|
||||
|
||||
Args:
|
||||
repo_ids: List of repository IDs for the datasets to aggregate.
|
||||
aggr_repo_id: Repository ID for the aggregated output dataset.
|
||||
roots: Optional list of root paths for the source datasets.
|
||||
aggr_root: Optional root path for the aggregated dataset.
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
num_workers: When > 1, performs a tree-based parallel reduction using a thread pool
|
||||
tmp_root: Optional base directory to store intermediate reduction outputs
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
|
||||
if data_files_size_in_mb is None:
|
||||
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_files_size_in_mb is None:
|
||||
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if chunk_size is None:
|
||||
chunk_size = DEFAULT_CHUNK_SIZE
|
||||
|
||||
if num_workers is None or num_workers <= 1:
|
||||
# Run aggregation sequentially
|
||||
_aggregate_datasets(
|
||||
repo_ids=repo_ids,
|
||||
aggr_repo_id=aggr_repo_id,
|
||||
aggr_root=aggr_root,
|
||||
roots=roots,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# Uses a parallel fan-out/fan-in strategy when num_workers is provided
|
||||
elif num_workers > 1:
|
||||
# Validate across all metadata early to fail fast
|
||||
all_metadata_for_validation = (
|
||||
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
||||
if roots is None
|
||||
else [
|
||||
LeRobotDatasetMetadata(repo_id, root=root)
|
||||
for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||
]
|
||||
)
|
||||
validate_all_metadata(all_metadata_for_validation)
|
||||
|
||||
# Clamp workers to a sensible upper bound (pairs per round)
|
||||
num_workers = min(num_workers, max(1, len(repo_ids) // 2))
|
||||
|
||||
# Choose a base temporary root for intermediate merge results
|
||||
if tmp_root is not None:
|
||||
base_tmp_root = tmp_root
|
||||
elif aggr_root is not None:
|
||||
base_tmp_root = aggr_root.parent / f".{aggr_repo_id}__tmp"
|
||||
else:
|
||||
base_tmp_root = Path.cwd() / f".{aggr_repo_id}__tmp"
|
||||
base_tmp_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
current_repo_ids: list[str] = list(repo_ids)
|
||||
# Always maintain a roots list aligned with repo_ids. Use None for Hub-backed inputs.
|
||||
current_roots: list[Path | None] = list(roots) if roots is not None else [None] * len(repo_ids)
|
||||
|
||||
try:
|
||||
level = 0
|
||||
while len(current_repo_ids) > 1:
|
||||
next_repo_ids: list[str] = []
|
||||
next_roots: list[Path | None] = []
|
||||
futures = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
group_index = 0
|
||||
i = 0
|
||||
while i < len(current_repo_ids):
|
||||
group_repo_ids = current_repo_ids[i : i + 2]
|
||||
group_roots = current_roots[i : i + 2]
|
||||
|
||||
if len(group_repo_ids) == 1:
|
||||
# Carry over singleton to next level
|
||||
next_repo_ids.append(group_repo_ids[0])
|
||||
next_roots.append(group_roots[0])
|
||||
i += 1
|
||||
continue
|
||||
|
||||
out_repo_id = f"{aggr_repo_id}__reduce_l{level}_g{group_index}"
|
||||
out_root = base_tmp_root / f"reduce_l{level}_g{group_index}"
|
||||
|
||||
futures.append(
|
||||
executor.submit(
|
||||
_aggregate_datasets,
|
||||
group_repo_ids,
|
||||
out_repo_id,
|
||||
group_roots,
|
||||
out_root,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
chunk_size,
|
||||
)
|
||||
)
|
||||
|
||||
next_repo_ids.append(out_repo_id)
|
||||
next_roots.append(out_root)
|
||||
|
||||
i += 2
|
||||
group_index += 1
|
||||
|
||||
for f in as_completed(futures):
|
||||
# Bubble up any exception raised inside tasks
|
||||
f.result()
|
||||
|
||||
# Cleanup previous level temporary outputs that won't be used again
|
||||
base_resolved = base_tmp_root.resolve()
|
||||
keep_set = {nr.resolve() for nr in next_roots if nr is not None}
|
||||
for prev_root in current_roots:
|
||||
if prev_root is None:
|
||||
continue
|
||||
# Suppress per-iteration to keep cleaning other roots even if one fails
|
||||
with contextlib.suppress(Exception):
|
||||
pr = prev_root.resolve()
|
||||
if pr not in keep_set and base_resolved in pr.parents:
|
||||
shutil.rmtree(prev_root, ignore_errors=True)
|
||||
|
||||
current_repo_ids = next_repo_ids
|
||||
current_roots = next_roots # aligned list of Path|None after first level
|
||||
level += 1
|
||||
|
||||
# Final copy/aggregation into the desired output
|
||||
_aggregate_datasets(
|
||||
repo_ids=current_repo_ids,
|
||||
aggr_repo_id=aggr_repo_id,
|
||||
roots=current_roots,
|
||||
aggr_root=aggr_root,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
finally:
|
||||
# Remove all temporary reduction artifacts
|
||||
with contextlib.suppress(Exception):
|
||||
shutil.rmtree(base_tmp_root, ignore_errors=True)
|
||||
|
||||
logging.info("Aggregation complete.")
|
||||
return
|
||||
|
||||
|
||||
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
|
||||
"""Aggregates video chunks from a source dataset into the destination dataset.
|
||||
|
||||
@@ -535,9 +366,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
|
||||
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
|
||||
|
||||
# Map source (chunk,file) -> destination (chunk,file) actually used during write
|
||||
src_to_dst_file: dict[tuple[int, int], tuple[int, int]] = {}
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||
@@ -545,7 +373,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
data_idx, used_chunk, used_file = append_or_create_parquet_file(
|
||||
data_idx = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
data_idx,
|
||||
@@ -555,12 +383,11 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
contains_images=len(dst_meta.image_keys) > 0,
|
||||
aggr_root=dst_meta.root,
|
||||
)
|
||||
src_to_dst_file[(src_chunk_idx, src_file_idx)] = (used_chunk, used_file)
|
||||
|
||||
return data_idx, src_to_dst_file
|
||||
return data_idx
|
||||
|
||||
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx):
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
"""Aggregates metadata from a source dataset into the destination dataset.
|
||||
|
||||
Reads source metadata files, updates all indices and timestamps,
|
||||
@@ -594,11 +421,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, vi
|
||||
dst_meta,
|
||||
meta_idx,
|
||||
data_idx,
|
||||
data_file_map,
|
||||
videos_idx,
|
||||
)
|
||||
|
||||
meta_idx, _m_used_chunk, _m_used_file = append_or_create_parquet_file(
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
meta_idx,
|
||||
@@ -652,7 +478,7 @@ def append_or_create_parquet_file(
|
||||
to_parquet_with_hf_images(df, dst_path)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx, idx["chunk"], idx["file"]
|
||||
return idx
|
||||
|
||||
src_size = get_parquet_file_size_in_mb(src_path)
|
||||
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||
@@ -663,19 +489,17 @@ def append_or_create_parquet_file(
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
final_df = df
|
||||
target_path = new_path
|
||||
used_chunk, used_file = idx["chunk"], idx["file"]
|
||||
else:
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||
target_path = dst_path
|
||||
used_chunk, used_file = idx["chunk"], idx["file"]
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
return idx, used_chunk, used_file
|
||||
return idx
|
||||
|
||||
|
||||
def finalize_aggregation(aggr_meta, all_metadata):
|
||||
|
||||
@@ -234,7 +234,6 @@ def merge_datasets(
|
||||
datasets: list[LeRobotDataset],
|
||||
output_repo_id: str,
|
||||
output_dir: str | Path | None = None,
|
||||
num_workers: int | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Merge multiple LeRobotDatasets into a single dataset.
|
||||
|
||||
@@ -258,7 +257,6 @@ def merge_datasets(
|
||||
aggr_repo_id=output_repo_id,
|
||||
roots=roots,
|
||||
aggr_root=output_dir,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
merged_dataset = LeRobotDataset(
|
||||
@@ -331,7 +329,7 @@ def modify_features(
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_modified"
|
||||
output_dir = Path(output_dir, exists_ok=True) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
new_features = dataset.meta.features.copy()
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import shutil
|
||||
import tempfile
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -31,6 +32,8 @@ import torch
|
||||
import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
@@ -50,11 +53,9 @@ from lerobot.datasets.utils import (
|
||||
get_file_size_in_mb,
|
||||
get_hf_features_from_features,
|
||||
get_safe_version,
|
||||
hf_transform_to_torch,
|
||||
is_valid_version,
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
update_chunk_file_indices,
|
||||
@@ -79,6 +80,51 @@ from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""
|
||||
Converts a batch from a Hugging Face dataset to torch tensors.
|
||||
"""
|
||||
|
||||
# Create a single ToTensor transform instance to reuse
|
||||
to_tensor = transforms.ToTensor()
|
||||
|
||||
for key in items_dict:
|
||||
items_list = items_dict[key]
|
||||
|
||||
# Check if the list is non-empty
|
||||
if not items_list:
|
||||
continue
|
||||
|
||||
first_item = items_list[0]
|
||||
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
# This is the (slow) CPU-bound part.
|
||||
# We convert every image in the batch list to a tensor.
|
||||
items_dict[key] = [to_tensor(img) for img in items_list]
|
||||
|
||||
elif isinstance(first_item, (str, bytes)):
|
||||
# List of strings (e.g., 'task'), do nothing
|
||||
pass
|
||||
|
||||
elif first_item is None:
|
||||
# List of Nones, do nothing
|
||||
pass
|
||||
|
||||
else:
|
||||
# List of other things (int, float, list, np.ndarray)
|
||||
try:
|
||||
# Convert each item in the list to a tensor
|
||||
items_dict[key] = [torch.tensor(item) for item in items_list]
|
||||
except Exception as e:
|
||||
# This catch is what was missing from the original v3.0 code
|
||||
print(
|
||||
f"Error converting batch['{key}'] to tensor. First item: {first_item}, Type: {type(first_item)}"
|
||||
)
|
||||
raise e
|
||||
|
||||
return items_dict
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -693,6 +739,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
|
||||
# Pre-load episodes metadata into memory to avoid file I/O in __getitem__
|
||||
self.episodes_metadata_list = [ep for ep in self.meta.episodes]
|
||||
|
||||
# Track dataset state for efficient incremental writing
|
||||
self._lazy_loading = False
|
||||
self._recorded_frames = self.meta.total_frames
|
||||
@@ -829,8 +878,36 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
|
||||
features = get_hf_features_from_features(self.features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features)
|
||||
|
||||
if self.episodes is not None:
|
||||
# Path for episode-specific loading (e.g., visualization)
|
||||
fpaths = set()
|
||||
for ep_idx in self.episodes:
|
||||
ep_meta = self.episodes_metadata_list[ep_idx]
|
||||
chunk_idx = ep_meta["data/chunk_index"]
|
||||
file_idx = ep_meta["data/file_index"]
|
||||
fpath_str = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
fpaths.add(str(self.root / fpath_str))
|
||||
|
||||
data_files = sorted(list(fpaths))
|
||||
|
||||
hf_dataset = datasets.load_dataset(
|
||||
"parquet", data_files=data_files, features=features, split="train"
|
||||
)
|
||||
|
||||
requested_episodes_set = set(self.episodes)
|
||||
hf_dataset = hf_dataset.filter(
|
||||
lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000
|
||||
)
|
||||
|
||||
else:
|
||||
# THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
|
||||
# Use `data_dir` to trigger the v2.1-style efficient cache.
|
||||
data_dir = str(self.root / "data")
|
||||
hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")
|
||||
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@@ -909,7 +986,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
ep = self.episodes_metadata_list[ep_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
query_indices = {
|
||||
@@ -940,26 +1017,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return query_timestamps
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
"""
|
||||
Query dataset for indices across keys, skipping video keys.
|
||||
|
||||
Tries column-first [key][indices] for speed, falls back to row-first.
|
||||
|
||||
Args:
|
||||
query_indices: Dict mapping keys to index lists to retrieve
|
||||
|
||||
Returns:
|
||||
Dict with stacked tensors of queried data (video keys excluded)
|
||||
"""
|
||||
result: dict = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
continue
|
||||
try:
|
||||
result[key] = torch.stack(self.hf_dataset[key][q_idx])
|
||||
except (KeyError, TypeError, IndexError):
|
||||
result[key] = torch.stack(self.hf_dataset[q_idx][key])
|
||||
return result
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset[q_idx][key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
@@ -967,7 +1029,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
the main process and a subprocess fails to access it.
|
||||
"""
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
ep = self.episodes_metadata_list[ep_idx]
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
|
||||
@@ -998,29 +1060,72 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx) -> dict:
|
||||
# Ensure dataset is loaded when we actually need to read from it
|
||||
self._ensure_hf_dataset_loaded()
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
# 1. Get query indices if deltas are needed
|
||||
query_indices = None
|
||||
padding = {}
|
||||
if self.delta_indices is not None:
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
# We need the episode index *first* to get boundaries.
|
||||
# This is a small read for just one item.
|
||||
ep_idx_only = self.hf_dataset[idx : idx + 1]["episode_index"][0].item()
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx_only)
|
||||
|
||||
# 2. Fetch all data (including images)
|
||||
if query_indices is not None:
|
||||
# --- Delta path ---
|
||||
# Fetch all keys (state, action, AND images) for all deltas
|
||||
item_batch = self.hf_dataset[query_indices["index"]]
|
||||
|
||||
# hf_transform_to_torch (item-level) has already run,
|
||||
# so all values are tensors. We stack them.
|
||||
item = {}
|
||||
for key in item_batch:
|
||||
item[key] = torch.stack(item_batch[key])
|
||||
|
||||
item.update(padding)
|
||||
|
||||
# Use the "current" item's index/timestamp/ep_idx
|
||||
# (assuming 'index' is the key for the main query)
|
||||
current_idx_in_batch = query_indices["index"].index(idx)
|
||||
item["index"] = item["index"][current_idx_in_batch]
|
||||
item["timestamp"] = item["timestamp"][current_idx_in_batch]
|
||||
item["episode_index"] = item["episode_index"][current_idx_in_batch]
|
||||
item["task_index"] = item["task_index"][current_idx_in_batch]
|
||||
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
else:
|
||||
# --- Single-frame path ---
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
# 3. Handle videos (which are always separate)
|
||||
if len(self.meta.video_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
current_ts = (
|
||||
item["timestamp"].item()
|
||||
if query_indices is None
|
||||
else item["timestamp"][current_idx_in_batch].item()
|
||||
)
|
||||
|
||||
video_query_indices = query_indices
|
||||
if video_query_indices is None:
|
||||
# If no deltas, create a dummy query_indices for _get_query_timestamps
|
||||
video_query_indices = {key: [idx] for key in self.meta.video_keys}
|
||||
|
||||
query_timestamps = self._get_query_timestamps(current_ts, video_query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
# video_frames are already stacked tensors (B, C, H, W) or (C, H, W)
|
||||
item = {**video_frames, **item}
|
||||
|
||||
# 4. Apply image transforms
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
if cam in item: # videos or images
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
# Add task as a string
|
||||
# 5. Add task string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
return item
|
||||
|
||||
@@ -35,7 +35,6 @@ from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.backward_compatibility import (
|
||||
@@ -116,10 +115,15 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
|
||||
if len(paths) == 0:
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
# Convert Path objects to a list of strings
|
||||
file_paths = [str(path) for path in paths]
|
||||
|
||||
# Use datasets.load_dataset to force creation of an efficient cache
|
||||
# This pre-decodes the images and avoids the on-the-fly bottleneck.
|
||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||
with SuppressProgressBars():
|
||||
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
return datasets
|
||||
dataset = datasets.load_dataset("parquet", data_files=file_paths, features=features, split="train")
|
||||
return dataset
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
@@ -394,33 +398,6 @@ def load_image_as_numpy(
|
||||
return img_array
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
This transform function converts items from Hugging Face dataset format (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
|
||||
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
|
||||
types are converted to torch.tensor.
|
||||
|
||||
Args:
|
||||
items_dict (dict): A dictionary representing a batch of data from a
|
||||
Hugging Face dataset.
|
||||
|
||||
Returns:
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
def is_valid_version(version: str) -> bool:
|
||||
"""Check if a string is a valid PEP 440 version.
|
||||
|
||||
|
||||
@@ -103,7 +103,6 @@ class SplitConfig:
|
||||
class MergeConfig:
|
||||
type: str = "merge"
|
||||
repo_ids: list[str] | None = None
|
||||
num_workers: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -216,7 +215,6 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
|
||||
datasets,
|
||||
output_repo_id=cfg.repo_id,
|
||||
output_dir=output_dir,
|
||||
num_workers=cfg.operation.num_workers,
|
||||
)
|
||||
|
||||
logging.info(f"Merged dataset saved to {output_dir}")
|
||||
|
||||
Reference in New Issue
Block a user