mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Migrate PR to new folder structure introduce on 1417
This commit is contained in:
@@ -5,9 +5,9 @@ from pathlib import Path
|
||||
import pandas as pd
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import (
|
||||
concat_video_files,
|
||||
get_parquet_file_size_in_mb,
|
||||
get_video_size_in_mb,
|
||||
to_parquet_with_hf_images,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -98,11 +99,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
||||
]
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||
video_keys = [k for k, v in features.items() if v["dtype"] == "video"]
|
||||
=======
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||
|
||||
# Initialize output dataset metadata
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
@@ -125,16 +122,14 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||
}
|
||||
|
||||
dst_meta.episodes = {}
|
||||
|
||||
# Process each dataset
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx)
|
||||
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys)
|
||||
=======
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
@@ -210,15 +205,9 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
|
||||
file_idx,
|
||||
)
|
||||
|
||||
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||
# Update the video index tracking
|
||||
video_idx["chunk_idx"] = chunk_idx
|
||||
video_idx["file_idx"] = file_idx
|
||||
=======
|
||||
# Update the videos_idx with the final chunk and file indices for this key
|
||||
videos_idx[key]["chunk"] = chunk_idx
|
||||
videos_idx[key]["file"] = file_idx
|
||||
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||
|
||||
return videos_idx
|
||||
|
||||
@@ -237,22 +226,15 @@ def aggregate_data(src_meta, dst_meta, data_idx):
|
||||
df = pd.read_parquet(src_path)
|
||||
df = update_data_df(df, src_meta, dst_meta)
|
||||
|
||||
dst_path = aggr_root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=data_idx["chunk"], file_index=data_idx["file"]
|
||||
)
|
||||
data_idx = write_parquet_safely(
|
||||
data_idx = append_or_create_parquet_file(
|
||||
df,
|
||||
src_path,
|
||||
dst_path,
|
||||
data_idx,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_PATH,
|
||||
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||
=======
|
||||
contains_images=len(dst_meta.image_keys) > 0,
|
||||
aggr_root=dst_meta.root,
|
||||
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||
)
|
||||
|
||||
return data_idx
|
||||
@@ -282,17 +264,9 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||
dst_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(
|
||||
chunk_index=meta_idx["chunk"], file_index=meta_idx["file"]
|
||||
)
|
||||
write_parquet_safely(
|
||||
=======
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||
df,
|
||||
src_path,
|
||||
dst_path,
|
||||
meta_idx,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
@@ -304,19 +278,15 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
return meta_idx
|
||||
|
||||
|
||||
def write_parquet_safely(
|
||||
def append_or_create_parquet_file(
|
||||
df: pd.DataFrame,
|
||||
src_path: Path,
|
||||
dst_path: Path,
|
||||
idx: dict[str, int],
|
||||
max_mb: float,
|
||||
chunk_size: int,
|
||||
default_path: str,
|
||||
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||
=======
|
||||
contains_images: bool = False,
|
||||
aggr_root: Path = None,
|
||||
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||
):
|
||||
"""
|
||||
Safely appends or creates a Parquet file at dst_path based on size constraints.
|
||||
@@ -324,7 +294,6 @@ def write_parquet_safely(
|
||||
Parameters:
|
||||
df (pd.DataFrame): Data to write.
|
||||
src_path (Path): Path to source file (used to get size).
|
||||
dst_path (Path): Target path for writing.
|
||||
idx (dict): Dictionary containing 'chunk' and 'file' indices.
|
||||
max_mb (float): Maximum allowed file size in MB.
|
||||
chunk_size (int): Maximum number of files per chunk.
|
||||
@@ -333,11 +302,8 @@ def write_parquet_safely(
|
||||
Returns:
|
||||
dict: Updated index dictionary.
|
||||
"""
|
||||
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||
=======
|
||||
# Initial destination path - use the correct default_path parameter
|
||||
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||
|
||||
# If destination file doesn't exist, just write the new one
|
||||
if not dst_path.exists():
|
||||
@@ -357,14 +323,6 @@ def write_parquet_safely(
|
||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
<<<<<<< HEAD:src/lerobot/datasets/aggregate.py
|
||||
df.to_parquet(new_path)
|
||||
else:
|
||||
# Append to existing file
|
||||
existing_df = pd.read_parquet(dst_path)
|
||||
combined_df = pd.concat([existing_df, df], ignore_index=True)
|
||||
combined_df.to_parquet(dst_path)
|
||||
=======
|
||||
final_df = df
|
||||
target_path = new_path
|
||||
else:
|
||||
@@ -377,7 +335,6 @@ def write_parquet_safely(
|
||||
to_parquet_with_hf_images(final_df, target_path)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
>>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py
|
||||
|
||||
return idx
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,954 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset, concatenate_datasets
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.datasets.backward_compatibility import (
|
||||
V21_MESSAGE,
|
||||
BackwardCompatibilityError,
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
|
||||
EPISODES_DIR = "meta/episodes"
|
||||
DATA_DIR = "data"
|
||||
VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
---
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
## {}
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_FEATURES = {
|
||||
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path):
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
total_uncompressed_size = 0
|
||||
for row_group in range(metadata.num_row_groups):
|
||||
rg_metadata = metadata.row_group(row_group)
|
||||
for column in range(rg_metadata.num_columns):
|
||||
col_metadata = rg_metadata.column(column)
|
||||
total_uncompressed_size += col_metadata.total_uncompressed_size
|
||||
return total_uncompressed_size / (1024**2)
|
||||
|
||||
|
||||
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
|
||||
return hf_ds.data.nbytes / (1024**2)
|
||||
|
||||
|
||||
def get_pd_dataframe_size_in_mb(df: pandas.DataFrame) -> int:
|
||||
# TODO(rcadene): unused?
|
||||
memory_usage_bytes = df.memory_usage(deep=True).sum()
|
||||
return memory_usage_bytes / (1024**2)
|
||||
|
||||
|
||||
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
|
||||
if file_idx == chunks_size - 1:
|
||||
file_idx = 0
|
||||
chunk_idx += 1
|
||||
else:
|
||||
file_idx += 1
|
||||
return chunk_idx, file_idx
|
||||
|
||||
|
||||
def load_nested_dataset(pq_dir: Path) -> Dataset:
|
||||
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
||||
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
||||
Concatenate all pyarrow references to return HF Dataset format
|
||||
"""
|
||||
paths = sorted(pq_dir.glob("*/*.parquet"))
|
||||
if len(paths) == 0:
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||
datasets = [Dataset.from_parquet(str(path)) for path in paths]
|
||||
return concatenate_datasets(datasets)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path):
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
return metadata.num_rows
|
||||
|
||||
|
||||
def get_video_size_in_mb(mp4_path: Path):
|
||||
file_size_bytes = mp4_path.stat().st_size
|
||||
file_size_mb = file_size_bytes / (1024**2)
|
||||
return file_size_mb
|
||||
|
||||
|
||||
def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chunk_idx: int, file_idx: int):
|
||||
# TODO(rcadene): move to video_utils.py
|
||||
# TODO(rcadene): add docstring
|
||||
tmp_dir = Path(tempfile.mkdtemp(dir=root))
|
||||
# Create a text file with the list of files to concatenate
|
||||
path_concat_video_files = tmp_dir / "concat_video_files.txt"
|
||||
with open(path_concat_video_files, "w") as f:
|
||||
for ep_path in paths_to_cat:
|
||||
f.write(f"file '{str(ep_path)}'\n")
|
||||
|
||||
path_tmp_output = tmp_dir / "tmp_output.mp4"
|
||||
command = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-f",
|
||||
"concat",
|
||||
"-safe",
|
||||
"0",
|
||||
"-i",
|
||||
str(path_concat_video_files),
|
||||
"-c",
|
||||
"copy",
|
||||
str(path_tmp_output),
|
||||
]
|
||||
subprocess.run(command, check=True)
|
||||
|
||||
output_path = root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(path_tmp_output), str(output_path))
|
||||
shutil.rmtree(str(tmp_dir))
|
||||
|
||||
|
||||
def get_video_duration_in_s(mp4_file: Path):
|
||||
# TODO(rcadene): move to video_utils.py
|
||||
command = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
str(mp4_file),
|
||||
]
|
||||
result = subprocess.run(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
return float(result.stdout)
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
|
||||
For example:
|
||||
```
|
||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
|
||||
>>> print(flatten_dict(dct))
|
||||
{"a/b": 1, "a/c/d": 2, "e": 3}
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
|
||||
|
||||
|
||||
def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
d = outdict
|
||||
for part in parts[:-1]:
|
||||
if part not in d:
|
||||
d[part] = {}
|
||||
d = d[part]
|
||||
d[parts[-1]] = value
|
||||
return outdict
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
|
||||
serialized_dict[key] = value
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
return dataset
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def write_info(info: dict, local_dir: Path):
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path):
|
||||
serialized_stats = serialize_dict(stats)
|
||||
write_json(serialized_stats, local_dir / STATS_PATH)
|
||||
|
||||
|
||||
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
return cast_stats_to_numpy(stats)
|
||||
|
||||
|
||||
def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
|
||||
if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||
raise NotImplementedError("Contact a maintainer.")
|
||||
|
||||
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
hf_dataset.to_parquet(path)
|
||||
|
||||
|
||||
def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
|
||||
path = local_dir / DEFAULT_TASKS_PATH
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tasks.to_parquet(path)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path):
|
||||
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||
return tasks
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path):
|
||||
if get_hf_dataset_size_in_mb(episodes) > DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||
raise NotImplementedError("Contact a maintainer.")
|
||||
|
||||
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes.to_parquet(fpath)
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> datasets.Dataset:
|
||||
episodes = load_nested_dataset(local_dir / EPISODES_DIR)
|
||||
# Select episode features/columns containing references to episode data and videos
|
||||
# (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
|
||||
# This is to speedup access to these data, instead of having to load episode stats.
|
||||
episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
|
||||
return episodes
|
||||
|
||||
|
||||
def backward_compatible_episodes_stats(
|
||||
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
return dict.fromkeys(episodes, stats)
|
||||
|
||||
|
||||
def load_image_as_numpy(
|
||||
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
|
||||
) -> np.ndarray:
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
img_array /= 255.0
|
||||
return img_array
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
a channel last representation (h w c) of uint8 type, to a torch image representation
|
||||
with channel first (c h w) of float32 type in range [0,1].
|
||||
"""
|
||||
for key in items_dict:
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
||||
|
||||
def is_valid_version(version: str) -> bool:
|
||||
try:
|
||||
packaging.version.parse(version)
|
||||
return True
|
||||
except packaging.version.InvalidVersion:
|
||||
return False
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str,
|
||||
version_to_check: str | packaging.version.Version,
|
||||
current_version: str | packaging.version.Version,
|
||||
enforce_breaking_major: bool = True,
|
||||
) -> None:
|
||||
v_check = (
|
||||
packaging.version.parse(version_to_check)
|
||||
if not isinstance(version_to_check, packaging.version.Version)
|
||||
else version_to_check
|
||||
)
|
||||
v_current = (
|
||||
packaging.version.parse(current_version)
|
||||
if not isinstance(current_version, packaging.version.Version)
|
||||
else current_version
|
||||
)
|
||||
if v_check.major < v_current.major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, v_check)
|
||||
elif v_check.minor < v_current.minor:
|
||||
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
||||
|
||||
|
||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||
"""Returns available valid versions (branches and tags) on given repo."""
|
||||
api = HfApi()
|
||||
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
|
||||
repo_versions = []
|
||||
for ref in repo_refs:
|
||||
with contextlib.suppress(packaging.version.InvalidVersion):
|
||||
repo_versions.append(packaging.version.parse(ref))
|
||||
|
||||
return repo_versions
|
||||
|
||||
|
||||
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
|
||||
"""
|
||||
Returns the version if available on repo or the latest compatible one.
|
||||
Otherwise, will throw a `CompatibilityError`.
|
||||
"""
|
||||
target_version = (
|
||||
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
|
||||
)
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
|
||||
if not hub_versions:
|
||||
raise RevisionNotFoundError(
|
||||
f"""Your dataset must be tagged with a codebase version.
|
||||
Assuming _version_ is the codebase_version value in the info.json, you can run this:
|
||||
```python
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
|
||||
```
|
||||
"""
|
||||
)
|
||||
|
||||
if target_version in hub_versions:
|
||||
return f"v{target_version}"
|
||||
|
||||
compatibles = [
|
||||
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
|
||||
]
|
||||
if compatibles:
|
||||
return_version = max(compatibles)
|
||||
if return_version < target_version:
|
||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
||||
return f"v{return_version}"
|
||||
|
||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||
if lower_major:
|
||||
raise BackwardCompatibilityError(repo_id, max(lower_major))
|
||||
|
||||
upper_versions = [v for v in hub_versions if v > target_version]
|
||||
assert len(upper_versions) > 0
|
||||
raise ForwardCompatibilityError(repo_id, min(upper_versions))
|
||||
|
||||
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 1:
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
elif len(ft["shape"]) == 2:
|
||||
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 3:
|
||||
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 4:
|
||||
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 5:
|
||||
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
|
||||
else:
|
||||
raise ValueError(f"Corresponding feature is not valid: {ft}")
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
|
||||
def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||
invalid_features = {name: ft for name, ft in features.items() if "/" in name}
|
||||
if invalid_features:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
|
||||
|
||||
|
||||
def hw_to_dataset_features(
|
||||
hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
|
||||
) -> dict[str, dict]:
|
||||
features = {}
|
||||
joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == "action":
|
||||
features[prefix] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
if joint_fts and prefix == "observation":
|
||||
features[f"{prefix}.state"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(joint_fts),),
|
||||
"names": list(joint_fts),
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
|
||||
|
||||
def build_dataset_frame(
|
||||
ds_features: dict[str, dict], values: dict[str, Any], prefix: str
|
||||
) -> dict[str, np.ndarray]:
|
||||
frame = {}
|
||||
for key, ft in ds_features.items():
|
||||
if key in DEFAULT_FEATURES or not key.startswith(prefix):
|
||||
continue
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
# TODO(rcadene): add fps for each feature
|
||||
camera_ft = {}
|
||||
if robot.cameras:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video" if use_videos else "image", **ft}
|
||||
for key, ft in robot.camera_features.items()
|
||||
}
|
||||
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
|
||||
|
||||
|
||||
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts): Implement "type" in dataset features and simplify this
|
||||
policy_features = {}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
if ft["dtype"] in ["image", "video"]:
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == "observation.environment_state":
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith("observation"):
|
||||
type = FeatureType.STATE
|
||||
elif key.startswith("action"):
|
||||
type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
|
||||
policy_features[key] = PolicyFeature(
|
||||
type=type,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
return policy_features
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
use_videos: bool,
|
||||
robot_type: str | None = None,
|
||||
) -> dict:
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
timestamps: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
episode_data_index: dict[str, np.ndarray],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
|
||||
to account for possible numerical error.
|
||||
|
||||
Args:
|
||||
timestamps (np.ndarray): Array of timestamps in seconds.
|
||||
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
|
||||
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
|
||||
which identifies indices for the end of each episode.
|
||||
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
|
||||
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
|
||||
raise_value_error (bool): Whether to raise a ValueError if the check fails.
|
||||
|
||||
Returns:
|
||||
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
|
||||
|
||||
Raises:
|
||||
ValueError: If the check fails and `raise_value_error` is True.
|
||||
"""
|
||||
if timestamps.shape != episode_indices.shape:
|
||||
raise ValueError(
|
||||
"timestamps and episode_indices should have the same shape. "
|
||||
f"Found {timestamps.shape=} and {episode_indices.shape=}."
|
||||
)
|
||||
|
||||
# Consecutive differences
|
||||
diffs = np.diff(timestamps)
|
||||
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
|
||||
|
||||
# Mask to ignore differences at the boundaries between episodes
|
||||
mask = np.ones(len(diffs), dtype=bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
|
||||
mask[ignored_diffs] = False
|
||||
filtered_within_tolerance = within_tolerance[mask]
|
||||
|
||||
# Check if all remaining diffs are within tolerance
|
||||
if not np.all(filtered_within_tolerance):
|
||||
# Track original indices before masking
|
||||
original_indices = np.arange(len(diffs))
|
||||
filtered_indices = original_indices[mask]
|
||||
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
|
||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
||||
|
||||
outside_tolerances = []
|
||||
for idx in outside_tolerance_indices:
|
||||
entry = {
|
||||
"timestamps": [timestamps[idx], timestamps[idx + 1]],
|
||||
"diff": diffs[idx],
|
||||
"episode_index": episode_indices[idx].item()
|
||||
if hasattr(episode_indices[idx], "item")
|
||||
else episode_indices[idx],
|
||||
}
|
||||
outside_tolerances.append(entry)
|
||||
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
|
||||
This might be due to synchronization issues during data collection.
|
||||
\n{pformat(outside_tolerances)}"""
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||
actual timestamps from the dataset.
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
]
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""
|
||||
The following delta_timestamps are found outside of tolerance range.
|
||||
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
|
||||
their values accordingly.
|
||||
\n{pformat(outside_tolerance)}
|
||||
"""
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = [round(d * fps) for d in delta_ts]
|
||||
|
||||
return delta_indices
|
||||
|
||||
|
||||
def cycle(iterable):
|
||||
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
||||
|
||||
See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
|
||||
"""
|
||||
iterator = iter(iterable)
|
||||
while True:
|
||||
try:
|
||||
yield next(iterator)
|
||||
except StopIteration:
|
||||
iterator = iter(iterable)
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
|
||||
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
|
||||
exists before creating it.
|
||||
"""
|
||||
api = HfApi()
|
||||
|
||||
branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
|
||||
refs = [branch.ref for branch in branches]
|
||||
ref = f"refs/heads/{branch}"
|
||||
if ref in refs:
|
||||
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
tags: list | None = None,
|
||||
dataset_info: dict | None = None,
|
||||
**kwargs,
|
||||
) -> DatasetCard:
|
||||
"""
|
||||
Keyword arguments will be used to replace values in ./lerobot/datasets/card_template.md.
|
||||
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
|
||||
"""
|
||||
card_tags = ["LeRobot"]
|
||||
|
||||
if tags:
|
||||
card_tags += tags
|
||||
if dataset_info:
|
||||
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
|
||||
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
||||
card_data = DatasetCardData(
|
||||
license=kwargs.get("license"),
|
||||
tags=card_tags,
|
||||
task_categories=["robotics"],
|
||||
configs=[
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
card_template = (importlib.resources.files("lerobot.datasets") / "card_template.md").read_text()
|
||||
|
||||
return DatasetCard.from_template(
|
||||
card_data=card_data,
|
||||
template_str=card_template,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class IterableNamespace(SimpleNamespace):
|
||||
"""
|
||||
A namespace object that supports both dictionary-like iteration and dot notation access.
|
||||
Automatically converts nested dictionaries into IterableNamespaces.
|
||||
|
||||
This class extends SimpleNamespace to provide:
|
||||
- Dictionary-style iteration over keys
|
||||
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
|
||||
- Dictionary-like methods: items(), keys(), values()
|
||||
- Recursive conversion of nested dictionaries
|
||||
|
||||
Args:
|
||||
dictionary: Optional dictionary to initialize the namespace
|
||||
**kwargs: Additional keyword arguments passed to SimpleNamespace
|
||||
|
||||
Examples:
|
||||
>>> data = {"name": "Alice", "details": {"age": 25}}
|
||||
>>> ns = IterableNamespace(data)
|
||||
>>> ns.name
|
||||
'Alice'
|
||||
>>> ns.details.age
|
||||
25
|
||||
>>> list(ns.keys())
|
||||
['name', 'details']
|
||||
>>> for key, value in ns.items():
|
||||
... print(f"{key}: {value}")
|
||||
name: Alice
|
||||
details: IterableNamespace(age=25)
|
||||
"""
|
||||
|
||||
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if dictionary is not None:
|
||||
for key, value in dictionary.items():
|
||||
if isinstance(value, dict):
|
||||
setattr(self, key, IterableNamespace(value))
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(vars(self))
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return vars(self)[key]
|
||||
|
||||
def items(self):
|
||||
return vars(self).items()
|
||||
|
||||
def values(self):
|
||||
return vars(self).values()
|
||||
|
||||
def keys(self):
|
||||
return vars(self).keys()
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict):
|
||||
expected_features = set(features) - set(DEFAULT_FEATURES)
|
||||
actual_features = set(frame)
|
||||
|
||||
# task is a special required field that's not part of regular features
|
||||
if "task" not in actual_features:
|
||||
raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")
|
||||
|
||||
# Remove task from actual_features for regular feature validation
|
||||
actual_features_for_validation = actual_features - {"task"}
|
||||
|
||||
error_message = validate_features_presence(actual_features_for_validation, expected_features)
|
||||
|
||||
common_features = actual_features_for_validation & expected_features
|
||||
for name in common_features:
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
|
||||
error_message = ""
|
||||
missing_features = expected_features - actual_features
|
||||
extra_features = actual_features - expected_features
|
||||
|
||||
if missing_features or extra_features:
|
||||
error_message += "Feature mismatch in `frame` dictionary:\n"
|
||||
if missing_features:
|
||||
error_message += f"Missing features: {missing_features}\n"
|
||||
if extra_features:
|
||||
error_message += f"Extra features: {extra_features}\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
def validate_feature_numpy_array(
|
||||
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||
):
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
actual_shape = value.shape
|
||||
|
||||
if actual_dtype != np.dtype(expected_dtype):
|
||||
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
|
||||
|
||||
if actual_shape != expected_shape:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str):
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
|
||||
if "size" not in episode_buffer:
|
||||
raise ValueError("size key not found in episode_buffer")
|
||||
|
||||
if "task" not in episode_buffer:
|
||||
raise ValueError("task key not found in episode_buffer")
|
||||
|
||||
if episode_buffer["episode_index"] != total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_buffer["size"] == 0:
|
||||
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||
|
||||
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||
if not buffer_keys == set(features):
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `features`."
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path):
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
|
||||
2.1. It will:
|
||||
|
||||
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
|
||||
- Check consistency between these new stats and the old ones.
|
||||
- Remove the deprecated `stats.json`.
|
||||
- Update codebase_version in `info.json`.
|
||||
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
|
||||
--repo-id=aliberts/koch_tutorial
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import LEGACY_EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
|
||||
|
||||
class SuppressWarnings:
|
||||
def __enter__(self):
|
||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
logging.getLogger().setLevel(self.previous_level)
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
branch: str | None = None,
|
||||
num_workers: int = 4,
|
||||
):
|
||||
with SuppressWarnings():
|
||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||
|
||||
if (dataset.root / LEGACY_EPISODES_STATS_PATH).is_file():
|
||||
(dataset.root / LEGACY_EPISODES_STATS_PATH).unlink()
|
||||
|
||||
convert_stats(dataset, num_workers=num_workers)
|
||||
ref_stats = load_stats(dataset.root)
|
||||
check_aggregate_stats(dataset, ref_stats)
|
||||
|
||||
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
|
||||
write_info(dataset.meta.info, dataset.root)
|
||||
|
||||
dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
|
||||
|
||||
# delete old stats.json file
|
||||
if (dataset.root / STATS_PATH).is_file:
|
||||
(dataset.root / STATS_PATH).unlink()
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
|
||||
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to push your dataset. Defaults to the main branch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of workers for parallelizing stats compute. Defaults to 4.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dataset(**vars(args))
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import legacy_write_episode_stats
|
||||
|
||||
|
||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||
ep_len = dataset.meta.episodes[episode_index]["length"]
|
||||
sampled_indices = sample_indices(ep_len)
|
||||
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
|
||||
video_frames = dataset._query_videos(query_timestamps, episode_index)
|
||||
return video_frames[ft_key].numpy()
|
||||
|
||||
|
||||
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
|
||||
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
|
||||
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
|
||||
|
||||
ep_stats = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if ft["dtype"] == "video":
|
||||
# We sample only for videos
|
||||
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
|
||||
else:
|
||||
ep_ft_data = np.array(ep_data[key])
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
dataset.meta.episodes_stats[ep_idx] = ep_stats
|
||||
|
||||
|
||||
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
||||
assert dataset.episodes is None
|
||||
print("Computing episodes stats")
|
||||
total_episodes = dataset.meta.total_episodes
|
||||
if num_workers > 0:
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = {
|
||||
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
|
||||
for ep_idx in range(total_episodes)
|
||||
}
|
||||
for future in tqdm(as_completed(futures), total=total_episodes):
|
||||
future.result()
|
||||
else:
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
convert_episode_stats(dataset, ep_idx)
|
||||
|
||||
for ep_idx in tqdm(range(total_episodes)):
|
||||
legacy_write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||
|
||||
|
||||
def check_aggregate_stats(
|
||||
dataset: LeRobotDataset,
|
||||
reference_stats: dict[str, dict[str, np.ndarray]],
|
||||
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
|
||||
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
|
||||
):
|
||||
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
|
||||
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
|
||||
for key, ft in dataset.features.items():
|
||||
# These values might need some fine-tuning
|
||||
if ft["dtype"] == "video":
|
||||
# to account for image sub-sampling
|
||||
rtol, atol = video_rtol_atol
|
||||
else:
|
||||
rtol, atol = default_rtol_atol
|
||||
|
||||
for stat, val in agg_stats[key].items():
|
||||
if key in reference_stats and stat in reference_stats[key]:
|
||||
err_msg = f"feature='{key}' stats='{stat}'"
|
||||
np.testing.assert_allclose(
|
||||
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ This script will help you convert any LeRobot dataset already pushed to the hub
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
python lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||
--repo-id=lerobot/pusht
|
||||
```
|
||||
|
||||
@@ -30,10 +30,10 @@ from datasets import Dataset, Features, Image
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from requests import HTTPError
|
||||
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
|
||||
@@ -0,0 +1,433 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
else:
|
||||
logging.warning(
|
||||
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
||||
)
|
||||
return "pyav"
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
|
||||
Args:
|
||||
video_path (Path): Path to the video file.
|
||||
timestamps (list[float]): List of timestamps to extract frames.
|
||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames.
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_codec()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated to the requested timestamps of a video
|
||||
|
||||
The backend can be either "pyav" (default) or "video_reader".
|
||||
"video_reader" requires installing torchvision from source, see:
|
||||
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
||||
(note that you need to compile against ffmpeg<4.3)
|
||||
|
||||
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
|
||||
For more info on video decoding, see `benchmark/video/README.md`
|
||||
|
||||
See torchvision doc for more info on these two backends:
|
||||
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
"""
|
||||
video_path = str(video_path)
|
||||
|
||||
# set backend
|
||||
keyframes_only = False
|
||||
torchvision.set_video_backend(backend)
|
||||
if backend == "pyav":
|
||||
keyframes_only = True # pyav doesnt support accuracte seek
|
||||
|
||||
# set a video stream reader
|
||||
# TODO(rcadene): also load audio stream at the same time
|
||||
reader = torchvision.io.VideoReader(video_path, "video")
|
||||
|
||||
# set the first and last requested timestamps
|
||||
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
# access closest key frame of the first requested frame
|
||||
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||
reader.seek(first_ts, keyframes_only=keyframes_only)
|
||||
|
||||
# load all frames until last requested frame
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
for frame in reader:
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
if backend == "pyav":
|
||||
reader.container.close()
|
||||
|
||||
reader = None
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
|
||||
# compute distances between each query timestamp and timestamps of all loaded frames
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
# TODO(rcadene): remove torch.stack
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
def decode_video_frames_torchcodec(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
device: str = "cpu",
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||
|
||||
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
"""
|
||||
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
raise ImportError("torchcodec is required but not available.")
|
||||
|
||||
# initialize video decoder
|
||||
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
# get metadata for frame information
|
||||
metadata = decoder.metadata
|
||||
average_fps = metadata.average_fps
|
||||
|
||||
# convert timestamps to frame indices
|
||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||
|
||||
# retrieve frames based on indices
|
||||
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||
|
||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
|
||||
# compute distances between each query timestamp and loaded timestamps
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
|
||||
# convert to float32 in [0,1] range (channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: str | None = "quiet",
|
||||
overwrite: bool = False,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ffmpeg_args = OrderedDict(
|
||||
[
|
||||
("-f", "image2"),
|
||||
("-r", str(fps)),
|
||||
("-i", str(imgs_dir / "frame-%06d.png")),
|
||||
("-vcodec", vcodec),
|
||||
("-pix_fmt", pix_fmt),
|
||||
]
|
||||
)
|
||||
|
||||
if g is not None:
|
||||
ffmpeg_args["-g"] = str(g)
|
||||
|
||||
if crf is not None:
|
||||
ffmpeg_args["-crf"] = str(crf)
|
||||
|
||||
if fast_decode:
|
||||
key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
ffmpeg_args[key] = value
|
||||
|
||||
if log_level is not None:
|
||||
ffmpeg_args["-loglevel"] = str(log_level)
|
||||
|
||||
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
|
||||
if overwrite:
|
||||
ffmpeg_args.append("-y")
|
||||
|
||||
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
|
||||
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
||||
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
||||
|
||||
if not video_path.exists():
|
||||
raise OSError(
|
||||
f"Video encoding did not work. File not found: {video_path}. "
|
||||
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrame:
|
||||
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
|
||||
"""
|
||||
Provides a type for a dataset containing video frames.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
|
||||
features = {"image": VideoFrame()}
|
||||
Dataset.from_dict(data_dict, features=Features(features))
|
||||
```
|
||||
"""
|
||||
|
||||
pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
|
||||
_type: str = field(default="VideoFrame", init=False, repr=False)
|
||||
|
||||
def __call__(self):
|
||||
return self.pa_type
|
||||
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"'register_feature' is experimental and might be subject to breaking changes in the future.",
|
||||
category=UserWarning,
|
||||
)
|
||||
# to make VideoFrame available in HuggingFace `datasets`
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
ffprobe_audio_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
"-show_entries",
|
||||
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||
if audio_stream_info is None:
|
||||
return {"has_audio": False}
|
||||
|
||||
# Return the information, defaulting to None if no audio stream is present
|
||||
return {
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
video_stream_info = info["streams"][0]
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||
num, denom = map(int, r_frame_rate.split("/"))
|
||||
fps = num / denom
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**get_audio_info(video_path),
|
||||
}
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
@@ -0,0 +1,384 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import platform
|
||||
import select
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def none_or_int(value):
|
||||
if value == "None":
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
|
||||
def inside_slurm():
|
||||
"""Check whether the python process was launched through slurm"""
|
||||
# TODO(rcadene): return False for interactive mode `--pty bash`
|
||||
return "SLURM_JOB_ID" in os.environ
|
||||
|
||||
|
||||
def auto_select_torch_device() -> torch.device:
|
||||
"""Tries to select automatically a torch device."""
|
||||
if torch.cuda.is_available():
|
||||
logging.info("Cuda backend detected, using cuda.")
|
||||
return torch.device("cuda")
|
||||
elif torch.backends.mps.is_available():
|
||||
logging.info("Metal backend detected, using cuda.")
|
||||
return torch.device("mps")
|
||||
else:
|
||||
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
try_device = str(try_device)
|
||||
match try_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
case "cpu":
|
||||
device = torch.device("cpu")
|
||||
if log:
|
||||
logging.warning("Using CPU, this will be slow.")
|
||||
case _:
|
||||
device = torch.device(try_device)
|
||||
if log:
|
||||
logging.warning(f"Using custom {try_device} device.")
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
||||
"""
|
||||
mps is currently not compatible with float64
|
||||
"""
|
||||
if isinstance(device, torch.device):
|
||||
device = device.type
|
||||
if device == "mps" and dtype == torch.float64:
|
||||
return torch.float32
|
||||
else:
|
||||
return dtype
|
||||
|
||||
|
||||
def is_torch_device_available(try_device: str) -> bool:
|
||||
try_device = str(try_device) # Ensure try_device is a string
|
||||
if try_device == "cuda":
|
||||
return torch.cuda.is_available()
|
||||
elif try_device == "mps":
|
||||
return torch.backends.mps.is_available()
|
||||
elif try_device == "cpu":
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
||||
|
||||
|
||||
def is_amp_available(device: str):
|
||||
if device in ["cuda", "cpu"]:
|
||||
return True
|
||||
elif device == "mps":
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Unknown device '{device}.")
|
||||
|
||||
|
||||
def init_logging(log_file: Path | None = None, display_pid: bool = False):
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
fnameline = f"{record.pathname}:{record.lineno}"
|
||||
|
||||
# NOTE: Display PID is useful for multi-process logging.
|
||||
if display_pid:
|
||||
pid_str = f"[PID: {os.getpid()}]"
|
||||
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
else:
|
||||
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}"
|
||||
return message
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
formatter = logging.Formatter()
|
||||
formatter.format = custom_format
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
if log_file is not None:
|
||||
# Additionally write logs to file
|
||||
file_handler = logging.FileHandler(log_file)
|
||||
file_handler.setFormatter(formatter)
|
||||
logging.getLogger().addHandler(file_handler)
|
||||
|
||||
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
divisor = 1000.0
|
||||
|
||||
for suffix in suffixes:
|
||||
if abs(num) < divisor:
|
||||
return f"{num:.{precision}f}{suffix}"
|
||||
num /= divisor
|
||||
|
||||
return num
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
path2 = path2.absolute()
|
||||
try:
|
||||
return path1.relative_to(path2)
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
)
|
||||
|
||||
|
||||
def print_cuda_memory_usage():
|
||||
"""Use this function to locate and debug memory leak."""
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
# Also clear the cache if you want to fully release the memory
|
||||
torch.cuda.empty_cache()
|
||||
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
|
||||
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
|
||||
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
|
||||
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
|
||||
|
||||
|
||||
def capture_timestamp_utc():
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def say(text: str, blocking: bool = False):
|
||||
system = platform.system()
|
||||
|
||||
if system == "Darwin":
|
||||
cmd = ["say", text]
|
||||
|
||||
elif system == "Linux":
|
||||
cmd = ["spd-say", text]
|
||||
if blocking:
|
||||
cmd.append("--wait")
|
||||
|
||||
elif system == "Windows":
|
||||
cmd = [
|
||||
"PowerShell",
|
||||
"-Command",
|
||||
"Add-Type -AssemblyName System.Speech; "
|
||||
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')",
|
||||
]
|
||||
|
||||
else:
|
||||
raise RuntimeError("Unsupported operating system for text-to-speech.")
|
||||
|
||||
if blocking:
|
||||
subprocess.run(cmd, check=True)
|
||||
else:
|
||||
subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
|
||||
|
||||
|
||||
def log_say(text: str, play_sounds: bool = True, blocking: bool = False):
|
||||
logging.info(text)
|
||||
|
||||
if play_sounds:
|
||||
say(text, blocking)
|
||||
|
||||
|
||||
def get_channel_first_image_shape(image_shape: tuple) -> tuple:
|
||||
shape = copy(image_shape)
|
||||
if shape[2] < shape[0] and shape[2] < shape[1]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif not (shape[0] < shape[1] and shape[0] < shape[2]):
|
||||
raise ValueError(image_shape)
|
||||
|
||||
return shape
|
||||
|
||||
|
||||
def has_method(cls: object, method_name: str) -> bool:
|
||||
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
|
||||
|
||||
|
||||
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||
"""
|
||||
Return True if a given string can be converted to a numpy dtype.
|
||||
"""
|
||||
try:
|
||||
# Attempt to convert the string to a numpy dtype
|
||||
np.dtype(dtype_str)
|
||||
return True
|
||||
except TypeError:
|
||||
# If a TypeError is raised, the string is not a valid dtype
|
||||
return False
|
||||
|
||||
|
||||
def enter_pressed() -> bool:
|
||||
if platform.system() == "Windows":
|
||||
import msvcrt
|
||||
|
||||
if msvcrt.kbhit():
|
||||
key = msvcrt.getch()
|
||||
return key in (b"\r", b"\n") # enter key
|
||||
return False
|
||||
else:
|
||||
return select.select([sys.stdin], [], [], 0)[0] and sys.stdin.readline().strip() == ""
|
||||
|
||||
|
||||
def move_cursor_up(lines):
|
||||
"""Move the cursor up by a specified number of lines."""
|
||||
print(f"\033[{lines}A", end="")
|
||||
|
||||
|
||||
def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
|
||||
days = int(elapsed_time_s // (24 * 3600))
|
||||
elapsed_time_s %= 24 * 3600
|
||||
hours = int(elapsed_time_s // 3600)
|
||||
elapsed_time_s %= 3600
|
||||
minutes = int(elapsed_time_s // 60)
|
||||
seconds = elapsed_time_s % 60
|
||||
return days, hours, minutes, seconds
|
||||
|
||||
|
||||
class TimerManager:
|
||||
"""
|
||||
Lightweight utility to measure elapsed time.
|
||||
|
||||
Examples
|
||||
--------
|
||||
```python
|
||||
# Example 1: Using context manager
|
||||
timer = TimerManager("Policy", log=False)
|
||||
for _ in range(3):
|
||||
with timer:
|
||||
time.sleep(0.01)
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
|
||||
```python
|
||||
# Example 2: Using start/stop methods
|
||||
timer = TimerManager("Policy", log=False)
|
||||
timer.start()
|
||||
time.sleep(0.01)
|
||||
timer.stop()
|
||||
print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
label: str = "Elapsed-time",
|
||||
log: bool = True,
|
||||
logger: logging.Logger | None = None,
|
||||
):
|
||||
self.label = label
|
||||
self.log = log
|
||||
self.logger = logger
|
||||
self._start: float | None = None
|
||||
self._history: list[float] = []
|
||||
|
||||
def __enter__(self):
|
||||
return self.start()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
self._start = time.perf_counter()
|
||||
return self
|
||||
|
||||
def stop(self) -> float:
|
||||
if self._start is None:
|
||||
raise RuntimeError("Timer was never started.")
|
||||
elapsed = time.perf_counter() - self._start
|
||||
self._history.append(elapsed)
|
||||
self._start = None
|
||||
if self.log:
|
||||
if self.logger is not None:
|
||||
self.logger.info(f"{self.label}: {elapsed:.6f} s")
|
||||
else:
|
||||
logging.info(f"{self.label}: {elapsed:.6f} s")
|
||||
return elapsed
|
||||
|
||||
def reset(self):
|
||||
self._history.clear()
|
||||
|
||||
@property
|
||||
def last(self) -> float:
|
||||
return self._history[-1] if self._history else 0.0
|
||||
|
||||
@property
|
||||
def avg(self) -> float:
|
||||
return mean(self._history) if self._history else 0.0
|
||||
|
||||
@property
|
||||
def total(self) -> float:
|
||||
return sum(self._history)
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self._history)
|
||||
|
||||
@property
|
||||
def history(self) -> list[float]:
|
||||
return deepcopy(self._history)
|
||||
|
||||
@property
|
||||
def fps_history(self) -> list[float]:
|
||||
return [1.0 / t for t in self._history]
|
||||
|
||||
@property
|
||||
def fps_last(self) -> float:
|
||||
return 0.0 if self.last == 0 else 1.0 / self.last
|
||||
|
||||
@property
|
||||
def fps_avg(self) -> float:
|
||||
return 0.0 if self.avg == 0 else 1.0 / self.avg
|
||||
|
||||
def percentile(self, p: float) -> float:
|
||||
"""
|
||||
Return the p-th percentile of recorded times.
|
||||
"""
|
||||
if not self._history:
|
||||
return 0.0
|
||||
return float(np.percentile(self._history, p))
|
||||
|
||||
def fps_percentile(self, p: float) -> float:
|
||||
"""
|
||||
FPS corresponding to the p-th percentile time.
|
||||
"""
|
||||
val = self.percentile(p)
|
||||
return 0.0 if val == 0 else 1.0 / val
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
|
||||
@@ -26,19 +26,19 @@ from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.utils import (
|
||||
create_branch,
|
||||
hw_to_dataset_features,
|
||||
)
|
||||
from lerobot.common.envs.factory import make_env_config
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
from lerobot.common.robots import make_robot_from_config
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
@@ -107,7 +107,7 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
|
||||
# and test the small resulting function that validates the features
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
|
||||
# make sure dir does not exist
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
|
||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||
# make sure does not exist
|
||||
|
||||
@@ -21,8 +21,8 @@ import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import (
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
|
||||
2
tests/fixtures/files.py
vendored
2
tests/fixtures/files.py
vendored
@@ -19,7 +19,7 @@ import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.utils import (
|
||||
write_episodes,
|
||||
write_hf_dataset,
|
||||
write_info,
|
||||
|
||||
2
tests/fixtures/hub.py
vendored
2
tests/fixtures/hub.py
vendored
@@ -18,7 +18,7 @@ import pandas as pd
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
|
||||
@@ -23,20 +23,20 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.common.envs.factory import make_env, make_env_config
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.common.policies.factory import (
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.policies.factory import (
|
||||
get_policy_class,
|
||||
make_policy,
|
||||
make_policy_config,
|
||||
)
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
Reference in New Issue
Block a user