refactor(profiling): consolidate into single module

Unify the profiling subsystem into one file per reviewer request.

Before (4 files):
  src/lerobot/utils/profiling_utils.py        399 LOC
  scripts/ci/run_model_profiling.py           337 LOC
  profiling/model_profiling_specs.json        181 LOC
  tests/scripts/test_model_profiling.py       423 LOC

After (2 files):
  src/lerobot/utils/model_profiling.py        758 LOC — TrainingProfiler +
                                                       CI orchestrator +
                                                       POLICY_SPECS (inline)
  tests/test_model_profiling.py               315 LOC

Net: -267 LOC and 4 files → 2. All functionality preserved: per-step
forward/backward/optimizer timings, torch profiler tables + chrome
traces, deterministic-forward fingerprint, HF Hub result upload, and
the same CLI surface.

Changes:
- Collapse `_StepTimingCollector` into inline attributes on
  `TrainingProfiler` (no separate class).
- Drop `ProfilingSpec` dataclass; specs are plain dicts.
- Inline the JSON matrix as a module-level `POLICY_SPECS` dict —
  one less file to keep in sync with the training args.
- CI workflow invokes `python -m lerobot.utils.model_profiling` in
  place of the standalone script.
- Tests import `lerobot.utils.model_profiling` directly instead of
  loading a script-by-path. Removed JSON schema tests that no
  longer apply.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-20 21:31:17 +02:00
parent 8d982614a6
commit a515eadc96
8 changed files with 1079 additions and 1346 deletions

View File

@@ -1,337 +0,0 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import argparse
import json
import re
import shutil
import subprocess
import time
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from huggingface_hub import CommitOperationAdd, HfApi
from huggingface_hub.errors import HfHubHTTPError
@dataclass(frozen=True)
class ProfilingSpec:
name: str
steps: int
train_args: list[str]
@dataclass(frozen=True)
class UploadTarget:
local_path: Path
path_in_repo: str
@dataclass(frozen=True)
class UploadResult:
uploaded_paths: dict[str, str]
pr_url: str | None = None
def utc_timestamp_slug(now: datetime | None = None) -> str:
current = now or datetime.now(UTC)
return current.strftime("%Y%m%dT%H%M%SZ")
def make_hub_file_url(
repo_id: str,
path_in_repo: str,
repo_type: str = "dataset",
revision: str = "main",
) -> str:
prefix = "datasets/" if repo_type == "dataset" else ""
return f"https://huggingface.co/{prefix}{repo_id}/resolve/{revision}/{path_in_repo}"
def parse_discussion_num(pr_url: str | None) -> int | None:
if not pr_url:
return None
match = re.search(r"/discussions/(\d+)$", pr_url)
return int(match.group(1)) if match else None
def upload_targets(
repo_id: str,
targets: list[UploadTarget],
*,
repo_type: str = "dataset",
token: str | None = None,
commit_message: str | None = None,
create_pr: bool = False,
) -> UploadResult:
api = HfApi(token=token)
operations = [
CommitOperationAdd(path_in_repo=target.path_in_repo, path_or_fileobj=str(target.local_path))
for target in targets
]
commit = api.create_commit(
repo_id=repo_id,
repo_type=repo_type,
operations=operations,
commit_message=commit_message or f"Upload {len(targets)} profiling artifacts",
revision="main",
create_pr=create_pr,
)
revision = "main"
pr_num = parse_discussion_num(commit.pr_url)
if create_pr and pr_num is not None:
revision = f"refs/pr/{pr_num}"
uploaded = {
target.path_in_repo: make_hub_file_url(
repo_id, target.path_in_repo, repo_type=repo_type, revision=revision
)
for target in targets
}
return UploadResult(uploaded_paths=uploaded, pr_url=commit.pr_url)
def normalize_repo_id(repo: str, hub_org: str) -> str:
return repo if "/" in repo else f"{hub_org}/{repo}"
def load_specs(path: Path) -> dict[str, ProfilingSpec]:
payload = json.loads(path.read_text())
return {
name: ProfilingSpec(name=name, steps=spec["steps"], train_args=spec["train_args"])
for name, spec in payload.items()
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--spec-file", type=Path, default=Path("profiling/model_profiling_specs.json"))
parser.add_argument("--policies", nargs="*", default=None)
parser.add_argument("--output_dir", type=Path, required=True)
parser.add_argument("--hub_org", default="lerobot")
parser.add_argument("--results_repo", default="model-profiling-history")
parser.add_argument("--publish", action="store_true")
parser.add_argument("--profile_mode", choices=["summary", "trace"], default="trace")
parser.add_argument("--git_commit", default="")
parser.add_argument("--git_ref", default="")
parser.add_argument("--pr_number", default="")
return parser.parse_args()
def get_selected_names(requested: list[str] | None, specs: dict[str, ProfilingSpec]) -> list[str]:
if not requested:
return list(specs)
unknown = sorted(set(requested) - set(specs))
if unknown:
raise ValueError(f"Unknown profiling policies: {', '.join(unknown)}")
return requested
def build_train_command(spec: ProfilingSpec, run_dir: Path, profile_mode: str) -> list[str]:
train_output_dir = run_dir / "train"
profile_output_dir = run_dir / "profiling"
return [
"uv",
"run",
"lerobot-train",
*spec.train_args,
f"--output_dir={train_output_dir}",
f"--steps={spec.steps}",
"--eval_freq=0",
"--save_checkpoint=false",
f"--save_freq={spec.steps}",
"--wandb.enable=false",
"--policy.push_to_hub=false",
"--num_workers=0",
"--log_freq=1",
f"--profile_mode={profile_mode}",
f"--profile_output_dir={profile_output_dir}",
]
def load_json_if_exists(path: Path) -> dict[str, Any] | None:
if not path.exists():
return None
return json.loads(path.read_text())
def build_artifact_index(
*,
repo_id: str,
run_dir: Path,
policy_name: str,
run_id: str,
) -> tuple[dict[str, Any], dict[str, Any], list[UploadTarget], str]:
row_path_in_repo = f"rows/{policy_name}/{run_id}.json"
artifact_root = f"artifacts/{policy_name}/{run_id}"
artifact_paths: dict[str, Any] = {
"row": row_path_in_repo,
"profiling_files": {},
"torch_tables": {},
"trace_files": {},
}
artifact_urls: dict[str, Any] = {
"row": make_hub_file_url(repo_id, row_path_in_repo),
"profiling_files": {},
"torch_tables": {},
"trace_files": {},
}
targets: list[UploadTarget] = []
for name in ("stdout.txt", "stderr.txt"):
path = run_dir / name
if not path.exists():
continue
repo_path = f"{artifact_root}/{name}"
artifact_paths[name.removesuffix(".txt")] = repo_path
artifact_urls[name.removesuffix(".txt")] = make_hub_file_url(repo_id, repo_path)
targets.append(UploadTarget(local_path=path, path_in_repo=repo_path))
profiling_dir = run_dir / "profiling"
for path in sorted(profiling_dir.rglob("*")) if profiling_dir.exists() else []:
if not path.is_file():
continue
relative_path = str(path.relative_to(run_dir))
repo_path = f"{artifact_root}/{relative_path}"
artifact_paths["profiling_files"][relative_path] = repo_path
artifact_urls["profiling_files"][relative_path] = make_hub_file_url(repo_id, repo_path)
targets.append(UploadTarget(local_path=path, path_in_repo=repo_path))
if path.name == "step_timing_summary.json":
artifact_paths["step_timing_summary"] = repo_path
artifact_urls["step_timing_summary"] = make_hub_file_url(repo_id, repo_path)
elif "torch_tables" in path.parts:
artifact_paths["torch_tables"][path.name] = repo_path
artifact_urls["torch_tables"][path.name] = make_hub_file_url(repo_id, repo_path)
elif "torch_traces" in path.parts:
artifact_paths["trace_files"][path.name] = repo_path
artifact_urls["trace_files"][path.name] = make_hub_file_url(repo_id, repo_path)
return artifact_paths, artifact_urls, targets, row_path_in_repo
def upload_profile_run(
*,
repo_id: str,
row_path: Path,
row_path_in_repo: str,
artifact_targets: list[UploadTarget],
create_pr: bool = False,
) -> UploadResult:
return upload_targets(
repo_id=repo_id,
targets=[*artifact_targets, UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)],
repo_type="dataset",
commit_message=f"Add model profiling row {row_path_in_repo}",
create_pr=create_pr,
)
def main() -> int:
args = parse_args()
specs = load_specs(args.spec_file)
selected = get_selected_names(args.policies, specs)
args.output_dir.mkdir(parents=True, exist_ok=True)
repo_id = normalize_repo_id(args.results_repo, args.hub_org)
git_executable = shutil.which("git")
if not git_executable:
raise RuntimeError("git executable not found in PATH")
git_commit = (
args.git_commit or subprocess.check_output([git_executable, "rev-parse", "HEAD"], text=True).strip()
)
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
for policy_name in selected:
spec = specs[policy_name]
run_id = f"{utc_timestamp_slug()}__{policy_name}"
run_dir = args.output_dir / policy_name / run_id
run_dir.mkdir(parents=True, exist_ok=True)
cmd = build_train_command(spec, run_dir, args.profile_mode)
start = time.perf_counter()
result = subprocess.run(cmd, capture_output=True, text=True)
duration_s = time.perf_counter() - start
stdout_path = run_dir / "stdout.txt"
stderr_path = run_dir / "stderr.txt"
stdout_path.write_text(result.stdout)
stderr_path.write_text(result.stderr)
profile_summary = load_json_if_exists(run_dir / "profiling" / "step_timing_summary.json") or {}
deterministic_forward = (
load_json_if_exists(run_dir / "profiling" / "deterministic_forward.json") or {}
)
artifact_paths, artifact_urls, artifact_targets, row_path_in_repo = build_artifact_index(
repo_id=repo_id,
run_dir=run_dir,
policy_name=policy_name,
run_id=run_id,
)
row = {
"schema_version": 1,
"created_at": datetime.now(UTC).isoformat(),
"run_id": run_id,
"policy": policy_name,
"git_commit": git_commit,
"git_ref": args.git_ref or None,
"pr_number": pr_number,
"status": "success" if result.returncode == 0 else "failed",
"return_code": result.returncode,
"profile_mode": args.profile_mode,
"wall_time_s": duration_s,
"spec": {
"steps": spec.steps,
"train_args": spec.train_args,
},
"step_timing_summary": profile_summary,
"deterministic_forward": deterministic_forward,
"artifact_paths": artifact_paths,
"artifact_urls": artifact_urls,
"stderr_tail": result.stderr.splitlines()[-20:],
}
row_path = run_dir / "profiling_row.json"
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
if args.publish:
try:
upload_result = upload_profile_run(
repo_id=repo_id,
row_path=row_path,
row_path_in_repo=row_path_in_repo,
artifact_targets=artifact_targets,
create_pr=pr_number is not None,
)
except HfHubHTTPError as exc:
row["publish_status"] = "failed"
row["publish_error"] = str(exc)
else:
row["publish_status"] = "success"
row["uploaded_paths"] = upload_result.uploaded_paths
row["publish_pr_url"] = upload_result.pr_url
row["publish_pr_number"] = parse_discussion_num(upload_result.pr_url)
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
print(json.dumps(row, indent=2, sort_keys=True))
return 0
if __name__ == "__main__":
raise SystemExit(main())