add libero

This commit is contained in:
Jade Choghari
2025-08-05 23:55:08 -04:00
parent 06bebd97b3
commit 21a961ecbb
6 changed files with 626 additions and 59 deletions

View File

@@ -50,12 +50,12 @@ import json
import logging
import threading
import time
from collections.abc import Callable
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import Callable
import einops
import gymnasium as gym
@@ -456,55 +456,179 @@ def _compile_episode_data(
return data_dict
@parser.wrap()
def eval_main(cfg: EvalPipelineConfig):
def eval(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
# Check device is available
device = get_safe_torch_device(cfg.policy.device, log=True)
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
set_global_seed(cfg.seed)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
log_output_dir(cfg.output_dir)
logging.info("Making environment.")
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.")
policy = make_policy(
cfg=cfg.policy,
device=device,
env_cfg=cfg.env,
)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env,
policy,
cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,
)
print(info["aggregated"])
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
if cfg.env.multitask_eval:
info = eval_policy_multitask(
env,
policy,
cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
verbose=False,
)
# Print overall stats
print("Overall Aggregated Metrics:")
print(info["overall"]["aggregated"])
# Print per-suite stats
for task_group, task_group_info in info.items():
if task_group == "overall":
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
for _task_group, v in env.items():
for _env in v.values():
_env.close()
else:
info = eval_policy(
env,
policy,
cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,
)
print(info["aggregated"])
env.close()
# Save info
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
json.dump(info, f, indent=2)
env.close()
logging.info("End of eval")
def eval_policy_multitask(
envs: dict[str, dict[str, gym.vector.VectorEnv]],
policy,
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
max_parallel_tasks: int = 5,
verbose: bool = True,
) -> dict:
global_start = time.time()
results = {}
def main():
init_logging()
eval_main()
overall_rewards, overall_max_rewards, overall_successes = [], [], []
overall_video_paths = []
overall_episode_data = None
def eval_task(task_group, task_id, env):
"""Evaluates a single task in parallel."""
print(f"Evaluating: task_group: {task_group}, task_id: {task_id} ...")
task_result = eval_policy(
env, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed, verbose=verbose
)
per_episode = task_result["per_episode"]
return {
"task_group": task_group,
"task_id": task_id,
"sum_rewards": [ep["sum_reward"] for ep in per_episode],
"max_rewards": [ep["max_reward"] for ep in per_episode],
"successes": [ep["success"] for ep in per_episode],
"video_paths": task_result.get("video_paths", []),
}
with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
future_to_task = {
executor.submit(eval_task, task_group, task_id, env): (task_group, task_id)
for task_group, tasks in envs.items()
for task_id, env in tasks.items()
}
task_group_results = {}
for future in concurrent.futures.as_completed(future_to_task):
task_result = future.result()
task_group = task_result["task_group"]
if task_group not in task_group_results:
task_group_results[task_group] = {
"sum_rewards": [],
"max_rewards": [],
"successes": [],
"video_paths": [],
}
task_group_results[task_group]["sum_rewards"].extend(task_result["sum_rewards"])
task_group_results[task_group]["max_rewards"].extend(task_result["max_rewards"])
task_group_results[task_group]["successes"].extend(task_result["successes"])
task_group_results[task_group]["video_paths"].extend(task_result["video_paths"])
# Process results per task group
for task_group, data in task_group_results.items():
suite_rewards = data["sum_rewards"]
suite_max_rewards = data["max_rewards"]
suite_successes = data["successes"]
suite_video_paths = data["video_paths"]
suite_eval_s = time.time() - global_start
suite_eval_ep_s = suite_eval_s / max(1, len(suite_rewards))
results[task_group] = {
"aggregated": {
"avg_sum_reward": float(np.nanmean(suite_rewards)),
"avg_max_reward": float(np.nanmean(suite_max_rewards)),
"pc_success": float(np.nanmean(suite_successes) * 100),
"eval_s": suite_eval_s,
"eval_ep_s": suite_eval_ep_s,
},
"video_paths": suite_video_paths,
"episodes": None, # Modify if episode data is needed
}
overall_rewards.extend(suite_rewards)
overall_max_rewards.extend(suite_max_rewards)
overall_successes.extend(suite_successes)
overall_video_paths.extend(suite_video_paths)
# Global metrics
global_eval_s = time.time() - global_start
global_eval_ep_s = global_eval_s / max(1, len(overall_rewards))
results["overall"] = {
"aggregated": {
"avg_sum_reward": float(np.nanmean(overall_rewards)),
"avg_max_reward": float(np.nanmean(overall_max_rewards)),
"pc_success": float(np.nanmean(overall_successes) * 100),
"eval_s": global_eval_s,
"eval_ep_s": global_eval_ep_s,
},
"video_paths": overall_video_paths,
"episodes": overall_episode_data,
}
return results
if __name__ == "__main__":
main()
init_logging()
eval_main()