add decode logging

This commit is contained in:
Pepijn
2025-08-30 16:00:55 +02:00
parent aed90c8042
commit 0be53ef3e1
4 changed files with 213 additions and 1 deletions

View File

@@ -86,6 +86,58 @@ def _add_video_decoding_timing(dataset):
instrument_dataset(dataset)
else:
print(f"Warning: Unknown dataset type {type(dataset)}, skipping video timing instrumentation")
def _add_video_frame_caching(dataset, cache_size=1000):
"""Add LRU caching to video decoding to avoid re-decoding the same frames."""
from functools import lru_cache
from lerobot.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
def instrument_dataset_caching(ds):
if not hasattr(ds, '_query_videos'):
return
# Store original method
original_query_videos = ds._query_videos
# Create cache key from timestamps and episode
def make_cache_key(query_timestamps, ep_idx):
# Convert to hashable tuple
key_parts = [ep_idx]
for vid_key in sorted(query_timestamps.keys()):
ts_tuple = tuple(round(ts, 6) for ts in query_timestamps[vid_key]) # Round to microsecond precision
key_parts.append((vid_key, ts_tuple))
return tuple(key_parts)
# Create LRU cached version
@lru_cache(maxsize=cache_size)
def cached_decode_frames(cache_key, ep_idx):
# Reconstruct query_timestamps from cache_key
query_timestamps = {}
for item in cache_key[1:]: # Skip ep_idx
vid_key, ts_tuple = item
query_timestamps[vid_key] = list(ts_tuple)
return original_query_videos(query_timestamps, ep_idx)
def cached_query_videos(self, query_timestamps, ep_idx):
cache_key = make_cache_key(query_timestamps, ep_idx)
return cached_decode_frames(cache_key, ep_idx)
# Bind the cached method to the instance
import types
ds._query_videos = types.MethodType(cached_query_videos, ds)
ds._cached_decode_frames = cached_decode_frames # Keep reference for cache info
print(f"Added video frame caching with size {cache_size}")
# Handle both single and multi datasets
if isinstance(dataset, MultiLeRobotDataset):
for ds in dataset._datasets:
instrument_dataset_caching(ds)
elif isinstance(dataset, LeRobotDataset):
instrument_dataset_caching(dataset)
else:
print(f"Warning: Unknown dataset type {type(dataset)}, skipping video caching")
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
@@ -243,7 +295,13 @@ def train(cfg: TrainPipelineConfig):
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("Creating dataset")
dataset = make_dataset(cfg)
# Pass video backend to dataset for RLearN optimization
dataset_kwargs = {}
if getattr(cfg.policy, "type", None) == "rlearn" and hasattr(cfg.policy, "video_backend"):
dataset_kwargs["video_backend"] = cfg.policy.video_backend
logging.info(f"Using video backend: {cfg.policy.video_backend}")
dataset = make_dataset(cfg, **dataset_kwargs)
# Add video decoding timing for RLearN debugging
if getattr(cfg.policy, "type", None) == "rlearn":
@@ -432,6 +490,12 @@ def train(cfg: TrainPipelineConfig):
if recent_decodes:
avg_video_decode = sum(recent_decodes) / len(recent_decodes)
print(f" └─ Video decoding: ~{avg_video_decode:.2f} ms/call (included in data loading)")
# Show cache hit rate if available
if hasattr(ds, '_cached_decode_frames'):
cache_info = ds._cached_decode_frames.cache_info()
hit_rate = cache_info.hits / max(cache_info.hits + cache_info.misses, 1) * 100
print(f" └─ Cache hit rate: {hit_rate:.1f}% ({cache_info.hits}H/{cache_info.misses}M, size={cache_info.currsize})")
except Exception:
pass