Compare commits

..

10 Commits

Author SHA1 Message Date
Steven Palma
bcc13f1d90 try fix 9 2025-11-05 21:52:15 +01:00
Steven Palma
76f25f6afd try fix 8 2025-11-05 21:49:04 +01:00
Steven Palma
ce23681d4b try fix 7 2025-11-05 21:46:09 +01:00
Steven Palma
e195f8d287 try fix 6 2025-11-05 21:42:31 +01:00
Steven Palma
bbcffc4999 try fix 5 2025-11-05 21:34:10 +01:00
Steven Palma
20333abc72 try fix 4 2025-11-05 21:26:52 +01:00
Steven Palma
00a4e6bfb3 try fix 3 2025-11-05 21:09:53 +01:00
Steven Palma
a19bd6e84d try fix 3 2025-11-05 21:08:23 +01:00
Steven Palma
550866a3c5 try fix 2 2025-11-05 20:49:29 +01:00
Steven Palma
3ec4e4ce37 try fix 2025-11-05 20:24:47 +01:00
5 changed files with 190 additions and 288 deletions

View File

@@ -15,10 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import pandas as pd
@@ -109,7 +107,6 @@ def update_meta_data(
dst_meta,
meta_idx,
data_idx,
data_file_map,
videos_idx,
):
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
@@ -130,25 +127,8 @@ def update_meta_data(
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
# Remap data chunk/file indices per-source-file using the actual destination
# file chosen during data aggregation. A flat offset is incorrect when
# multiple source files are concatenated into a single destination file.
if data_file_map:
new_data_chunk = []
new_data_file = []
for idx in df.index:
src_chunk = int(df.at[idx, "data/chunk_index"]) # original source file location
src_file = int(df.at[idx, "data/file_index"]) # original source file location
dst_chunk, dst_file = data_file_map.get(
(src_chunk, src_file), (src_chunk + data_idx["chunk"], src_file + data_idx["file"])
)
new_data_chunk.append(dst_chunk)
new_data_file.append(dst_file)
df["data/chunk_index"] = new_data_chunk
df["data/file_index"] = new_data_file
else:
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
# Store original video file indices before updating
orig_chunk_col = f"videos/{key}/chunk_index"
@@ -186,7 +166,7 @@ def update_meta_data(
return df
def _aggregate_datasets(
def aggregate_datasets(
repo_ids: list[str],
aggr_repo_id: str,
roots: list[Path] | None = None,
@@ -195,24 +175,39 @@ def _aggregate_datasets(
video_files_size_in_mb: float | None = None,
chunk_size: int | None = None,
):
"""Serial aggregation kernel: combines datasets into a destination dataset.
"""Aggregates multiple LeRobot datasets into a single unified dataset.
This function performs a single-process aggregation. It assumes it is the
sole writer for its destination `aggr_root`.
This is the main function that orchestrates the aggregation process by:
1. Loading and validating all source dataset metadata
2. Creating a new destination dataset with unified tasks
3. Aggregating videos, data, and metadata from all source datasets
4. Finalizing the aggregated dataset with proper statistics
Args:
repo_ids: List of repository IDs for the datasets to aggregate.
aggr_repo_id: Repository ID for the aggregated output dataset.
roots: Optional list of root paths for the source datasets.
aggr_root: Optional root path for the aggregated dataset.
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
"""
# Build metadata objects, supporting a per-dataset "root" that may be None.
# When root is provided we load from the local filesystem, otherwise from Hub cache.
if roots is None:
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
else:
all_metadata = [
(
LeRobotDatasetMetadata(repo_id, root=root)
if root is not None
else LeRobotDatasetMetadata(repo_id)
)
for repo_id, root in zip(repo_ids, roots, strict=False)
logging.info("Start aggregate_datasets")
if data_files_size_in_mb is None:
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_files_size_in_mb is None:
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if chunk_size is None:
chunk_size = DEFAULT_CHUNK_SIZE
all_metadata = (
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
if roots is None
else [
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
]
)
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
@@ -242,11 +237,9 @@ def _aggregate_datasets(
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
data_idx, data_file_map = aggregate_data(
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size
)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames
@@ -255,168 +248,6 @@ def _aggregate_datasets(
logging.info("Aggregation complete.")
def aggregate_datasets(
    repo_ids: list[str],
    aggr_repo_id: str,
    roots: list[Path] | None = None,
    aggr_root: Path | None = None,
    data_files_size_in_mb: float | None = None,
    video_files_size_in_mb: float | None = None,
    chunk_size: int | None = None,
    num_workers: int | None = None,
    tmp_root: Path | None = None,
):
    """Aggregates multiple LeRobot datasets into a single unified dataset.

    When ``num_workers`` is unset or <= 1, the whole job is delegated to a single
    ``_aggregate_datasets`` call. When ``num_workers`` > 1, datasets are merged
    pairwise in rounds on a thread pool (a tree reduction) into intermediate
    datasets stored in a scratch directory, and the last remaining dataset is
    then aggregated into the requested destination.

    Args:
        repo_ids: List of repository IDs for the datasets to aggregate.
        aggr_repo_id: Repository ID for the aggregated output dataset.
        roots: Optional list of root paths for the source datasets.
        aggr_root: Optional root path for the aggregated dataset.
        data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
        video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
        chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
        num_workers: When > 1, performs a tree-based parallel reduction using a thread pool
        tmp_root: Optional base directory to store intermediate reduction outputs. A private
            scratch sub-directory is created inside it and only that sub-directory is
            removed afterwards; ``tmp_root`` itself and its other contents are left alone.

    Raises:
        Exception: Any exception raised by a worker merge is re-raised to the caller.
    """
    logging.info("Start aggregate_datasets")

    if data_files_size_in_mb is None:
        data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
    if video_files_size_in_mb is None:
        video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE

    if num_workers is None or num_workers <= 1:
        # Run aggregation sequentially
        _aggregate_datasets(
            repo_ids=repo_ids,
            aggr_repo_id=aggr_repo_id,
            aggr_root=aggr_root,
            roots=roots,
            data_files_size_in_mb=data_files_size_in_mb,
            video_files_size_in_mb=video_files_size_in_mb,
            chunk_size=chunk_size,
        )
        logging.info("Aggregation complete.")
        return

    # ---- Parallel fan-out/fan-in strategy (num_workers > 1) ----
    import tempfile  # local import: only the parallel path needs it

    # Validate across all metadata early to fail fast
    all_metadata_for_validation = (
        [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
        if roots is None
        else [
            LeRobotDatasetMetadata(repo_id, root=root)
            for repo_id, root in zip(repo_ids, roots, strict=False)
        ]
    )
    validate_all_metadata(all_metadata_for_validation)

    # Clamp workers to a sensible upper bound (pairs per round)
    num_workers = min(num_workers, max(1, len(repo_ids) // 2))

    # Choose the parent directory that will host the scratch directory.
    if tmp_root is not None:
        scratch_parent = Path(tmp_root)
    elif aggr_root is not None:
        scratch_parent = Path(aggr_root).parent
    else:
        scratch_parent = Path.cwd()
    scratch_parent.mkdir(parents=True, exist_ok=True)

    # Use a uniquely named directory owned by this call. Cleaning up must never
    # delete a user-supplied `tmp_root` (or files other runs placed next to ours).
    base_tmp_root = Path(tempfile.mkdtemp(prefix=".aggregate_tmp_", dir=scratch_parent))

    current_repo_ids: list[str] = list(repo_ids)
    # Always maintain a roots list aligned with repo_ids. Use None for Hub-backed inputs.
    current_roots: list[Path | None] = list(roots) if roots is not None else [None] * len(repo_ids)

    try:
        level = 0
        while len(current_repo_ids) > 1:
            next_repo_ids: list[str] = []
            next_roots: list[Path | None] = []
            futures = []

            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                # Walk the current level two datasets at a time. Only the last
                # group can be a singleton, so the enumerate index equals the
                # running count of submitted pairs.
                for group_index, start in enumerate(range(0, len(current_repo_ids), 2)):
                    group_repo_ids = current_repo_ids[start : start + 2]
                    group_roots = current_roots[start : start + 2]

                    if len(group_repo_ids) == 1:
                        # Carry over singleton to next level
                        next_repo_ids.append(group_repo_ids[0])
                        next_roots.append(group_roots[0])
                        continue

                    out_repo_id = f"{aggr_repo_id}__reduce_l{level}_g{group_index}"
                    out_root = base_tmp_root / f"reduce_l{level}_g{group_index}"
                    futures.append(
                        executor.submit(
                            _aggregate_datasets,
                            group_repo_ids,
                            out_repo_id,
                            group_roots,
                            out_root,
                            data_files_size_in_mb,
                            video_files_size_in_mb,
                            chunk_size,
                        )
                    )
                    next_repo_ids.append(out_repo_id)
                    next_roots.append(out_root)

                for future in as_completed(futures):
                    # Bubble up any exception raised inside tasks
                    future.result()

            # Cleanup previous level temporary outputs that won't be used again.
            # Only paths located inside our own scratch directory are ever removed,
            # so the original source datasets are never touched.
            scratch_resolved = base_tmp_root.resolve()
            keep_set = {root.resolve() for root in next_roots if root is not None}
            for prev_root in current_roots:
                if prev_root is None:
                    continue
                # Deliberately best-effort: suppress per-iteration to keep cleaning
                # other roots even if one fails.
                with contextlib.suppress(Exception):
                    resolved = prev_root.resolve()
                    if resolved not in keep_set and scratch_resolved in resolved.parents:
                        shutil.rmtree(prev_root, ignore_errors=True)

            current_repo_ids = next_repo_ids
            current_roots = next_roots  # aligned list of Path|None after first level
            level += 1

        # Final copy/aggregation into the desired output
        _aggregate_datasets(
            repo_ids=current_repo_ids,
            aggr_repo_id=aggr_repo_id,
            roots=current_roots,
            aggr_root=aggr_root,
            data_files_size_in_mb=data_files_size_in_mb,
            video_files_size_in_mb=video_files_size_in_mb,
            chunk_size=chunk_size,
        )
    finally:
        # Remove all temporary reduction artifacts (only the directory we created).
        with contextlib.suppress(Exception):
            shutil.rmtree(base_tmp_root, ignore_errors=True)

    logging.info("Aggregation complete.")
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
"""Aggregates video chunks from a source dataset into the destination dataset.
@@ -535,9 +366,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
# Map source (chunk,file) -> destination (chunk,file) actually used during write
src_to_dst_file: dict[tuple[int, int], tuple[int, int]] = {}
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
@@ -545,7 +373,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
data_idx, used_chunk, used_file = append_or_create_parquet_file(
data_idx = append_or_create_parquet_file(
df,
src_path,
data_idx,
@@ -555,12 +383,11 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
contains_images=len(dst_meta.image_keys) > 0,
aggr_root=dst_meta.root,
)
src_to_dst_file[(src_chunk_idx, src_file_idx)] = (used_chunk, used_file)
return data_idx, src_to_dst_file
return data_idx
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx):
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
"""Aggregates metadata from a source dataset into the destination dataset.
Reads source metadata files, updates all indices and timestamps,
@@ -594,11 +421,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, vi
dst_meta,
meta_idx,
data_idx,
data_file_map,
videos_idx,
)
meta_idx, _m_used_chunk, _m_used_file = append_or_create_parquet_file(
meta_idx = append_or_create_parquet_file(
df,
src_path,
meta_idx,
@@ -652,7 +478,7 @@ def append_or_create_parquet_file(
to_parquet_with_hf_images(df, dst_path)
else:
df.to_parquet(dst_path)
return idx, idx["chunk"], idx["file"]
return idx
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
@@ -663,19 +489,17 @@ def append_or_create_parquet_file(
new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df
target_path = new_path
used_chunk, used_file = idx["chunk"], idx["file"]
else:
existing_df = pd.read_parquet(dst_path)
final_df = pd.concat([existing_df, df], ignore_index=True)
target_path = dst_path
used_chunk, used_file = idx["chunk"], idx["file"]
if contains_images:
to_parquet_with_hf_images(final_df, target_path)
else:
final_df.to_parquet(target_path)
return idx, used_chunk, used_file
return idx
def finalize_aggregation(aggr_meta, all_metadata):

View File

@@ -234,7 +234,6 @@ def merge_datasets(
datasets: list[LeRobotDataset],
output_repo_id: str,
output_dir: str | Path | None = None,
num_workers: int | None = None,
) -> LeRobotDataset:
"""Merge multiple LeRobotDatasets into a single dataset.
@@ -258,7 +257,6 @@ def merge_datasets(
aggr_repo_id=output_repo_id,
roots=roots,
aggr_root=output_dir,
num_workers=num_workers,
)
merged_dataset = LeRobotDataset(
@@ -331,7 +329,7 @@ def modify_features(
if repo_id is None:
repo_id = f"{dataset.repo_id}_modified"
output_dir = Path(output_dir, exists_ok=True) if output_dir is not None else HF_LEROBOT_HOME / repo_id
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
new_features = dataset.meta.features.copy()

View File

@@ -19,6 +19,7 @@ import shutil
import tempfile
from collections.abc import Callable
from pathlib import Path
from typing import Any
import datasets
import numpy as np
@@ -31,6 +32,8 @@ import torch
import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
@@ -50,11 +53,9 @@ from lerobot.datasets.utils import (
get_file_size_in_mb,
get_hf_features_from_features,
get_safe_version,
hf_transform_to_torch,
is_valid_version,
load_episodes,
load_info,
load_nested_dataset,
load_stats,
load_tasks,
update_chunk_file_indices,
@@ -79,6 +80,51 @@ from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
    """Convert the columns of a Hugging Face dataset batch to torch tensors, in place.

    The kind of each column is decided from its first element only:
    PIL images go through ``transforms.ToTensor``; strings, bytes and ``None``
    columns are left untouched; anything else is passed element-wise to
    ``torch.tensor``. Empty columns are skipped.

    Args:
        items_dict: Mapping from column name to the list of values in the batch.

    Returns:
        The same ``items_dict`` object, with converted columns replaced.

    Raises:
        Exception: Whatever ``torch.tensor`` raised for a column, re-raised after
            a diagnostic line naming the column has been printed.
    """
    # One transform instance shared by every image column of the batch.
    image_to_tensor = transforms.ToTensor()

    for key, values in items_dict.items():
        if not values:
            # Nothing to inspect or convert.
            continue

        first_item = values[0]

        if isinstance(first_item, PILImage.Image):
            # Every image of the column is converted individually.
            items_dict[key] = [image_to_tensor(image) for image in values]
            continue

        if isinstance(first_item, (str, bytes)) or first_item is None:
            # Text columns (e.g. 'task') and all-None columns stay as they are.
            continue

        # Remaining case: numeric-like values (int, float, list, np.ndarray, ...).
        try:
            items_dict[key] = [torch.tensor(value) for value in values]
        except Exception as e:
            # Report which column failed before propagating the error.
            print(
                f"Error converting batch['{key}'] to tensor. First item: {first_item}, Type: {type(first_item)}"
            )
            raise e

    return items_dict
class LeRobotDatasetMetadata:
def __init__(
self,
@@ -693,6 +739,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
# Pre-load episodes metadata into memory to avoid file I/O in __getitem__
self.episodes_metadata_list = [ep for ep in self.meta.episodes]
# Track dataset state for efficient incremental writing
self._lazy_loading = False
self._recorded_frames = self.meta.total_frames
@@ -829,8 +878,36 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
    """hf_dataset contains all the observations, states, actions, rewards, etc.

    When ``self.episodes`` is set, only the parquet files that hold the requested
    episodes are loaded, and the rows are then filtered down to those episodes.
    Otherwise the whole ``data`` directory under ``self.root`` is loaded.

    Returns:
        The loaded dataset with ``hf_transform_to_torch`` installed as its transform.
    """
    features = get_hf_features_from_features(self.features)

    if self.episodes is not None:
        # Path for episode-specific loading (e.g., visualization).
        # A set comprehension de-duplicates files shared by several episodes.
        data_files = sorted(
            {
                str(
                    self.root
                    / self.meta.data_path.format(
                        chunk_index=self.episodes_metadata_list[ep_idx]["data/chunk_index"],
                        file_index=self.episodes_metadata_list[ep_idx]["data/file_index"],
                    )
                )
                for ep_idx in self.episodes
            }
        )

        hf_dataset = datasets.load_dataset(
            "parquet", data_files=data_files, features=features, split="train"
        )

        requested_episodes_set = set(self.episodes)
        # With batched=True the predicate receives a dict of column lists and must
        # return one boolean per row. Testing the list itself for set membership
        # would raise TypeError (lists are unhashable).
        hf_dataset = hf_dataset.filter(
            lambda batch: [ep in requested_episodes_set for ep in batch["episode_index"]],
            batched=True,
            batch_size=1000,
        )
    else:
        # Path used when no episode subset is requested (self.episodes is None).
        # NOTE(review): the original author states that loading through `data_dir`
        # gives a more efficient cache than listing files explicitly - not verified here.
        data_dir = str(self.root / "data")
        hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")

    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
@@ -909,7 +986,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep = self.meta.episodes[ep_idx]
ep = self.episodes_metadata_list[ep_idx]
ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"]
query_indices = {
@@ -940,26 +1017,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
"""
Query dataset for indices across keys, skipping video keys.
Tries column-first [key][indices] for speed, falls back to row-first.
Args:
query_indices: Dict mapping keys to index lists to retrieve
Returns:
Dict with stacked tensors of queried data (video keys excluded)
"""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self.meta.video_keys:
continue
try:
result[key] = torch.stack(self.hf_dataset[key][q_idx])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[q_idx][key])
return result
return {
key: torch.stack(self.hf_dataset[q_idx][key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -967,7 +1029,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
the main process and a subprocess fails to access it.
"""
ep = self.meta.episodes[ep_idx]
ep = self.episodes_metadata_list[ep_idx]
item = {}
for vid_key, query_ts in query_timestamps.items():
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
@@ -998,29 +1060,72 @@ class LeRobotDataset(torch.utils.data.Dataset):
def __getitem__(self, idx) -> dict:
# Ensure dataset is loaded when we actually need to read from it
self._ensure_hf_dataset_loaded()
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
# 1. Get query indices if deltas are needed
query_indices = None
padding = {}
if self.delta_indices is not None:
query_indices, padding = self._get_query_indices(idx, ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
item[key] = val
# We need the episode index *first* to get boundaries.
# This is a small read for just one item.
ep_idx_only = self.hf_dataset[idx : idx + 1]["episode_index"][0].item()
query_indices, padding = self._get_query_indices(idx, ep_idx_only)
# 2. Fetch all data (including images)
if query_indices is not None:
# --- Delta path ---
# Fetch all keys (state, action, AND images) for all deltas
item_batch = self.hf_dataset[query_indices["index"]]
# hf_transform_to_torch (item-level) has already run,
# so all values are tensors. We stack them.
item = {}
for key in item_batch:
item[key] = torch.stack(item_batch[key])
item.update(padding)
# Use the "current" item's index/timestamp/ep_idx
# (assuming 'index' is the key for the main query)
current_idx_in_batch = query_indices["index"].index(idx)
item["index"] = item["index"][current_idx_in_batch]
item["timestamp"] = item["timestamp"][current_idx_in_batch]
item["episode_index"] = item["episode_index"][current_idx_in_batch]
item["task_index"] = item["task_index"][current_idx_in_batch]
ep_idx = item["episode_index"].item()
else:
# --- Single-frame path ---
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
# 3. Handle videos (which are always separate)
if len(self.meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
current_ts = (
item["timestamp"].item()
if query_indices is None
else item["timestamp"][current_idx_in_batch].item()
)
video_query_indices = query_indices
if video_query_indices is None:
# If no deltas, create a dummy query_indices for _get_query_timestamps
video_query_indices = {key: [idx] for key in self.meta.video_keys}
query_timestamps = self._get_query_timestamps(current_ts, video_query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
# video_frames are already stacked tensors (B, C, H, W) or (C, H, W)
item = {**video_frames, **item}
# 4. Apply image transforms
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
if cam in item: # videos or images
item[cam] = self.image_transforms(item[cam])
# Add task as a string
# 5. Add task string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
return item

View File

@@ -35,7 +35,6 @@ from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.backward_compatibility import (
@@ -116,10 +115,15 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# Convert Path objects to a list of strings
file_paths = [str(path) for path in paths]
# Use datasets.load_dataset to force creation of an efficient cache
# This pre-decodes the images and avoids the on-the-fly bottleneck.
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars():
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
return datasets
dataset = datasets.load_dataset("parquet", data_files=file_paths, features=features, split="train")
return dataset
def get_parquet_num_frames(parquet_path: str | Path) -> int:
@@ -394,33 +398,6 @@ def load_image_as_numpy(
return img_array
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
types are converted to torch.tensor.
Args:
items_dict (dict): A dictionary representing a batch of data from a
Hugging Face dataset.
Returns:
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
def is_valid_version(version: str) -> bool:
"""Check if a string is a valid PEP 440 version.

View File

@@ -103,7 +103,6 @@ class SplitConfig:
class MergeConfig:
type: str = "merge"
repo_ids: list[str] | None = None
num_workers: int | None = None
@dataclass
@@ -216,7 +215,6 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
datasets,
output_repo_id=cfg.repo_id,
output_dir=output_dir,
num_workers=cfg.operation.num_workers,
)
logging.info(f"Merged dataset saved to {output_dir}")