From b1e16783de629e7f191ec89191516b8bfbd8186c Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 16 Apr 2026 16:00:49 +0200 Subject: [PATCH] refactor: extract profiling into self-contained TrainingProfiler class MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move all profiling orchestration out of lerobot_train.py and TrainPipelineConfig into a TrainingProfiler class in profiling_utils.py. - lerobot_train.py: ~74 lines of profiling code reduced to ~7 call sites - TrainPipelineConfig: 10 profile_* fields reduced to 2 (mode + output_dir) - update_policy: reverted to clean main-branch signature (no timing_collector) - TrainingProfiler encapsulates torch profiler, timing collection, deterministic forward artifacts, and all output writing - CI script (run_model_profiling.py) unchanged—it only passes the 2 kept fields Made-with: Cursor --- src/lerobot/configs/train.py | 12 --- src/lerobot/scripts/lerobot_train.py | 85 +++------------ src/lerobot/utils/profiling_utils.py | 150 +++++++++++++++++++++----- tests/scripts/test_model_profiling.py | 39 +++---- 4 files changed, 148 insertions(+), 138 deletions(-) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index e4a5b7fb6..793f495d1 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -57,15 +57,7 @@ class TrainPipelineConfig(HubMixin): num_workers: int = 4 batch_size: int = 8 profile_mode: str = "off" - profile_wait_steps: int = 1 - profile_warmup_steps: int = 2 - profile_active_steps: int = 6 - profile_repeat: int = 1 profile_output_dir: Path | None = None - profile_record_shapes: bool = True - profile_with_memory: bool = True - profile_with_flops: bool = True - profile_with_stack: bool = False steps: int = 100_000 eval_freq: int = 20_000 log_freq: int = 200 @@ -147,10 +139,6 @@ class TrainPipelineConfig(HubMixin): raise ValueError( f"`profile_mode` must be one of 'off', 'summary', or 'trace', got {self.profile_mode}." ) - if self.profile_wait_steps < 0 or self.profile_warmup_steps < 0 or self.profile_active_steps < 0: - raise ValueError("Profiler schedule steps must be non-negative.") - if self.profile_repeat <= 0: - raise ValueError("`profile_repeat` must be strictly positive.") if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 7202ddea4..392bfb51f 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -22,7 +22,6 @@ import dataclasses import logging import time from contextlib import nullcontext -from pathlib import Path from pprint import pformat from typing import TYPE_CHECKING, Any @@ -50,13 +49,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.logging_utils import AverageMeter, MetricsTracker -from lerobot.utils.profiling_utils import ( - StepTimingCollector, - ensure_dir, - make_torch_profiler, - write_deterministic_forward_artifacts, - write_torch_profiler_outputs, -) +from lerobot.utils.profiling_utils import TrainingProfiler from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( cycle, @@ -79,7 +72,6 @@ def update_policy( lr_scheduler=None, lock=None, rabc_weights_provider=None, - timing_collector: StepTimingCollector | None = None, ) -> tuple[MetricsTracker, dict]: """ Performs a single training step to update the policy's weights. @@ -113,7 +105,6 @@ def update_policy( rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch) # Let accelerator handle mixed precision - forward_start = time.perf_counter() with accelerator.autocast(): # Use per-sample loss when RA-BC is enabled for proper weighting if rabc_batch_weights is not None: @@ -132,15 +123,11 @@ def update_policy( loss, output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) - forward_s = time.perf_counter() - forward_start # Use accelerator's backward method - backward_start = time.perf_counter() accelerator.backward(loss) - backward_s = time.perf_counter() - backward_start # Clip gradients if specified - optimizer_start = time.perf_counter() if grad_clip_norm > 0: grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) else: @@ -161,19 +148,11 @@ def update_policy( # Update internal buffers if policy has update method if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() - optimizer_s = time.perf_counter() - optimizer_start train_metrics.loss = loss.item() train_metrics.grad_norm = grad_norm.item() train_metrics.lr = optimizer.param_groups[0]["lr"] train_metrics.update_s = time.perf_counter() - start_time - if timing_collector is not None: - timing_collector.record( - forward_s=forward_s, - backward_s=backward_s, - optimizer_s=optimizer_s, - total_update_s=train_metrics.update_s.val, - ) return train_metrics, output_dict @@ -228,12 +207,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if is_main_process: logging.info(pformat(cfg.to_dict())) - profiling_enabled = cfg.profile_mode != "off" - profile_output_dir = None - if profiling_enabled and is_main_process and cfg.profile_output_dir is not None: - profile_output_dir = ensure_dir(Path(cfg.profile_output_dir)) - logging.info("Profiling enabled. Artifacts will be written to %s", profile_output_dir) - # Initialize wandb only on main process if cfg.wandb.enable and cfg.wandb.project and is_main_process: wandb_logger = WandBLogger(cfg) @@ -344,15 +317,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) - if profiling_enabled and is_main_process and profile_output_dir is not None: - logging.info("Recording deterministic forward-pass artifacts") - write_deterministic_forward_artifacts( - policy=policy, - dataset=dataset, - batch_size=cfg.batch_size, - preprocessor=preprocessor, - output_dir=profile_output_dir, - device_type=device.type, + profiler = ( + TrainingProfiler.from_cfg(cfg, device) if cfg.profile_mode != "off" and is_main_process else None + ) + if profiler: + profiler.record_deterministic_forward( + policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor ) # Load precomputed SARM progress for RA-BC if enabled @@ -468,16 +438,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): logging.info( f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}" ) - timing_collector = StepTimingCollector() if profiling_enabled and is_main_process else None - profiler = None - profiler_context = nullcontext() - if profiling_enabled and is_main_process and profile_output_dir is not None: - if device.type == "cuda": - torch.cuda.reset_peak_memory_stats(device) - profiler = make_torch_profiler(cfg, profile_output_dir, device.type) - profiler_context = profiler - - with profiler_context: + with profiler or nullcontext(): for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) @@ -493,7 +454,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): accelerator=accelerator, lr_scheduler=lr_scheduler, rabc_weights_provider=rabc_weights, - timing_collector=timing_collector, ) # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we @@ -501,17 +461,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): step += 1 if is_main_process: progbar.update(1) - if timing_collector is not None: - timing_collector.record_dataloading(train_tracker.dataloading_s.val) - if device.type == "cuda": - timing_collector.record_memory( - step=step, - allocated_bytes=torch.cuda.memory_allocated(device), - reserved_bytes=torch.cuda.memory_reserved(device), - ) + if profiler: + profiler.step(step, train_tracker) train_tracker.step() - if profiler is not None: - profiler.step() is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 @@ -606,21 +558,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if is_main_process: progbar.close() - if timing_collector is not None and profile_output_dir is not None: - extra_profile_metrics = { - "profile_mode": cfg.profile_mode, - "peak_memory_allocated_bytes": ( - torch.cuda.max_memory_allocated(device) if device.type == "cuda" else None - ), - "peak_memory_reserved_bytes": ( - torch.cuda.max_memory_reserved(device) if device.type == "cuda" else None - ), - } - timing_collector.write_json( - profile_output_dir / "step_timing_summary.json", extra=extra_profile_metrics - ) - if profiler is not None and profile_output_dir is not None: - write_torch_profiler_outputs(profiler, profile_output_dir, device_type=device.type) + if profiler: + profiler.finalize() if eval_env: close_envs(eval_env) diff --git a/src/lerobot/utils/profiling_utils.py b/src/lerobot/utils/profiling_utils.py index 69e558356..5cfc66f3b 100644 --- a/src/lerobot/utils/profiling_utils.py +++ b/src/lerobot/utils/profiling_utils.py @@ -18,6 +18,7 @@ from __future__ import annotations import hashlib import json +import logging import statistics from dataclasses import dataclass, field from numbers import Real @@ -47,7 +48,20 @@ def write_profiler_table( output_path.write_text(table) -def make_torch_profiler(cfg: Any, output_dir: Path, device_type: str) -> Any: +def _make_torch_profiler( + *, + mode: str, + output_dir: Path, + device_type: str, + wait_steps: int = 1, + warmup_steps: int = 2, + active_steps: int = 6, + repeat: int = 1, + record_shapes: bool = True, + with_memory: bool = True, + with_flops: bool = True, + with_stack: bool = False, +) -> Any: activities = [torch.profiler.ProfilerActivity.CPU] if device_type == "cuda": activities.append(torch.profiler.ProfilerActivity.CUDA) @@ -55,23 +69,23 @@ def make_torch_profiler(cfg: Any, output_dir: Path, device_type: str) -> Any: trace_dir = ensure_dir(output_dir / "torch_traces") def _trace_ready(profiler: Any) -> None: - if cfg.profile_mode != "trace": + if mode != "trace": return profiler.export_chrome_trace(str(trace_dir / f"trace_step_{profiler.step_num}.json")) return torch.profiler.profile( activities=activities, schedule=torch.profiler.schedule( - wait=cfg.profile_wait_steps, - warmup=cfg.profile_warmup_steps, - active=cfg.profile_active_steps, - repeat=cfg.profile_repeat, + wait=wait_steps, + warmup=warmup_steps, + active=active_steps, + repeat=repeat, ), on_trace_ready=_trace_ready, - record_shapes=cfg.profile_record_shapes, - profile_memory=cfg.profile_with_memory, - with_flops=cfg.profile_with_flops, - with_stack=cfg.profile_with_stack, + record_shapes=record_shapes, + profile_memory=with_memory, + with_flops=with_flops, + with_stack=with_stack, ) @@ -228,25 +242,12 @@ def _as_float(value: Any) -> float: @dataclass -class StepTimingCollector: - forward_s: list[float] = field(default_factory=list) - backward_s: list[float] = field(default_factory=list) - optimizer_s: list[float] = field(default_factory=list) +class _StepTimingCollector: total_update_s: list[float] = field(default_factory=list) dataloading_s: list[float] = field(default_factory=list) memory_timeline: list[dict[str, float | int]] = field(default_factory=list) - def record( - self, - *, - forward_s: float, - backward_s: float, - optimizer_s: float, - total_update_s: float, - ) -> None: - self.forward_s.append(_as_float(forward_s)) - self.backward_s.append(_as_float(backward_s)) - self.optimizer_s.append(_as_float(optimizer_s)) + def record_step(self, total_update_s: float) -> None: self.total_update_s.append(_as_float(total_update_s)) def record_dataloading(self, dataloading_s: float) -> None: @@ -263,9 +264,6 @@ class StepTimingCollector: def to_dict(self) -> dict[str, Any]: return { - "forward_s": _summary(self.forward_s), - "backward_s": _summary(self.backward_s), - "optimizer_s": _summary(self.optimizer_s), "total_update_s": _summary(self.total_update_s), "dataloading_s": _summary(self.dataloading_s), "memory_timeline": self.memory_timeline, @@ -277,3 +275,99 @@ class StepTimingCollector: payload.update(extra) output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) + + +class TrainingProfiler: + """Self-contained profiling orchestrator for the training loop. + + Encapsulates torch profiler setup, step-level timing collection, deterministic + forward-pass artifact recording, and all output writing. The training script + interacts with it through a thin interface (~7 lines). + """ + + def __init__( + self, + mode: str, + output_dir: Path, + device: torch.device, + *, + wait_steps: int = 1, + warmup_steps: int = 2, + active_steps: int = 6, + repeat: int = 1, + record_shapes: bool = True, + with_memory: bool = True, + with_flops: bool = True, + with_stack: bool = False, + ) -> None: + self._mode = mode + self._output_dir = ensure_dir(output_dir) + self._device = device + self._timing = _StepTimingCollector() + self._torch_profiler = _make_torch_profiler( + mode=mode, + output_dir=output_dir, + device_type=device.type, + wait_steps=wait_steps, + warmup_steps=warmup_steps, + active_steps=active_steps, + repeat=repeat, + record_shapes=record_shapes, + with_memory=with_memory, + with_flops=with_flops, + with_stack=with_stack, + ) + logging.info("Profiling enabled. Artifacts will be written to %s", output_dir) + + @classmethod + def from_cfg(cls, cfg: Any, device: torch.device) -> TrainingProfiler: + output_dir = cfg.profile_output_dir + if output_dir is None: + output_dir = Path(cfg.output_dir) / "profiling" + return cls(mode=cfg.profile_mode, output_dir=Path(output_dir), device=device) + + def record_deterministic_forward( + self, + *, + policy: Any, + dataset: Any, + batch_size: int, + preprocessor: Any, + ) -> None: + logging.info("Recording deterministic forward-pass artifacts") + write_deterministic_forward_artifacts( + policy=policy, + dataset=dataset, + batch_size=batch_size, + preprocessor=preprocessor, + output_dir=self._output_dir, + device_type=self._device.type, + ) + + def __enter__(self) -> TrainingProfiler: + if self._device.type == "cuda": + torch.cuda.reset_peak_memory_stats(self._device) + self._torch_profiler.__enter__() + return self + + def __exit__(self, *exc: Any) -> bool: + return self._torch_profiler.__exit__(*exc) + + def step(self, step_num: int, train_tracker: Any) -> None: + self._timing.record_step(_as_float(train_tracker.update_s)) + self._timing.record_dataloading(_as_float(train_tracker.dataloading_s)) + if self._device.type == "cuda": + self._timing.record_memory( + step=step_num, + allocated_bytes=torch.cuda.memory_allocated(self._device), + reserved_bytes=torch.cuda.memory_reserved(self._device), + ) + self._torch_profiler.step() + + def finalize(self) -> None: + extra: dict[str, Any] = {"profile_mode": self._mode} + if self._device.type == "cuda": + extra["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device) + extra["peak_memory_reserved_bytes"] = torch.cuda.max_memory_reserved(self._device) + self._timing.write_json(self._output_dir / "step_timing_summary.json", extra=extra) + write_torch_profiler_outputs(self._torch_profiler, self._output_dir, device_type=self._device.type) diff --git a/tests/scripts/test_model_profiling.py b/tests/scripts/test_model_profiling.py index be11dab49..4a6930f65 100644 --- a/tests/scripts/test_model_profiling.py +++ b/tests/scripts/test_model_profiling.py @@ -65,29 +65,24 @@ def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization(): specs = module.load_specs(spec_path) assert ( - "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", " - "\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" - in specs["pi0"]["train_args"] + '--rename_map={"observation.images.front": "observation.images.base_0_rgb", ' + '"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0"].train_args ) assert ( - "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", " - "\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" - in specs["pi0_fast"]["train_args"] + '--rename_map={"observation.images.front": "observation.images.base_0_rgb", ' + '"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi0_fast"].train_args ) assert ( - "--rename_map={\"observation.images.front\": \"observation.images.base_0_rgb\", " - "\"observation.images.wrist\": \"observation.images.left_wrist_0_rgb\"}" - in specs["pi05"]["train_args"] + '--rename_map={"observation.images.front": "observation.images.base_0_rgb", ' + '"observation.images.wrist": "observation.images.left_wrist_0_rgb"}' in specs["pi05"].train_args ) assert ( - "--policy.normalization_mapping={\"ACTION\": \"MEAN_STD\", " - "\"STATE\": \"MEAN_STD\", \"VISUAL\": \"IDENTITY\"}" - in specs["pi05"]["train_args"] + '--policy.normalization_mapping={"ACTION": "MEAN_STD", ' + '"STATE": "MEAN_STD", "VISUAL": "IDENTITY"}' in specs["pi05"].train_args ) assert ( - "--rename_map={\"observation.images.front\": \"observation.images.camera1\", " - "\"observation.images.wrist\": \"observation.images.camera2\"}" - in specs["smolvla"]["train_args"] + '--rename_map={"observation.images.front": "observation.images.camera1", ' + '"observation.images.wrist": "observation.images.camera2"}' in specs["smolvla"].train_args ) @@ -222,7 +217,6 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path): (profile_dir / "step_timing_summary.json").write_text( json.dumps( { - "forward_s": {"count": 1, "mean": 0.1, "median": 0.1, "min": 0.1, "max": 0.1}, "total_update_s": {"count": 1, "mean": 0.3, "median": 0.3, "min": 0.3, "max": 0.3}, "peak_memory_allocated_bytes": 1024, } @@ -251,7 +245,7 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path): assert row["git_commit"] == "deadbeef" assert row["git_ref"] == "codex/model-profiling" assert row["pr_number"] == 3389 - assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1 + assert row["step_timing_summary"]["total_update_s"]["mean"] == 0.3 assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint" @@ -364,19 +358,14 @@ def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path): def test_step_timing_collector_accepts_metric_like_values(tmp_path): - from lerobot.utils.profiling_utils import StepTimingCollector + from lerobot.utils.profiling_utils import _StepTimingCollector class _MetricLike: def __init__(self, val): self.val = val - collector = StepTimingCollector() - collector.record( - forward_s=0.1, - backward_s=0.2, - optimizer_s=0.3, - total_update_s=_MetricLike(0.6), - ) + collector = _StepTimingCollector() + collector.record_step(_MetricLike(0.6)) collector.record_dataloading(_MetricLike(0.05)) collector.write_json(tmp_path / "step_timing_summary.json")