diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 2190070f5..0d5bf1982 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -659,79 +659,86 @@ def control_loop( episode_step = 0 episode_start_time = time.perf_counter() - while episode_idx < cfg.dataset.num_episodes_to_record: - step_start_time = time.perf_counter() + try: + while episode_idx < cfg.dataset.num_episodes_to_record: + step_start_time = time.perf_counter() - # Create a neutral action (no movement) - neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) - if use_gripper: - neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay - - # Use the new step function - transition = step_env_and_process_transition( - env=env, - transition=transition, - action=neutral_action, - env_processor=env_processor, - action_processor=action_processor, - ) - terminated = transition.get(TransitionKey.DONE, False) - truncated = transition.get(TransitionKey.TRUNCATED, False) - - if cfg.mode == "record": - observations = { - k: v.squeeze(0).cpu() - for k, v in transition[TransitionKey.OBSERVATION].items() - if isinstance(v, torch.Tensor) - } - # Use teleop_action if available, otherwise use the action from the transition - action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get( - "teleop_action", transition[TransitionKey.ACTION] - ) - frame = { - **observations, - ACTION: action_to_record.cpu(), - REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32), - DONE: np.array([terminated or truncated], dtype=bool), - } + # Create a neutral action (no movement) + neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) if use_gripper: - discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0) - frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32) + neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay - if dataset is not None: - frame["task"] = cfg.dataset.task - dataset.add_frame(frame) - - episode_step += 1 - - # Handle episode termination - if terminated or truncated: - episode_time = time.perf_counter() - episode_start_time - logging.info( - f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}" + transition = step_env_and_process_transition( + env=env, + transition=transition, + action=neutral_action, + env_processor=env_processor, + action_processor=action_processor, ) - episode_step = 0 - episode_idx += 1 + terminated = transition.get(TransitionKey.DONE, False) + truncated = transition.get(TransitionKey.TRUNCATED, False) - if dataset is not None: - if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False): - logging.info(f"Re-recording episode {episode_idx}") - dataset.clear_episode_buffer() - episode_idx -= 1 - else: - logging.info(f"Saving episode {episode_idx}") - dataset.save_episode() + if cfg.mode == "record": + observations = { + k: v.squeeze(0).cpu() + for k, v in transition[TransitionKey.OBSERVATION].items() + if isinstance(v, torch.Tensor) + } + action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get( + "teleop_action", transition[TransitionKey.ACTION] + ) + frame = { + **observations, + ACTION: action_to_record.cpu(), + REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32), + DONE: np.array([terminated or truncated], dtype=bool), + } + if use_gripper: + discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get( + "discrete_penalty", 0.0 + ) + frame["complementary_info.discrete_penalty"] = np.array( + [discrete_penalty], dtype=np.float32 + ) - # Reset for new episode - obs, info = env.reset() - env_processor.reset() - action_processor.reset() + if dataset is not None: + frame["task"] = cfg.dataset.task + dataset.add_frame(frame) - transition = create_transition(observation=obs, info=info) - transition = env_processor(transition) + episode_step += 1 - # Maintain fps timing - precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0)) + # Handle episode termination + if terminated or truncated: + episode_time = time.perf_counter() - episode_start_time + logging.info( + f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}" + ) + episode_step = 0 + episode_idx += 1 + + if dataset is not None: + if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False): + logging.info(f"Re-recording episode {episode_idx}") + dataset.clear_episode_buffer() + episode_idx -= 1 + else: + logging.info(f"Saving episode {episode_idx}") + dataset.save_episode() + + # Reset for new episode + obs, info = env.reset() + env_processor.reset() + action_processor.reset() + + transition = create_transition(observation=obs, info=info) + transition = env_processor(transition) + + # Maintain fps timing + precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0)) + finally: + if dataset is not None and dataset.writer is not None and dataset.writer.image_writer is not None: + logging.info("Waiting for image writer to finish...") + dataset.writer.image_writer.stop() if dataset is not None and cfg.dataset.push_to_hub: logging.info("Finalizing dataset before pushing to hub")