refactor(policies): rename policies/sac → policies/gaussian_actor

This commit is contained in:
Khalil Meftah
2026-04-23 19:13:18 +02:00
parent 8065bf15c7
commit 06255996ea
24 changed files with 185 additions and 168 deletions

View File

@@ -820,10 +820,10 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
1. Configure the policy settings (`type="sac"`, `device`, etc.)
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
2. Set `dataset` to your cropped dataset
3. Configure environment settings with crop parameters
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
4. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
**Starting the Learner**

View File

@@ -7,9 +7,9 @@ import torch
from lerobot.datasets import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
@@ -28,7 +28,7 @@ def run_learner(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_learner: SACPolicy,
policy_learner: GaussianActorPolicy,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer,
lr: float = 3e-4,
@@ -116,7 +116,7 @@ def run_actor(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_actor: SACPolicy,
policy_actor: GaussianActorPolicy,
reward_classifier: Classifier,
env_cfg: HILSerlRobotEnvConfig,
device: torch.device = "mps",
@@ -264,14 +264,14 @@ def main():
action_features = hw_to_dataset_features(env.robot.action_features, "action")
# Create SAC policy for action selection
policy_cfg = SACConfig(
policy_cfg = GaussianActorConfig(
device=device,
input_features=obs_features,
output_features=action_features,
)
policy_actor = SACPolicy(policy_cfg)
policy_learner = SACPolicy(policy_cfg)
policy_actor = GaussianActorPolicy(policy_cfg)
policy_learner = GaussianActorPolicy(policy_cfg)
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)

View File

@@ -15,6 +15,10 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
from .gaussian_actor.reward_model.configuration_classifier import (
RewardClassifierConfig as RewardClassifierConfig,
)
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
@@ -22,8 +26,6 @@ from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
from .rtc import ActionInterpolator as ActionInterpolator
from .sac.configuration_sac import SACConfig as SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -32,21 +34,21 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
__all__ = [
# Configuration classes
"ACTConfig",
"DiffusionConfig",
"GaussianActorConfig",
"GrootConfig",
"MultiTaskDiTConfig",
"PI0Config",
"PI0FastConfig",
"PI05Config",
"RewardClassifierConfig",
"SACConfig",
"SARMConfig",
"SmolVLAConfig",
"TDMPCConfig",

View File

@@ -46,13 +46,13 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
from .groot.configuration_groot import GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
from .pi05.configuration_pi05 import PI05Config
from .pretrained import PreTrainedPolicy
from .sac.configuration_sac import SACConfig
from .sac.reward_model.configuration_classifier import RewardClassifierConfig
from .sarm.configuration_sarm import SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig
@@ -89,7 +89,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -128,12 +128,12 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from .pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "sac":
from .sac.modeling_sac import SACPolicy
elif name == "gaussian_actor":
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
return SACPolicy
return GaussianActorPolicy
elif name == "reward_classifier":
from .sac.reward_model.modeling_classifier import Classifier
from .gaussian_actor.reward_model.modeling_classifier import Classifier
return Classifier
elif name == "smolvla":
@@ -172,7 +172,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
"smolvla", "reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
@@ -196,8 +196,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "gaussian_actor":
return GaussianActorConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
@@ -370,16 +370,16 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SACConfig):
from .sac.processor_sac import make_sac_pre_post_processors
elif isinstance(policy_cfg, GaussianActorConfig):
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
processors = make_sac_pre_post_processors(
processors = make_gaussian_actor_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, RewardClassifierConfig):
from .sac.reward_model.processor_classifier import make_classifier_processor
from .gaussian_actor.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(
config=policy_cfg,

View File

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_sac import SACConfig
from .modeling_sac import SACPolicy
from .processor_sac import make_sac_pre_post_processors
from .configuration_gaussian_actor import GaussianActorConfig
from .modeling_gaussian_actor import GaussianActorPolicy
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]

View File

@@ -75,18 +75,19 @@ class PolicyConfig:
init_final: float = 0.05
@PreTrainedConfig.register_subclass("sac")
@PreTrainedConfig.register_subclass("gaussian_actor")
@dataclass
class SACConfig(PreTrainedConfig):
"""Soft Actor-Critic (SAC) configuration.
class GaussianActorConfig(PreTrainedConfig):
"""Gaussian actor configuration.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
reinforcement learning framework. It learns a policy and a Q-function simultaneously
using experience collected from the environment.
This configures the policy-side (actor + observation encoder) of a Gaussian
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
By default the actor output is a tanh-squashed diagonal Gaussian
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
This configuration class contains all the parameters needed to define a SAC agent,
including network architectures, optimization settings, and algorithm-specific
hyperparameters.
CLI: ``--policy.type=gaussian_actor``.
"""
# Mapping of feature types to normalization modes

View File

@@ -29,22 +29,29 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
from ..pretrained import PreTrainedPolicy
from ..utils import get_device_from_parameters
from .configuration_sac import SACConfig, is_image_feature
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class SACPolicy(
class GaussianActorPolicy(
PreTrainedPolicy,
):
"""SAC policy."""
"""Gaussian actor + observation encoder.
config_class = SACConfig
name = "sac"
Policy-side ``nn.Module`` used by SAC and related maximum-entropy continuous
control algorithms. It owns the actor network (``Policy``) and the observation
encoder (``GaussianActorObservationEncoder``); the critics, temperature, and
Bellman-update logic live on the algorithm side
(see ``lerobot.rl.algorithms.sac``).
"""
config_class = GaussianActorConfig
name = "gaussian_actor"
def __init__(
self,
config: SACConfig | None = None,
config: GaussianActorConfig | None = None,
):
super().__init__(config)
config.validate_features()
@@ -73,7 +80,9 @@ class SACPolicy(
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
raise NotImplementedError(
"GaussianActorPolicy does not support action chunking. It returns single actions!"
)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -133,9 +142,9 @@ class SACPolicy(
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_critic = GaussianActorObservationEncoder(self.config)
self.encoder_actor = (
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
)
def _init_actor(self, continuous_action_dim):
@@ -155,10 +164,10 @@ class SACPolicy(
self.target_entropy = -np.prod(dim) / 2
class SACObservationEncoder(nn.Module):
class GaussianActorObservationEncoder(nn.Module):
"""Encode image and/or state vector observations."""
def __init__(self, config: SACConfig) -> None:
def __init__(self, config: GaussianActorConfig) -> None:
super().__init__()
self.config = config
self._init_image_layers()
@@ -411,7 +420,7 @@ class DiscreteCritic(nn.Module):
class Policy(nn.Module):
def __init__(
self,
encoder: SACObservationEncoder,
encoder: GaussianActorObservationEncoder,
network: nn.Module,
action_dim: int,
std_min: float = -5,
@@ -422,7 +431,7 @@ class Policy(nn.Module):
encoder_is_shared: bool = False,
):
super().__init__()
self.encoder: SACObservationEncoder = encoder
self.encoder: GaussianActorObservationEncoder = encoder
self.network = network
self.action_dim = action_dim
self.std_min = std_min
@@ -496,7 +505,7 @@ class Policy(nn.Module):
class DefaultImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
def __init__(self, config: GaussianActorConfig):
super().__init__()
image_key = next(key for key in config.input_features if is_image_feature(key))
self.image_enc_layers = nn.Sequential(
@@ -542,12 +551,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
class PretrainedImageEncoder(nn.Module):
def __init__(self, config: SACConfig):
def __init__(self, config: GaussianActorConfig):
super().__init__()
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
def _load_pretrained_vision_encoder(self, config: SACConfig):
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
"""Set up CNN encoder"""
from transformers import AutoModel

View File

@@ -32,18 +32,18 @@ from lerobot.processor import (
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from .configuration_sac import SACConfig
from .configuration_gaussian_actor import GaussianActorConfig
def make_sac_pre_post_processors(
config: SACConfig,
def make_gaussian_actor_pre_post_processors(
config: GaussianActorConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the SAC policy.
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the SAC policy.
config: The configuration object for the tanh-Gaussian policy.
dataset_stats: A dictionary of statistics for normalization.
Returns:

View File

@@ -557,7 +557,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
def __post_init__(self):
"""Initializes the reward classifier model after the dataclass is created."""
if self.pretrained_path is not None:
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
self.reward_classifier.to(self.device)

View File

@@ -251,7 +251,7 @@ def act_with_policy(
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### To avoid sending a policy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
policy = make_policy(
cfg=cfg.policy,

View File

@@ -19,7 +19,10 @@ from typing import TYPE_CHECKING
import torch
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig, SACConfig
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
CriticNetworkConfig,
GaussianActorConfig,
)
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
if TYPE_CHECKING:
@@ -32,7 +35,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
"""SAC algorithm hyperparameters."""
# Policy config
sac_config: SACConfig
sac_config: GaussianActorConfig
# Optimizer learning rates
actor_lr: float = 3e-4
@@ -59,7 +62,7 @@ class SACAlgorithmConfig(RLAlgorithmConfig):
grad_clip_norm: float = 40.0
@classmethod
def from_policy_config(cls, policy_cfg: SACConfig) -> SACAlgorithmConfig:
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
"""Build an algorithm config by copying hyperparameters from the policy config."""
return cls(
actor_lr=policy_cfg.actor_lr,

View File

@@ -26,12 +26,12 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.optim import Optimizer
from lerobot.policies.sac.modeling_sac import (
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import (
DISCRETE_DIMENSION_INDEX,
MLP,
DiscreteCritic,
SACObservationEncoder,
SACPolicy,
GaussianActorObservationEncoder,
GaussianActorPolicy,
orthogonal_init,
)
from lerobot.policies.utils import get_device_from_parameters
@@ -50,7 +50,7 @@ class SACAlgorithm(RLAlgorithm):
def __init__(
self,
policy: SACPolicy,
policy: GaussianActorPolicy,
config: SACAlgorithmConfig,
):
self.config = config
@@ -100,7 +100,9 @@ class SACAlgorithm(RLAlgorithm):
self.discrete_critic, self.discrete_critic_target = self._init_discrete_critics(encoder)
self.policy.discrete_critic = self.discrete_critic
def _init_discrete_critics(self, encoder: SACObservationEncoder) -> tuple[DiscreteCritic, DiscreteCritic]:
def _init_discrete_critics(
self, encoder: GaussianActorObservationEncoder
) -> tuple[DiscreteCritic, DiscreteCritic]:
"""Build discrete critic ensemble and target networks."""
discrete_critic = DiscreteCritic(
encoder=encoder,
@@ -557,7 +559,7 @@ class CriticEnsemble(nn.Module):
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
Args:
encoder (SACObservationEncoder): encoder for observations.
encoder (GaussianActorObservationEncoder): encoder for observations.
ensemble (List[CriticHead]): list of critic heads.
init_final (float | None): optional initializer scale for final layers.
@@ -566,7 +568,7 @@ class CriticEnsemble(nn.Module):
def __init__(
self,
encoder: SACObservationEncoder,
encoder: GaussianActorObservationEncoder,
ensemble: list[CriticHead],
init_final: float | None = None,
):

View File

@@ -39,8 +39,8 @@ For more details, see the [Physical Intelligence π₀ blog post](https://www.ph
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
{% elif model_name == "sac" %}
[Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments.
{% elif model_name == "gaussian_actor" %}
This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms.
{% elif model_name == "reward_classifier" %}
A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable.
{% else %}

View File

@@ -17,8 +17,8 @@
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.policies.gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import skip_if_package_missing
@@ -38,7 +38,7 @@ def test_classifier_output():
@skip_if_package_missing("transformers")
def test_binary_classifier_with_default_params():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
config.input_features = {
@@ -79,7 +79,7 @@ def test_binary_classifier_with_default_params():
@skip_if_package_missing("transformers")
def test_multiclass_classifier():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
num_classes = 5
config = RewardClassifierConfig()
@@ -118,7 +118,7 @@ def test_multiclass_classifier():
@skip_if_package_missing("transformers")
def test_default_device():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
assert config.device == "cpu"
@@ -130,7 +130,7 @@ def test_default_device():
@skip_if_package_missing("transformers")
def test_explicit_device_setup():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig(device="cpu")
assert config.device == "cpu"

View File

@@ -17,19 +17,19 @@
import pytest
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import (
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
ActorLearnerConfig,
ActorNetworkConfig,
ConcurrencyConfig,
CriticNetworkConfig,
GaussianActorConfig,
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
config = SACConfig()
def test_gaussian_actor_config_default_initialization():
config = GaussianActorConfig()
assert config.normalization_mapping == {
"VISUAL": NormalizationMode.MEAN_STD,
@@ -175,8 +175,8 @@ def test_concurrency_config():
assert config.learner == "threads"
def test_sac_config_custom_initialization():
config = SACConfig(
def test_gaussian_actor_config_custom_initialization():
config = GaussianActorConfig(
device="cpu",
discount=0.95,
temperature_init=0.5,
@@ -190,7 +190,7 @@ def test_sac_config_custom_initialization():
def test_validate_features():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -198,7 +198,7 @@ def test_validate_features():
def test_validate_features_missing_observation():
config = SACConfig(
config = GaussianActorConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -209,7 +209,7 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)

View File

@@ -22,8 +22,8 @@ import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
@@ -81,9 +81,9 @@ def test_mlp_with_custom_final_activation():
assert (y >= -1).all() and (y <= 1).all()
def test_sac_policy_with_default_args():
def test_gaussian_actor_policy_with_default_args():
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
SACPolicy()
GaussianActorPolicy()
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
@@ -142,12 +142,12 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
) -> GaussianActorConfig:
action_dim = continuous_action_dim
if has_discrete_action:
action_dim += 1
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
@@ -167,7 +167,7 @@ def create_default_config(
def create_config_with_visual_input(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
) -> GaussianActorConfig:
config = create_default_config(
state_dim=state_dim,
continuous_action_dim=continuous_action_dim,
@@ -186,9 +186,9 @@ def create_config_with_visual_input(
return config
def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
def _make_algorithm(config: GaussianActorConfig) -> tuple[SACAlgorithm, GaussianActorPolicy]:
"""Helper to create policy + algorithm pair for tests that need critics."""
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
@@ -197,9 +197,9 @@ def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
def test_gaussian_actor_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
@@ -209,11 +209,11 @@ def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: i
assert selected_action.shape[-1] == action_dim
def test_sac_policy_select_action_with_discrete():
def test_gaussian_actor_policy_select_action_with_discrete():
"""select_action should return continuous + discrete actions."""
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.num_discrete_actions = 3
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
@@ -225,9 +225,9 @@ def test_sac_policy_select_action_with_discrete():
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int):
def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
@@ -307,7 +307,7 @@ def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_
[(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_sac_policy_with_pretrained_encoder(
def test_gaussian_actor_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
@@ -415,7 +415,7 @@ def test_sac_algorithm_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
config.num_discrete_actions = 5
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.target_entropy == -3.5
@@ -425,7 +425,7 @@ def test_sac_algorithm_temperature():
config = create_default_config(continuous_action_dim=10, state_dim=10)
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.temperature == pytest.approx(1.0)
@@ -437,7 +437,7 @@ def test_sac_algorithm_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
for p in algorithm.critic_ensemble.parameters():
@@ -472,7 +472,7 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
algorithm.optimizers["critic"].step()
def test_sac_policy_save_and_load(tmp_path):
def test_gaussian_actor_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded from pretrained."""
root = tmp_path / "test_sac_save_and_load"
@@ -481,10 +481,10 @@ def test_sac_policy_save_and_load(tmp_path):
batch_size = 2
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
@@ -503,7 +503,7 @@ def test_sac_policy_save_and_load(tmp_path):
assert torch.allclose(actions, loaded_actions)
def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
"""Discrete critic should be saved/loaded as part of the policy."""
root = tmp_path / "test_sac_save_and_load_discrete"
@@ -512,11 +512,11 @@ def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_discrete_actions = 3
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert loaded_policy.discrete_critic is not None

View File

@@ -21,8 +21,8 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
from lerobot.policies.gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.gaussian_actor.reward_model.processor_classifier import make_classifier_processor
from lerobot.processor import (
DataProcessorPipeline,
DeviceProcessorStep,

View File

@@ -21,8 +21,8 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
@@ -38,7 +38,7 @@ from lerobot.utils.constants import ACTION, OBS_STATE
def create_default_config():
"""Create a default SAC configuration for testing."""
config = SACConfig()
config = GaussianActorConfig()
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
}
@@ -66,7 +66,7 @@ def test_make_sac_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -88,12 +88,12 @@ def test_make_sac_processor_basic():
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
def test_sac_processor_normalization_modes():
def test_gaussian_actor_processor_normalization_modes():
"""Test that SAC processor correctly handles different normalization modes."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -121,13 +121,13 @@ def test_sac_processor_normalization_modes():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_cuda():
def test_gaussian_actor_processor_cuda():
"""Test SAC processor with CUDA device."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -153,13 +153,13 @@ def test_sac_processor_cuda():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_accelerate_scenario():
def test_gaussian_actor_processor_accelerate_scenario():
"""Test SAC processor in simulated Accelerate scenario."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -180,13 +180,13 @@ def test_sac_processor_accelerate_scenario():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_sac_processor_multi_gpu():
def test_gaussian_actor_processor_multi_gpu():
"""Test SAC processor with multi-GPU setup."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -206,11 +206,11 @@ def test_sac_processor_multi_gpu():
assert processed[TransitionKey.ACTION.value].device == device
def test_sac_processor_without_stats():
def test_gaussian_actor_processor_without_stats():
"""Test SAC processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None
@@ -226,12 +226,12 @@ def test_sac_processor_without_stats():
assert processed is not None
def test_sac_processor_save_and_load():
def test_gaussian_actor_processor_save_and_load():
"""Test saving and loading SAC processor."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -257,14 +257,14 @@ def test_sac_processor_save_and_load():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_mixed_precision():
def test_gaussian_actor_processor_mixed_precision():
"""Test SAC processor with mixed precision."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -304,12 +304,12 @@ def test_sac_processor_mixed_precision():
assert processed[TransitionKey.ACTION.value].dtype == torch.float16
def test_sac_processor_batch_data():
def test_gaussian_actor_processor_batch_data():
"""Test SAC processor with batched data."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -329,12 +329,12 @@ def test_sac_processor_batch_data():
assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5)
def test_sac_processor_edge_cases():
def test_gaussian_actor_processor_edge_cases():
"""Test SAC processor with edge cases."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -358,13 +358,13 @@ def test_sac_processor_edge_cases():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_bfloat16_device_float32_normalizer():
def test_gaussian_actor_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, _ = make_sac_pre_post_processors(
preprocessor, _ = make_gaussian_actor_pre_post_processors(
config,
stats,
)

View File

@@ -28,7 +28,7 @@ from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition
from tests.utils import skip_if_package_missing
@@ -81,7 +81,7 @@ def cfg():
port = find_free_port()
policy_cfg = SACConfig()
policy_cfg = GaussianActorConfig()
policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
policy_cfg.actor_learner_config.learner_port = port
policy_cfg.concurrency.actor = "threads"
@@ -312,7 +312,7 @@ def test_learner_algorithm_wiring():
"""Verify that make_algorithm constructs an SACAlgorithm from config,
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
get_weights() output is serializable."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
from lerobot.transport.utils import state_to_bytes
@@ -320,7 +320,7 @@ def test_learner_algorithm_wiring():
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -331,7 +331,7 @@ def test_learner_algorithm_wiring():
)
sac_cfg.validate_features()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
@@ -399,13 +399,13 @@ def test_learner_algorithm_wiring():
def test_initial_and_periodic_weight_push_consistency():
"""Both initial and periodic weight pushes should use algorithm.get_weights()
and produce identical structures."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -416,7 +416,7 @@ def test_initial_and_periodic_weight_push_consistency():
)
sac_cfg.validate_features()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm.make_optimizers_and_scheduler()
@@ -437,13 +437,13 @@ def test_initial_and_periodic_weight_push_consistency():
def test_actor_side_algorithm_select_action_and_load_weights():
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -455,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
sac_cfg.validate_features()
# Actor side: no optimizers
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.eval()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)

View File

@@ -22,8 +22,8 @@ pytest.importorskip("grpc")
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
@@ -47,8 +47,8 @@ def _make_sac_config(
utd_ratio: int = 1,
policy_update_freq: int = 1,
with_images: bool = False,
) -> SACConfig:
config = SACConfig(
) -> GaussianActorConfig:
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -79,7 +79,7 @@ def _make_algorithm(
policy_update_freq: int = 1,
num_discrete_actions: int | None = None,
with_images: bool = False,
) -> tuple[SACAlgorithm, SACPolicy]:
) -> tuple[SACAlgorithm, GaussianActorPolicy]:
sac_cfg = _make_sac_config(
state_dim=state_dim,
action_dim=action_dim,
@@ -88,7 +88,7 @@ def _make_algorithm(
num_discrete_actions=num_discrete_actions,
with_images=with_images,
)
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
@@ -349,7 +349,7 @@ def test_optimization_step_can_be_set_for_resume():
def test_make_algorithm_returns_sac_for_sac_policy():
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -358,7 +358,7 @@ def test_make_algorithm_returns_sac_for_sac_policy():
def test_make_optimizers_creates_expected_keys():
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
optimizers = algorithm.make_optimizers_and_scheduler()
assert "actor" in optimizers
@@ -371,7 +371,7 @@ def test_make_optimizers_creates_expected_keys():
def test_actor_side_no_optimizers():
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -379,7 +379,7 @@ def test_actor_side_no_optimizers():
def test_make_algorithm_copies_config_fields():
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert algorithm.config.utd_ratio == 5
assert algorithm.config.policy_update_freq == 3
@@ -404,7 +404,7 @@ def test_load_weights_round_trip():
algo_src.update(_batch_iterator())
sac_cfg = _make_sac_config(state_dim=10, action_dim=6)
policy_dst = SACPolicy(config=sac_cfg)
policy_dst = GaussianActorPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
@@ -423,7 +423,7 @@ def test_load_weights_round_trip_with_discrete_critic():
algo_src.update(_batch_iterator(action_dim=7))
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_dst = SACPolicy(config=sac_cfg)
policy_dst = GaussianActorPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
@@ -470,7 +470,7 @@ def test_build_algorithm_via_config():
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
sac_cfg = _make_sac_config(utd_ratio=2)
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = algo_config.build_algorithm(policy)
assert isinstance(algorithm, SACAlgorithm)
@@ -480,6 +480,6 @@ def test_build_algorithm_via_config():
def test_make_algorithm_uses_build_algorithm():
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)