diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 872597340..4cbf89b27 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -46,6 +46,7 @@ Note that in both examples, the repo/folder should contain at least `config.json You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py """ +import concurrent import json import logging import threading @@ -56,7 +57,7 @@ from copy import deepcopy from dataclasses import asdict from pathlib import Path from pprint import pformat -import concurrent + import einops import gymnasium as gym import numpy as np @@ -158,7 +159,7 @@ def rollout( observation = add_envs_task(env, observation) with torch.inference_mode(): action = policy.select_action(observation) - observation['observation.images.image'] + observation["observation.images.image"] # Convert to CPU / numpy. action = action.to("cpu").numpy() assert action.ndim == 2, "Action dimensions should be (batch, action_dim)" @@ -177,12 +178,11 @@ def rollout( # Keep track of which environments are done so far. # done = terminated | truncated | done - #TODO: jadechoghari changed, this is cleaner + # TODO: jadechoghari changed, this is cleaner done = terminated | truncated | done if step + 1 == max_steps: done = np.ones_like(done, dtype=bool) - all_actions.append(torch.from_numpy(action)) all_rewards.append(torch.from_numpy(reward)) all_dones.append(torch.from_numpy(done)) @@ -378,7 +378,7 @@ def eval_policy( # Wait till all video rendering threads are done. for thread in threads: thread.join() - + # Compile eval info. info = { "per_episode": [ @@ -460,16 +460,22 @@ def _compile_episode_data( return data_dict + def set_global_seed(seed): """Set seed for reproducibility.""" import random + random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + + def log_output_dir(out_dir): - logging.info("Output dir:"+ f" {out_dir}") + logging.info("Output dir:" + f" {out_dir}") + + @parser.wrap() def eval(cfg: EvalPipelineConfig): logging.info(pformat(asdict(cfg)))