refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions (#1900)

* refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions

- Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline.
- `rollout` now runs the preprocessor on each observation before calling `policy.select_action`, and runs the postprocessor on the resulting action before it is moved to the CPU and converted to a NumPy array.

* refactor(eval): remove redundant observation device conversion in rollout function

- Eliminated the explicit per-key `.to(device)` conversion of the observation dictionary within the `rollout` function.
- The observation is now passed to the preprocessor after the task is added, which simplifies observation handling inside the rollout loop.

* debug

* refactor(utils): enhance task handling in add_envs_task function

- Improved the `add_envs_task` function to validate the output of the `task_description` and `task` calls: a tuple result is converted to a list, and a `TypeError` is raised unless the result is a list whose items are all strings.
- Retained the `else` fallback for environments without language instructions, which assigns an empty task string to each environment.
- Streamlined the observation dictionary handling so that `observation["task"]` is always a list of strings, whichever branch sets it.
This commit is contained in:
Adil Zouitine
2025-09-09 17:00:34 +02:00
committed by GitHub
parent 846677f9cc
commit a74b90edd1
3 changed files with 51 additions and 19 deletions

View File

@@ -127,9 +127,29 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
"""Adds task feature to the observation dict with respect to the first environment attribute."""
if hasattr(env.envs[0], "task_description"):
observation["task"] = env.call("task_description")
task_result = env.call("task_description")
if isinstance(task_result, tuple):
task_result = list(task_result)
if not isinstance(task_result, list):
raise TypeError(f"Expected task_description to return a list, got {type(task_result)}")
if not all(isinstance(item, str) for item in task_result):
raise TypeError("All items in task_description result must be strings")
observation["task"] = task_result
elif hasattr(env.envs[0], "task"):
observation["task"] = env.call("task")
task_result = env.call("task")
if isinstance(task_result, tuple):
task_result = list(task_result)
if not isinstance(task_result, list):
raise TypeError(f"Expected task to return a list, got {type(task_result)}")
if not all(isinstance(item, str) for item in task_result):
raise TypeError("All items in task result must be strings")
observation["task"] = task_result
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
num_envs = observation[list(observation.keys())[0]].shape[0]
observation["task"] = ["" for _ in range(num_envs)]

View File

@@ -56,6 +56,7 @@ from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import Any
import einops
import gymnasium as gym
@@ -69,9 +70,10 @@ from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.policies.factory import make_policy
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.processor.core import TransitionKey
from lerobot.processor.pipeline import PolicyProcessorPipeline
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -84,6 +86,8 @@ from lerobot.utils.utils import (
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline[dict[str, Any]],
postprocessor: PolicyProcessorPipeline[dict[str, Any]],
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
@@ -120,7 +124,6 @@ def rollout(
The dictionary described above.
"""
assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
device = get_device_from_parameters(policy)
# Reset the policy and environments.
policy.reset()
@@ -151,19 +154,16 @@ def rollout(
if return_observations:
all_observations.append(deepcopy(observation))
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
# Convert to CPU / numpy.
action = action.to("cpu").numpy()
action: np.ndarray = action.to("cpu").numpy()
assert action.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
@@ -220,6 +220,8 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
preprocessor: PolicyProcessorPipeline,
postprocessor: PolicyProcessorPipeline,
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,
@@ -296,8 +298,10 @@ def eval_policy(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
)
rollout_data = rollout(
env,
policy,
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
@@ -479,13 +483,19 @@ def eval_main(cfg: EvalPipelineConfig):
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy.eval()
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path
)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy(
env,
policy,
cfg.eval.n_episodes,
env=env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,

View File

@@ -298,9 +298,11 @@ def train(cfg: TrainPipelineConfig):
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
env=eval_env,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,