diff --git a/examples/dataset/create_progress_videos.py b/examples/dataset/create_progress_videos.py new file mode 100644 index 000000000..5f98d2cea --- /dev/null +++ b/examples/dataset/create_progress_videos.py @@ -0,0 +1,680 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes. + +Downloads datasets from HuggingFace, seeks directly into the episode segment +of the source video, draws a progress line on each frame, and writes the result. + +Usage: + python examples/dataset/create_progress_videos.py \ + --repo-id lerobot-data-collection/level2_final_quality3 \ + --episode 1100 + + python examples/dataset/create_progress_videos.py \ + --repo-id lerobot-data-collection/level2_final_quality3 \ + --episode 1100 \ + --camera-key observation.images.top \ + --output-dir ./my_videos \ + --gif +""" + +from __future__ import annotations + +import argparse +import json +import logging +import subprocess +from pathlib import Path + +import cv2 +import numpy as np +import pandas as pd +from huggingface_hub import snapshot_download + +GRAPH_Y_TOP_FRAC = 0.01 +GRAPH_Y_BOT_FRAC = 0.99 +LINE_THICKNESS = 3 +SHADOW_THICKNESS = 6 +REF_ALPHA = 0.45 +FILL_ALPHA = 0.55 +SCORE_FONT_SCALE = 0.8 +TASK_FONT_SCALE = 0.55 + + +def download_episode_metadata(repo_id: str, episode: int) -> Path: + """Download only the metadata and sarm_progress files for a dataset. + + Args: + repo_id: HuggingFace dataset repository ID. + episode: Episode index (used for logging only; all meta is fetched). + + Returns: + Local cache path for the downloaded snapshot. + """ + logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode) + local_path = Path( + snapshot_download( + repo_id=repo_id, + repo_type="dataset", + allow_patterns=["meta/**", "sarm_progress.parquet"], + ignore_patterns=["*.mp4"], + ) + ) + return local_path + + +def load_episode_meta(local_path: Path, episode: int, camera_key: str | None) -> dict: + """Read info.json and episode parquet to resolve fps, video path, and timestamps. + + Args: + local_path: Local cache directory containing meta/. + episode: Episode index to look up. + camera_key: Camera observation key (e.g. "observation.images.base"). + If None, the first available video key is used. + + Returns: + Dict with keys: fps, camera, video_rel, chunk_index, file_index, + from_ts, to_ts, task_name. + """ + info = json.loads((local_path / "meta" / "info.json").read_text()) + fps = info["fps"] + features = info["features"] + + video_keys = [k for k, v in features.items() if v.get("dtype") == "video"] + if not video_keys: + raise RuntimeError("No video keys found in dataset features") + + if camera_key is not None: + if camera_key not in video_keys: + raise RuntimeError(f"camera_key='{camera_key}' not found. Available: {video_keys}") + selected_camera = camera_key + else: + selected_camera = video_keys[0] + logging.info(" fps=%d camera='%s' all_cams=%s", fps, selected_camera, video_keys) + + episode_rows = [] + for parquet_file in sorted((local_path / "meta" / "episodes").glob("**/*.parquet")): + episode_rows.append(pd.read_parquet(parquet_file)) + episode_df = pd.concat(episode_rows, ignore_index=True) + row = episode_df[episode_df["episode_index"] == episode] + if row.empty: + raise RuntimeError(f"Episode {episode} not found in episode metadata") + row = row.iloc[0] + + chunk_col = f"videos/{selected_camera}/chunk_index" + file_col = f"videos/{selected_camera}/file_index" + ts_from_col = f"videos/{selected_camera}/from_timestamp" + ts_to_col = f"videos/{selected_camera}/to_timestamp" + + if chunk_col not in row.index: + chunk_col = f"{selected_camera}/chunk_index" + file_col = f"{selected_camera}/file_index" + ts_from_col = f"{selected_camera}/from_timestamp" + ts_to_col = f"{selected_camera}/to_timestamp" + if chunk_col not in row.index: + raise RuntimeError( + f"Cannot find video metadata columns for {selected_camera}.\nAvailable: {list(row.index)}" + ) + + chunk_index = int(row[chunk_col]) + file_index = int(row[file_col]) + from_timestamp = float(row[ts_from_col]) + to_timestamp = float(row[ts_to_col]) + + video_template = info.get( + "video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4" + ) + video_rel = video_template.format( + video_key=selected_camera, + chunk_index=chunk_index, + file_index=file_index, + ) + + task_name = _resolve_task_name(row, local_path) + + return { + "fps": fps, + "camera": selected_camera, + "video_rel": video_rel, + "chunk_index": chunk_index, + "file_index": file_index, + "from_ts": from_timestamp, + "to_ts": to_timestamp, + "task_name": task_name, + } + + +def _resolve_task_name(row: pd.Series, local_path: Path) -> str: + """Best-effort extraction of the task name for an episode row. + + Args: + row: Single-episode row from the episodes parquet. + local_path: Dataset cache root. + + Returns: + Task name string, or empty string if unavailable. + """ + try: + if "tasks" in row.index and row["tasks"] is not None: + tasks_val = row["tasks"] + if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0: + return str(tasks_val[0]) + return str(tasks_val).strip("[]'") + + tasks_parquet = local_path / "meta" / "tasks.parquet" + if tasks_parquet.exists(): + tasks_df = pd.read_parquet(tasks_parquet) + task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0 + match = tasks_df[tasks_df["task_index"] == task_idx] + if not match.empty: + return str(match.index[0]) + except Exception as exc: + logging.warning("Could not load task name: %s", exc) + return "" + + +def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path: + """Download the specific video file if not already cached. + + Args: + repo_id: HuggingFace dataset repository ID. + local_path: Local cache directory. + video_rel: Relative path to the video file within the dataset. + + Returns: + Absolute path to the downloaded video file. + """ + video_path = local_path / video_rel + if video_path.exists(): + logging.info(" Video already cached: %s", video_path) + return video_path + logging.info("[2/4] Downloading video file %s ...", video_rel) + snapshot_download( + repo_id=repo_id, + repo_type="dataset", + local_dir=str(local_path), + allow_patterns=[video_rel], + ) + if not video_path.exists(): + raise RuntimeError(f"Video not found after download: {video_path}") + return video_path + + +def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None: + """Load sarm_progress values for an episode. + + Args: + local_path: Dataset cache root. + episode: Episode index. + + Returns: + Sorted (N, 2) array of (frame_index, progress), or None if unavailable. + """ + parquet_path = local_path / "sarm_progress.parquet" + if not parquet_path.exists(): + logging.warning("sarm_progress.parquet not found") + return None + df = pd.read_parquet(parquet_path) + logging.info(" sarm_progress.parquet columns: %s", list(df.columns)) + episode_df = df[df["episode_index"] == episode].copy() + if episode_df.empty: + logging.warning("No sarm_progress rows for episode %d", episode) + return None + episode_df = episode_df.sort_values("frame_index") + + if "progress_dense" in episode_df.columns and episode_df["progress_dense"].notna().any(): + progress_column = "progress_dense" + elif "progress_sparse" in episode_df.columns: + progress_column = "progress_sparse" + else: + progress_columns = [c for c in episode_df.columns if "progress" in c.lower()] + if not progress_columns: + return None + progress_column = progress_columns[0] + + logging.info(" Using progress column: '%s'", progress_column) + return episode_df[["frame_index", progress_column]].rename(columns={progress_column: "progress"}).values + + +def _precompute_pixel_coords( + progress_data: np.ndarray, + num_frames: int, + frame_width: int, + frame_height: int, +) -> np.ndarray: + """Map progress samples to pixel coordinates for overlay drawing. + + Args: + progress_data: (N, 2) array of (frame_index, progress). + num_frames: Total number of video frames. + frame_width: Video width in pixels. + frame_height: Video height in pixels. + + Returns: + (N, 2) array of (x, y) pixel coordinates. + """ + frame_indices = progress_data[:, 0].astype(float) + progress_values = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0) + + y_top = int(frame_height * GRAPH_Y_TOP_FRAC) + y_bot = int(frame_height * GRAPH_Y_BOT_FRAC) + graph_height = y_bot - y_top + + x_coords = (frame_indices / (num_frames - 1) * (frame_width - 1)).astype(int) + y_coords = (y_bot - progress_values * graph_height).astype(int) + + return np.stack([x_coords, y_coords], axis=1) + + +def _progress_color(normalized_position: float) -> tuple[int, int, int]: + """Interpolate BGR color from red to green based on position in [0, 1]. + + Args: + normalized_position: Value in [0, 1] indicating how far along the episode. + + Returns: + BGR color tuple. + """ + red = int(255 * (1.0 - normalized_position)) + green = int(255 * normalized_position) + return (0, green, red) + + +def _prerender_fill_polygon( + pixel_coords: np.ndarray, + frame_width: int, + frame_height: int, +) -> np.ndarray: + """Pre-render the grey fill polygon under the progress curve as a BGRA image. + + Args: + pixel_coords: (N, 2) array of (x, y) pixel coordinates. + frame_width: Video width in pixels. + frame_height: Video height in pixels. + + Returns: + BGRA image array of shape (frame_height, frame_width, 4). + """ + y_bot = int(frame_height * GRAPH_Y_BOT_FRAC) + fill_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8) + polygon = np.concatenate( + [ + pixel_coords, + [[pixel_coords[-1][0], y_bot], [pixel_coords[0][0], y_bot]], + ], + axis=0, + ).astype(np.int32) + cv2.fillPoly(fill_image, [polygon], color=(128, 128, 128, int(255 * FILL_ALPHA))) + return fill_image + + +def _alpha_composite_region(base: np.ndarray, overlay_bgra: np.ndarray, x_limit: int) -> None: + """Blend BGRA overlay onto BGR base in-place, up to x_limit columns. + + Args: + base: BGR frame to draw on (modified in-place). + overlay_bgra: BGRA overlay image. + x_limit: Only blend columns [0, x_limit). + """ + if x_limit <= 0: + return + region_base = base[:, :x_limit] + region_overlay = overlay_bgra[:, :x_limit] + alpha = region_overlay[:, :, 3:4].astype(np.float32) / 255.0 + region_base[:] = np.clip( + region_overlay[:, :, :3].astype(np.float32) * alpha + region_base.astype(np.float32) * (1.0 - alpha), + 0, + 255, + ).astype(np.uint8) + + +def _draw_text_outlined( + frame: np.ndarray, + text: str, + position: tuple[int, int], + font_scale: float, + thickness: int = 1, +) -> None: + """Draw white text with a dark outline for readability on any background. + + Args: + frame: BGR image to draw on (modified in-place). + text: String to render. + position: (x, y) bottom-left corner of the text. + font_scale: OpenCV font scale. + thickness: Text stroke thickness. + """ + font = cv2.FONT_HERSHEY_SIMPLEX + cv2.putText(frame, text, position, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA) + cv2.putText(frame, text, position, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA) + + +def composite_progress_video( + video_path: Path, + from_timestamp: float, + to_timestamp: float, + progress_data: np.ndarray, + output_path: Path, + fps: float, + task_name: str = "", +) -> Path: + """Read episode frames by seeking into the source video, draw progress overlay, write output. + + Uses cv2.CAP_PROP_POS_MSEC to seek directly into the source video, + eliminating the need for an intermediate clip file. + + Args: + video_path: Path to the full source video file. + from_timestamp: Start timestamp of the episode in seconds. + to_timestamp: End timestamp of the episode in seconds. + progress_data: (N, 2) array of (frame_index, progress). + output_path: Path to write the output MP4. + fps: Frames per second for the output video. + task_name: Optional task name to display at the top of the video. + + Returns: + Path to the written output file (MP4). + """ + capture = cv2.VideoCapture(str(video_path)) + try: + capture.set(cv2.CAP_PROP_POS_MSEC, from_timestamp * 1000) + + frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + duration_seconds = to_timestamp - from_timestamp + num_frames = int(round(duration_seconds * fps)) + + logging.info( + " Video: %dx%d, %d frames @ %.1f fps (%.2fs)", + frame_width, + frame_height, + num_frames, + fps, + duration_seconds, + ) + + pixel_coords = _precompute_pixel_coords(progress_data, num_frames, frame_width, frame_height) + y_ref = int(frame_height * GRAPH_Y_TOP_FRAC) + + fill_image = _prerender_fill_polygon(pixel_coords, frame_width, frame_height) + + ref_line_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8) + cv2.line( + ref_line_image, + (0, y_ref), + (frame_width - 1, y_ref), + (200, 200, 200, int(255 * REF_ALPHA)), + 1, + cv2.LINE_AA, + ) + + frame_indices = progress_data[:, 0].astype(int) + progress_values = progress_data[:, 1].astype(float) + + logging.info("[3/4] Compositing %d frames ...", num_frames) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height)) + + for frame_idx in range(num_frames): + ret, frame = capture.read() + if not ret: + break + + drawn_count = int(np.searchsorted(frame_indices, frame_idx, side="right")) + x_current = ( + int(pixel_coords[min(drawn_count, len(pixel_coords)) - 1][0]) + 1 if drawn_count > 0 else 0 + ) + + _alpha_composite_region(frame, ref_line_image, frame_width) + _alpha_composite_region(frame, fill_image, x_current) + + if drawn_count >= 2: + time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1) + line_color = _progress_color(time_position) + points = pixel_coords[:drawn_count].reshape(-1, 1, 2).astype(np.int32) + cv2.polylines( + frame, + [points], + isClosed=False, + color=(255, 255, 255), + thickness=SHADOW_THICKNESS, + lineType=cv2.LINE_AA, + ) + cv2.polylines( + frame, + [points], + isClosed=False, + color=line_color, + thickness=LINE_THICKNESS, + lineType=cv2.LINE_AA, + ) + + if drawn_count > 0: + score = float(progress_values[min(drawn_count, len(progress_values)) - 1]) + score_text = f"{score:.2f}" + (text_width, _), _ = cv2.getTextSize( + score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2 + ) + score_x = frame_width - text_width - 12 + score_y = frame_height - 12 + time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1) + score_color = _progress_color(time_position) + cv2.putText( + frame, + score_text, + (score_x, score_y), + cv2.FONT_HERSHEY_SIMPLEX, + SCORE_FONT_SCALE, + (0, 0, 0), + 4, + cv2.LINE_AA, + ) + cv2.putText( + frame, + score_text, + (score_x, score_y), + cv2.FONT_HERSHEY_SIMPLEX, + SCORE_FONT_SCALE, + score_color, + 2, + cv2.LINE_AA, + ) + + if task_name: + (text_width, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1) + task_x = max((frame_width - text_width) // 2, 4) + _draw_text_outlined(frame, task_name, (task_x, 22), TASK_FONT_SCALE) + + writer.write(frame) + if frame_idx % 100 == 0: + logging.info(" Frame %d/%d ...", frame_idx, num_frames) + + writer.release() + finally: + capture.release() + + logging.info(" MP4 written: %s", output_path) + return output_path + + +def convert_mp4_to_gif(mp4_path: Path) -> Path: + """Convert an MP4 to an optimized GIF using ffmpeg palette generation. + + Args: + mp4_path: Path to the source MP4 file. + + Returns: + Path to the generated GIF file. + """ + capture = cv2.VideoCapture(str(mp4_path)) + frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + capture.release() + + gif_path = mp4_path.with_suffix(".gif") + palette_path = mp4_path.parent / "_palette.png" + + logging.info("[4/4] Converting to GIF ...") + result_palette = subprocess.run( # nosec B607 + [ + "ffmpeg", + "-y", + "-i", + str(mp4_path), + "-vf", + f"fps=10,scale={frame_width}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff", + "-update", + "1", + str(palette_path), + ], + capture_output=True, + text=True, + ) + if result_palette.returncode != 0: + logging.warning("palettegen failed:\n%s", result_palette.stderr[-500:]) + + result_gif = subprocess.run( # nosec B607 + [ + "ffmpeg", + "-y", + "-i", + str(mp4_path), + "-i", + str(palette_path), + "-filter_complex", + f"fps=10,scale={frame_width}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3", + str(gif_path), + ], + capture_output=True, + text=True, + ) + if result_gif.returncode != 0: + logging.warning("GIF encode failed:\n%s", result_gif.stderr[-500:]) + + palette_path.unlink(missing_ok=True) + logging.info(" GIF written: %s", gif_path) + return gif_path + + +def process_dataset( + repo_id: str, + episode: int, + camera_key: str | None, + output_dir: Path, + create_gif: bool = False, +) -> Path | None: + """Full pipeline: download, extract metadata, composite progress, write output. + + Args: + repo_id: HuggingFace dataset repository ID. + episode: Episode index. + camera_key: Camera key to use, or None for auto-selection. + output_dir: Directory to write output files. + create_gif: If True, also generate a GIF from the MP4. + + Returns: + Path to the final output file, or None on failure. + """ + safe_name = repo_id.replace("/", "_") + logging.info("Processing: %s | episode %d", repo_id, episode) + + local_path = download_episode_metadata(repo_id, episode) + logging.info(" Local cache: %s", local_path) + + episode_meta = load_episode_meta(local_path, episode, camera_key) + logging.info(" Episode meta: %s", episode_meta) + + video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"]) + + progress_data = load_progress_data(local_path, episode) + if progress_data is None: + logging.error("Could not load sarm_progress data. Skipping overlay.") + return None + + logging.info(" Progress frames: %d", len(progress_data)) + + output_path = output_dir / f"{safe_name}_ep{episode}_progress.mp4" + final_path = composite_progress_video( + video_path=video_path, + from_timestamp=episode_meta["from_ts"], + to_timestamp=episode_meta["to_ts"], + progress_data=progress_data, + output_path=output_path, + fps=episode_meta["fps"], + task_name=episode_meta.get("task_name", ""), + ) + + if create_gif: + final_path = convert_mp4_to_gif(final_path) + + logging.info("Done: %s", final_path) + return final_path + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes." + ) + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="HuggingFace dataset repository ID (e.g. 'lerobot-data-collection/level2_final_quality3').", + ) + parser.add_argument( + "--episode", + type=int, + required=True, + help="Episode index to visualize.", + ) + parser.add_argument( + "--camera-key", + type=str, + default=None, + help="Camera observation key (e.g. 'observation.images.base'). Auto-selects first camera if omitted.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("progress_videos"), + help="Directory to write output files (default: ./progress_videos).", + ) + parser.add_argument( + "--gif", + action="store_true", + help="Also generate a GIF from the MP4 output.", + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + args.output_dir.mkdir(parents=True, exist_ok=True) + + result = process_dataset( + repo_id=args.repo_id, + episode=args.episode, + camera_key=args.camera_key, + output_dir=args.output_dir, + create_gif=args.gif, + ) + + if result: + logging.info("Output: %s", result) + + +if __name__ == "__main__": + main()