stats every minute

This commit is contained in:
Pepijn
2025-08-30 14:38:28 +02:00
parent bde397e891
commit 2507341a32
2 changed files with 117 additions and 21 deletions

View File

@@ -112,15 +112,43 @@ def update_policy(
optim_time = time.perf_counter() - optim_start
total_time = time.perf_counter() - start_time
# Print detailed timing for RLearN policy
# Collect timing statistics for RLearN policy (averaged reporting every minute)
if getattr(policy, "name", None) == "rlearn":
print(f"Training Step Timing:")
print(f" Forward pass: {forward_time*1000:.2f} ms")
print(f" Backward pass: {backward_time*1000:.2f} ms")
print(f" Optimizer step: {optim_time*1000:.2f} ms")
print(f" Total update: {total_time*1000:.2f} ms")
print(f" Steps/sec: {1.0/total_time:.2f}")
print("-" * 40)
# Initialize timing accumulator if not exists
if not hasattr(policy, '_train_timing_stats'):
policy._train_timing_stats = {
'forward_times': [],
'backward_times': [],
'optim_times': [],
'total_times': [],
'last_print_time': time.perf_counter()
}
# Accumulate current step's timings
stats = policy._train_timing_stats
stats['forward_times'].append(forward_time * 1000)
stats['backward_times'].append(backward_time * 1000)
stats['optim_times'].append(optim_time * 1000)
stats['total_times'].append(total_time * 1000)
# Print averaged stats every minute (60 seconds)
current_time = time.perf_counter()
if current_time - stats['last_print_time'] >= 60.0:
n_samples = len(stats['forward_times'])
if n_samples > 0:
print(f"\nTraining Step Average Timing (last {n_samples} steps):")
print(f" Forward pass: {sum(stats['forward_times'])/n_samples:.2f} ms")
print(f" Backward pass: {sum(stats['backward_times'])/n_samples:.2f} ms")
print(f" Optimizer step: {sum(stats['optim_times'])/n_samples:.2f} ms")
print(f" Total update: {sum(stats['total_times'])/n_samples:.2f} ms")
print(f" Avg steps/sec: {1000.0/(sum(stats['total_times'])/n_samples):.2f}")
print("-" * 50)
# Reset stats for next minute
for key in stats:
if key != 'last_print_time':
stats[key] = []
stats['last_print_time'] = current_time
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
@@ -296,12 +324,40 @@ def train(cfg: TrainPipelineConfig):
except Exception:
pass
# Print data loading timing for RLearN
# Collect data pipeline timing for RLearN (averaged reporting every minute)
if getattr(policy, "name", None) == "rlearn":
print(f"Data Pipeline Timing:")
print(f" Data loading: {data_loading_time*1000:.2f} ms")
print(f" Preprocessing: {preprocess_time*1000:.2f} ms")
print(f" Total data pipeline: {(data_loading_time + preprocess_time)*1000:.2f} ms")
# Initialize data timing accumulator if not exists
if not hasattr(policy, '_data_timing_stats'):
policy._data_timing_stats = {
'data_loading_times': [],
'preprocess_times': [],
'last_print_time': time.perf_counter()
}
# Accumulate current step's data timings
data_stats = policy._data_timing_stats
data_stats['data_loading_times'].append(data_loading_time * 1000)
data_stats['preprocess_times'].append(preprocess_time * 1000)
# Print averaged stats every minute (60 seconds)
current_time = time.perf_counter()
if current_time - data_stats['last_print_time'] >= 60.0:
n_samples = len(data_stats['data_loading_times'])
if n_samples > 0:
avg_data_loading = sum(data_stats['data_loading_times']) / n_samples
avg_preprocessing = sum(data_stats['preprocess_times']) / n_samples
print(f"\nData Pipeline Average Timing (last {n_samples} steps):")
print(f" Data loading: {avg_data_loading:.2f} ms")
print(f" Preprocessing: {avg_preprocessing:.2f} ms")
print(f" Total data pipeline: {avg_data_loading + avg_preprocessing:.2f} ms")
print("-" * 50)
# Reset stats for next minute
for key in data_stats:
if key != 'last_print_time':
data_stats[key] = []
data_stats['last_print_time'] = current_time
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.