From 9124b36b0a714e7202809ca66635555ddf4bc19b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Aug 2025 04:06:02 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lerobot/envs/configs.py | 1 + src/lerobot/envs/libero.py | 3 ++- src/lerobot/scripts/eval.py | 16 ++++++++++++---- src/lerobot/scripts/train.py | 32 ++++++++++++++++---------------- 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index c9db0979f..eab53085e 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -272,6 +272,7 @@ class HILEnvConfig(EnvConfig): "gripper_penalty": self.gripper_penalty, } + @EnvConfig.register_subclass("libero") @dataclass class LiberoEnv(EnvConfig): diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 75dfd6ada..218464070 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -2,7 +2,8 @@ import math import os from collections import defaultdict from itertools import chain -from typing import Any, Callable +from typing import Any +from collections.abc import Callable import gymnasium as gym import numpy as np diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index a56c4c3b5..7f9b8f362 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -55,13 +55,12 @@ from copy import deepcopy from dataclasses import asdict from pathlib import Path from pprint import pformat -from typing import Callable +from collections.abc import Callable import einops import gymnasium as gym import numpy as np import torch -from termcolor import colored from torch import Tensor, nn from tqdm import trange @@ -73,7 +72,6 @@ from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters from lerobot.utils.io_utils import write_video -from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( get_safe_torch_device, init_logging, @@ -456,6 +454,7 @@ def _compile_episode_data( return data_dict + @parser.wrap() def eval(cfg: EvalPipelineConfig): logging.info(pformat(asdict(cfg))) @@ -523,6 +522,7 @@ def eval(cfg: EvalPipelineConfig): logging.info("End of eval") + def eval_policy_multitask( envs: dict[str, dict[str, gym.vector.VectorEnv]], policy, @@ -545,7 +545,14 @@ def eval_policy_multitask( """Evaluates a single task in parallel.""" print(f"Evaluating: task_group: {task_group}, task_id: {task_id} ...") task_result = eval_policy( - env, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed, verbose=verbose + env, + policy, + n_episodes, + max_episodes_rendered, + videos_dir, + return_episode_data, + start_seed, + verbose=verbose, ) per_episode = task_result["per_episode"] @@ -629,6 +636,7 @@ def eval_policy_multitask( return results + if __name__ == "__main__": init_logging() eval_main() diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 9b287d957..6cb476afb 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -253,22 +253,22 @@ def train(cfg: TrainPipelineConfig): torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), ): if cfg.env.multitask_eval: - eval_info = eval_policy_multitask( - eval_env, - policy, - cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", - max_episodes_rendered=4, - start_seed=cfg.seed, - max_parallel_tasks=cfg.env.max_parallel_tasks, - ) - aggregated_results = eval_info["overall"]["aggregated"] - # Print per-suite stats - for task_group, task_group_info in eval_info.items(): - if task_group == "overall": - continue # Skip the overall stats since we already printed it - print(f"\nAggregated Metrics for {task_group}:") - print(task_group_info["aggregated"]) + eval_info = eval_policy_multitask( + eval_env, + policy, + cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, + ) + aggregated_results = eval_info["overall"]["aggregated"] + # Print per-suite stats + for task_group, task_group_info in eval_info.items(): + if task_group == "overall": + continue # Skip the overall stats since we already printed it + print(f"\nAggregated Metrics for {task_group}:") + print(task_group_info["aggregated"]) else: eval_info = eval_policy( eval_env,