mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
Clean the code
This commit is contained in:
@@ -256,7 +256,8 @@ def add_actor_information_and_train(
|
||||
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
|
||||
parameters_queue (Queue): Queue for sending policy parameters to the actor.
|
||||
"""
|
||||
# Extract all configuration variables at the beginning
|
||||
# Extract all configuration variables at the beginning, it improve the speed performance
|
||||
# of 7%
|
||||
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
|
||||
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
|
||||
clip_grad_norm_value = cfg.policy.grad_clip_norm
|
||||
@@ -283,11 +284,11 @@ def add_actor_information_and_train(
|
||||
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
# ds_meta=cfg.dataset,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
policy.train()
|
||||
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
@@ -295,6 +296,8 @@ def add_actor_information_and_train(
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
log_training_info(cfg=cfg, policy=policy)
|
||||
@@ -330,6 +333,7 @@ def add_actor_information_and_train(
|
||||
# Initialize iterators
|
||||
online_iterator = None
|
||||
offline_iterator = None
|
||||
|
||||
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
|
||||
while True:
|
||||
# Exit the training loop if shutdown is requested
|
||||
@@ -337,7 +341,7 @@ def add_actor_information_and_train(
|
||||
logging.info("[LEARNER] Shutdown signal received. Exiting...")
|
||||
break
|
||||
|
||||
# Process all available transitions
|
||||
# Process all available transitions to the replay buffer, send by the actor server
|
||||
logging.debug("[LEARNER] Waiting for transitions")
|
||||
process_transitions(
|
||||
transition_queue=transition_queue,
|
||||
@@ -349,7 +353,7 @@ def add_actor_information_and_train(
|
||||
)
|
||||
logging.debug("[LEARNER] Received transitions")
|
||||
|
||||
# Process all available interaction messages
|
||||
# Process all available interaction messages sent by the actor server
|
||||
logging.debug("[LEARNER] Waiting for interactions")
|
||||
interaction_message = process_interaction_messages(
|
||||
interaction_message_queue=interaction_message_queue,
|
||||
@@ -359,7 +363,7 @@ def add_actor_information_and_train(
|
||||
)
|
||||
logging.debug("[LEARNER] Received interactions")
|
||||
|
||||
# Wait until the replay buffer has enough samples
|
||||
# Wait until the replay buffer has enough samples to start training
|
||||
if len(replay_buffer) < online_step_before_learning:
|
||||
continue
|
||||
|
||||
@@ -410,7 +414,7 @@ def add_actor_information_and_train(
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss (includes both main critic and discrete critic)
|
||||
# Use the forward method for critic loss
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
# Main critic optimization
|
||||
@@ -433,7 +437,7 @@ def add_actor_information_and_train(
|
||||
)
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Update target networks
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
|
||||
# Sample for the last update in the UTD ratio
|
||||
@@ -468,10 +472,8 @@ def add_actor_information_and_train(
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss (includes both main critic and discrete critic)
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
# Main critic optimization
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
@@ -541,7 +543,7 @@ def add_actor_information_and_train(
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
# Update target networks
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
|
||||
# Log training metrics at specified intervals
|
||||
@@ -601,6 +603,8 @@ def start_learner_server(
|
||||
):
|
||||
"""
|
||||
Start the learner server for training.
|
||||
It will receive transitions and interaction messages from the actor server,
|
||||
and send policy parameters to the actor server.
|
||||
|
||||
Args:
|
||||
parameters_queue: Queue for sending policy parameters to the actor
|
||||
@@ -756,7 +760,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
**NOTE:**
|
||||
NOTE:
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user