mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference
This commit is contained in:
@@ -58,16 +58,16 @@ class SACAlgorithm(RLAlgorithm):
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
|
||||
self._init_critics()
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
self._init_critics(action_dim)
|
||||
self._init_temperature()
|
||||
|
||||
self._device = torch.device(self.policy.config.device)
|
||||
self._move_to_device()
|
||||
|
||||
def _init_critics(self) -> None:
|
||||
def _init_critics(self, action_dim) -> None:
|
||||
"""Build critic ensemble, targets."""
|
||||
encoder = self.policy.encoder_critic
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
|
||||
heads = [
|
||||
CriticHead(
|
||||
@@ -76,7 +76,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder.output_dim + action_dim,
|
||||
@@ -84,7 +84,7 @@ class SACAlgorithm(RLAlgorithm):
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
# TODO(Khalil): Investigate and fix torch.compile
|
||||
|
||||
Reference in New Issue
Block a user