From bf5c037959faee7c7f49d6c02b28ad59ebe32a14 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sat, 30 Aug 2025 16:28:29 +0200 Subject: [PATCH] remove decode logging --- src/lerobot/scripts/train.py | 153 ----------------------------------- 1 file changed, 153 deletions(-) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index b7b82cf0f..91514e711 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -25,133 +25,6 @@ import torch # Fix tokenizer parallelism conflicts with multiprocessing os.environ["TOKENIZERS_PARALLELISM"] = "false" - -def _add_video_decoding_timing(dataset): - """Add timing instrumentation to video decoding for debugging.""" - from lerobot.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset - - def instrument_dataset(ds): - if not hasattr(ds, '_query_videos'): - return - - # Store original method - original_query_videos = ds._query_videos - - # Initialize timing stats - if not hasattr(ds, '_video_decode_timing'): - ds._video_decode_timing = { - 'decode_times': [], - 'last_print_time': time.perf_counter() - } - - def timed_query_videos(self, query_timestamps, ep_idx): - # Debug: print what backend is being used - if not hasattr(self, '_backend_logged'): - print(f"DEBUG: Video backend in use: {getattr(self, 'video_backend', 'UNKNOWN')}") - self._backend_logged = True - - decode_start = time.perf_counter() - result = original_query_videos(query_timestamps, ep_idx) - decode_time = time.perf_counter() - decode_start - - # Debug problematic 0.5 frames issue - actual_frames = 0 - for key in query_timestamps: - actual_frames += len(query_timestamps[key]) - - # Accumulate timing - timing_stats = self._video_decode_timing - timing_stats['decode_times'].append(decode_time * 1000) # Convert to ms - timing_stats['actual_frame_counts'] = timing_stats.get('actual_frame_counts', []) - timing_stats['actual_frame_counts'].append(actual_frames) - - # Print averaged stats every minute - current_time = time.perf_counter() - if current_time - timing_stats['last_print_time'] >= 60.0: - n_samples = len(timing_stats['decode_times']) - if n_samples > 0: - avg_decode_time = sum(timing_stats['decode_times']) / n_samples - # Use actual frame counts tracked per call - actual_counts = timing_stats.get('actual_frame_counts', []) - avg_frames_per_call = sum(actual_counts) / len(actual_counts) if actual_counts else 0 - - print(f"\nVideo Decoding Timing (last {n_samples} calls):") - print(f" Avg decode time: {avg_decode_time:.2f} ms") - print(f" Avg frames/call: {avg_frames_per_call:.1f}") - print(f" Time per frame: {avg_decode_time/max(avg_frames_per_call, 1):.2f} ms/frame") - print("-" * 50) - - # Reset stats - timing_stats['decode_times'] = [] - timing_stats['actual_frame_counts'] = [] - timing_stats['last_print_time'] = current_time - - return result - - # Bind the method to the instance - import types - ds._query_videos = types.MethodType(timed_query_videos, ds) - - # Handle both single and multi datasets - if isinstance(dataset, MultiLeRobotDataset): - for ds in dataset._datasets: - instrument_dataset(ds) - elif isinstance(dataset, LeRobotDataset): - instrument_dataset(dataset) - else: - print(f"Warning: Unknown dataset type {type(dataset)}, skipping video timing instrumentation") - - -def _add_video_frame_caching(dataset, cache_size=1000): - """Add LRU caching to video decoding to avoid re-decoding the same frames.""" - from functools import lru_cache - from lerobot.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset - - def instrument_dataset_caching(ds): - if not hasattr(ds, '_query_videos'): - return - - # Store original method - original_query_videos = ds._query_videos - - # Create cache key from timestamps and episode - def make_cache_key(query_timestamps, ep_idx): - # Convert to hashable tuple - key_parts = [ep_idx] - for vid_key in sorted(query_timestamps.keys()): - ts_tuple = tuple(round(ts, 6) for ts in query_timestamps[vid_key]) # Round to microsecond precision - key_parts.append((vid_key, ts_tuple)) - return tuple(key_parts) - - # Create LRU cached version - @lru_cache(maxsize=cache_size) - def cached_decode_frames(cache_key, ep_idx): - # Reconstruct query_timestamps from cache_key - query_timestamps = {} - for item in cache_key[1:]: # Skip ep_idx - vid_key, ts_tuple = item - query_timestamps[vid_key] = list(ts_tuple) - return original_query_videos(query_timestamps, ep_idx) - - def cached_query_videos(self, query_timestamps, ep_idx): - cache_key = make_cache_key(query_timestamps, ep_idx) - return cached_decode_frames(cache_key, ep_idx) - - # Bind the cached method to the instance - import types - ds._query_videos = types.MethodType(cached_query_videos, ds) - ds._cached_decode_frames = cached_decode_frames # Keep reference for cache info - - print(f"Added video frame caching with size {cache_size}") - - # Handle both single and multi datasets - if isinstance(dataset, MultiLeRobotDataset): - for ds in dataset._datasets: - instrument_dataset_caching(ds) - elif isinstance(dataset, LeRobotDataset): - instrument_dataset_caching(dataset) - else: - print(f"Warning: Unknown dataset type {type(dataset)}, skipping video caching") from termcolor import colored from torch.amp import GradScaler from torch.optim import Optimizer @@ -322,14 +195,6 @@ def train(cfg: TrainPipelineConfig): logging.info("RLearN: Setting video_backend to 'pyav' for better performance") dataset = make_dataset(cfg) - - # Add video decoding timing and caching for RLearN debugging - if getattr(cfg.policy, "type", None) == "rlearn": - _add_video_decoding_timing(dataset) - # Add frame caching for small datasets - if hasattr(dataset, 'num_frames') and dataset.num_frames < 1000: - _add_video_frame_caching(dataset, cache_size=500) - logging.info(f"RLearN: Added frame caching for {dataset.num_frames} frame dataset") # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, @@ -505,24 +370,6 @@ def train(cfg: TrainPipelineConfig): print(f" Data loading: {avg_data_loading:.2f} ms") print(f" Preprocessing: {avg_preprocessing:.2f} ms") print(f" Total data pipeline: {avg_data_loading + avg_preprocessing:.2f} ms") - - # Show video decoding breakdown if available - try: - ds = dataset._datasets[0] if hasattr(dataset, '_datasets') else dataset - if hasattr(ds, '_video_decode_timing'): - recent_decodes = ds._video_decode_timing.get('decode_times', []) - if recent_decodes: - avg_video_decode = sum(recent_decodes) / len(recent_decodes) - print(f" └─ Video decoding: ~{avg_video_decode:.2f} ms/call (included in data loading)") - - # Show cache hit rate if available - if hasattr(ds, '_cached_decode_frames'): - cache_info = ds._cached_decode_frames.cache_info() - hit_rate = cache_info.hits / max(cache_info.hits + cache_info.misses, 1) * 100 - print(f" └─ Cache hit rate: {hit_rate:.1f}% ({cache_info.hits}H/{cache_info.misses}M, size={cache_info.currsize})") - except Exception: - pass - print("-" * 50) # Reset stats for next minute