mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
stats every minute
This commit is contained in:
@@ -112,15 +112,43 @@ def update_policy(
|
||||
optim_time = time.perf_counter() - optim_start
|
||||
total_time = time.perf_counter() - start_time
|
||||
|
||||
# Print detailed timing for RLearN policy
|
||||
# Collect timing statistics for RLearN policy (averaged reporting every minute)
|
||||
if getattr(policy, "name", None) == "rlearn":
|
||||
print(f"Training Step Timing:")
|
||||
print(f" Forward pass: {forward_time*1000:.2f} ms")
|
||||
print(f" Backward pass: {backward_time*1000:.2f} ms")
|
||||
print(f" Optimizer step: {optim_time*1000:.2f} ms")
|
||||
print(f" Total update: {total_time*1000:.2f} ms")
|
||||
print(f" Steps/sec: {1.0/total_time:.2f}")
|
||||
print("-" * 40)
|
||||
# Initialize the timing accumulator if it does not exist yet
|
||||
if not hasattr(policy, '_train_timing_stats'):
|
||||
policy._train_timing_stats = {
|
||||
'forward_times': [],
|
||||
'backward_times': [],
|
||||
'optim_times': [],
|
||||
'total_times': [],
|
||||
'last_print_time': time.perf_counter()
|
||||
}
|
||||
|
||||
# Accumulate current step's timings
|
||||
stats = policy._train_timing_stats
|
||||
stats['forward_times'].append(forward_time * 1000)
|
||||
stats['backward_times'].append(backward_time * 1000)
|
||||
stats['optim_times'].append(optim_time * 1000)
|
||||
stats['total_times'].append(total_time * 1000)
|
||||
|
||||
# Print averaged stats every minute (60 seconds)
|
||||
current_time = time.perf_counter()
|
||||
if current_time - stats['last_print_time'] >= 60.0:
|
||||
n_samples = len(stats['forward_times'])
|
||||
if n_samples > 0:
|
||||
print(f"\nTraining Step Average Timing (last {n_samples} steps):")
|
||||
print(f" Forward pass: {sum(stats['forward_times'])/n_samples:.2f} ms")
|
||||
print(f" Backward pass: {sum(stats['backward_times'])/n_samples:.2f} ms")
|
||||
print(f" Optimizer step: {sum(stats['optim_times'])/n_samples:.2f} ms")
|
||||
print(f" Total update: {sum(stats['total_times'])/n_samples:.2f} ms")
|
||||
print(f" Avg steps/sec: {1000.0/(sum(stats['total_times'])/n_samples):.2f}")
|
||||
print("-" * 50)
|
||||
|
||||
# Reset stats for next minute
|
||||
for key in stats:
|
||||
if key != 'last_print_time':
|
||||
stats[key] = []
|
||||
stats['last_print_time'] = current_time
|
||||
|
||||
train_metrics.loss = loss.item()
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
@@ -296,12 +324,40 @@ def train(cfg: TrainPipelineConfig):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Print data loading timing for RLearN
|
||||
# Collect data pipeline timing for RLearN (averaged reporting every minute)
|
||||
if getattr(policy, "name", None) == "rlearn":
|
||||
print(f"Data Pipeline Timing:")
|
||||
print(f" Data loading: {data_loading_time*1000:.2f} ms")
|
||||
print(f" Preprocessing: {preprocess_time*1000:.2f} ms")
|
||||
print(f" Total data pipeline: {(data_loading_time + preprocess_time)*1000:.2f} ms")
|
||||
# Initialize the data timing accumulator if it does not exist yet
|
||||
if not hasattr(policy, '_data_timing_stats'):
|
||||
policy._data_timing_stats = {
|
||||
'data_loading_times': [],
|
||||
'preprocess_times': [],
|
||||
'last_print_time': time.perf_counter()
|
||||
}
|
||||
|
||||
# Accumulate current step's data timings
|
||||
data_stats = policy._data_timing_stats
|
||||
data_stats['data_loading_times'].append(data_loading_time * 1000)
|
||||
data_stats['preprocess_times'].append(preprocess_time * 1000)
|
||||
|
||||
# Print averaged stats every minute (60 seconds)
|
||||
current_time = time.perf_counter()
|
||||
if current_time - data_stats['last_print_time'] >= 60.0:
|
||||
n_samples = len(data_stats['data_loading_times'])
|
||||
if n_samples > 0:
|
||||
avg_data_loading = sum(data_stats['data_loading_times']) / n_samples
|
||||
avg_preprocessing = sum(data_stats['preprocess_times']) / n_samples
|
||||
|
||||
print(f"\nData Pipeline Average Timing (last {n_samples} steps):")
|
||||
print(f" Data loading: {avg_data_loading:.2f} ms")
|
||||
print(f" Preprocessing: {avg_preprocessing:.2f} ms")
|
||||
print(f" Total data pipeline: {avg_data_loading + avg_preprocessing:.2f} ms")
|
||||
print("-" * 50)
|
||||
|
||||
# Reset stats for next minute
|
||||
for key in data_stats:
|
||||
if key != 'last_print_time':
|
||||
data_stats[key] = []
|
||||
data_stats['last_print_time'] = current_time
|
||||
|
||||
# Note: eval and checkpoint happen *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
|
||||
Reference in New Issue
Block a user