refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference

This commit is contained in:
Khalil Meftah
2026-04-13 18:31:17 +02:00
parent 7b0bdf2a98
commit ee0814ef60

View File

@@ -58,16 +58,16 @@ class SACAlgorithm(RLAlgorithm):
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
self._init_critics()
action_dim = self.policy.config.output_features[ACTION].shape[0]
self._init_critics(action_dim)
self._init_temperature()
self._device = torch.device(self.policy.config.device)
self._move_to_device()
def _init_critics(self) -> None:
def _init_critics(self, action_dim) -> None:
"""Build critic ensemble, targets."""
encoder = self.policy.encoder_critic
action_dim = self.policy.config.output_features[ACTION].shape[0]
heads = [
CriticHead(
@@ -76,7 +76,7 @@ class SACAlgorithm(RLAlgorithm):
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads)
target_heads = [
CriticHead(
input_dim=encoder.output_dim + action_dim,
@@ -84,7 +84,7 @@ class SACAlgorithm(RLAlgorithm):
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
# TODO(Khalil): Investigate and fix torch.compile