stats every minute

This commit is contained in:
Pepijn
2025-08-30 14:38:28 +02:00
parent bde397e891
commit 2507341a32
2 changed files with 117 additions and 21 deletions

View File

@@ -573,15 +573,55 @@ class RLearNPolicy(PreTrainedPolicy):
"timing_total_forward_ms": float(total_forward_time * 1000),
})
# Print detailed timing breakdown during training
# Collect timing statistics for averaged reporting every minute
if self.training:
print(f"RLearN Forward Pass Timing (B={B}, T_eff={T_eff}):")
print(f" Vision encoding: {vision_time*1000:.2f} ms")
print(f" Language encoding: {lang_time*1000:.2f} ms")
print(f" Transformer: {transformer_time*1000:.2f} ms")
print(f" Loss computation: {loss_time*1000:.2f} ms")
print(f" Total forward pass: {total_forward_time*1000:.2f} ms")
print(f" Throughput: {B*T_eff/(total_forward_time):.1f} frames/sec")
# Initialize timing accumulator if not exists
if not hasattr(self, '_timing_stats'):
self._timing_stats = {
'vision_times': [],
'language_times': [],
'transformer_times': [],
'loss_times': [],
'total_forward_times': [],
'throughputs': [],
'batch_sizes': [],
't_effs': [],
'last_print_time': time.perf_counter()
}
# Accumulate current step's timings
stats = self._timing_stats
stats['vision_times'].append(vision_time * 1000)
stats['language_times'].append(lang_time * 1000)
stats['transformer_times'].append(transformer_time * 1000)
stats['loss_times'].append(loss_time * 1000)
stats['total_forward_times'].append(total_forward_time * 1000)
stats['throughputs'].append(B * T_eff / total_forward_time)
stats['batch_sizes'].append(B)
stats['t_effs'].append(T_eff)
# Print averaged stats every minute (60 seconds)
current_time = time.perf_counter()
if current_time - stats['last_print_time'] >= 60.0:
n_samples = len(stats['vision_times'])
if n_samples > 0:
avg_b = sum(stats['batch_sizes']) / n_samples
avg_t_eff = sum(stats['t_effs']) / n_samples
print(f"\nRLearN Average Timing (last {n_samples} steps, avg B={avg_b:.1f}, avg T_eff={avg_t_eff:.1f}):")
print(f" Vision encoding: {sum(stats['vision_times'])/n_samples:.2f} ms")
print(f" Language encoding: {sum(stats['language_times'])/n_samples:.2f} ms")
print(f" Transformer: {sum(stats['transformer_times'])/n_samples:.2f} ms")
print(f" Loss computation: {sum(stats['loss_times'])/n_samples:.2f} ms")
print(f" Total forward pass: {sum(stats['total_forward_times'])/n_samples:.2f} ms")
print(f" Avg throughput: {sum(stats['throughputs'])/n_samples:.1f} frames/sec")
print("-" * 60)
# Reset stats for next minute
for key in stats:
if key != 'last_print_time':
stats[key] = []
stats['last_print_time'] = current_time
return total_loss, loss_dict

View File

@@ -112,15 +112,43 @@ def update_policy(
optim_time = time.perf_counter() - optim_start
total_time = time.perf_counter() - start_time
# Print detailed timing for RLearN policy
# Collect timing statistics for RLearN policy (averaged reporting every minute)
if getattr(policy, "name", None) == "rlearn":
print(f"Training Step Timing:")
print(f" Forward pass: {forward_time*1000:.2f} ms")
print(f" Backward pass: {backward_time*1000:.2f} ms")
print(f" Optimizer step: {optim_time*1000:.2f} ms")
print(f" Total update: {total_time*1000:.2f} ms")
print(f" Steps/sec: {1.0/total_time:.2f}")
print("-" * 40)
# Initialize timing accumulator if not exists
if not hasattr(policy, '_train_timing_stats'):
policy._train_timing_stats = {
'forward_times': [],
'backward_times': [],
'optim_times': [],
'total_times': [],
'last_print_time': time.perf_counter()
}
# Accumulate current step's timings
stats = policy._train_timing_stats
stats['forward_times'].append(forward_time * 1000)
stats['backward_times'].append(backward_time * 1000)
stats['optim_times'].append(optim_time * 1000)
stats['total_times'].append(total_time * 1000)
# Print averaged stats every minute (60 seconds)
current_time = time.perf_counter()
if current_time - stats['last_print_time'] >= 60.0:
n_samples = len(stats['forward_times'])
if n_samples > 0:
print(f"\nTraining Step Average Timing (last {n_samples} steps):")
print(f" Forward pass: {sum(stats['forward_times'])/n_samples:.2f} ms")
print(f" Backward pass: {sum(stats['backward_times'])/n_samples:.2f} ms")
print(f" Optimizer step: {sum(stats['optim_times'])/n_samples:.2f} ms")
print(f" Total update: {sum(stats['total_times'])/n_samples:.2f} ms")
print(f" Avg steps/sec: {1000.0/(sum(stats['total_times'])/n_samples):.2f}")
print("-" * 50)
# Reset stats for next minute
for key in stats:
if key != 'last_print_time':
stats[key] = []
stats['last_print_time'] = current_time
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
@@ -296,12 +324,40 @@ def train(cfg: TrainPipelineConfig):
except Exception:
pass
# Print data loading timing for RLearN
# Collect data pipeline timing for RLearN (averaged reporting every minute)
if getattr(policy, "name", None) == "rlearn":
print(f"Data Pipeline Timing:")
print(f" Data loading: {data_loading_time*1000:.2f} ms")
print(f" Preprocessing: {preprocess_time*1000:.2f} ms")
print(f" Total data pipeline: {(data_loading_time + preprocess_time)*1000:.2f} ms")
# Initialize data timing accumulator if not exists
if not hasattr(policy, '_data_timing_stats'):
policy._data_timing_stats = {
'data_loading_times': [],
'preprocess_times': [],
'last_print_time': time.perf_counter()
}
# Accumulate current step's data timings
data_stats = policy._data_timing_stats
data_stats['data_loading_times'].append(data_loading_time * 1000)
data_stats['preprocess_times'].append(preprocess_time * 1000)
# Print averaged stats every minute (60 seconds)
current_time = time.perf_counter()
if current_time - data_stats['last_print_time'] >= 60.0:
n_samples = len(data_stats['data_loading_times'])
if n_samples > 0:
avg_data_loading = sum(data_stats['data_loading_times']) / n_samples
avg_preprocessing = sum(data_stats['preprocess_times']) / n_samples
print(f"\nData Pipeline Average Timing (last {n_samples} steps):")
print(f" Data loading: {avg_data_loading:.2f} ms")
print(f" Preprocessing: {avg_preprocessing:.2f} ms")
print(f" Total data pipeline: {avg_data_loading + avg_preprocessing:.2f} ms")
print("-" * 50)
# Reset stats for next minute
for key in data_stats:
if key != 'last_print_time':
data_stats[key] = []
data_stats['last_print_time'] = current_time
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.