Lazy env creation + smart sharding to fix container OOM

This commit is contained in:
Pepijn Kooijmans
2026-03-23 23:15:23 +01:00
parent aae68e3448
commit a9e355bd03
2 changed files with 20 additions and 4 deletions

View File

@@ -596,7 +596,17 @@ class LiberoEnv(gym.Env):
f"Expected action to be 1-D (shape (action_dim,)), "
f"but got shape {action.shape} with ndim={action.ndim}"
)
raw_obs, reward, done, info = self._env.step(action)
try:
raw_obs, reward, done, info = self._env.step(action)
except ValueError as e:
if "terminated episode" not in str(e):
raise
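# Robosuite's internal done flag is stale (e.g. from a previous
# termination that wasn't properly cleared by SyncVectorEnv).
# Reset here to clear it and obtain a valid observation, and report
# termination (with is_success=False) so the caller treats the episode as ended.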
obs, reset_info = self.reset()
return obs, 0.0, True, False, {"is_success": False, **reset_info}
is_success = self._env.check_success()
terminated = done or is_success
@@ -616,7 +626,6 @@ class LiberoEnv(gym.Env):
"done": bool(done),
"is_success": bool(is_success),
}
self.reset()
truncated = False
return observation, reward, terminated, truncated, info
@@ -714,7 +723,7 @@ def create_libero_envs(
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
total_tasks += len(selected)
lazy = total_tasks > 50
lazy = total_tasks > 1
if lazy:
print(f"Using lazy env creation for {total_tasks} tasks (envs created on demand)")

View File

@@ -98,7 +98,14 @@ def run_worker(cfg: EvalWorkerConfig) -> dict:
# Shard: this worker handles tasks where index % instance_count == instance_id
if cfg.instance_count > 1:
total = len(tasks)
tasks = [t for idx, t in enumerate(tasks) if idx % cfg.instance_count == cfg.instance_id]
assigned = {i for i in range(total) if i % cfg.instance_count == cfg.instance_id}
for i, (_, _, env) in enumerate(tasks):
if i not in assigned:
try:
env.close()
except Exception:
pass
tasks = [t for i, t in enumerate(tasks) if i in assigned]
logger.info(
"Shard %d/%d: %d/%d tasks assigned.",
cfg.instance_id + 1,