mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
add multitask
This commit is contained in:
@@ -269,7 +269,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
continue # Skip the overall stats since we already printed it
|
||||
print(f"\nAggregated Metrics for {task_group}:")
|
||||
print(task_group_info["aggregated"])
|
||||
breakpoint()
|
||||
else:
|
||||
print("START EVAL")
|
||||
breakpoint()
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
@@ -278,6 +281,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
aggregated = eval_info["aggregated"]
|
||||
print("END EVAL")
|
||||
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
@@ -287,9 +292,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
eval_tracker = MetricsTracker(
|
||||
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
|
||||
)
|
||||
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
|
||||
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
|
||||
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
|
||||
eval_tracker.eval_s = aggregated.pop("eval_s")
|
||||
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
|
||||
eval_tracker.pc_success = aggregated.pop("pc_success")
|
||||
logging.info(eval_tracker)
|
||||
if wandb_logger:
|
||||
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
|
||||
|
||||
Reference in New Issue
Block a user