add multitask

This commit is contained in:
Jade Choghari (jchoghar)
2025-08-17 14:27:53 -04:00
parent c20bf75ba0
commit ac0993c2e3
15 changed files with 91 additions and 32 deletions

View File

@@ -269,7 +269,10 @@ def train(cfg: TrainPipelineConfig):
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
breakpoint()
else:
print("START EVAL")
breakpoint()
eval_info = eval_policy(
eval_env,
policy,
@@ -278,6 +281,8 @@ def train(cfg: TrainPipelineConfig):
max_episodes_rendered=4,
start_seed=cfg.seed,
)
aggregated = eval_info["aggregated"]
print("END EVAL")
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
@@ -287,9 +292,9 @@ def train(cfg: TrainPipelineConfig):
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
)
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
eval_tracker.eval_s = aggregated.pop("eval_s")
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
eval_tracker.pc_success = aggregated.pop("pc_success")
logging.info(eval_tracker)
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}