[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
committed by AdilZouitine
parent 76df8a31b3
commit 38f5fa4523
79 changed files with 2782 additions and 788 deletions

View File

@@ -170,7 +170,10 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available of none of the envs finished.
if "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
successes = [
info["is_success"] if info is not None else False
for info in info["final_info"]
]
else:
successes = [False] * env.num_envs
@@ -184,9 +187,13 @@ def rollout(
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
.numpy()
.mean()
)
progbar.set_postfix(
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
# Track the final observation.
@@ -204,7 +211,9 @@ def rollout(
if return_observations:
stacked_observations = {}
for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
stacked_observations[key] = torch.stack(
[obs[key] for obs in all_observations], dim=1
)
ret["observation"] = stacked_observations
if hasattr(policy, "use_original_modules"):
@@ -266,7 +275,9 @@ def eval_policy(
return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
ep_frames.append(
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
) # noqa: B023
elif isinstance(env, gym.vector.AsyncVectorEnv):
# Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
@@ -278,7 +289,9 @@ def eval_policy(
episode_data: dict | None = None
# we dont want progress bar when we use slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
progbar = trange(
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
)
for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step.
@@ -289,7 +302,8 @@ def eval_policy(
seeds = None
else:
seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
start_seed + (batch_ix * env.num_envs),
start_seed + ((batch_ix + 1) * env.num_envs),
)
rollout_data = rollout(
env,
@@ -307,13 +321,22 @@ def eval_policy(
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
mask = (
torch.arange(n_steps)
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
).int()
# Extend metrics.
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
batch_sum_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "sum"
)
sum_rewards.extend(batch_sum_rewards.tolist())
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
batch_max_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "max"
)
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
batch_successes = einops.reduce(
(rollout_data["success"] * mask), "b n -> b", "any"
)
all_successes.extend(batch_successes.tolist())
if seeds:
all_seeds.extend(seeds)
@@ -326,17 +349,27 @@ def eval_policy(
rollout_data,
done_indices,
start_episode_index=batch_ix * env.num_envs,
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
start_data_index=(
0
if episode_data is None
else (episode_data["index"][-1].item() + 1)
),
fps=env.unwrapped.metadata["render_fps"],
)
if episode_data is None:
episode_data = this_episode_data
else:
# Some sanity checks to make sure we are correctly compiling the data.
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
assert (
episode_data["episode_index"][-1] + 1
== this_episode_data["episode_index"][0]
)
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data.
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
episode_data = {
k: torch.cat([episode_data[k], this_episode_data[k]])
for k in episode_data
}
# Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0:
@@ -354,7 +387,9 @@ def eval_policy(
target=write_video,
args=(
str(video_path),
stacked_frames[: done_index + 1], # + 1 to capture the last observation
stacked_frames[
: done_index + 1
], # + 1 to capture the last observation
env.unwrapped.metadata["render_fps"],
),
)
@@ -363,7 +398,9 @@ def eval_policy(
n_episodes_rendered += 1
progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
{
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
}
)
# Wait till all video rendering threads are done.
@@ -409,7 +446,11 @@ def eval_policy(
def _compile_episode_data(
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
rollout_data: dict,
done_indices: Tensor,
start_episode_index: int,
start_data_index: int,
fps: float,
) -> dict:
"""Convenience function for `eval_policy(return_episode_data=True)`
@@ -427,12 +468,16 @@ def _compile_episode_data(
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"episode_index": torch.tensor(
[start_episode_index + ep_ix] * (num_frames - 1)
),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
torch.float32
),
}
# For the last observation frame, all other keys will just be copy padded.
@@ -448,7 +493,9 @@ def _compile_episode_data(
for key in ep_dicts[0]:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
data_dict["index"] = torch.arange(
start_data_index, start_data_index + total_frames, 1
)
return data_dict