perf(eval): shared memory, observation passthrough, task prefetch

- AsyncVectorEnv now uses shared_memory=True for zero-copy observation transfer
- LiberoEnvConfig.gym_kwargs passes observation_height/width to the env
- eval_policy_all prefetches next task's workers while current task runs

Made-with: Cursor
This commit is contained in:
Pepijn Kooijmans
2026-04-07 16:21:52 +02:00
parent 634aa89558
commit 5ce727f20f
3 changed files with 22 additions and 3 deletions

View File

@@ -402,7 +402,12 @@ class LiberoEnv(EnvConfig):
@property
def gym_kwargs(self) -> dict:
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
kwargs: dict[str, Any] = {
"obs_type": self.obs_type,
"render_mode": self.render_mode,
"observation_height": self.observation_height,
"observation_width": self.observation_width,
}
if self.task_ids is not None:
kwargs["task_ids"] = self.task_ids
return kwargs

View File

@@ -435,7 +435,9 @@ class _LazyAsyncVectorEnv:
def _ensure(self):
if self._env is None:
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver")
self._env = gym.vector.AsyncVectorEnv(
self._env_fns, context="forkserver", shared_memory=True
)
def reset(self, **kwargs):
self._ensure()

View File

@@ -760,7 +760,19 @@ def eval_policy_all(
)
if max_parallel_tasks <= 1:
for task_group, task_id, env in tasks:
prefetch_thread: threading.Thread | None = None
for i, (task_group, task_id, env) in enumerate(tasks):
if prefetch_thread is not None:
prefetch_thread.join()
prefetch_thread = None
# Prefetch next task's AsyncVectorEnv workers while this task runs.
if i + 1 < len(tasks):
next_env = tasks[i + 1][2]
if hasattr(next_env, "_ensure"):
prefetch_thread = threading.Thread(target=next_env._ensure, daemon=True)
prefetch_thread.start()
try:
tg, tid, metrics = task_runner(task_group, task_id, env)
_accumulate_to(tg, metrics)