diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 3928e6c96..45f7a9c39 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -90,9 +90,9 @@ from lerobot.utils.utils import ( ) from .gym_manipulator import ( - create_transition, make_processors, make_robot_env, + reset_and_build_transition, step_env_and_process_transition, ) @@ -265,13 +265,7 @@ def act_with_policy( dataset_stats=cfg.policy.dataset_stats, ) - obs, info = online_env.reset() - env_processor.reset() - action_processor.reset() - - # Process initial observation - transition = create_transition(observation=obs, info=info) - transition = env_processor(transition) + transition = reset_and_build_transition(online_env, env_processor, action_processor) # NOTE: For the moment we will solely handle the case of a single environment sum_reward_episode = 0 @@ -396,14 +390,7 @@ def act_with_policy( episode_intervention_steps = 0 episode_total_steps = 0 - # Reset environment and processors - obs, info = online_env.reset() - env_processor.reset() - action_processor.reset() - - # Process initial observation - transition = create_transition(observation=obs, info=info) - transition = env_processor(transition) + transition = reset_and_build_transition(online_env, env_processor, action_processor) if cfg.env.fps is not None: dt_time = time.perf_counter() - start_time