fix(profiling): address review feedback

This commit is contained in:
Pepijn
2026-04-23 13:23:09 +02:00
parent bfff81fd4b
commit a23ebf9d35
4 changed files with 37 additions and 21 deletions

View File

@@ -156,7 +156,7 @@ def test_parse_discussion_num_handles_hf_discussion_urls():
@pytest.fixture
def _fake_args(tmp_path):
def fake_args(tmp_path):
"""Shared argparse namespace for main() smoke tests — overridden per-test."""
return argparse.Namespace(
policies=["act"],
@@ -195,14 +195,14 @@ def _stub_train_subprocess(mp_module, *, returncode: int = 0, write_artifacts: b
return _fake_run
def test_main_smoke_writes_row(monkeypatch, _fake_args):
monkeypatch.setattr(mp, "parse_args", lambda: _fake_args)
def test_main_smoke_writes_row(monkeypatch, fake_args):
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp))
assert mp.main() == 0
row_paths = list(_fake_args.output_dir.rglob("profiling_row.json"))
row_paths = list(fake_args.output_dir.rglob("profiling_row.json"))
assert len(row_paths) == 1
row = json.loads(row_paths[0].read_text())
assert row["policy"] == "act"
@@ -214,10 +214,10 @@ def test_main_smoke_writes_row(monkeypatch, _fake_args):
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args):
_fake_args.publish = True
_fake_args.git_commit = "deadbeef"
monkeypatch.setattr(mp, "parse_args", lambda: _fake_args)
def test_main_records_publish_failure_without_failing(monkeypatch, fake_args):
fake_args.publish = True
fake_args.git_commit = "deadbeef"
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, write_artifacts=False))
def _fail_upload(**kwargs):
@@ -227,12 +227,24 @@ def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args):
monkeypatch.setattr(mp, "upload_profile_run", _fail_upload)
assert mp.main() == 0
row = json.loads(next(_fake_args.output_dir.rglob("profiling_row.json")).read_text())
row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
assert row["status"] == "success"
assert row["publish_status"] == "failed"
assert "Authorization error" in row["publish_error"]
def test_main_returns_nonzero_when_training_subprocess_fails(monkeypatch, fake_args):
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, returncode=3))
assert mp.main() == 1
row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
assert row["status"] == "failed"
assert row["return_code"] == 3
# ---------------------------------------------------------------------------
# TrainingProfiler behavior
# ---------------------------------------------------------------------------