This commit is contained in:
Jade Choghari
2025-08-28 22:49:32 +03:00
committed by GitHub
parent 440e22c184
commit cb18fc07ef
2 changed files with 14 additions and 11 deletions

View File

@@ -49,7 +49,7 @@ def create_libero_envs(
n_repeat = n_envs // len(tasks_id)
print("n_repeat", n_repeat)
episode_indices = []
for i in range(len(tasks_id)):
for _ in range(len(tasks_id)):
episode_indices.extend(list(range(n_repeat)))
tasks_id = list(chain.from_iterable([[item] * n_repeat for item in tasks_id]))
elif n_envs < len(tasks_id):
@@ -90,14 +90,17 @@ def create_libero_envs(
f"Creating Libero envs with task ids {tasks_id} from suite {_task}, episode_indices: {episode_indices}"
)
envs_list = [
lambda i=i: LiberoEnv(
task_suite=task_suite,
task_id=tasks_id,
task_suite_name=_task,
camera_name=camera_name,
init_states=init_states,
episode_index=episode_indices[i],
**gym_kwargs,
(lambda i=i, task_suite=task_suite, tasks_id=tasks_id,
_task=_task, episode_indices=episode_indices:
LiberoEnv(
task_suite=task_suite,
task_id=tasks_id,
task_suite_name=_task,
camera_name=camera_name,
init_states=init_states,
episode_index=episode_indices[i],
**gym_kwargs,
)
)
for i in range(n_envs)
]

View File

@@ -301,8 +301,8 @@ def train(cfg: TrainPipelineConfig):
if eval_env:
# added by jade, close all env in multi eval setup
if cfg.env.multitask_eval:
for task_group, envs_dict in eval_env.items():
for idx, env in envs_dict.items():
for _task_group, envs_dict in eval_env.items():
for _idx, env in envs_dict.items():
env.close()
else:
eval_env.close()