mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
feat(metrics): add avg_sum_reward and eval_s to metrics output
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -39,30 +39,30 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _extract_pc_success(info: dict) -> tuple[float | None, int | None]:
|
||||
"""Extract (pc_success, n_episodes) from eval_info.json.
|
||||
def _extract_metrics(info: dict) -> tuple[float | None, int | None, float | None, float | None]:
|
||||
"""Extract (pc_success, n_episodes, avg_sum_reward, eval_s) from eval_info.json.
|
||||
|
||||
Handles two output shapes:
|
||||
- Single-task: {"aggregated": {"pc_success": 80.0, ...}}
|
||||
- Multi-task: {"overall": {"pc_success": 80.0, "n_episodes": 5, ...}}
|
||||
"""
|
||||
# Single-task path
|
||||
if "aggregated" in info:
|
||||
agg = info["aggregated"]
|
||||
for key in ("aggregated", "overall"):
|
||||
if key not in info:
|
||||
continue
|
||||
agg = info[key]
|
||||
pc = agg.get("pc_success")
|
||||
n = agg.get("n_episodes") # may be absent in older format
|
||||
n = agg.get("n_episodes")
|
||||
reward = agg.get("avg_sum_reward")
|
||||
eval_s = agg.get("eval_s")
|
||||
if pc is not None and not math.isnan(pc):
|
||||
return float(pc), int(n) if n is not None else None
|
||||
return (
|
||||
float(pc),
|
||||
int(n) if n is not None else None,
|
||||
float(reward) if reward is not None else None,
|
||||
float(eval_s) if eval_s is not None else None,
|
||||
)
|
||||
|
||||
# Multi-task path
|
||||
if "overall" in info:
|
||||
overall = info["overall"]
|
||||
pc = overall.get("pc_success")
|
||||
n = overall.get("n_episodes")
|
||||
if pc is not None and not math.isnan(pc):
|
||||
return float(pc), int(n) if n is not None else None
|
||||
|
||||
return None, None
|
||||
return None, None, None, None
|
||||
|
||||
|
||||
def main() -> int:
|
||||
@@ -80,11 +80,13 @@ def main() -> int:
|
||||
|
||||
pc_success: float | None = None
|
||||
n_episodes: int | None = None
|
||||
avg_sum_reward: float | None = None
|
||||
eval_s: float | None = None
|
||||
|
||||
if eval_info_path.exists():
|
||||
try:
|
||||
info = json.loads(eval_info_path.read_text())
|
||||
pc_success, n_episodes = _extract_pc_success(info)
|
||||
pc_success, n_episodes, avg_sum_reward, eval_s = _extract_metrics(info)
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as exc:
|
||||
print(f"[parse_eval_metrics] Warning: could not parse eval_info.json: {exc}", file=sys.stderr)
|
||||
else:
|
||||
@@ -99,6 +101,8 @@ def main() -> int:
|
||||
"policy": args.policy,
|
||||
"pc_success": pc_success,
|
||||
"n_episodes": n_episodes,
|
||||
"avg_sum_reward": avg_sum_reward,
|
||||
"eval_s": eval_s,
|
||||
}
|
||||
|
||||
out_path = artifacts_dir / "metrics.json"
|
||||
|
||||
Reference in New Issue
Block a user