feat(eval): add per-episode timing logs to eval worker

Logs avg env step time, avg inference call time, and totals per episode
to identify whether env or policy server is the bottleneck.

Made-with: Cursor
This commit is contained in:
Pepijn Kooijmans
2026-03-25 07:30:49 +01:00
parent ddcda8f1ca
commit 8770c011b0

View File

@@ -35,6 +35,7 @@ from __future__ import annotations
import json
import logging
import pickle # nosec B403 — internal serialisation only
import time
import urllib.request
from dataclasses import dataclass, field
from pathlib import Path
@@ -130,15 +131,24 @@ def run_worker(cfg: EvalWorkerConfig) -> dict:
ep_rewards: list[float] = []
ep_success = False
done = np.zeros(1, dtype=bool)
ep_steps = 0
ep_infer_time = 0.0
ep_env_time = 0.0
ep_infer_calls = 0
while not np.all(done):
if not action_buffer:
t0 = time.monotonic()
chunk_np = _call_server(cfg.server_address, obs_t, cfg.server_timeout)
# chunk_np: (T, action_dim) — split into per-step slices of shape (1, action_dim)
ep_infer_time += time.monotonic() - t0
ep_infer_calls += 1
action_buffer = [chunk_np[i : i + 1] for i in range(chunk_np.shape[0])]
action_np = action_buffer.pop(0) # (1, action_dim)
action_np = action_buffer.pop(0)
t0 = time.monotonic()
obs, reward, terminated, truncated, info = env.step(action_np)
ep_env_time += time.monotonic() - t0
ep_steps += 1
done = terminated | truncated | done
ep_rewards.append(float(np.mean(reward)))
@@ -155,13 +165,22 @@ def run_worker(cfg: EvalWorkerConfig) -> dict:
sum_rewards.append(float(np.sum(ep_rewards)))
max_rewards.append(float(np.max(ep_rewards)) if ep_rewards else 0.0)
successes.append(ep_success)
avg_env_ms = (ep_env_time / ep_steps * 1000) if ep_steps else 0
avg_infer_ms = (ep_infer_time / ep_infer_calls * 1000) if ep_infer_calls else 0
logger.info(
"Task %s[%d] ep %d/%d — success=%s",
"Task %s[%d] ep %d/%d — success=%s | %d steps, %d infer calls | "
"env %.0fms/step, infer %.0fms/call (env %.1fs, infer %.1fs total)",
task_group,
task_id,
ep_idx + 1,
cfg.n_episodes,
ep_success,
ep_steps,
ep_infer_calls,
avg_env_ms,
avg_infer_ms,
ep_env_time,
ep_infer_time,
)
per_task.append(