From ee0814ef60a6f6c75bbc73ec5fd9447e4c5587cb Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 13 Apr 2026 18:31:17 +0200 Subject: [PATCH] refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference --- src/lerobot/rl/algorithms/sac/sac_algorithm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py index 10c104bb6..4f92ce859 100644 --- a/src/lerobot/rl/algorithms/sac/sac_algorithm.py +++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py @@ -58,16 +58,16 @@ class SACAlgorithm(RLAlgorithm): self.optimizers: dict[str, Optimizer] = {} self._optimization_step: int = 0 - self._init_critics() + action_dim = self.policy.config.output_features[ACTION].shape[0] + self._init_critics(action_dim) self._init_temperature() self._device = torch.device(self.policy.config.device) self._move_to_device() - def _init_critics(self) -> None: + def _init_critics(self, action_dim) -> None: """Build critic ensemble, targets.""" encoder = self.policy.encoder_critic - action_dim = self.policy.config.output_features[ACTION].shape[0] heads = [ CriticHead( @@ -76,7 +76,7 @@ class SACAlgorithm(RLAlgorithm): ) for _ in range(self.config.num_critics) ] - self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads) + self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads) target_heads = [ CriticHead( input_dim=encoder.output_dim + action_dim, @@ -84,7 +84,7 @@ class SACAlgorithm(RLAlgorithm): ) for _ in range(self.config.num_critics) ] - self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads) + self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads) self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) # TODO(Khalil): Investigate and fix torch.compile