diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 1c29797e6..8715cb9b6 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -573,15 +573,55 @@ class RLearNPolicy(PreTrainedPolicy): "timing_total_forward_ms": float(total_forward_time * 1000), }) - # Print detailed timing breakdown during training + # Collect timing statistics for averaged reporting every minute if self.training: - print(f"RLearN Forward Pass Timing (B={B}, T_eff={T_eff}):") - print(f" Vision encoding: {vision_time*1000:.2f} ms") - print(f" Language encoding: {lang_time*1000:.2f} ms") - print(f" Transformer: {transformer_time*1000:.2f} ms") - print(f" Loss computation: {loss_time*1000:.2f} ms") - print(f" Total forward pass: {total_forward_time*1000:.2f} ms") - print(f" Throughput: {B*T_eff/(total_forward_time):.1f} frames/sec") + # Initialize timing accumulator if not exists + if not hasattr(self, '_timing_stats'): + self._timing_stats = { + 'vision_times': [], + 'language_times': [], + 'transformer_times': [], + 'loss_times': [], + 'total_forward_times': [], + 'throughputs': [], + 'batch_sizes': [], + 't_effs': [], + 'last_print_time': time.perf_counter() + } + + # Accumulate current step's timings + stats = self._timing_stats + stats['vision_times'].append(vision_time * 1000) + stats['language_times'].append(lang_time * 1000) + stats['transformer_times'].append(transformer_time * 1000) + stats['loss_times'].append(loss_time * 1000) + stats['total_forward_times'].append(total_forward_time * 1000) + stats['throughputs'].append(B * T_eff / total_forward_time) + stats['batch_sizes'].append(B) + stats['t_effs'].append(T_eff) + + # Print averaged stats every minute (60 seconds) + current_time = time.perf_counter() + if current_time - stats['last_print_time'] >= 60.0: + n_samples = len(stats['vision_times']) + if n_samples > 0: + avg_b = sum(stats['batch_sizes']) / n_samples + avg_t_eff = sum(stats['t_effs']) / n_samples + + print(f"\nRLearN Average Timing (last {n_samples} steps, avg B={avg_b:.1f}, avg T_eff={avg_t_eff:.1f}):") + print(f" Vision encoding: {sum(stats['vision_times'])/n_samples:.2f} ms") + print(f" Language encoding: {sum(stats['language_times'])/n_samples:.2f} ms") + print(f" Transformer: {sum(stats['transformer_times'])/n_samples:.2f} ms") + print(f" Loss computation: {sum(stats['loss_times'])/n_samples:.2f} ms") + print(f" Total forward pass: {sum(stats['total_forward_times'])/n_samples:.2f} ms") + print(f" Avg throughput: {sum(stats['throughputs'])/n_samples:.1f} frames/sec") + print("-" * 60) + + # Reset stats for next minute + for key in stats: + if key != 'last_print_time': + stats[key] = [] + stats['last_print_time'] = current_time return total_loss, loss_dict diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 26e531ba5..292990c0c 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -112,15 +112,43 @@ def update_policy( optim_time = time.perf_counter() - optim_start total_time = time.perf_counter() - start_time - # Print detailed timing for RLearN policy + # Collect timing statistics for RLearN policy (averaged reporting every minute) if getattr(policy, "name", None) == "rlearn": - print(f"Training Step Timing:") - print(f" Forward pass: {forward_time*1000:.2f} ms") - print(f" Backward pass: {backward_time*1000:.2f} ms") - print(f" Optimizer step: {optim_time*1000:.2f} ms") - print(f" Total update: {total_time*1000:.2f} ms") - print(f" Steps/sec: {1.0/total_time:.2f}") - print("-" * 40) + # Initialize timing accumulator if not exists + if not hasattr(policy, '_train_timing_stats'): + policy._train_timing_stats = { + 'forward_times': [], + 'backward_times': [], + 'optim_times': [], + 'total_times': [], + 'last_print_time': time.perf_counter() + } + + # Accumulate current step's timings + stats = policy._train_timing_stats + stats['forward_times'].append(forward_time * 1000) + stats['backward_times'].append(backward_time * 1000) + stats['optim_times'].append(optim_time * 1000) + stats['total_times'].append(total_time * 1000) + + # Print averaged stats every minute (60 seconds) + current_time = time.perf_counter() + if current_time - stats['last_print_time'] >= 60.0: + n_samples = len(stats['forward_times']) + if n_samples > 0: + print(f"\nTraining Step Average Timing (last {n_samples} steps):") + print(f" Forward pass: {sum(stats['forward_times'])/n_samples:.2f} ms") + print(f" Backward pass: {sum(stats['backward_times'])/n_samples:.2f} ms") + print(f" Optimizer step: {sum(stats['optim_times'])/n_samples:.2f} ms") + print(f" Total update: {sum(stats['total_times'])/n_samples:.2f} ms") + print(f" Avg steps/sec: {1000.0/(sum(stats['total_times'])/n_samples):.2f}") + print("-" * 50) + + # Reset stats for next minute + for key in stats: + if key != 'last_print_time': + stats[key] = [] + stats['last_print_time'] = current_time train_metrics.loss = loss.item() train_metrics.grad_norm = grad_norm.item() @@ -296,12 +324,40 @@ def train(cfg: TrainPipelineConfig): except Exception: pass - # Print data loading timing for RLearN + # Collect data pipeline timing for RLearN (averaged reporting every minute) if getattr(policy, "name", None) == "rlearn": - print(f"Data Pipeline Timing:") - print(f" Data loading: {data_loading_time*1000:.2f} ms") - print(f" Preprocessing: {preprocess_time*1000:.2f} ms") - print(f" Total data pipeline: {(data_loading_time + preprocess_time)*1000:.2f} ms") + # Initialize data timing accumulator if not exists + if not hasattr(policy, '_data_timing_stats'): + policy._data_timing_stats = { + 'data_loading_times': [], + 'preprocess_times': [], + 'last_print_time': time.perf_counter() + } + + # Accumulate current step's data timings + data_stats = policy._data_timing_stats + data_stats['data_loading_times'].append(data_loading_time * 1000) + data_stats['preprocess_times'].append(preprocess_time * 1000) + + # Print averaged stats every minute (60 seconds) + current_time = time.perf_counter() + if current_time - data_stats['last_print_time'] >= 60.0: + n_samples = len(data_stats['data_loading_times']) + if n_samples > 0: + avg_data_loading = sum(data_stats['data_loading_times']) / n_samples + avg_preprocessing = sum(data_stats['preprocess_times']) / n_samples + + print(f"\nData Pipeline Average Timing (last {n_samples} steps):") + print(f" Data loading: {avg_data_loading:.2f} ms") + print(f" Preprocessing: {avg_preprocessing:.2f} ms") + print(f" Total data pipeline: {avg_data_loading + avg_preprocessing:.2f} ms") + print("-" * 50) + + # Reset stats for next minute + for key in data_stats: + if key != 'last_print_time': + data_stats[key] = [] + data_stats['last_print_time'] = current_time # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # increment `step` here.