diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 457dad576..842befb74 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -596,7 +596,17 @@ class LiberoEnv(gym.Env): f"Expected action to be 1-D (shape (action_dim,)), " f"but got shape {action.shape} with ndim={action.ndim}" ) - raw_obs, reward, done, info = self._env.step(action) + + try: + raw_obs, reward, done, info = self._env.step(action) + except ValueError as e: + if "terminated episode" not in str(e): + raise + # Robosuite's internal done flag is stale (e.g. from a previous + # termination that wasn't properly cleared by SyncVectorEnv). + # Signal termination so the caller resets us. + obs, reset_info = self.reset() + return obs, 0.0, True, False, {"is_success": False, **reset_info} is_success = self._env.check_success() terminated = done or is_success @@ -616,7 +626,6 @@ class LiberoEnv(gym.Env): "done": bool(done), "is_success": bool(is_success), } - self.reset() truncated = False return observation, reward, terminated, truncated, info @@ -714,7 +723,7 @@ def create_libero_envs( raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).") total_tasks += len(selected) - lazy = total_tasks > 50 + lazy = total_tasks > 1 if lazy: print(f"Using lazy env creation for {total_tasks} tasks (envs created on demand)") diff --git a/src/lerobot/scripts/lerobot_eval_worker.py b/src/lerobot/scripts/lerobot_eval_worker.py index 321795237..e7820d962 100644 --- a/src/lerobot/scripts/lerobot_eval_worker.py +++ b/src/lerobot/scripts/lerobot_eval_worker.py @@ -98,7 +98,14 @@ def run_worker(cfg: EvalWorkerConfig) -> dict: # Shard: this worker handles tasks where index % instance_count == instance_id if cfg.instance_count > 1: total = len(tasks) - tasks = [t for idx, t in enumerate(tasks) if idx % cfg.instance_count == cfg.instance_id] + assigned = {i for i in range(total) if i % cfg.instance_count == cfg.instance_id} + for i, (_, _, env) in enumerate(tasks): + if i not in assigned: + try: + env.close() + except Exception: + pass + tasks = [t for i, t in enumerate(tasks) if i in assigned] logger.info( "Shard %d/%d: %d/%d tasks assigned.", cfg.instance_id + 1,