mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
fix(profiling): address review feedback
This commit is contained in:
@@ -156,7 +156,7 @@ def test_parse_discussion_num_handles_hf_discussion_urls():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _fake_args(tmp_path):
|
||||
def fake_args(tmp_path):
|
||||
"""Shared argparse namespace for main() smoke tests — overridden per-test."""
|
||||
return argparse.Namespace(
|
||||
policies=["act"],
|
||||
@@ -195,14 +195,14 @@ def _stub_train_subprocess(mp_module, *, returncode: int = 0, write_artifacts: b
|
||||
return _fake_run
|
||||
|
||||
|
||||
def test_main_smoke_writes_row(monkeypatch, _fake_args):
|
||||
monkeypatch.setattr(mp, "parse_args", lambda: _fake_args)
|
||||
def test_main_smoke_writes_row(monkeypatch, fake_args):
|
||||
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
|
||||
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
|
||||
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp))
|
||||
|
||||
assert mp.main() == 0
|
||||
|
||||
row_paths = list(_fake_args.output_dir.rglob("profiling_row.json"))
|
||||
row_paths = list(fake_args.output_dir.rglob("profiling_row.json"))
|
||||
assert len(row_paths) == 1
|
||||
row = json.loads(row_paths[0].read_text())
|
||||
assert row["policy"] == "act"
|
||||
@@ -214,10 +214,10 @@ def test_main_smoke_writes_row(monkeypatch, _fake_args):
|
||||
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
|
||||
|
||||
|
||||
def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args):
|
||||
_fake_args.publish = True
|
||||
_fake_args.git_commit = "deadbeef"
|
||||
monkeypatch.setattr(mp, "parse_args", lambda: _fake_args)
|
||||
def test_main_records_publish_failure_without_failing(monkeypatch, fake_args):
|
||||
fake_args.publish = True
|
||||
fake_args.git_commit = "deadbeef"
|
||||
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
|
||||
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, write_artifacts=False))
|
||||
|
||||
def _fail_upload(**kwargs):
|
||||
@@ -227,12 +227,24 @@ def test_main_records_publish_failure_without_failing(monkeypatch, _fake_args):
|
||||
monkeypatch.setattr(mp, "upload_profile_run", _fail_upload)
|
||||
|
||||
assert mp.main() == 0
|
||||
row = json.loads(next(_fake_args.output_dir.rglob("profiling_row.json")).read_text())
|
||||
row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
|
||||
assert row["status"] == "success"
|
||||
assert row["publish_status"] == "failed"
|
||||
assert "Authorization error" in row["publish_error"]
|
||||
|
||||
|
||||
def test_main_returns_nonzero_when_training_subprocess_fails(monkeypatch, fake_args):
|
||||
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
|
||||
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
|
||||
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, returncode=3))
|
||||
|
||||
assert mp.main() == 1
|
||||
|
||||
row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
|
||||
assert row["status"] == "failed"
|
||||
assert row["return_code"] == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TrainingProfiler behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user