mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
add changes
This commit is contained in:
@@ -143,7 +143,7 @@ def rollout(
|
||||
leave=False,
|
||||
)
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done):
|
||||
while not np.all(done) and step < max_steps:
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
observation = preprocess_observation(observation)
|
||||
if return_observations:
|
||||
@@ -185,8 +185,11 @@ def rollout(
|
||||
all_successes.append(torch.tensor(successes))
|
||||
|
||||
step += 1
|
||||
print(step)
|
||||
running_success_rate = (
|
||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
|
||||
# einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() #TODO: changed by jade
|
||||
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "max")
|
||||
|
||||
)
|
||||
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
|
||||
progbar.update()
|
||||
@@ -315,7 +318,8 @@ def eval_policy(
|
||||
sum_rewards.extend(batch_sum_rewards.tolist())
|
||||
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
|
||||
max_rewards.extend(batch_max_rewards.tolist())
|
||||
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
|
||||
# batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
|
||||
batch_successes = einops.reduce((rollout_data["success"] * mask).float(), "b n -> b", "max")
|
||||
all_successes.extend(batch_successes.tolist())
|
||||
if seeds:
|
||||
all_seeds.extend(seeds)
|
||||
|
||||
Reference in New Issue
Block a user