diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py index 61886289a..55b228026 100644 --- a/src/lerobot/rl/algorithms/sac/sac_algorithm.py +++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py @@ -74,6 +74,7 @@ class SACAlgorithm(RLAlgorithm): self._init_critic_encoder() self._init_critics() self._init_temperature() + self._move_to_device() def _init_critic_encoder(self) -> None: """Build or share the encoder used by critics.""" @@ -131,6 +132,14 @@ class SACAlgorithm(RLAlgorithm): dim = action_dim + (1 if self.config.num_discrete_actions is not None else 0) self.target_entropy = -np.prod(dim) / 2 + def _move_to_device(self) -> None: + """Move algorithm-owned modules to the policy device.""" + self.critic_ensemble.to(self._device) + self.critic_target.to(self._device) + self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device)) + if hasattr(self, "discrete_critic_target"): + self.discrete_critic_target.to(self._device) + @property def temperature(self) -> float: return self.log_alpha.exp().item()