mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
cleanup (#5)
This commit is contained in:
@@ -49,7 +49,7 @@ def create_libero_envs(
|
||||
n_repeat = n_envs // len(tasks_id)
|
||||
print("n_repeat", n_repeat)
|
||||
episode_indices = []
|
||||
for i in range(len(tasks_id)):
|
||||
for _ in range(len(tasks_id)):
|
||||
episode_indices.extend(list(range(n_repeat)))
|
||||
tasks_id = list(chain.from_iterable([[item] * n_repeat for item in tasks_id]))
|
||||
elif n_envs < len(tasks_id):
|
||||
@@ -90,14 +90,17 @@ def create_libero_envs(
|
||||
f"Creating Libero envs with task ids {tasks_id} from suite {_task}, episode_indices: {episode_indices}"
|
||||
)
|
||||
envs_list = [
|
||||
lambda i=i: LiberoEnv(
|
||||
task_suite=task_suite,
|
||||
task_id=tasks_id,
|
||||
task_suite_name=_task,
|
||||
camera_name=camera_name,
|
||||
init_states=init_states,
|
||||
episode_index=episode_indices[i],
|
||||
**gym_kwargs,
|
||||
(lambda i=i, task_suite=task_suite, tasks_id=tasks_id,
|
||||
_task=_task, episode_indices=episode_indices:
|
||||
LiberoEnv(
|
||||
task_suite=task_suite,
|
||||
task_id=tasks_id,
|
||||
task_suite_name=_task,
|
||||
camera_name=camera_name,
|
||||
init_states=init_states,
|
||||
episode_index=episode_indices[i],
|
||||
**gym_kwargs,
|
||||
)
|
||||
)
|
||||
for i in range(n_envs)
|
||||
]
|
||||
|
||||
@@ -301,8 +301,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
if eval_env:
|
||||
# added by jade, close all env in multi eval setup
|
||||
if cfg.env.multitask_eval:
|
||||
for task_group, envs_dict in eval_env.items():
|
||||
for idx, env in envs_dict.items():
|
||||
for _task_group, envs_dict in eval_env.items():
|
||||
for _idx, env in envs_dict.items():
|
||||
env.close()
|
||||
else:
|
||||
eval_env.close()
|
||||
|
||||
Reference in New Issue
Block a user