mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
fix(hil-serl): discrete critic send through network (#1468)
Co-authored-by: Khalil Meftah <kmeftah.khalil@gmail.com> Co-authored-by: jpizarrom <jpizarrom@gmail.com>
This commit is contained in:
@@ -1109,8 +1109,18 @@ def check_nan_in_transition(
|
||||
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
    """Serialize the policy's weights and put them on the parameters queue.

    The payload is a dictionary of state dicts, each moved to CPU first:

    - ``"policy"``: the state dict of ``policy.actor``.
    - ``"discrete_critic"``: the state dict of ``policy.discrete_critic``,
      included only when the policy has that attribute and it is not ``None``.

    The dictionary is passed through ``state_to_bytes`` and the result is
    placed on ``parameters_queue``.

    Args:
        parameters_queue: Queue that receives the serialized payload.
        policy: Policy module exposing an ``actor`` sub-module and, optionally,
            a ``discrete_critic`` sub-module.
    """
    logging.debug("[LEARNER] Pushing actor policy to the queue")

    # Create a dictionary to hold all the state dicts
    state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}

    # Add discrete critic if it exists
    if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
        state_dicts["discrete_critic"] = move_state_dict_to_device(
            policy.discrete_critic.state_dict(), device="cpu"
        )
        logging.debug("[LEARNER] Including discrete critic in state dict push")

    # Serialize once, after every optional component has been collected.
    state_bytes = state_to_bytes(state_dicts)
    parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user