Refactored actor.py to use the pipeline

This commit is contained in:
Michel Aractingi
2025-08-02 19:06:56 +02:00
parent e6e1edfd74
commit cfa672129e
2 changed files with 186 additions and 93 deletions

View File

@@ -83,6 +83,13 @@ from lerobot.utils.transition import (
move_state_dict_to_device,
move_transition_to_device,
)
from lerobot.processor.pipeline import EnvTransition, TransitionKey
from lerobot.scripts.rl.gym_manipulator import (
create_transition,
make_processors,
step_env_and_process_transition,
)
from lerobot.utils.utils import (
TimerManager,
get_safe_torch_device,
@@ -236,7 +243,8 @@ def act_with_policy(
logging.info("make_env online")
online_env = make_robot_env(cfg=cfg.env)
online_env, teleop_device = make_robot_env(cfg=cfg.env)
env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -257,6 +265,13 @@ def act_with_policy(
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(transition)
# NOTE: For now, only the single-environment case is handled
sum_reward_episode = 0
@@ -277,7 +292,9 @@ def act_with_policy(
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with policy_timer:
action = policy.select_action(batch=obs)
# Extract observation from transition for policy
batch_obs = transition[TransitionKey.OBSERVATION]
action = policy.select_action(batch=batch_obs)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
@@ -285,34 +302,46 @@ def act_with_policy(
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# Use the new step function
new_transition, terminate_episode = step_env_and_process_transition(
env=online_env,
transition=transition,
action=action,
teleop_device=teleop_device,
env_processor=env_processor,
action_processor=action_processor,
)
# Extract values from processed transition
reward = new_transition[TransitionKey.REWARD]
done = new_transition.get(TransitionKey.DONE, False)
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
processed_action = new_transition[TransitionKey.ACTION]
sum_reward_episode += float(reward)
# Increment total steps counter for intervention rate
episode_total_steps += 1
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
if "is_intervention" in info and info["is_intervention"]:
# NOTE: Demonstrations recorded beforehand use the full action space,
# but sometimes we want, for example, to deactivate the gripper
action = info["action_intervention"]
# Check for intervention from transition info
intervention_info = new_transition[TransitionKey.INFO]
if intervention_info.get("is_intervention", False):
episode_intervention = True
# Increment intervention steps counter
episode_intervention_steps += 1
# Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append(
Transition(
state=obs,
action=action,
state=transition[TransitionKey.OBSERVATION],
action=processed_action,
reward=reward,
next_state=next_obs,
next_state=new_transition[TransitionKey.OBSERVATION],
done=done,
truncated=truncated, # TODO: (azouitine) Handle truncation properly
complementary_info=info,
truncated=truncated,
complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
)
)
# assign obs to the next obs and continue the rollout
obs = next_obs
# Update transition for next iteration
transition = new_transition
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
@@ -347,12 +376,21 @@ def act_with_policy(
)
)
# Reset intervention counters
# Reset intervention counters and environment
sum_reward_episode = 0.0
episode_intervention = False
episode_intervention_steps = 0
episode_total_steps = 0
# Reset environment and processors
obs, info = online_env.reset()
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(transition)
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time