mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
76df8a31b3
commit
38f5fa4523
@@ -170,7 +170,10 @@ def rollout(
|
||||
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
|
||||
# available if none of the envs finished.
|
||||
if "final_info" in info:
|
||||
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
|
||||
successes = [
|
||||
info["is_success"] if info is not None else False
|
||||
for info in info["final_info"]
|
||||
]
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
@@ -184,9 +187,13 @@ def rollout(
|
||||
|
||||
step += 1
|
||||
running_success_rate = (
|
||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
|
||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
|
||||
.numpy()
|
||||
.mean()
|
||||
)
|
||||
progbar.set_postfix(
|
||||
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
|
||||
)
|
||||
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
|
||||
progbar.update()
|
||||
|
||||
# Track the final observation.
|
||||
@@ -204,7 +211,9 @@ def rollout(
|
||||
if return_observations:
|
||||
stacked_observations = {}
|
||||
for key in all_observations[0]:
|
||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
stacked_observations[key] = torch.stack(
|
||||
[obs[key] for obs in all_observations], dim=1
|
||||
)
|
||||
ret["observation"] = stacked_observations
|
||||
|
||||
if hasattr(policy, "use_original_modules"):
|
||||
@@ -266,7 +275,9 @@ def eval_policy(
|
||||
return
|
||||
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
|
||||
if isinstance(env, gym.vector.SyncVectorEnv):
|
||||
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
|
||||
ep_frames.append(
|
||||
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
|
||||
) # noqa: B023
|
||||
elif isinstance(env, gym.vector.AsyncVectorEnv):
|
||||
# Here we must render all frames and discard any we don't need.
|
||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||
@@ -278,7 +289,9 @@ def eval_policy(
|
||||
episode_data: dict | None = None
|
||||
|
||||
# We don't want a progress bar when we use slurm, since it clutters the logs.
|
||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
||||
progbar = trange(
|
||||
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
|
||||
)
|
||||
for batch_ix in progbar:
|
||||
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
||||
# step.
|
||||
@@ -289,7 +302,8 @@ def eval_policy(
|
||||
seeds = None
|
||||
else:
|
||||
seeds = range(
|
||||
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
|
||||
start_seed + (batch_ix * env.num_envs),
|
||||
start_seed + ((batch_ix + 1) * env.num_envs),
|
||||
)
|
||||
rollout_data = rollout(
|
||||
env,
|
||||
@@ -307,13 +321,22 @@ def eval_policy(
|
||||
|
||||
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
|
||||
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
|
||||
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
|
||||
mask = (
|
||||
torch.arange(n_steps)
|
||||
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
|
||||
).int()
|
||||
# Extend metrics.
|
||||
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
|
||||
batch_sum_rewards = einops.reduce(
|
||||
(rollout_data["reward"] * mask), "b n -> b", "sum"
|
||||
)
|
||||
sum_rewards.extend(batch_sum_rewards.tolist())
|
||||
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
|
||||
batch_max_rewards = einops.reduce(
|
||||
(rollout_data["reward"] * mask), "b n -> b", "max"
|
||||
)
|
||||
max_rewards.extend(batch_max_rewards.tolist())
|
||||
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
|
||||
batch_successes = einops.reduce(
|
||||
(rollout_data["success"] * mask), "b n -> b", "any"
|
||||
)
|
||||
all_successes.extend(batch_successes.tolist())
|
||||
if seeds:
|
||||
all_seeds.extend(seeds)
|
||||
@@ -326,17 +349,27 @@ def eval_policy(
|
||||
rollout_data,
|
||||
done_indices,
|
||||
start_episode_index=batch_ix * env.num_envs,
|
||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
||||
start_data_index=(
|
||||
0
|
||||
if episode_data is None
|
||||
else (episode_data["index"][-1].item() + 1)
|
||||
),
|
||||
fps=env.unwrapped.metadata["render_fps"],
|
||||
)
|
||||
if episode_data is None:
|
||||
episode_data = this_episode_data
|
||||
else:
|
||||
# Some sanity checks to make sure we are correctly compiling the data.
|
||||
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
|
||||
assert (
|
||||
episode_data["episode_index"][-1] + 1
|
||||
== this_episode_data["episode_index"][0]
|
||||
)
|
||||
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
|
||||
# Concatenate the episode data.
|
||||
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
|
||||
episode_data = {
|
||||
k: torch.cat([episode_data[k], this_episode_data[k]])
|
||||
for k in episode_data
|
||||
}
|
||||
|
||||
# Maybe render video for visualization.
|
||||
if max_episodes_rendered > 0 and len(ep_frames) > 0:
|
||||
@@ -354,7 +387,9 @@ def eval_policy(
|
||||
target=write_video,
|
||||
args=(
|
||||
str(video_path),
|
||||
stacked_frames[: done_index + 1], # + 1 to capture the last observation
|
||||
stacked_frames[
|
||||
: done_index + 1
|
||||
], # + 1 to capture the last observation
|
||||
env.unwrapped.metadata["render_fps"],
|
||||
),
|
||||
)
|
||||
@@ -363,7 +398,9 @@ def eval_policy(
|
||||
n_episodes_rendered += 1
|
||||
|
||||
progbar.set_postfix(
|
||||
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
|
||||
{
|
||||
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
|
||||
}
|
||||
)
|
||||
|
||||
# Wait till all video rendering threads are done.
|
||||
@@ -409,7 +446,11 @@ def eval_policy(
|
||||
|
||||
|
||||
def _compile_episode_data(
|
||||
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
|
||||
rollout_data: dict,
|
||||
done_indices: Tensor,
|
||||
start_episode_index: int,
|
||||
start_data_index: int,
|
||||
fps: float,
|
||||
) -> dict:
|
||||
"""Convenience function for `eval_policy(return_episode_data=True)`
|
||||
|
||||
@@ -427,12 +468,16 @@ def _compile_episode_data(
|
||||
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
|
||||
ep_dict = {
|
||||
"action": rollout_data["action"][ep_ix, : num_frames - 1],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
|
||||
"episode_index": torch.tensor(
|
||||
[start_episode_index + ep_ix] * (num_frames - 1)
|
||||
),
|
||||
"frame_index": torch.arange(0, num_frames - 1, 1),
|
||||
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
|
||||
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
|
||||
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
|
||||
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
|
||||
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
|
||||
torch.float32
|
||||
),
|
||||
}
|
||||
|
||||
# For the last observation frame, all other keys will just be copy padded.
|
||||
@@ -448,7 +493,9 @@ def _compile_episode_data(
|
||||
for key in ep_dicts[0]:
|
||||
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
|
||||
data_dict["index"] = torch.arange(
|
||||
start_data_index, start_data_index + total_frames, 1
|
||||
)
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user