[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-08-06 04:06:02 +00:00
parent 4bc356b7f3
commit 9124b36b0a
4 changed files with 31 additions and 21 deletions

View File

@@ -272,6 +272,7 @@ class HILEnvConfig(EnvConfig):
"gripper_penalty": self.gripper_penalty,
}
@EnvConfig.register_subclass("libero")
@dataclass
class LiberoEnv(EnvConfig):

View File

@@ -2,7 +2,8 @@ import math
import os
from collections import defaultdict
from itertools import chain
from typing import Any, Callable
from typing import Any
from collections.abc import Callable
import gymnasium as gym
import numpy as np

View File

@@ -55,13 +55,12 @@ from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import Callable
from collections.abc import Callable
import einops
import gymnasium as gym
import numpy as np
import torch
from termcolor import colored
from torch import Tensor, nn
from tqdm import trange
@@ -73,7 +72,6 @@ from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
get_safe_torch_device,
init_logging,
@@ -456,6 +454,7 @@ def _compile_episode_data(
return data_dict
@parser.wrap()
def eval(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
@@ -523,6 +522,7 @@ def eval(cfg: EvalPipelineConfig):
logging.info("End of eval")
def eval_policy_multitask(
envs: dict[str, dict[str, gym.vector.VectorEnv]],
policy,
@@ -545,7 +545,14 @@ def eval_policy_multitask(
"""Evaluates a single task in parallel."""
print(f"Evaluating: task_group: {task_group}, task_id: {task_id} ...")
task_result = eval_policy(
env, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed, verbose=verbose
env,
policy,
n_episodes,
max_episodes_rendered,
videos_dir,
return_episode_data,
start_seed,
verbose=verbose,
)
per_episode = task_result["per_episode"]
@@ -629,6 +636,7 @@ def eval_policy_multitask(
return results
if __name__ == "__main__":
init_logging()
eval_main()

View File

@@ -253,22 +253,22 @@ def train(cfg: TrainPipelineConfig):
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
if cfg.env.multitask_eval:
eval_info = eval_policy_multitask(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
)
aggregated_results = eval_info["overall"]["aggregated"]
# Print per-suite stats
for task_group, task_group_info in eval_info.items():
if task_group == "overall":
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
eval_info = eval_policy_multitask(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
)
aggregated_results = eval_info["overall"]["aggregated"]
# Print per-suite stats
for task_group, task_group_info in eval_info.items():
if task_group == "overall":
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
else:
eval_info = eval_policy(
eval_env,