[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-08-06 04:06:02 +00:00
parent 4bc356b7f3
commit 9124b36b0a
4 changed files with 31 additions and 21 deletions

View File

@@ -253,22 +253,22 @@ def train(cfg: TrainPipelineConfig):
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
if cfg.env.multitask_eval:
eval_info = eval_policy_multitask(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
)
aggregated_results = eval_info["overall"]["aggregated"]
# Print per-suite stats
for task_group, task_group_info in eval_info.items():
if task_group == "overall":
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
eval_info = eval_policy_multitask(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
)
aggregated_results = eval_info["overall"]["aggregated"]
# Print per-suite stats
for task_group, task_group_info in eval_info.items():
if task_group == "overall":
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
else:
eval_info = eval_policy(
eval_env,