Clean the code

This commit is contained in:
AdilZouitine
2025-04-24 17:22:54 +02:00
parent b8c2b0bb93
commit a8da4a347e
4 changed files with 56 additions and 39 deletions

View File

@@ -256,7 +256,8 @@ def add_actor_information_and_train(
interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
parameters_queue (Queue): Queue for sending policy parameters to the actor.
"""
# Extract all configuration variables at the beginning
# Extract all configuration variables at the beginning, it improve the speed performance
# of 7%
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
clip_grad_norm_value = cfg.policy.grad_clip_norm
@@ -283,11 +284,11 @@ def add_actor_information_and_train(
policy: SACPolicy = make_policy(
cfg=cfg.policy,
# ds_meta=cfg.dataset,
env_cfg=cfg.env,
)
assert isinstance(policy, nn.Module)
policy.train()
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
@@ -295,6 +296,8 @@ def add_actor_information_and_train(
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
# If we are resuming, we need to load the training state
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy=policy)
@@ -330,6 +333,7 @@ def add_actor_information_and_train(
# Initialize iterators
online_iterator = None
offline_iterator = None
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
# Exit the training loop if shutdown is requested
@@ -337,7 +341,7 @@ def add_actor_information_and_train(
logging.info("[LEARNER] Shutdown signal received. Exiting...")
break
# Process all available transitions
# Process all available transitions to the replay buffer, send by the actor server
logging.debug("[LEARNER] Waiting for transitions")
process_transitions(
transition_queue=transition_queue,
@@ -349,7 +353,7 @@ def add_actor_information_and_train(
)
logging.debug("[LEARNER] Received transitions")
# Process all available interaction messages
# Process all available interaction messages sent by the actor server
logging.debug("[LEARNER] Waiting for interactions")
interaction_message = process_interaction_messages(
interaction_message_queue=interaction_message_queue,
@@ -359,7 +363,7 @@ def add_actor_information_and_train(
)
logging.debug("[LEARNER] Received interactions")
# Wait until the replay buffer has enough samples
# Wait until the replay buffer has enough samples to start training
if len(replay_buffer) < online_step_before_learning:
continue
@@ -410,7 +414,7 @@ def add_actor_information_and_train(
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss (includes both main critic and discrete critic)
# Use the forward method for critic loss
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
@@ -433,7 +437,7 @@ def add_actor_information_and_train(
)
optimizers["discrete_critic"].step()
# Update target networks
# Update target networks (main and discrete)
policy.update_target_networks()
# Sample for the last update in the UTD ratio
@@ -468,10 +472,8 @@ def add_actor_information_and_train(
"next_observation_feature": next_observation_features,
}
# Use the forward method for critic loss (includes both main critic and discrete critic)
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -541,7 +543,7 @@ def add_actor_information_and_train(
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
last_time_policy_pushed = time.time()
# Update target networks
# Update target networks (main and discrete)
policy.update_target_networks()
# Log training metrics at specified intervals
@@ -601,6 +603,8 @@ def start_learner_server(
):
"""
Start the learner server for training.
It will receive transitions and interaction messages from the actor server,
and send policy parameters to the actor server.
Args:
parameters_queue: Queue for sending policy parameters to the actor
@@ -756,7 +760,7 @@ def make_optimizers_and_scheduler(cfg: TrainPipelineConfig, policy: nn.Module):
It also initializes a learning rate scheduler, though currently, it is set to `None`.
**NOTE:**
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.