#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Visualize SARM Subtask Annotations

This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
For each episode, it shows:
- A timeline with dashed vertical lines at subtask boundaries
- One sample frame per subtask, taken at the midpoint of that subtask
- Color-coded subtask segments

Usage:
    python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
"""

import argparse
import random
from pathlib import Path

import cv2
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from rich.console import Console

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import load_episodes
from lerobot.policies.sarm.sarm_utils import Subtask, SubtaskAnnotation, Timestamp


def timestamp_to_seconds(timestamp: str) -> float:
    """Convert a ``MM:SS`` or ``SS`` timestamp string to a number of seconds.

    Args:
        timestamp: Either ``"MM:SS"`` (exactly one colon) or a plain integer
            string of seconds.

    Returns:
        The total number of seconds. Any input that does not contain exactly
        one colon is interpreted from its first colon-separated field only.

    Raises:
        ValueError: If a field that is read is not an integer literal.
    """
    parts = timestamp.split(":")
    if len(parts) == 2:
        return int(parts[0]) * 60 + int(parts[1])
    return int(parts[0])


def _seconds_to_mmss(seconds: float) -> str:
    """Format ``seconds`` (truncated to an int) as a zero-padded ``MM:SS`` string."""
    whole_seconds = int(seconds)
    return f"{whole_seconds // 60:02d}:{whole_seconds % 60:02d}"


def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
    """Load subtask annotations from a LeRobot dataset's episode metadata.

    Reads the ``subtask_names``, ``subtask_start_times`` and
    ``subtask_end_times`` columns of the episodes metadata and rebuilds one
    ``SubtaskAnnotation`` per annotated episode.

    Args:
        dataset_path: Root directory of the dataset, passed to ``load_episodes``.

    Returns:
        A mapping from the episodes table's row index to its annotation.
        Empty when the metadata is missing/empty or has no ``subtask_names``
        column. Episodes whose ``subtask_names`` is ``None``/NaN are skipped.
    """
    episodes_dataset = load_episodes(dataset_path)
    if episodes_dataset is None or len(episodes_dataset) == 0:
        return {}

    # Datasets that were never annotated simply lack the subtask columns.
    if "subtask_names" not in episodes_dataset.column_names:
        return {}

    # pandas gives convenient per-row, per-column access.
    episodes_df = episodes_dataset.to_pandas()

    annotations: dict[int, SubtaskAnnotation] = {}
    for ep_idx in episodes_df.index:
        subtask_names = episodes_df.loc[ep_idx, "subtask_names"]

        # Un-annotated rows show up as None or as a float NaN.
        if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
            continue

        start_times = episodes_df.loc[ep_idx, "subtask_start_times"]
        end_times = episodes_df.loc[ep_idx, "subtask_end_times"]

        # Stored times are in seconds; Timestamp carries MM:SS strings.
        subtasks = [
            Subtask(
                name=name,
                timestamps=Timestamp(
                    start=_seconds_to_mmss(start_times[i]),
                    end=_seconds_to_mmss(end_times[i]),
                ),
            )
            for i, name in enumerate(subtask_names)
        ]
        annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)

    return annotations


# Color palette for subtasks (colorblind-friendly)
SUBTASK_COLORS = [
    "#E69F00",  # Orange
    "#56B4E9",  # Sky blue
    "#009E73",  # Bluish green
    "#F0E442",  # Yellow
    "#0072B2",  # Blue
    "#D55E00",  # Vermillion
    "#CC79A7",  # Reddish purple
    "#999999",  # Gray
]

# Palette entries light enough that labels drawn on top of them need dark text.
_LIGHT_SUBTASK_COLORS = frozenset({"#F0E442"})


def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
    """Extract a single RGB frame from a video at the given timestamp.

    Args:
        video_path: Path of the video file to open with OpenCV.
        timestamp: Position in the video file, in seconds.

    Returns:
        The decoded frame converted from BGR to RGB, or ``None`` when the
        video cannot be opened or no frame can be read at that position.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return None

        # OpenCV seeks in milliseconds.
        cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
        ret, frame = cap.read()
    finally:
        # Always release the capture, even if seeking/decoding raised.
        cap.release()

    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def _draw_sample_frames(fig, grid, subtasks, sample_frames, frame_timestamps) -> None:
    """Draw one sample frame per subtask along the top row of ``grid``."""
    for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
        ax = fig.add_subplot(grid[0, i])
        ax.set_facecolor("#16213e")

        if frame is not None:
            ax.imshow(frame)
        else:
            # Keep the column so titles stay aligned with the timeline colors.
            ax.text(
                0.5, 0.5, "Frame\nN/A",
                ha="center", va="center",
                fontsize=12, color="white",
                transform=ax.transAxes,
            )

        ax.set_title(
            f"{subtask.name}",
            fontsize=10,
            fontweight="bold",
            color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
            pad=8,
        )
        ax.axis("off")

        # Frame timestamp (relative to the episode start) below the image.
        ax.text(
            0.5, -0.08,
            f"t={frame_timestamps[i]:.1f}s",
            ha="center", fontsize=9, color="#888888",
            transform=ax.transAxes,
        )


def _draw_timeline(ax_timeline, subtasks, spans, total_duration) -> None:
    """Draw the colored subtask bars, boundary lines, ticks and legend on ``ax_timeline``."""
    ax_timeline.set_facecolor("#16213e")

    bar_height = 0.6
    bar_y = 0.5

    # Subtask segments as colored bars.
    for i, (subtask, (start_sec, end_sec)) in enumerate(zip(subtasks, spans)):
        color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]
        duration = end_sec - start_sec

        rect = mpatches.FancyBboxPatch(
            (start_sec, bar_y - bar_height / 2),
            duration,
            bar_height,
            boxstyle="round,pad=0.02,rounding_size=0.1",
            facecolor=color,
            edgecolor="white",
            linewidth=1.5,
            alpha=0.85,
        )
        ax_timeline.add_patch(rect)

        # Only label segments wide enough to hold text; tilt the label on
        # medium-width segments so it still fits.
        if duration > total_duration * 0.08:
            ax_timeline.text(
                (start_sec + end_sec) / 2, bar_y,
                subtask.name,
                ha="center", va="center",
                fontsize=9, fontweight="bold",
                # Decide from the actual bar color so the choice stays right
                # when the palette wraps around for many subtasks.
                color="black" if color in _LIGHT_SUBTASK_COLORS else "white",
                rotation=0 if duration > total_duration * 0.15 else 45,
            )

    # Boundaries: every subtask start except a first subtask starting at t=0
    # (that position gets the dedicated green start line), plus the final end.
    boundary_times = []
    last_index = len(spans) - 1
    for i, (start_sec, end_sec) in enumerate(spans):
        if i > 0 or start_sec > 0:
            boundary_times.append(start_sec)
        if i == last_index:
            boundary_times.append(end_sec)

    for t in boundary_times:
        ax_timeline.axvline(
            x=t, ymin=0.1, ymax=0.9,
            color="white", linestyle="--", linewidth=2, alpha=0.9,
        )
        ax_timeline.text(
            t, 0.0,
            _seconds_to_mmss(t),
            ha="center", va="top",
            fontsize=8, color="#cccccc",
        )

    # Start line at t=0.
    ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color="#00ff00", linestyle="-", linewidth=2.5, alpha=0.9)
    ax_timeline.text(0, 0.0, "00:00", ha="center", va="top", fontsize=8, color="#00ff00", fontweight="bold")

    # Axis limits with a 2% margin on both sides.
    ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
    ax_timeline.set_ylim(-0.3, 1.2)
    ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color="white", labelpad=10)
    ax_timeline.set_ylabel("")

    ax_timeline.spines["top"].set_visible(False)
    ax_timeline.spines["right"].set_visible(False)
    ax_timeline.spines["left"].set_visible(False)
    ax_timeline.spines["bottom"].set_color("#444444")

    ax_timeline.tick_params(axis="x", colors="#888888", labelsize=9)
    ax_timeline.tick_params(axis="y", left=False, labelleft=False)

    # Roughly ten ticks, never finer than one second.
    tick_interval = max(1, int(total_duration / 10))
    ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))

    legend_elements = [
        Line2D([0], [0], color="#00ff00", linewidth=2.5, linestyle="-", label="Start"),
        Line2D([0], [0], color="white", linewidth=2, linestyle="--", label="Subtask boundary"),
    ]
    ax_timeline.legend(
        handles=legend_elements,
        loc="upper right",
        framealpha=0.3,
        facecolor="#16213e",
        edgecolor="#444444",
        fontsize=9,
        labelcolor="white",
    )


def visualize_episode(
    episode_idx: int,
    annotation: SubtaskAnnotation,
    video_path: Path,
    video_start_timestamp: float,
    video_end_timestamp: float,
    fps: int,
    output_path: Path,
    video_key: str,
) -> Path | None:
    """Create and save the visualization for a single episode.

    The figure shows:
    - Top row: one sample frame per subtask, taken at the subtask's midpoint
    - Bottom: a timeline with colored subtask segments and boundary lines

    Args:
        episode_idx: Episode index, used in the title and messages.
        annotation: Object whose ``subtasks`` each have a ``name`` and
            ``timestamps.start`` / ``timestamps.end`` strings.
        video_path: Video file containing the episode.
        video_start_timestamp: Offset in seconds of the episode start inside
            the video file.
        video_end_timestamp: Offset in seconds of the episode end inside the
            video file.
        fps: Dataset frame rate. Currently unused; kept for API compatibility.
        output_path: Destination image path.
        video_key: Camera/video key shown in the subtitle.

    Returns:
        ``output_path`` once the figure is saved, or ``None`` when the
        annotation has no subtasks.
    """
    subtasks = annotation.subtasks
    num_subtasks = len(subtasks)

    if num_subtasks == 0:
        print(f"No subtasks found for episode {episode_idx}")
        return None

    episode_duration = video_end_timestamp - video_start_timestamp

    # Parse every subtask's (start, end) once; all drawing code reuses this.
    spans = [
        (timestamp_to_seconds(subtask.timestamps.start), timestamp_to_seconds(subtask.timestamps.end))
        for subtask in subtasks
    ]

    # One frame from the middle of each subtask. Subtask times are relative to
    # the episode, so shift by the episode's offset inside the video file.
    frame_timestamps = [(start_sec + end_sec) / 2 for start_sec, end_sec in spans]
    sample_frames = [
        extract_frame_from_video(video_path, video_start_timestamp + mid_sec)
        for mid_sec in frame_timestamps
    ]

    # Timeline extent: the latest subtask end. Fall back to a positive value
    # so the axis limits are never degenerate.
    total_duration = max(end_sec for _, end_sec in spans)
    if total_duration <= 0:
        total_duration = max(episode_duration, 1.0)

    fig = plt.figure(figsize=(16, 10))
    try:
        # Dark background for better contrast.
        fig.patch.set_facecolor("#1a1a2e")

        # Row 0: one column per subtask frame. Row 1: timeline spanning all columns.
        gs = fig.add_gridspec(
            2, max(num_subtasks, 1),
            height_ratios=[2, 1],
            hspace=0.3, wspace=0.1,
            left=0.05, right=0.95, top=0.88, bottom=0.1,
        )

        fig.suptitle(
            f"Episode {episode_idx} - Subtask Annotations",
            fontsize=18, fontweight="bold", color="white",
            y=0.96,
        )
        fig.text(
            0.5, 0.91,
            f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
            ha="center", fontsize=11, color="#888888",
        )

        _draw_sample_frames(fig, gs, subtasks, sample_frames, frame_timestamps)
        _draw_timeline(fig.add_subplot(gs[1, :]), subtasks, spans, total_duration)

        fig.savefig(
            output_path,
            dpi=150,
            facecolor=fig.get_facecolor(),
            edgecolor="none",
            bbox_inches="tight",
        )
    finally:
        # Close this specific figure even on failure; callers keep going after
        # an error and would otherwise accumulate open figures.
        plt.close(fig)

    return output_path


def main():
    """Parse CLI arguments, load the dataset and annotations, and render the selected episodes."""
    parser = argparse.ArgumentParser(
        description="Visualize SARM subtask annotations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID",
    )
    parser.add_argument(
        "--num-episodes",
        type=int,
        default=5,
        help="Number of random episodes to visualize (default: 5)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        default=None,
        help="Specific episode indices to visualize (overrides --num-episodes)",
    )
    parser.add_argument(
        "--video-key",
        type=str,
        default=None,
        help="Camera/video key to use. If not specified, uses first available.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./subtask_viz",
        help="Output directory for visualizations (default: ./subtask_viz)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )

    args = parser.parse_args()
    console = Console()

    if args.seed is not None:
        random.seed(args.seed)

    console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
    dataset = LeRobotDataset(args.repo_id, download_videos=True)
    fps = dataset.fps

    # Resolve which camera stream to read frames from.
    video_keys = dataset.meta.video_keys
    if args.video_key:
        if args.video_key not in video_keys:
            console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
            console.print(f"[yellow]Available: {', '.join(video_keys)}[/yellow]")
            return
        video_key = args.video_key
    elif video_keys:
        video_key = video_keys[0]
    else:
        console.print("[red]Error: Dataset has no video keys[/red]")
        return

    console.print(f"[cyan]Using camera: {video_key}[/cyan]")
    console.print(f"[cyan]FPS: {fps}[/cyan]")

    console.print("\n[cyan]Loading annotations...[/cyan]")
    annotations = load_annotations_from_dataset(dataset.root)

    if not annotations:
        console.print("[red]Error: No annotations found in dataset[/red]")
        console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
        return

    console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")

    # Determine which episodes to visualize.
    if args.episodes:
        for ep in args.episodes:
            if ep not in annotations:
                console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
        episode_indices = [ep for ep in args.episodes if ep in annotations]
    else:
        available_episodes = list(annotations.keys())
        # Clamp to [0, available] so random.sample never gets a negative count.
        num_to_select = max(0, min(args.num_episodes, len(available_episodes)))
        episode_indices = random.sample(available_episodes, num_to_select)
        episode_indices.sort()

    if not episode_indices:
        console.print("[red]Error: No valid episodes to visualize[/red]")
        return

    console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Metadata columns holding each episode's span inside its video file.
    from_timestamp_key = f"videos/{video_key}/from_timestamp"
    to_timestamp_key = f"videos/{video_key}/to_timestamp"

    num_saved = 0
    for ep_idx in episode_indices:
        console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")

        annotation = annotations[ep_idx]

        video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
        if not video_path.exists():
            console.print(f"[red]Video not found: {video_path}[/red]")
            continue

        video_start_timestamp = float(dataset.meta.episodes[from_timestamp_key][ep_idx])
        video_end_timestamp = float(dataset.meta.episodes[to_timestamp_key][ep_idx])

        output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"

        # Best effort per episode: report the failure and keep going.
        try:
            saved_path = visualize_episode(
                episode_idx=ep_idx,
                annotation=annotation,
                video_path=video_path,
                video_start_timestamp=video_start_timestamp,
                video_end_timestamp=video_end_timestamp,
                fps=fps,
                output_path=output_path,
                video_key=video_key,
            )
        except Exception as e:
            console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")
        else:
            console.print(f"[green]✓ Saved: {output_path}[/green]")
            # visualize_episode returns None when it skipped writing a file.
            if saved_path is not None:
                num_saved += 1

    console.print(f"\n[bold green]{'=' * 50}[/bold green]")
    console.print("[bold green]Visualization Complete![/bold green]")
    console.print(f"[bold green]{'=' * 50}[/bold green]")
    console.print(f"Output directory: {output_dir.absolute()}")
    # Count only episodes whose image was actually written.
    console.print(f"Episodes visualized: {num_saved}")


if __name__ == "__main__":
    main()