mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
RL works at this commit - fixed actor.py and bugs in gym_manipulator
This commit is contained in:
@@ -243,7 +243,7 @@ def act_with_policy(
     logging.info("make_env online")

     online_env, teleop_device = make_robot_env(cfg=cfg.env)
-    env_processor, action_processor = make_processors(online_env, cfg.env, cfg.policy.device)
+    env_processor, action_processor = make_processors(online_env, cfg.env)

     set_seed(cfg.seed)
     device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -288,18 +288,15 @@ def act_with_policy(
             logging.info("[ACTOR] Shutting down act_with_policy")
             return

-        if interaction_step >= cfg.policy.online_step_before_learning:
-            # Time policy inference and check if it meets FPS requirement
-            with policy_timer:
-                # Extract observation from transition for policy
-                batch_obs = transition[TransitionKey.OBSERVATION]
-                action = policy.select_action(batch=batch_obs)
-            policy_fps = policy_timer.fps_last
+        observation = transition[TransitionKey.OBSERVATION]

-            log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
+        # Time policy inference and check if it meets FPS requirement
+        with policy_timer:
+            # Extract observation from transition for policy
+            action = policy.select_action(batch=observation)
+        policy_fps = policy_timer.fps_last

-        else:
-            action = online_env.action_space.sample()
+        log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)

         # Use the new step function
         new_transition, terminate_episode = step_env_and_process_transition(
@@ -312,10 +309,11 @@ def act_with_policy(
         )

         # Extract values from processed transition
+        next_observation = new_transition[TransitionKey.OBSERVATION]
+        executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
         reward = new_transition[TransitionKey.REWARD]
         done = new_transition.get(TransitionKey.DONE, False)
         truncated = new_transition.get(TransitionKey.TRUNCATED, False)
-        processed_action = new_transition[TransitionKey.ACTION]

         sum_reward_episode += float(reward)
         episode_total_steps += 1
@@ -329,13 +327,13 @@ def act_with_policy(
         # Create transition for learner (convert to old format)
         list_transition_to_send_to_learner.append(
             Transition(
-                state=transition[TransitionKey.OBSERVATION],
-                action=processed_action,
+                state=observation,
+                action=executed_action,
                 reward=reward,
-                next_state=new_transition[TransitionKey.OBSERVATION],
+                next_state=next_observation,
                 done=done,
                 truncated=truncated,
-                complementary_info=new_transition[TransitionKey.COMPLEMENTARY_DATA],
+                complementary_info={},  # new_transition[TransitionKey.COMPLEMENTARY_DATA],
             )
         )

Reference in New Issue
Block a user