mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Lazy env creation + smart sharding to fix container OOM
This commit is contained in:
@@ -596,7 +596,17 @@ class LiberoEnv(gym.Env):
|
||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||
f"but got shape {action.shape} with ndim={action.ndim}"
|
||||
)
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
|
||||
try:
|
||||
raw_obs, reward, done, info = self._env.step(action)
|
||||
except ValueError as e:
|
||||
if "terminated episode" not in str(e):
|
||||
raise
|
||||
# Robosuite's internal done flag is stale (e.g. from a previous
|
||||
# termination that wasn't properly cleared by SyncVectorEnv).
|
||||
# Signal termination so the caller resets us.
|
||||
obs, reset_info = self.reset()
|
||||
return obs, 0.0, True, False, {"is_success": False, **reset_info}
|
||||
|
||||
is_success = self._env.check_success()
|
||||
terminated = done or is_success
|
||||
@@ -616,7 +626,6 @@ class LiberoEnv(gym.Env):
|
||||
"done": bool(done),
|
||||
"is_success": bool(is_success),
|
||||
}
|
||||
self.reset()
|
||||
truncated = False
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
@@ -714,7 +723,7 @@ def create_libero_envs(
|
||||
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
|
||||
total_tasks += len(selected)
|
||||
|
||||
lazy = total_tasks > 50
|
||||
lazy = total_tasks > 1
|
||||
if lazy:
|
||||
print(f"Using lazy env creation for {total_tasks} tasks (envs created on demand)")
|
||||
|
||||
|
||||
@@ -98,7 +98,14 @@ def run_worker(cfg: EvalWorkerConfig) -> dict:
|
||||
# Shard: this worker handles tasks where index % instance_count == instance_id
|
||||
if cfg.instance_count > 1:
|
||||
total = len(tasks)
|
||||
tasks = [t for idx, t in enumerate(tasks) if idx % cfg.instance_count == cfg.instance_id]
|
||||
assigned = {i for i in range(total) if i % cfg.instance_count == cfg.instance_id}
|
||||
for i, (_, _, env) in enumerate(tasks):
|
||||
if i not in assigned:
|
||||
try:
|
||||
env.close()
|
||||
except Exception:
|
||||
pass
|
||||
tasks = [t for i, t in enumerate(tasks) if i in assigned]
|
||||
logger.info(
|
||||
"Shard %d/%d: %d/%d tasks assigned.",
|
||||
cfg.instance_id + 1,
|
||||
|
||||
Reference in New Issue
Block a user