mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
* fix bug in `augment_dataset_quantile_stats.py` that was not detecting the image features because we were looping over hf_dataset. Now we loop over the dataset itself * Update src/lerobot/datasets/v30/augment_dataset_quantile_stats.py Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co> --------- Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
261 lines
8.6 KiB
Python
261 lines
8.6 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
This script augments existing LeRobot datasets with quantile statistics.
|
|
|
|
Most datasets created before the quantile feature was added do not contain
|
|
quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script:
|
|
|
|
1. Loads an existing LeRobot dataset in v3.0 format
|
|
2. Checks if it already contains quantile statistics
|
|
3. If missing, computes quantile statistics for all features
|
|
4. Updates the dataset metadata with the new quantile statistics
|
|
|
|
Usage:
|
|
|
|
```bash
|
|
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
|
    --repo-id=lerobot/pusht
|
|
```
|
|
"""
|
|
|
|
import argparse
|
|
import concurrent.futures
|
|
import logging
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
from huggingface_hub import HfApi
|
|
from requests import HTTPError
|
|
from tqdm import tqdm
|
|
|
|
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
|
|
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
|
from lerobot.datasets.utils import write_stats
|
|
from lerobot.utils.utils import init_logging
|
|
|
|
|
|
def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str] | None = None) -> bool:
|
|
"""Check if dataset statistics already contain quantile information.
|
|
|
|
Args:
|
|
stats: Dataset statistics dictionary
|
|
|
|
Returns:
|
|
True if quantile statistics are present, False otherwise
|
|
"""
|
|
if quantile_list_keys is None:
|
|
quantile_list_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES]
|
|
|
|
if stats is None:
|
|
return False
|
|
|
|
for feature_stats in stats.values():
|
|
if any(q_key in feature_stats for q_key in quantile_list_keys):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
    """Compute statistics (including quantiles) for one episode of the dataset.

    Every frame of the episode is read through ``dataset[idx]``, the values of
    each known feature are stacked, and ``get_feature_stats`` is applied per
    feature with ``DEFAULT_QUANTILES``.

    Args:
        dataset: The LeRobot dataset to read frames from.
        episode_idx: Index of the episode to process.

    Returns:
        Dictionary mapping each non-string feature key to its statistics for
        this episode.
    """
    logging.info(f"Computing stats for episode {episode_idx}")

    episode_meta = dataset.meta.episodes[episode_idx]
    first_frame = episode_meta["dataset_from_index"]
    frame_stop = episode_meta["dataset_to_index"]

    # Gather per-frame values, keeping only keys declared as dataset features.
    frames_by_key: dict[str, list] = {}
    for frame_idx in range(first_frame, frame_stop):
        for name, value in dataset[frame_idx].items():
            if name in dataset.features:
                frames_by_key.setdefault(name, []).append(value)

    stats_by_key = {}
    for name, values in frames_by_key.items():
        feature_dtype = dataset.features[name]["dtype"]
        if feature_dtype == "string":
            continue

        is_visual = feature_dtype in ("image", "video")
        stacked = torch.stack(values).cpu().numpy()

        if is_visual:
            # Raw uint8 pixels are rescaled to [0, 1] floats before computing stats.
            if stacked.dtype == np.uint8:
                stacked = stacked.astype(np.float32) / 255.0
            # Reduce over every axis except axis 1; keepdims so shapes stay aligned.
            reduce_axes, keep = (0, 2, 3), True
        else:
            reduce_axes, keep = 0, stacked.ndim == 1

        feature_stats = get_feature_stats(
            stacked, axis=reduce_axes, keepdims=keep, quantile_list=DEFAULT_QUANTILES
        )

        if is_visual:
            # Drop the leading singleton axis left by keepdims; "count" is kept as-is.
            feature_stats = {
                stat_name: stat if stat_name == "count" else np.squeeze(stat, axis=0)
                for stat_name, stat in feature_stats.items()
            }

        stats_by_key[name] = feature_stats

    return stats_by_key
|
|
|
|
|
|
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
    """Compute quantile statistics for all episodes in the dataset.

    Per-episode statistics are produced by ``process_single_episode`` and then
    merged with ``aggregate_stats``, in episode-index order.

    Args:
        dataset: The LeRobot dataset to compute statistics for.

    Returns:
        Dictionary containing aggregated statistics with quantiles.

    Raises:
        ValueError: If the dataset contains no episodes.

    Note:
        Video decoding operations are not thread-safe, so we process episodes sequentially
        when video keys are present. For datasets without videos, we use parallel processing
        with ThreadPoolExecutor for better performance.
    """
    num_episodes = dataset.num_episodes
    logging.info(f"Computing quantile statistics for dataset with {num_episodes} episodes")

    # Fail early with the intended message: without this guard the parallel branch
    # would build ThreadPoolExecutor(max_workers=0), which raises an unrelated
    # "max_workers must be greater than 0" ValueError.
    if num_episodes <= 0:
        raise ValueError("No episode data found for computing statistics")

    if len(dataset.meta.video_keys) > 0:
        logging.info("Dataset contains video keys - using sequential processing for thread safety")
        episode_stats_list = [
            process_single_episode(dataset, episode_idx)
            for episode_idx in tqdm(range(num_episodes), desc="Processing episodes")
        ]
    else:
        logging.info("Dataset has no video keys - using parallel processing for better performance")
        max_workers = min(num_episodes, 16)

        episode_results = {}
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_episode = {
                executor.submit(process_single_episode, dataset, episode_idx): episode_idx
                for episode_idx in range(num_episodes)
            }

            with tqdm(total=num_episodes, desc="Processing episodes") as pbar:
                for future in concurrent.futures.as_completed(future_to_episode):
                    # result() re-raises any exception raised inside the worker,
                    # so every episode that reaches the dict has valid stats.
                    episode_results[future_to_episode[future]] = future.result()
                    pbar.update(1)

        # Futures complete in arbitrary order; restore episode order before aggregating.
        episode_stats_list = [episode_results[episode_idx] for episode_idx in range(num_episodes)]

    logging.info(f"Aggregating statistics from {len(episode_stats_list)} episodes")
    return aggregate_stats(episode_stats_list)
|
|
|
|
|
|
def augment_dataset_with_quantile_stats(
    repo_id: str,
    root: str | Path | None = None,
    overwrite: bool = False,
) -> None:
    """Augment a dataset with quantile statistics if they are missing.

    The dataset is loaded, its statistics are recomputed with quantiles, written
    to the dataset's metadata root with ``write_stats``, pushed to the hub, and
    the ``CODEBASE_VERSION`` tag is deleted (if present) and recreated.

    Args:
        repo_id: Repository ID of the dataset.
        root: Local root directory for the dataset.
        overwrite: Overwrite existing quantile statistics if they already exist.
    """
    logging.info(f"Loading dataset: {repo_id}")
    dataset = LeRobotDataset(
        repo_id=repo_id,
        root=root,
    )

    already_augmented = has_quantile_stats(dataset.meta.stats)
    if already_augmented and not overwrite:
        logging.info("Dataset already contains quantile statistics. No action needed.")
        return

    # Report the actual situation: with overwrite=True the statistics may already
    # exist, in which case claiming they are missing would be misleading.
    if already_augmented:
        logging.info("Dataset already contains quantile statistics. Recomputing them because overwrite is set...")
    else:
        logging.info("Dataset does not contain quantile statistics. Computing them now...")

    new_stats = compute_quantile_stats_for_dataset(dataset)

    logging.info("Updating dataset metadata with new quantile statistics")
    dataset.meta.stats = new_stats

    write_stats(new_stats, dataset.meta.root)

    logging.info("Successfully updated dataset with quantile statistics")
    dataset.push_to_hub()

    hub_api = HfApi()
    try:
        hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
    except HTTPError as e:
        # Best effort: a missing tag is expected for datasets that were never tagged.
        logging.info(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=None, repo_type="dataset")
|
|
|
|
|
|
def main():
    """Parse the command-line options and run the quantile augmentation."""
    cli = argparse.ArgumentParser(description="Augment LeRobot dataset with quantile statistics")
    cli.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository ID of the dataset (e.g., 'lerobot/pusht')",
    )
    cli.add_argument("--root", type=str, help="Local root directory for the dataset")
    cli.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing quantile statistics if they already exist",
    )
    options = cli.parse_args()

    # An omitted or empty --root is passed on as None.
    if options.root:
        dataset_root = Path(options.root)
    else:
        dataset_root = None

    init_logging()

    augment_dataset_with_quantile_stats(
        repo_id=options.repo_id,
        root=dataset_root,
        overwrite=options.overwrite,
    )
|
|
|
|
|
|
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|