[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions

View File

@@ -152,7 +152,8 @@ def rollout(
all_observations.append(deepcopy(observation))
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
key: observation[key].to(device, non_blocking=device.type == "cuda")
for key in observation
}
# Infer "task" from attributes of environments.
@@ -515,10 +516,14 @@ def eval_main(cfg: EvalPipelineConfig):
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info(
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
)
logging.info("Making environment.")
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
env = make_env(
cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs
)
logging.info("Making policy.")
@@ -528,7 +533,12 @@ def eval_main(cfg: EvalPipelineConfig):
)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
with (
torch.no_grad(),
torch.autocast(device_type=device.type)
if cfg.policy.use_amp
else nullcontext(),
):
info = eval_policy(
env,
policy,