Compare commits

..

3 Commits

Author SHA1 Message Date
Michel Aractingi
8008cb357d remove bad typing 2025-11-06 09:13:26 +01:00
Michel Aractingi
ca5a4a7ae5 add typing hints 2025-11-06 09:12:09 +01:00
Michel Aractingi
b5dcd70d2c add embed images in conversion to v3 script; add parquet writer in conversion script 2025-11-05 23:41:38 +01:00
5 changed files with 73 additions and 259 deletions

View File

@@ -15,10 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import pandas as pd
@@ -109,7 +107,6 @@ def update_meta_data(
dst_meta,
meta_idx,
data_idx,
data_file_map,
videos_idx,
):
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
@@ -130,25 +127,8 @@ def update_meta_data(
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
# Remap data chunk/file indices per-source-file using the actual destination
# file chosen during data aggregation. A flat offset is incorrect when
# multiple source files are concatenated into a single destination file.
if data_file_map:
new_data_chunk = []
new_data_file = []
for idx in df.index:
src_chunk = int(df.at[idx, "data/chunk_index"]) # original source file location
src_file = int(df.at[idx, "data/file_index"]) # original source file location
dst_chunk, dst_file = data_file_map.get(
(src_chunk, src_file), (src_chunk + data_idx["chunk"], src_file + data_idx["file"])
)
new_data_chunk.append(dst_chunk)
new_data_file.append(dst_file)
df["data/chunk_index"] = new_data_chunk
df["data/file_index"] = new_data_file
else:
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
# Store original video file indices before updating
orig_chunk_col = f"videos/{key}/chunk_index"
@@ -186,7 +166,7 @@ def update_meta_data(
return df
def _aggregate_datasets(
def aggregate_datasets(
repo_ids: list[str],
aggr_repo_id: str,
roots: list[Path] | None = None,
@@ -195,24 +175,39 @@ def _aggregate_datasets(
video_files_size_in_mb: float | None = None,
chunk_size: int | None = None,
):
"""Serial aggregation kernel: combines datasets into a destination dataset.
"""Aggregates multiple LeRobot datasets into a single unified dataset.
This function performs a single-process aggregation. It assumes it is the
sole writer for its destination `aggr_root`.
This is the main function that orchestrates the aggregation process by:
1. Loading and validating all source dataset metadata
2. Creating a new destination dataset with unified tasks
3. Aggregating videos, data, and metadata from all source datasets
4. Finalizing the aggregated dataset with proper statistics
Args:
repo_ids: List of repository IDs for the datasets to aggregate.
aggr_repo_id: Repository ID for the aggregated output dataset.
roots: Optional list of root paths for the source datasets.
aggr_root: Optional root path for the aggregated dataset.
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
"""
# Build metadata objects, supporting a per-dataset "root" that may be None.
# When root is provided we load from the local filesystem, otherwise from Hub cache.
if roots is None:
all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
else:
all_metadata = [
(
LeRobotDatasetMetadata(repo_id, root=root)
if root is not None
else LeRobotDatasetMetadata(repo_id)
)
for repo_id, root in zip(repo_ids, roots, strict=False)
logging.info("Start aggregate_datasets")
if data_files_size_in_mb is None:
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_files_size_in_mb is None:
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if chunk_size is None:
chunk_size = DEFAULT_CHUNK_SIZE
all_metadata = (
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
if roots is None
else [
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
]
)
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
@@ -242,11 +237,9 @@ def _aggregate_datasets(
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
data_idx, data_file_map = aggregate_data(
src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size
)
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames
@@ -255,168 +248,6 @@ def _aggregate_datasets(
logging.info("Aggregation complete.")
def aggregate_datasets(
repo_ids: list[str],
aggr_repo_id: str,
roots: list[Path] | None = None,
aggr_root: Path | None = None,
data_files_size_in_mb: float | None = None,
video_files_size_in_mb: float | None = None,
chunk_size: int | None = None,
num_workers: int | None = None,
tmp_root: Path | None = None,
):
"""Aggregates multiple LeRobot datasets into a single unified dataset.
This is the main function that orchestrates the aggregation process by:
1. Loading and validating all source dataset metadata
2. Creating a new destination dataset with unified tasks
3. Aggregating videos, data, and metadata from all source datasets
4. Finalizing the aggregated dataset with proper statistics
Args:
repo_ids: List of repository IDs for the datasets to aggregate.
aggr_repo_id: Repository ID for the aggregated output dataset.
roots: Optional list of root paths for the source datasets.
aggr_root: Optional root path for the aggregated dataset.
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
num_workers: When > 1, performs a tree-based parallel reduction using a thread pool
tmp_root: Optional base directory to store intermediate reduction outputs
"""
logging.info("Start aggregate_datasets")
if data_files_size_in_mb is None:
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_files_size_in_mb is None:
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if chunk_size is None:
chunk_size = DEFAULT_CHUNK_SIZE
if num_workers is None or num_workers <= 1:
# Run aggregation sequentially
_aggregate_datasets(
repo_ids=repo_ids,
aggr_repo_id=aggr_repo_id,
aggr_root=aggr_root,
roots=roots,
data_files_size_in_mb=data_files_size_in_mb,
video_files_size_in_mb=video_files_size_in_mb,
chunk_size=chunk_size,
)
# Uses a parallel fan-out/fan-in strategy when num_workers is provided
elif num_workers > 1:
# Validate across all metadata early to fail fast
all_metadata_for_validation = (
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
if roots is None
else [
LeRobotDatasetMetadata(repo_id, root=root)
for repo_id, root in zip(repo_ids, roots, strict=False)
]
)
validate_all_metadata(all_metadata_for_validation)
# Clamp workers to a sensible upper bound (pairs per round)
num_workers = min(num_workers, max(1, len(repo_ids) // 2))
# Choose a base temporary root for intermediate merge results
if tmp_root is not None:
base_tmp_root = tmp_root
elif aggr_root is not None:
base_tmp_root = aggr_root.parent / f".{aggr_repo_id}__tmp"
else:
base_tmp_root = Path.cwd() / f".{aggr_repo_id}__tmp"
base_tmp_root.mkdir(parents=True, exist_ok=True)
current_repo_ids: list[str] = list(repo_ids)
# Always maintain a roots list aligned with repo_ids. Use None for Hub-backed inputs.
current_roots: list[Path | None] = list(roots) if roots is not None else [None] * len(repo_ids)
try:
level = 0
while len(current_repo_ids) > 1:
next_repo_ids: list[str] = []
next_roots: list[Path | None] = []
futures = []
with ThreadPoolExecutor(max_workers=num_workers) as executor:
group_index = 0
i = 0
while i < len(current_repo_ids):
group_repo_ids = current_repo_ids[i : i + 2]
group_roots = current_roots[i : i + 2]
if len(group_repo_ids) == 1:
# Carry over singleton to next level
next_repo_ids.append(group_repo_ids[0])
next_roots.append(group_roots[0])
i += 1
continue
out_repo_id = f"{aggr_repo_id}__reduce_l{level}_g{group_index}"
out_root = base_tmp_root / f"reduce_l{level}_g{group_index}"
futures.append(
executor.submit(
_aggregate_datasets,
group_repo_ids,
out_repo_id,
group_roots,
out_root,
data_files_size_in_mb,
video_files_size_in_mb,
chunk_size,
)
)
next_repo_ids.append(out_repo_id)
next_roots.append(out_root)
i += 2
group_index += 1
for f in as_completed(futures):
# Bubble up any exception raised inside tasks
f.result()
# Cleanup previous level temporary outputs that won't be used again
base_resolved = base_tmp_root.resolve()
keep_set = {nr.resolve() for nr in next_roots if nr is not None}
for prev_root in current_roots:
if prev_root is None:
continue
# Suppress per-iteration to keep cleaning other roots even if one fails
with contextlib.suppress(Exception):
pr = prev_root.resolve()
if pr not in keep_set and base_resolved in pr.parents:
shutil.rmtree(prev_root, ignore_errors=True)
current_repo_ids = next_repo_ids
current_roots = next_roots # aligned list of Path|None after first level
level += 1
# Final copy/aggregation into the desired output
_aggregate_datasets(
repo_ids=current_repo_ids,
aggr_repo_id=aggr_repo_id,
roots=current_roots,
aggr_root=aggr_root,
data_files_size_in_mb=data_files_size_in_mb,
video_files_size_in_mb=video_files_size_in_mb,
chunk_size=chunk_size,
)
finally:
# Remove all temporary reduction artifacts
with contextlib.suppress(Exception):
shutil.rmtree(base_tmp_root, ignore_errors=True)
logging.info("Aggregation complete.")
return
def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size):
"""Aggregates video chunks from a source dataset into the destination dataset.
@@ -535,9 +366,6 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
unique_chunk_file_ids = sorted(unique_chunk_file_ids)
# Map source (chunk,file) -> destination (chunk,file) actually used during write
src_to_dst_file: dict[tuple[int, int], tuple[int, int]] = {}
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
@@ -545,7 +373,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
data_idx, used_chunk, used_file = append_or_create_parquet_file(
data_idx = append_or_create_parquet_file(
df,
src_path,
data_idx,
@@ -555,12 +383,11 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
contains_images=len(dst_meta.image_keys) > 0,
aggr_root=dst_meta.root,
)
src_to_dst_file[(src_chunk_idx, src_file_idx)] = (used_chunk, used_file)
return data_idx, src_to_dst_file
return data_idx
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, videos_idx):
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
"""Aggregates metadata from a source dataset into the destination dataset.
Reads source metadata files, updates all indices and timestamps,
@@ -594,11 +421,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, data_file_map, vi
dst_meta,
meta_idx,
data_idx,
data_file_map,
videos_idx,
)
meta_idx, _m_used_chunk, _m_used_file = append_or_create_parquet_file(
meta_idx = append_or_create_parquet_file(
df,
src_path,
meta_idx,
@@ -652,7 +478,7 @@ def append_or_create_parquet_file(
to_parquet_with_hf_images(df, dst_path)
else:
df.to_parquet(dst_path)
return idx, idx["chunk"], idx["file"]
return idx
src_size = get_parquet_file_size_in_mb(src_path)
dst_size = get_parquet_file_size_in_mb(dst_path)
@@ -663,19 +489,17 @@ def append_or_create_parquet_file(
new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df
target_path = new_path
used_chunk, used_file = idx["chunk"], idx["file"]
else:
existing_df = pd.read_parquet(dst_path)
final_df = pd.concat([existing_df, df], ignore_index=True)
target_path = dst_path
used_chunk, used_file = idx["chunk"], idx["file"]
if contains_images:
to_parquet_with_hf_images(final_df, target_path)
else:
final_df.to_parquet(target_path)
return idx, used_chunk, used_file
return idx
def finalize_aggregation(aggr_meta, all_metadata):

View File

@@ -234,7 +234,6 @@ def merge_datasets(
datasets: list[LeRobotDataset],
output_repo_id: str,
output_dir: str | Path | None = None,
num_workers: int | None = None,
) -> LeRobotDataset:
"""Merge multiple LeRobotDatasets into a single dataset.
@@ -258,7 +257,6 @@ def merge_datasets(
aggr_repo_id=output_repo_id,
roots=roots,
aggr_root=output_dir,
num_workers=num_workers,
)
merged_dataset = LeRobotDataset(
@@ -331,7 +329,7 @@ def modify_features(
if repo_id is None:
repo_id = f"{dataset.repo_id}_modified"
output_dir = Path(output_dir, exists_ok=True) if output_dir is not None else HF_LEROBOT_HOME / repo_id
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
new_features = dataset.meta.features.copy()

View File

@@ -940,26 +940,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
"""
Query dataset for indices across keys, skipping video keys.
Tries column-first [key][indices] for speed, falls back to row-first.
Args:
query_indices: Dict mapping keys to index lists to retrieve
Returns:
Dict with stacked tensors of queried data (video keys excluded)
"""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self.meta.video_keys:
continue
try:
result[key] = torch.stack(self.hf_dataset[key][q_idx])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[q_idx][key])
return result
return {
key: torch.stack(self.hf_dataset[q_idx][key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function

View File

@@ -50,9 +50,9 @@ from typing import Any
import jsonlines
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import tqdm
from datasets import Dataset, Features, Image
from datasets import Dataset, concatenate_datasets
from huggingface_hub import HfApi, snapshot_download
from requests import HTTPError
@@ -68,6 +68,7 @@ from lerobot.datasets.utils import (
LEGACY_EPISODES_STATS_PATH,
LEGACY_TASKS_PATH,
cast_stats_to_numpy,
embed_images,
flatten_dict,
get_file_size_in_mb,
get_parquet_file_size_in_mb,
@@ -174,25 +175,33 @@ def convert_tasks(root, new_root):
write_tasks(df_tasks, new_root)
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
# TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
dataframes = [pd.read_parquet(file) for file in paths_to_cat]
# Concatenate all DataFrames along rows
concatenated_df = pd.concat(dataframes, ignore_index=True)
def concat_data_files(
paths_to_cat: list[Path], new_root: Path, chunk_idx: int, file_idx: int, image_keys: list[str]
):
"""Concatenate multiple parquet data files into a single file.
Args:
paths_to_cat: List of parquet file paths to concatenate
new_root: Root directory for the new dataset
chunk_idx: Chunk index for the output file
file_idx: File index within the chunk
image_keys: List of feature keys that contain images
"""
datasets_list: list[Dataset] = [Dataset.from_parquet(str(file)) for file in paths_to_cat]
concatenated_ds: Dataset = concatenate_datasets(datasets_list)
if len(image_keys) > 0:
logging.debug(f"Embedding {len(image_keys)} image features for optimal training performance")
concatenated_ds = embed_images(concatenated_ds)
path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
if len(image_keys) > 0:
schema = pa.Schema.from_pandas(concatenated_df)
features = Features.from_arrow_schema(schema)
for key in image_keys:
features[key] = Image()
schema = features.arrow_schema
else:
schema = None
concatenated_df.to_parquet(path, index=False, schema=schema)
table = concatenated_ds.with_format("arrow")[:]
writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
writer.write_table(table)
writer.close()
def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):

View File

@@ -103,7 +103,6 @@ class SplitConfig:
class MergeConfig:
type: str = "merge"
repo_ids: list[str] | None = None
num_workers: int | None = None
@dataclass
@@ -216,7 +215,6 @@ def handle_merge(cfg: EditDatasetConfig) -> None:
datasets,
output_repo_id=cfg.repo_id,
output_dir=output_dir,
num_workers=cfg.operation.num_workers,
)
logging.info(f"Merged dataset saved to {output_dir}")