mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
Refactored actor.py to use the pipeline
This commit is contained in:
@@ -83,6 +83,13 @@ from lerobot.utils.transition import (
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
from lerobot.scripts.rl.gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
@@ -236,7 +243,8 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
||||
env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
@@ -257,6 +265,13 @@ def act_with_policy(
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
transition = env_processor(transition)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
@@ -277,7 +292,9 @@ def act_with_policy(
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
# Extract observation from transition for policy
|
||||
batch_obs = transition[TransitionKey.OBSERVATION]
|
||||
action = policy.select_action(batch=batch_obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
@@ -285,34 +302,46 @@ def act_with_policy(
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
# Use the new step function
|
||||
new_transition, terminate_episode = step_env_and_process_transition(
|
||||
env=online_env,
|
||||
transition=transition,
|
||||
action=action,
|
||||
teleop_device=teleop_device,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
|
||||
# Extract values from processed transition
|
||||
reward = new_transition[TransitionKey.REWARD]
|
||||
done = new_transition.get(TransitionKey.DONE, False)
|
||||
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||
processed_action = new_transition[TransitionKey.ACTION]
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
if intervention_info.get("is_intervention", False):
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
state=transition[TransitionKey.OBSERVATION],
|
||||
action=processed_action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
next_state=new_transition[TransitionKey.OBSERVATION],
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
truncated=truncated,
|
||||
complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
# Update transition for next iteration
|
||||
transition = new_transition
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
@@ -347,12 +376,21 @@ def act_with_policy(
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
# Reset intervention counters and environment
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
complementary_data = {"raw_joint_positions": info.pop("raw_joint_positions")}
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
transition = env_processor(transition)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
|
||||
Reference in New Issue
Block a user