fix: move algorithm-owned modules to the policy device

This commit is contained in:
Khalil Meftah
2026-03-18 15:27:41 +01:00
parent 1f5487eea8
commit d3e6f14d4f

View File

@@ -74,6 +74,7 @@ class SACAlgorithm(RLAlgorithm):
self._init_critic_encoder()
self._init_critics()
self._init_temperature()
self._move_to_device()
def _init_critic_encoder(self) -> None:
"""Build or share the encoder used by critics."""
@@ -131,6 +132,14 @@ class SACAlgorithm(RLAlgorithm):
dim = action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _move_to_device(self) -> None:
    """Move algorithm-owned modules to the policy device.

    Transfers ``critic_ensemble`` and ``critic_target`` to ``self._device``,
    rebinds ``log_alpha`` to a new ``nn.Parameter`` holding the same values
    on that device, and moves ``discrete_critic_target`` as well when that
    attribute exists and is not ``None``.

    NOTE(review): ``log_alpha`` is replaced by a *new* parameter object, so
    anything that captured the previous one (e.g. an optimizer) would keep
    the stale tensor -- presumably this runs before optimizers are built;
    confirm against the caller.
    """
    device = self._device

    self.critic_ensemble.to(device)
    self.critic_target.to(device)

    # ``detach()`` instead of the legacy ``.data`` accessor: same values and
    # no graph history, so the re-wrapped parameter is a clean leaf tensor.
    self.log_alpha = nn.Parameter(self.log_alpha.detach().to(device))

    # The discrete critic target is optional: tolerate the attribute being
    # absent *or* explicitly set to None instead of failing on ``None.to``.
    discrete_target = getattr(self, "discrete_critic_target", None)
    if discrete_target is not None:
        discrete_target.to(device)
@property
def temperature(self) -> float:
    """Return ``exp(log_alpha)`` as a plain Python number.

    The value is read out with ``.item()``, so the result is detached from
    autograd and carries no device information.
    """
    alpha = self.log_alpha.exp()
    return alpha.item()