mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
remove decode logging
This commit is contained in:
@@ -25,133 +25,6 @@ import torch
|
||||
# Fix tokenizer parallelism conflicts with multiprocessing
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def _add_video_decoding_timing(dataset):
|
||||
"""Add timing instrumentation to video decoding for debugging."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
|
||||
def instrument_dataset(ds):
|
||||
if not hasattr(ds, '_query_videos'):
|
||||
return
|
||||
|
||||
# Store original method
|
||||
original_query_videos = ds._query_videos
|
||||
|
||||
# Initialize timing stats
|
||||
if not hasattr(ds, '_video_decode_timing'):
|
||||
ds._video_decode_timing = {
|
||||
'decode_times': [],
|
||||
'last_print_time': time.perf_counter()
|
||||
}
|
||||
|
||||
def timed_query_videos(self, query_timestamps, ep_idx):
|
||||
# Debug: print what backend is being used
|
||||
if not hasattr(self, '_backend_logged'):
|
||||
print(f"DEBUG: Video backend in use: {getattr(self, 'video_backend', 'UNKNOWN')}")
|
||||
self._backend_logged = True
|
||||
|
||||
decode_start = time.perf_counter()
|
||||
result = original_query_videos(query_timestamps, ep_idx)
|
||||
decode_time = time.perf_counter() - decode_start
|
||||
|
||||
# Debug problematic 0.5 frames issue
|
||||
actual_frames = 0
|
||||
for key in query_timestamps:
|
||||
actual_frames += len(query_timestamps[key])
|
||||
|
||||
# Accumulate timing
|
||||
timing_stats = self._video_decode_timing
|
||||
timing_stats['decode_times'].append(decode_time * 1000) # Convert to ms
|
||||
timing_stats['actual_frame_counts'] = timing_stats.get('actual_frame_counts', [])
|
||||
timing_stats['actual_frame_counts'].append(actual_frames)
|
||||
|
||||
# Print averaged stats every minute
|
||||
current_time = time.perf_counter()
|
||||
if current_time - timing_stats['last_print_time'] >= 60.0:
|
||||
n_samples = len(timing_stats['decode_times'])
|
||||
if n_samples > 0:
|
||||
avg_decode_time = sum(timing_stats['decode_times']) / n_samples
|
||||
# Use actual frame counts tracked per call
|
||||
actual_counts = timing_stats.get('actual_frame_counts', [])
|
||||
avg_frames_per_call = sum(actual_counts) / len(actual_counts) if actual_counts else 0
|
||||
|
||||
print(f"\nVideo Decoding Timing (last {n_samples} calls):")
|
||||
print(f" Avg decode time: {avg_decode_time:.2f} ms")
|
||||
print(f" Avg frames/call: {avg_frames_per_call:.1f}")
|
||||
print(f" Time per frame: {avg_decode_time/max(avg_frames_per_call, 1):.2f} ms/frame")
|
||||
print("-" * 50)
|
||||
|
||||
# Reset stats
|
||||
timing_stats['decode_times'] = []
|
||||
timing_stats['actual_frame_counts'] = []
|
||||
timing_stats['last_print_time'] = current_time
|
||||
|
||||
return result
|
||||
|
||||
# Bind the method to the instance
|
||||
import types
|
||||
ds._query_videos = types.MethodType(timed_query_videos, ds)
|
||||
|
||||
# Handle both single and multi datasets
|
||||
if isinstance(dataset, MultiLeRobotDataset):
|
||||
for ds in dataset._datasets:
|
||||
instrument_dataset(ds)
|
||||
elif isinstance(dataset, LeRobotDataset):
|
||||
instrument_dataset(dataset)
|
||||
else:
|
||||
print(f"Warning: Unknown dataset type {type(dataset)}, skipping video timing instrumentation")
|
||||
|
||||
|
||||
def _add_video_frame_caching(dataset, cache_size=1000):
|
||||
"""Add LRU caching to video decoding to avoid re-decoding the same frames."""
|
||||
from functools import lru_cache
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
|
||||
def instrument_dataset_caching(ds):
|
||||
if not hasattr(ds, '_query_videos'):
|
||||
return
|
||||
|
||||
# Store original method
|
||||
original_query_videos = ds._query_videos
|
||||
|
||||
# Create cache key from timestamps and episode
|
||||
def make_cache_key(query_timestamps, ep_idx):
|
||||
# Convert to hashable tuple
|
||||
key_parts = [ep_idx]
|
||||
for vid_key in sorted(query_timestamps.keys()):
|
||||
ts_tuple = tuple(round(ts, 6) for ts in query_timestamps[vid_key]) # Round to microsecond precision
|
||||
key_parts.append((vid_key, ts_tuple))
|
||||
return tuple(key_parts)
|
||||
|
||||
# Create LRU cached version
|
||||
@lru_cache(maxsize=cache_size)
|
||||
def cached_decode_frames(cache_key, ep_idx):
|
||||
# Reconstruct query_timestamps from cache_key
|
||||
query_timestamps = {}
|
||||
for item in cache_key[1:]: # Skip ep_idx
|
||||
vid_key, ts_tuple = item
|
||||
query_timestamps[vid_key] = list(ts_tuple)
|
||||
return original_query_videos(query_timestamps, ep_idx)
|
||||
|
||||
def cached_query_videos(self, query_timestamps, ep_idx):
|
||||
cache_key = make_cache_key(query_timestamps, ep_idx)
|
||||
return cached_decode_frames(cache_key, ep_idx)
|
||||
|
||||
# Bind the cached method to the instance
|
||||
import types
|
||||
ds._query_videos = types.MethodType(cached_query_videos, ds)
|
||||
ds._cached_decode_frames = cached_decode_frames # Keep reference for cache info
|
||||
|
||||
print(f"Added video frame caching with size {cache_size}")
|
||||
|
||||
# Handle both single and multi datasets
|
||||
if isinstance(dataset, MultiLeRobotDataset):
|
||||
for ds in dataset._datasets:
|
||||
instrument_dataset_caching(ds)
|
||||
elif isinstance(dataset, LeRobotDataset):
|
||||
instrument_dataset_caching(dataset)
|
||||
else:
|
||||
print(f"Warning: Unknown dataset type {type(dataset)}, skipping video caching")
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
@@ -322,14 +195,6 @@ def train(cfg: TrainPipelineConfig):
|
||||
logging.info("RLearN: Setting video_backend to 'pyav' for better performance")
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Add video decoding timing and caching for RLearN debugging
|
||||
if getattr(cfg.policy, "type", None) == "rlearn":
|
||||
_add_video_decoding_timing(dataset)
|
||||
# Add frame caching for small datasets
|
||||
if hasattr(dataset, 'num_frames') and dataset.num_frames < 1000:
|
||||
_add_video_frame_caching(dataset, cache_size=500)
|
||||
logging.info(f"RLearN: Added frame caching for {dataset.num_frames} frame dataset")
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
@@ -505,24 +370,6 @@ def train(cfg: TrainPipelineConfig):
|
||||
print(f" Data loading: {avg_data_loading:.2f} ms")
|
||||
print(f" Preprocessing: {avg_preprocessing:.2f} ms")
|
||||
print(f" Total data pipeline: {avg_data_loading + avg_preprocessing:.2f} ms")
|
||||
|
||||
# Show video decoding breakdown if available
|
||||
try:
|
||||
ds = dataset._datasets[0] if hasattr(dataset, '_datasets') else dataset
|
||||
if hasattr(ds, '_video_decode_timing'):
|
||||
recent_decodes = ds._video_decode_timing.get('decode_times', [])
|
||||
if recent_decodes:
|
||||
avg_video_decode = sum(recent_decodes) / len(recent_decodes)
|
||||
print(f" └─ Video decoding: ~{avg_video_decode:.2f} ms/call (included in data loading)")
|
||||
|
||||
# Show cache hit rate if available
|
||||
if hasattr(ds, '_cached_decode_frames'):
|
||||
cache_info = ds._cached_decode_frames.cache_info()
|
||||
hit_rate = cache_info.hits / max(cache_info.hits + cache_info.misses, 1) * 100
|
||||
print(f" └─ Cache hit rate: {hit_rate:.1f}% ({cache_info.hits}H/{cache_info.misses}M, size={cache_info.currsize})")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("-" * 50)
|
||||
|
||||
# Reset stats for next minute
|
||||
|
||||
Reference in New Issue
Block a user