added the file and video max size as arguments

This commit is contained in:
Michel Aractingi
2025-09-02 15:41:22 +02:00
parent 4062d0564a
commit 2df4e25558

View File

@@ -172,7 +172,7 @@ def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
concatenated_df.to_parquet(path, index=False, schema=schema)
def convert_data(root, new_root):
def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
data_dir = root / "data"
ep_paths = sorted(data_dir.glob("*/*.parquet"))
@@ -200,7 +200,7 @@ def convert_data(root, new_root):
episodes_metadata.append(ep_metadata)
ep_idx += 1
if size_in_mb < DEFAULT_DATA_FILE_SIZE_IN_MB:
if size_in_mb < data_file_size_in_mb:
paths_to_cat.append(ep_path)
continue
@@ -234,7 +234,7 @@ def get_image_keys(root):
return image_keys
def convert_videos(root: Path, new_root: Path):
def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
video_keys = get_video_keys(root)
if len(video_keys) == 0:
return None
@@ -243,7 +243,7 @@ def convert_videos(root: Path, new_root: Path):
eps_metadata_per_cam = []
for camera in video_keys:
eps_metadata = convert_videos_of_camera(root, new_root, camera)
eps_metadata = convert_videos_of_camera(root, new_root, camera, video_file_size_in_mb)
eps_metadata_per_cam.append(eps_metadata)
num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam]
@@ -268,7 +268,7 @@ def convert_videos(root: Path, new_root: Path):
return episods_metadata
def convert_videos_of_camera(root: Path, new_root: Path, video_key):
def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_file_size_in_mb: int):
# Access old paths to mp4
videos_dir = root / "videos"
ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))
@@ -285,7 +285,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):
ep_duration_in_s = get_video_duration_in_s(ep_path)
# Check if adding this episode would exceed the limit
if size_in_mb + ep_size_in_mb >= DEFAULT_VIDEO_FILE_SIZE_IN_MB and len(paths_to_cat) > 0:
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
# Size limit would be exceeded, save current accumulation WITHOUT this episode
concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
@@ -386,13 +386,13 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
write_stats(stats, new_root)
def convert_info(root, new_root):
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
info = load_info(root)
info["codebase_version"] = "v3.0"
del info["total_chunks"]
del info["total_videos"]
info["data_files_size_in_mb"] = DEFAULT_DATA_FILE_SIZE_IN_MB
info["video_files_size_in_mb"] = DEFAULT_VIDEO_FILE_SIZE_IN_MB
info["data_files_size_in_mb"] = data_file_size_in_mb
info["video_files_size_in_mb"] = video_file_size_in_mb
info["data_path"] = DEFAULT_DATA_PATH
info["video_path"] = DEFAULT_VIDEO_PATH
info["fps"] = float(info["fps"])
@@ -407,12 +407,18 @@ def convert_info(root, new_root):
def convert_dataset(
repo_id: str,
branch: str | None = None,
num_workers: int = 4,
data_file_size_in_mb: int | None = None,
video_file_size_in_mb: int | None = None,
):
root = HF_LEROBOT_HOME / repo_id
old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
if data_file_size_in_mb is None:
data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_file_size_in_mb is None:
video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
if old_root.is_dir() and root.is_dir():
shutil.rmtree(str(root))
shutil.move(str(old_root), str(root))
@@ -427,10 +433,10 @@ def convert_dataset(
local_dir=root,
)
convert_info(root, new_root)
convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb)
convert_tasks(root, new_root)
episodes_metadata = convert_data(root, new_root)
episodes_videos_metadata = convert_videos(root, new_root)
episodes_metadata = convert_data(root, new_root, data_file_size_in_mb)
episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb)
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
shutil.move(str(root), str(old_root))
@@ -469,10 +475,16 @@ if __name__ == "__main__":
help="Repo branch to push your dataset. Defaults to the main branch.",
)
parser.add_argument(
"--num-workers",
"--data-file-size-in-mb",
type=int,
default=4,
help="Number of workers for parallelizing stats compute. Defaults to 4.",
default=None,
help="File size in MB. Defaults to 100 for data and 500 for videos.",
)
parser.add_argument(
"--video-file-size-in-mb",
type=int,
default=None,
help="File size in MB. Defaults to 100 for data and 500 for videos.",
)
args = parser.parse_args()