final refactor/fix

This commit is contained in:
Jade Choghari (jchoghar)
2025-08-25 06:25:02 -04:00
parent afad90ffaa
commit 8d2c66abd2
7 changed files with 47 additions and 75 deletions

View File

@@ -186,7 +186,6 @@ def train(cfg: TrainPipelineConfig):
dl_iter = cycle(dataloader)
policy.train()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"),
@@ -263,15 +262,14 @@ def train(cfg: TrainPipelineConfig):
max_parallel_tasks=cfg.env.max_parallel_tasks,
)
aggregated = eval_info["overall"]["aggregated"]
# Print per-suite stats
# Print per-suite stats, log?
for task_group, task_group_info in eval_info.items():
if task_group == "overall":
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
breakpoint()
breakpoint()
else:
print("START EVAL")
eval_info = eval_policy(
eval_env,
policy,
@@ -280,9 +278,8 @@ def train(cfg: TrainPipelineConfig):
max_episodes_rendered=4,
start_seed=cfg.seed,
)
breakpoint()
aggregated = eval_info["aggregated"]
print("END EVAL")
breakpoint()
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),