mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -272,6 +272,7 @@ class HILEnvConfig(EnvConfig):
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("libero")
|
||||
@dataclass
|
||||
class LiberoEnv(EnvConfig):
|
||||
|
||||
@@ -2,7 +2,8 @@ import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
from collections.abc import Callable
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
@@ -55,13 +55,12 @@ from copy import deepcopy
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Callable
|
||||
from collections.abc import Callable
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
@@ -73,7 +72,6 @@ from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
@@ -456,6 +454,7 @@ def _compile_episode_data(
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def eval(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
@@ -523,6 +522,7 @@ def eval(cfg: EvalPipelineConfig):
|
||||
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
def eval_policy_multitask(
|
||||
envs: dict[str, dict[str, gym.vector.VectorEnv]],
|
||||
policy,
|
||||
@@ -545,7 +545,14 @@ def eval_policy_multitask(
|
||||
"""Evaluates a single task in parallel."""
|
||||
print(f"Evaluating: task_group: {task_group}, task_id: {task_id} ...")
|
||||
task_result = eval_policy(
|
||||
env, policy, n_episodes, max_episodes_rendered, videos_dir, return_episode_data, start_seed, verbose=verbose
|
||||
env,
|
||||
policy,
|
||||
n_episodes,
|
||||
max_episodes_rendered,
|
||||
videos_dir,
|
||||
return_episode_data,
|
||||
start_seed,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
per_episode = task_result["per_episode"]
|
||||
@@ -629,6 +636,7 @@ def eval_policy_multitask(
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
eval_main()
|
||||
|
||||
@@ -253,22 +253,22 @@ def train(cfg: TrainPipelineConfig):
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
if cfg.env.multitask_eval:
|
||||
eval_info = eval_policy_multitask(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
)
|
||||
aggregated_results = eval_info["overall"]["aggregated"]
|
||||
# Print per-suite stats
|
||||
for task_group, task_group_info in eval_info.items():
|
||||
if task_group == "overall":
|
||||
continue # Skip the overall stats since we already printed it
|
||||
print(f"\nAggregated Metrics for {task_group}:")
|
||||
print(task_group_info["aggregated"])
|
||||
eval_info = eval_policy_multitask(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
)
|
||||
aggregated_results = eval_info["overall"]["aggregated"]
|
||||
# Print per-suite stats
|
||||
for task_group, task_group_info in eval_info.items():
|
||||
if task_group == "overall":
|
||||
continue # Skip the overall stats since we already printed it
|
||||
print(f"\nAggregated Metrics for {task_group}:")
|
||||
print(task_group_info["aggregated"])
|
||||
else:
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
|
||||
Reference in New Issue
Block a user