diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index f348b40eb..2f6c990e4 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -90,7 +90,7 @@ DEFAULT_FEATURES = { } -def get_parquet_file_size_in_mb(parquet_path): +def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: metadata = pq.read_metadata(parquet_path) total_uncompressed_size = 0 for row_group in range(metadata.num_row_groups): @@ -102,10 +102,10 @@ def get_parquet_file_size_in_mb(parquet_path): def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: - return hf_ds.data.nbytes / (1024**2) + return hf_ds.data.nbytes // (1024**2) -def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int): +def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: if file_idx == chunks_size - 1: file_idx = 0 chunk_idx += 1 @@ -132,18 +132,18 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) return concatenate_datasets(datasets) -def get_parquet_num_frames(parquet_path): +def get_parquet_num_frames(parquet_path: str | Path) -> int: metadata = pq.read_metadata(parquet_path) return metadata.num_rows -def get_video_size_in_mb(mp4_path: Path): +def get_video_size_in_mb(mp4_path: Path) -> float: file_size_bytes = mp4_path.stat().st_size file_size_mb = file_size_bytes / (1024**2) return file_size_mb -def get_video_duration_in_s(mp4_file: Path): +def get_video_duration_in_s(mp4_file: Path) -> float: # TODO(rcadene): move to video_utils.py command = [ "ffprobe", @@ -171,6 +171,7 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` >>> print(flatten_dict(dct)) {"a/b": 1, "a/c/d": 2, "e": 3} + ``` """ items = [] for k, v in d.items(): @@ -231,7 +232,7 @@ def write_json(data: dict, fpath: Path) -> None: json.dump(data, f, indent=4, ensure_ascii=False) -def write_info(info: dict, local_dir: Path): +def write_info(info: dict, local_dir: Path) -> None: write_json(info, local_dir / INFO_PATH) @@ -242,35 +243,35 @@ def load_info(local_dir: Path) -> dict: return info -def write_stats(stats: dict, local_dir: Path): +def write_stats(stats: dict, local_dir: Path) -> None: serialized_stats = serialize_dict(stats) write_json(serialized_stats, local_dir / STATS_PATH) -def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]: +def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats) -def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: +def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: if not (local_dir / STATS_PATH).exists(): return None stats = load_json(local_dir / STATS_PATH) return cast_stats_to_numpy(stats) -def write_tasks(tasks: pandas.DataFrame, local_dir: Path): +def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None: path = local_dir / DEFAULT_TASKS_PATH path.parent.mkdir(parents=True, exist_ok=True) tasks.to_parquet(path) -def load_tasks(local_dir: Path): +def load_tasks(local_dir: Path) -> pandas.DataFrame: tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH) return tasks -def write_episodes(episodes: Dataset, local_dir: Path): +def write_episodes(episodes: Dataset, local_dir: Path) -> None: if get_hf_dataset_size_in_mb(episodes) > DEFAULT_DATA_FILE_SIZE_IN_MB: raise NotImplementedError("Contact a maintainer.") @@ -290,7 +291,7 @@ def load_episodes(local_dir: Path) -> datasets.Dataset: def backward_compatible_episodes_stats( stats: dict[str, dict[str, np.ndarray]], episodes: list[int] -) -> dict[str, dict[str, np.ndarray]]: +) -> dict[int, dict[str, dict[str, np.ndarray]]]: return dict.fromkeys(episodes, stats) @@ -306,7 +307,7 @@ def load_image_as_numpy( return img_array -def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): +def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to a channel last representation (h w c) of uint8 type, to a torch image representation @@ -595,7 +596,7 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic return delta_indices -def cycle(iterable): +def cycle(iterable: Any) -> Iterator[Any]: """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. @@ -608,7 +609,7 @@ def cycle(iterable): iterator = iter(iterable) -def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None: +def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None: """Create a branch on a existing Hugging Face repo. Delete the branch if it already exists before creating it. """ @@ -716,7 +717,7 @@ class IterableNamespace(SimpleNamespace): return vars(self).keys() -def validate_frame(frame: dict, features: dict): +def validate_frame(frame: dict, features: dict) -> None: expected_features = set(features) - set(DEFAULT_FEATURES) actual_features = set(frame) @@ -737,7 +738,7 @@ def validate_frame(frame: dict, features: dict): raise ValueError(error_message) -def validate_features_presence(actual_features: set[str], expected_features: set[str]): +def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: error_message = "" missing_features = expected_features - actual_features extra_features = actual_features - expected_features @@ -752,7 +753,9 @@ def validate_features_presence(actual_features: set[str], expected_features: set return error_message -def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): +def validate_feature_dtype_and_shape( + name: str, feature: dict, value: np.ndarray | PILImage.Image | str +) -> str: expected_dtype = feature["dtype"] expected_shape = feature["shape"] if is_valid_numpy_dtype_string(expected_dtype): @@ -767,7 +770,7 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray def validate_feature_numpy_array( name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray -): +) -> str: error_message = "" if isinstance(value, np.ndarray): actual_dtype = value.dtype @@ -784,7 +787,9 @@ def validate_feature_numpy_array( return error_message -def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): +def validate_feature_image_or_video( + name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image +) -> str: # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): @@ -800,13 +805,13 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value: return error_message -def validate_feature_string(name: str, value: str): +def validate_feature_string(name: str, value: str) -> str: if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" return "" -def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict): +def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: if "size" not in episode_buffer: raise ValueError("size key not found in episode_buffer") @@ -832,7 +837,7 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: ) -def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path): +def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None: """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. This way, it can be loaded by HF dataset and correctly formatted images are returned. """