From 1ed32210c7285d3586d06c8c361cb369315ebce1 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Fri, 24 Apr 2026 13:18:33 +0200 Subject: [PATCH] refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic --- docs/source/hilserl.mdx | 2 +- .../configuration_gaussian_actor.py | 77 +++++------------ .../gaussian_actor/modeling_gaussian_actor.py | 41 +++++----- .../rl/algorithms/sac/configuration_sac.py | 27 +++--- .../rl/algorithms/sac/sac_algorithm.py | 64 +++++++-------- tests/policies/test_gaussian_actor_config.py | 37 ++------- tests/policies/test_gaussian_actor_policy.py | 19 +++-- tests/rl/test_actor_learner.py | 3 - tests/rl/test_sac_algorithm.py | 82 ++++++++++++++----- 9 files changed, 162 insertions(+), 190 deletions(-) diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index e596ff00b..8dc98c49f 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -926,7 +926,7 @@ The ideal behaviour is that your intervention rate should drop gradually during Some configuration values have a disproportionate impact on training stability and speed: -- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning. +- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning. - **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency. - **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second. diff --git a/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py b/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py index fb1b12c39..f7fb6b03b 100644 --- a/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py +++ b/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py @@ -136,80 +136,41 @@ class GaussianActorConfig(PreTrainedConfig): # Dimension of the image embedding pooling image_embedding_pooling_dim: int = 8 - # Training parameter - # Number of steps for online training - online_steps: int = 1000000 - # Capacity of the online replay buffer - online_buffer_capacity: int = 100000 - # Capacity of the offline replay buffer - offline_buffer_capacity: int = 100000 - # Whether to use asynchronous prefetching for the buffers - async_prefetch: bool = False - # Number of steps before learning starts - online_step_before_learning: int = 100 - # Frequency of policy updates - policy_update_freq: int = 1 - - # SAC algorithm parameters - # Discount factor for the SAC algorithm - discount: float = 0.99 - # Initial temperature value - temperature_init: float = 1.0 - # Number of critics in the ensemble - num_critics: int = 2 - # Number of subsampled critics for training - num_subsample_critics: int | None = None - # Learning rate for the critic network - critic_lr: float = 3e-4 - # Learning rate for the actor network - actor_lr: float = 3e-4 - # Learning rate for the temperature parameter - temperature_lr: float = 3e-4 - # Weight for the critic target update - critic_target_update_weight: float = 0.005 - # Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1) - utd_ratio: int = 1 + # Encoder architecture # Hidden dimension size for the state encoder state_encoder_hidden_dim: int = 256 # Dimension of the latent space latent_dim: int = 256 - # Target entropy for the SAC algorithm - target_entropy: float | None = None - # Whether to use backup entropy for the SAC algorithm - use_backup_entropy: bool = True - # Gradient clipping norm for the SAC algorithm - grad_clip_norm: float = 40.0 - # Network configuration - # Configuration for the critic network architecture - critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) - # Configuration for the actor network architecture - actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) - # Configuration for the policy parameters - policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) - # Configuration for the discrete critic network - discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) - # Configuration for actor-learner architecture + # Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig) + online_steps: int = 1000000 + online_buffer_capacity: int = 100000 + offline_buffer_capacity: int = 100000 + async_prefetch: bool = False + online_step_before_learning: int = 100 + + # Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig). actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) - # Configuration for concurrency settings (you can use threads or processes for the actor and learner) concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) - # Optimizations - # torch.compile is currently disabled by default due to known issues with the SAC - # critic ensemble and shared encoder. - use_torch_compile: bool = False + # Network architecture + # Actor network + actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) + # Gaussian head parameters + policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) + # Discrete critic + discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) def __post_init__(self): super().__post_init__() - # Any validation specific to SAC configuration def get_optimizer_preset(self) -> MultiAdamConfig: return MultiAdamConfig( weight_decay=0.0, optimizer_groups={ - "actor": {"lr": self.actor_lr}, - "critic": {"lr": self.critic_lr}, - "temperature": {"lr": self.temperature_lr}, + "actor": {"lr": 3e-4}, + "critic": {"lr": 3e-4}, + "temperature": {"lr": 3e-4}, }, ) diff --git a/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py b/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py index 5242954c1..cabeae3f4 100644 --- a/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py +++ b/src/lerobot/policies/gaussian_actor/modeling_gaussian_actor.py @@ -19,7 +19,6 @@ from collections.abc import Callable from dataclasses import asdict from typing import Any -import numpy as np import torch import torch.nn as nn from torch import Tensor @@ -61,7 +60,7 @@ class GaussianActorPolicy( continuous_action_dim = config.output_features[ACTION].shape[0] self._init_encoders() self._init_actor(continuous_action_dim) - self.discrete_critic = None + self._init_discrete_critic() def get_optim_params(self) -> dict: optim_params = { @@ -125,19 +124,14 @@ class GaussianActorPolicy( def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None: from lerobot.utils.transition import move_state_dict_to_device - actor_sd = move_state_dict_to_device(state_dicts["policy"], device=device) - self.actor.load_state_dict(actor_sd) + actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device) + self.actor.load_state_dict(actor_state_dict) - if "discrete_critic" in state_dicts: - dc_sd = move_state_dict_to_device(state_dicts["discrete_critic"], device=device) - if self.discrete_critic is None: - self.discrete_critic = DiscreteCritic( - encoder=self.encoder_critic, - input_dim=self.encoder_critic.output_dim, - output_dim=self.config.num_discrete_actions, - **asdict(self.config.discrete_critic_network_kwargs), - ).to(device) - self.discrete_critic.load_state_dict(dc_sd) + if "discrete_critic" in state_dicts and self.discrete_critic is not None: + discrete_critic_state_dict = move_state_dict_to_device( + state_dicts["discrete_critic"], device=device + ) + self.discrete_critic.load_state_dict(discrete_critic_state_dict) def _init_encoders(self): """Initialize shared or separate encoders for actor and critic.""" @@ -148,7 +142,7 @@ class GaussianActorPolicy( ) def _init_actor(self, continuous_action_dim): - """Initialize policy actor network and default target entropy.""" + """Initialize policy actor network.""" # NOTE: The actor select only the continuous action part self.actor = Policy( encoder=self.encoder_actor, @@ -158,10 +152,19 @@ class GaussianActorPolicy( **asdict(self.config.policy_kwargs), ) - self.target_entropy = self.config.target_entropy - if self.target_entropy is None: - dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) - self.target_entropy = -np.prod(dim) / 2 + def _init_discrete_critic(self) -> None: + """Initialize discrete critic network.""" + if self.config.num_discrete_actions is None: + self.discrete_critic = None + return + + # TODO(Khalil): Compile the discrete critic + self.discrete_critic = DiscreteCritic( + encoder=self.encoder_critic, + input_dim=self.encoder_critic.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) class GaussianActorObservationEncoder(nn.Module): diff --git a/src/lerobot/rl/algorithms/sac/configuration_sac.py b/src/lerobot/rl/algorithms/sac/configuration_sac.py index bd5e5ed9e..cbce441d4 100644 --- a/src/lerobot/rl/algorithms/sac/configuration_sac.py +++ b/src/lerobot/rl/algorithms/sac/configuration_sac.py @@ -35,7 +35,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig): """SAC algorithm hyperparameters.""" # Policy config - sac_config: GaussianActorConfig + policy_config: GaussianActorConfig # Optimizer learning rates actor_lr: float = 3e-4 @@ -55,31 +55,26 @@ class SACAlgorithmConfig(RLAlgorithmConfig): # Temperature / entropy temperature_init: float = 1.0 + # Target entropy for automatic temperature tuning. If ``None``, defaults to + # ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if + # there is a discrete action head). + target_entropy: float | None = None # Update loop utd_ratio: int = 1 policy_update_freq: int = 1 grad_clip_norm: float = 40.0 + # Optimizations + # torch.compile is currently disabled by default + use_torch_compile: bool = False + @classmethod def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig: - """Build an algorithm config by copying hyperparameters from the policy config.""" + """Build an algorithm config with default hyperparameters for a given policy.""" return cls( - actor_lr=policy_cfg.actor_lr, - critic_lr=policy_cfg.critic_lr, - temperature_lr=policy_cfg.temperature_lr, - discount=policy_cfg.discount, - use_backup_entropy=policy_cfg.use_backup_entropy, - critic_target_update_weight=policy_cfg.critic_target_update_weight, - num_critics=policy_cfg.num_critics, - num_subsample_critics=policy_cfg.num_subsample_critics, - critic_network_kwargs=policy_cfg.critic_network_kwargs, + policy_config=policy_cfg, discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs, - temperature_init=policy_cfg.temperature_init, - utd_ratio=policy_cfg.utd_ratio, - policy_update_freq=policy_cfg.policy_update_freq, - grad_clip_norm=policy_cfg.grad_clip_norm, - sac_config=policy_cfg, ) def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm: diff --git a/src/lerobot/rl/algorithms/sac/sac_algorithm.py b/src/lerobot/rl/algorithms/sac/sac_algorithm.py index da0e56a6c..78cb40536 100644 --- a/src/lerobot/rl/algorithms/sac/sac_algorithm.py +++ b/src/lerobot/rl/algorithms/sac/sac_algorithm.py @@ -54,14 +54,14 @@ class SACAlgorithm(RLAlgorithm): config: SACAlgorithmConfig, ): self.config = config - self.policy_config = config.sac_config + self.policy_config = config.policy_config self.policy = policy self.optimizers: dict[str, Optimizer] = {} self._optimization_step: int = 0 action_dim = self.policy.config.output_features[ACTION].shape[0] self._init_critics(action_dim) - self._init_temperature() + self._init_temperature(action_dim) self._device = torch.device(self.policy.config.device) self._move_to_device() @@ -90,49 +90,44 @@ class SACAlgorithm(RLAlgorithm): # TODO(Khalil): Investigate and fix torch.compile # NOTE: torch.compile is disabled, policy does not converge when enabled. - if self.policy_config.use_torch_compile: + if self.config.use_torch_compile: self.critic_ensemble = torch.compile(self.critic_ensemble) self.critic_target = torch.compile(self.critic_target) - self.discrete_critic = None self.discrete_critic_target = None if self.policy_config.num_discrete_actions is not None: - self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder) - self.policy.discrete_critic = self.discrete_critic + self.discrete_critic_target = self._init_discrete_critic_target(encoder) - def _init_discrete_critics( - self, encoder: GaussianActorObservationEncoder - ) -> tuple[DiscreteCritic, DiscreteCritic]: - """Build discrete critic ensemble and target networks.""" - discrete_critic = DiscreteCritic( - encoder=encoder, - input_dim=encoder.output_dim, - output_dim=self.policy_config.num_discrete_actions, - **asdict(self.config.discrete_critic_network_kwargs), - ) + def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic: + """Build target discrete critic (main network is owned by the policy).""" discrete_critic_target = DiscreteCritic( encoder=encoder, input_dim=encoder.output_dim, output_dim=self.policy_config.num_discrete_actions, **asdict(self.config.discrete_critic_network_kwargs), ) - # TODO(Khalil): Compile the discrete critic - discrete_critic_target.load_state_dict(discrete_critic.state_dict()) - return discrete_critic, discrete_critic_target + discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict()) + return discrete_critic_target - def _init_temperature(self) -> None: - """Set up temperature parameter (log_alpha).""" + def _init_temperature(self, continuous_action_dim: int) -> None: + """Set up temperature parameter (log_alpha) and target entropy.""" temp_init = self.config.temperature_init self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) + self.target_entropy = self.config.target_entropy + if self.target_entropy is None: + total_action_dim = continuous_action_dim + ( + 1 if self.policy_config.num_discrete_actions is not None else 0 + ) + self.target_entropy = -total_action_dim / 2 + def _move_to_device(self) -> None: self.policy.to(self._device) self.critic_ensemble.to(self._device) self.critic_target.to(self._device) self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device)) - if self.discrete_critic is not None: - self.discrete_critic.to(self._device) + if self.discrete_critic_target is not None: self.discrete_critic_target.to(self._device) @property @@ -175,7 +170,7 @@ class SACAlgorithm(RLAlgorithm): Returns: Tensor of Q-values from the discrete critic network """ - discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic + discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic q_values = discrete_critic(observations, observation_features) return q_values @@ -196,7 +191,7 @@ class SACAlgorithm(RLAlgorithm): loss_dc = self._compute_loss_discrete_critic(fb) self.optimizers["discrete_critic"].zero_grad() loss_dc.backward() - torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip) + torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip) self.optimizers["discrete_critic"].step() self._update_target_networks() @@ -219,7 +214,9 @@ class SACAlgorithm(RLAlgorithm): loss_dc = self._compute_loss_discrete_critic(fb) self.optimizers["discrete_critic"].zero_grad() loss_dc.backward() - dc_grad = torch.nn.utils.clip_grad_norm_(self.discrete_critic.parameters(), max_norm=clip).item() + dc_grad = torch.nn.utils.clip_grad_norm_( + self.policy.discrete_critic.parameters(), max_norm=clip + ).item() self.optimizers["discrete_critic"].step() stats.losses["loss_discrete_critic"] = loss_dc.item() stats.grad_norms["discrete_critic"] = dc_grad @@ -396,7 +393,7 @@ class SACAlgorithm(RLAlgorithm): with torch.no_grad(): _, log_probs, _ = self.policy.actor(observations, observation_features) - temperature_loss = (-self.log_alpha.exp() * (log_probs + self.policy.target_entropy)).mean() + temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean() return temperature_loss def _update_target_networks(self) -> None: @@ -411,7 +408,7 @@ class SACAlgorithm(RLAlgorithm): if self.policy_config.num_discrete_actions is not None: for target_p, p in zip( self.discrete_critic_target.parameters(), - self.discrete_critic.parameters(), + self.policy.discrete_critic.parameters(), strict=True, ): target_p.data.copy_( @@ -471,7 +468,7 @@ class SACAlgorithm(RLAlgorithm): } if self.policy_config.num_discrete_actions is not None: self.optimizers["discrete_critic"] = torch.optim.Adam( - self.discrete_critic.parameters(), lr=self.config.critic_lr + self.policy.discrete_critic.parameters(), lr=self.config.critic_lr ) return self.optimizers @@ -485,16 +482,13 @@ class SACAlgorithm(RLAlgorithm): } if self.policy_config.num_discrete_actions is not None: state_dicts["discrete_critic"] = move_state_dict_to_device( - self.discrete_critic.state_dict(), device="cpu" + self.policy.discrete_critic.state_dict(), device="cpu" ) return state_dicts def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: - actor_sd = move_state_dict_to_device(weights["policy"], device=device) - self.policy.actor.load_state_dict(actor_sd) - if "discrete_critic" in weights and self.policy_config.num_discrete_actions is not None: - dc_sd = move_state_dict_to_device(weights["discrete_critic"], device=device) - self.discrete_critic.load_state_dict(dc_sd) + """Load actor + discrete-critic weights into the policy.""" + self.policy.load_actor_weights(weights, device=device) def get_observation_features( self, observations: Tensor, next_observations: Tensor diff --git a/tests/policies/test_gaussian_actor_config.py b/tests/policies/test_gaussian_actor_config.py index 0a77a421e..004612374 100644 --- a/tests/policies/test_gaussian_actor_config.py +++ b/tests/policies/test_gaussian_actor_config.py @@ -55,9 +55,6 @@ def test_gaussian_actor_config_default_initialization(): # Basic parameters assert config.device == "cpu" assert config.storage_device == "cpu" - assert config.discount == 0.99 - assert config.temperature_init == 1.0 - assert config.num_critics == 2 # Architecture specifics assert config.vision_encoder_name is None @@ -66,6 +63,8 @@ def test_gaussian_actor_config_default_initialization(): assert config.shared_encoder is True assert config.num_discrete_actions is None assert config.image_embedding_pooling_dim == 8 + assert config.state_encoder_hidden_dim == 256 + assert config.latent_dim == 256 # Training parameters assert config.online_steps == 1000000 @@ -73,20 +72,6 @@ def test_gaussian_actor_config_default_initialization(): assert config.offline_buffer_capacity == 100000 assert config.async_prefetch is False assert config.online_step_before_learning == 100 - assert config.policy_update_freq == 1 - - # SAC algorithm parameters - assert config.num_subsample_critics is None - assert config.critic_lr == 3e-4 - assert config.actor_lr == 3e-4 - assert config.temperature_lr == 3e-4 - assert config.critic_target_update_weight == 0.005 - assert config.utd_ratio == 1 - assert config.state_encoder_hidden_dim == 256 - assert config.latent_dim == 256 - assert config.target_entropy is None - assert config.use_backup_entropy is True - assert config.grad_clip_norm == 40.0 # Dataset stats defaults expected_dataset_stats = { @@ -105,11 +90,6 @@ def test_gaussian_actor_config_default_initialization(): } assert config.dataset_stats == expected_dataset_stats - # Critic network configuration - assert config.critic_network_kwargs.hidden_dims == [256, 256] - assert config.critic_network_kwargs.activate_final is True - assert config.critic_network_kwargs.final_activation is None - # Actor network configuration assert config.actor_network_kwargs.hidden_dims == [256, 256] assert config.actor_network_kwargs.activate_final is True @@ -135,7 +115,6 @@ def test_gaussian_actor_config_default_initialization(): assert config.concurrency.learner == "threads" assert isinstance(config.actor_network_kwargs, ActorNetworkConfig) - assert isinstance(config.critic_network_kwargs, CriticNetworkConfig) assert isinstance(config.policy_kwargs, PolicyConfig) assert isinstance(config.actor_learner_config, ActorLearnerConfig) assert isinstance(config.concurrency, ConcurrencyConfig) @@ -178,15 +157,15 @@ def test_concurrency_config(): def test_gaussian_actor_config_custom_initialization(): config = GaussianActorConfig( device="cpu", - discount=0.95, - temperature_init=0.5, - num_critics=3, + latent_dim=128, + state_encoder_hidden_dim=128, + num_discrete_actions=3, ) assert config.device == "cpu" - assert config.discount == 0.95 - assert config.temperature_init == 0.5 - assert config.num_critics == 3 + assert config.latent_dim == 128 + assert config.state_encoder_hidden_dim == 128 + assert config.num_discrete_actions == 3 def test_validate_features(): diff --git a/tests/policies/test_gaussian_actor_policy.py b/tests/policies/test_gaussian_actor_policy.py index 7e32bc5fa..ea807c907 100644 --- a/tests/policies/test_gaussian_actor_policy.py +++ b/tests/policies/test_gaussian_actor_policy.py @@ -404,19 +404,16 @@ def test_sac_training_with_discrete_critic(): def test_sac_algorithm_target_entropy(): + """Target entropy is an SAC hyperparameter and lives on the algorithm.""" config = create_default_config(continuous_action_dim=10, state_dim=10) - _, policy = _make_algorithm(config) - algo_config = SACAlgorithmConfig.from_policy_config(config) - algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm, _ = _make_algorithm(config) assert algorithm.target_entropy == -5.0 def test_sac_algorithm_target_entropy_with_discrete_action(): config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) config.num_discrete_actions = 5 - algo_config = SACAlgorithmConfig.from_policy_config(config) - policy = GaussianActorPolicy(config=config) - algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm, _ = _make_algorithm(config) assert algorithm.target_entropy == -3.5 @@ -435,8 +432,8 @@ def test_sac_algorithm_temperature(): def test_sac_algorithm_update_target_network(): config = create_default_config(state_dim=10, continuous_action_dim=6) - config.critic_target_update_weight = 1.0 algo_config = SACAlgorithmConfig.from_policy_config(config) + algo_config.critic_target_update_weight = 1.0 policy = GaussianActorPolicy(config=config) algorithm = SACAlgorithm(policy=policy, config=algo_config) @@ -454,9 +451,13 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int): action_dim = 10 state_dim = 10 config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) - config.num_critics = num_critics - algorithm, policy = _make_algorithm(config) + policy = GaussianActorPolicy(config=config) + policy.train() + algo_config = SACAlgorithmConfig.from_policy_config(config) + algo_config.num_critics = num_critics + algorithm = SACAlgorithm(policy=policy, config=algo_config) + algorithm.make_optimizers_and_scheduler() assert len(algorithm.critic_ensemble.critics) == num_critics diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index 02231849a..3dc65118f 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -327,7 +327,6 @@ def test_learner_algorithm_wiring(): OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, }, - use_torch_compile=False, ) sac_cfg.validate_features() @@ -412,7 +411,6 @@ def test_initial_and_periodic_weight_push_consistency(): OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, }, - use_torch_compile=False, ) sac_cfg.validate_features() @@ -450,7 +448,6 @@ def test_actor_side_algorithm_select_action_and_load_weights(): OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, }, - use_torch_compile=False, ) sac_cfg.validate_features() diff --git a/tests/rl/test_sac_algorithm.py b/tests/rl/test_sac_algorithm.py index 7ca2b3228..a2653b6cb 100644 --- a/tests/rl/test_sac_algorithm.py +++ b/tests/rl/test_sac_algorithm.py @@ -44,8 +44,6 @@ def _make_sac_config( state_dim: int = 10, action_dim: int = 6, num_discrete_actions: int | None = None, - utd_ratio: int = 1, - policy_update_freq: int = 1, with_images: bool = False, ) -> GaussianActorConfig: config = GaussianActorConfig( @@ -55,10 +53,7 @@ def _make_sac_config( OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim}, ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim}, }, - utd_ratio=utd_ratio, - policy_update_freq=policy_update_freq, num_discrete_actions=num_discrete_actions, - use_torch_compile=False, ) if with_images: config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) @@ -83,14 +78,14 @@ def _make_algorithm( sac_cfg = _make_sac_config( state_dim=state_dim, action_dim=action_dim, - utd_ratio=utd_ratio, - policy_update_freq=policy_update_freq, num_discrete_actions=num_discrete_actions, with_images=with_images, ) policy = GaussianActorPolicy(config=sac_cfg) policy.train() algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) + algo_config.utd_ratio = utd_ratio + algo_config.policy_update_freq = policy_update_freq algorithm = SACAlgorithm(policy=policy, config=algo_config) algorithm.make_optimizers_and_scheduler() return algorithm, policy @@ -136,13 +131,16 @@ def test_sac_algorithm_config_registered(): def test_sac_algorithm_config_from_policy_config(): - """from_policy_config should copy algorithm hyperparameters from the policy config.""" - sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2) + """from_policy_config embeds the policy config and uses SAC defaults.""" + sac_cfg = _make_sac_config() algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg) - assert algo_cfg.sac_config is sac_cfg - assert algo_cfg.utd_ratio == 4 - assert algo_cfg.policy_update_freq == 2 - assert algo_cfg.grad_clip_norm == sac_cfg.grad_clip_norm + assert algo_cfg.policy_config is sac_cfg + assert algo_cfg.discrete_critic_network_kwargs is sac_cfg.discrete_critic_network_kwargs + # Defaults come from SACAlgorithmConfig, not from the policy config. + assert algo_cfg.utd_ratio == 1 + assert algo_cfg.policy_update_freq == 1 + assert algo_cfg.grad_clip_norm == 40.0 + assert algo_cfg.actor_lr == 3e-4 # =========================================================================== @@ -377,12 +375,14 @@ def test_actor_side_no_optimizers(): assert algorithm.optimizers == {} -def test_make_algorithm_copies_config_fields(): - sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3) +def test_make_algorithm_uses_sac_algorithm_defaults(): + """make_algorithm populates SACAlgorithmConfig with its own defaults.""" + sac_cfg = _make_sac_config() policy = GaussianActorPolicy(config=sac_cfg) algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac") - assert algorithm.config.utd_ratio == 5 - assert algorithm.config.policy_update_freq == 3 + assert algorithm.config.utd_ratio == 1 + assert algorithm.config.policy_update_freq == 1 + assert algorithm.config.grad_clip_norm == 40.0 def test_make_algorithm_raises_for_unknown_type(): @@ -431,10 +431,10 @@ def test_load_weights_round_trip_with_discrete_critic(): assert "discrete_critic" in weights assert len(weights["discrete_critic"]) > 0 - dst_dc_state_dict = algo_dst.discrete_critic.state_dict() + dst_discrete_critic_state_dict = algo_dst.policy.discrete_critic.state_dict() for key, tensor in weights["discrete_critic"].items(): assert torch.equal( - dst_dc_state_dict[key].cpu(), + dst_discrete_critic_state_dict[key].cpu(), tensor.cpu(), ), f"Discrete critic param '{key}' mismatch after load_weights" @@ -446,6 +446,47 @@ def test_load_weights_ignores_missing_discrete_critic(): algorithm.load_weights(weights, device="cpu") +def test_actor_side_weight_sync_with_discrete_critic(): + """End-to-end: learner ``algorithm.get_weights()`` -> actor ``policy.load_actor_weights()``.""" + # Learner side: train the algorithm so its weights diverge from init. + algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6) + algo_src.update(_batch_iterator(action_dim=7)) + weights = algo_src.get_weights() + + # Actor side: fresh policy, no algorithm/optimizer. + sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6) + policy_actor = GaussianActorPolicy(config=sac_cfg) + + # Snapshot initial actor state for the "did it change?" assertion below. + initial_discrete_critic_state_dict = { + k: v.clone() for k, v in policy_actor.discrete_critic.state_dict().items() + } + + policy_actor.load_actor_weights(weights, device="cpu") + + # Actor weights match the learner's exported actor state dict. + actor_state_dict = policy_actor.actor.state_dict() + for key, tensor in weights["policy"].items(): + assert torch.equal(actor_state_dict[key].cpu(), tensor.cpu()), ( + f"Actor param '{key}' not synced by load_actor_weights" + ) + + # Discrete critic weights match the learner's exported discrete critic. + discrete_critic_state_dict = policy_actor.discrete_critic.state_dict() + for key, tensor in weights["discrete_critic"].items(): + assert torch.equal(discrete_critic_state_dict[key].cpu(), tensor.cpu()), ( + f"Discrete critic param '{key}' not synced by load_actor_weights" + ) + + # Sanity: the discrete critic actually changed (otherwise the sync is trivial). + changed = any( + not torch.equal(initial_discrete_critic_state_dict[key], discrete_critic_state_dict[key]) + for key in initial_discrete_critic_state_dict + if key in discrete_critic_state_dict + ) + assert changed, "Discrete critic weights did not change between init and after sync" + + # =========================================================================== # TrainingStats generic losses dict # =========================================================================== @@ -468,8 +509,9 @@ def test_training_stats_generic_losses(): def test_build_algorithm_via_config(): """SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm.""" - sac_cfg = _make_sac_config(utd_ratio=2) + sac_cfg = _make_sac_config() algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg) + algo_config.utd_ratio = 2 policy = GaussianActorPolicy(config=sac_cfg) algorithm = algo_config.build_algorithm(policy)