mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
- Fixed big issue in the loading of the policy parameters sent by the learner to the actor -- pass only the actor to the update_policy_parameters and remove strict=False
- Fixed big issue in the normalization of the actions in the `forward` function of the critic -- remove the `torch.no_grad` decorator in `normalize.py` in the normalization function - Fixed performance issue to boost the optimization frequency by setting the storage device to be the same as the device of learning. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -166,7 +166,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, d
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dict = parameters_queue.get()
|
||||
state_dict = move_state_dict_to_device(state_dict, device=device)
|
||||
policy.load_state_dict(state_dict, strict=False)
|
||||
policy.load_state_dict(state_dict)
|
||||
|
||||
|
||||
def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
|
||||
@@ -182,7 +182,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg.env)
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
@@ -283,7 +283,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
# TODO: Handle logging for episode information
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
send_transitions_in_chunks(
|
||||
|
||||
Reference in New Issue
Block a user