mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
perf(eval): shared memory, observation passthrough, task prefetch
- AsyncVectorEnv now uses shared_memory=True for zero-copy observation transfer - LiberoEnvConfig.gym_kwargs passes observation_height/width to the env - eval_policy_all prefetches next task's workers while current task runs Made-with: Cursor
This commit is contained in:
@@ -402,7 +402,12 @@ class LiberoEnv(EnvConfig):
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode}
|
||||
kwargs: dict[str, Any] = {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"observation_height": self.observation_height,
|
||||
"observation_width": self.observation_width,
|
||||
}
|
||||
if self.task_ids is not None:
|
||||
kwargs["task_ids"] = self.task_ids
|
||||
return kwargs
|
||||
|
||||
@@ -435,7 +435,9 @@ class _LazyAsyncVectorEnv:
|
||||
|
||||
def _ensure(self):
|
||||
if self._env is None:
|
||||
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver")
|
||||
self._env = gym.vector.AsyncVectorEnv(
|
||||
self._env_fns, context="forkserver", shared_memory=True
|
||||
)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self._ensure()
|
||||
|
||||
@@ -760,7 +760,19 @@ def eval_policy_all(
|
||||
)
|
||||
|
||||
if max_parallel_tasks <= 1:
|
||||
for task_group, task_id, env in tasks:
|
||||
prefetch_thread: threading.Thread | None = None
|
||||
for i, (task_group, task_id, env) in enumerate(tasks):
|
||||
if prefetch_thread is not None:
|
||||
prefetch_thread.join()
|
||||
prefetch_thread = None
|
||||
|
||||
# Prefetch next task's AsyncVectorEnv workers while this task runs.
|
||||
if i + 1 < len(tasks):
|
||||
next_env = tasks[i + 1][2]
|
||||
if hasattr(next_env, "_ensure"):
|
||||
prefetch_thread = threading.Thread(target=next_env._ensure, daemon=True)
|
||||
prefetch_thread.start()
|
||||
|
||||
try:
|
||||
tg, tid, metrics = task_runner(task_group, task_id, env)
|
||||
_accumulate_to(tg, metrics)
|
||||
|
||||
Reference in New Issue
Block a user