feat(metrics): add avg_sum_reward and eval_s to metrics output

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-09 10:04:53 +02:00
parent 82034805d6
commit 501b916601

View File

@@ -39,30 +39,30 @@ import sys
from pathlib import Path
def _extract_pc_success(info: dict) -> tuple[float | None, int | None]:
"""Extract (pc_success, n_episodes) from eval_info.json.
def _extract_metrics(info: dict) -> tuple[float | None, int | None, float | None, float | None]:
"""Extract (pc_success, n_episodes, avg_sum_reward, eval_s) from eval_info.json.
Handles two output shapes:
- Single-task: {"aggregated": {"pc_success": 80.0, ...}}
- Multi-task: {"overall": {"pc_success": 80.0, "n_episodes": 5, ...}}
"""
# Single-task path
if "aggregated" in info:
agg = info["aggregated"]
for key in ("aggregated", "overall"):
if key not in info:
continue
agg = info[key]
pc = agg.get("pc_success")
n = agg.get("n_episodes") # may be absent in older format
n = agg.get("n_episodes")
reward = agg.get("avg_sum_reward")
eval_s = agg.get("eval_s")
if pc is not None and not math.isnan(pc):
return float(pc), int(n) if n is not None else None
return (
float(pc),
int(n) if n is not None else None,
float(reward) if reward is not None else None,
float(eval_s) if eval_s is not None else None,
)
# Multi-task path
if "overall" in info:
overall = info["overall"]
pc = overall.get("pc_success")
n = overall.get("n_episodes")
if pc is not None and not math.isnan(pc):
return float(pc), int(n) if n is not None else None
return None, None
return None, None, None, None
def main() -> int:
@@ -80,11 +80,13 @@ def main() -> int:
pc_success: float | None = None
n_episodes: int | None = None
avg_sum_reward: float | None = None
eval_s: float | None = None
if eval_info_path.exists():
try:
info = json.loads(eval_info_path.read_text())
pc_success, n_episodes = _extract_pc_success(info)
pc_success, n_episodes, avg_sum_reward, eval_s = _extract_metrics(info)
except (json.JSONDecodeError, KeyError, TypeError) as exc:
print(f"[parse_eval_metrics] Warning: could not parse eval_info.json: {exc}", file=sys.stderr)
else:
@@ -99,6 +101,8 @@ def main() -> int:
"policy": args.policy,
"pc_success": pc_success,
"n_episodes": n_episodes,
"avg_sum_reward": avg_sum_reward,
"eval_s": eval_s,
}
out_path = artifacts_dir / "metrics.json"