refactor(policies): rename policies/sac → policies/gaussian_actor

This commit is contained in:
Khalil Meftah
2026-04-23 19:13:18 +02:00
parent 8065bf15c7
commit 06255996ea
24 changed files with 185 additions and 168 deletions

View File

@@ -17,8 +17,8 @@
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.policies.gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import skip_if_package_missing
@@ -38,7 +38,7 @@ def test_classifier_output():
@skip_if_package_missing("transformers")
def test_binary_classifier_with_default_params():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
config.input_features = {
@@ -79,7 +79,7 @@ def test_binary_classifier_with_default_params():
@skip_if_package_missing("transformers")
def test_multiclass_classifier():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
num_classes = 5
config = RewardClassifierConfig()
@@ -118,7 +118,7 @@ def test_multiclass_classifier():
@skip_if_package_missing("transformers")
def test_default_device():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig()
assert config.device == "cpu"
@@ -130,7 +130,7 @@ def test_default_device():
@skip_if_package_missing("transformers")
def test_explicit_device_setup():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
config = RewardClassifierConfig(device="cpu")
assert config.device == "cpu"

View File

@@ -17,19 +17,19 @@
import pytest
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import (
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
ActorLearnerConfig,
ActorNetworkConfig,
ConcurrencyConfig,
CriticNetworkConfig,
GaussianActorConfig,
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
config = SACConfig()
def test_gaussian_actor_config_default_initialization():
config = GaussianActorConfig()
assert config.normalization_mapping == {
"VISUAL": NormalizationMode.MEAN_STD,
@@ -175,8 +175,8 @@ def test_concurrency_config():
assert config.learner == "threads"
def test_sac_config_custom_initialization():
config = SACConfig(
def test_gaussian_actor_config_custom_initialization():
config = GaussianActorConfig(
device="cpu",
discount=0.95,
temperature_init=0.5,
@@ -190,7 +190,7 @@ def test_sac_config_custom_initialization():
def test_validate_features():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -198,7 +198,7 @@ def test_validate_features():
def test_validate_features_missing_observation():
config = SACConfig(
config = GaussianActorConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
@@ -209,7 +209,7 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action():
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)

View File

@@ -22,8 +22,8 @@ import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import MLP, GaussianActorPolicy
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
@@ -81,9 +81,9 @@ def test_mlp_with_custom_final_activation():
assert (y >= -1).all() and (y <= 1).all()
def test_sac_policy_with_default_args():
def test_gaussian_actor_policy_with_default_args():
with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"):
SACPolicy()
GaussianActorPolicy()
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
@@ -142,12 +142,12 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
) -> GaussianActorConfig:
action_dim = continuous_action_dim
if has_discrete_action:
action_dim += 1
config = SACConfig(
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
@@ -167,7 +167,7 @@ def create_default_config(
def create_config_with_visual_input(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
) -> GaussianActorConfig:
config = create_default_config(
state_dim=state_dim,
continuous_action_dim=continuous_action_dim,
@@ -186,9 +186,9 @@ def create_config_with_visual_input(
return config
def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
def _make_algorithm(config: GaussianActorConfig) -> tuple[SACAlgorithm, GaussianActorPolicy]:
"""Helper to create policy + algorithm pair for tests that need critics."""
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
@@ -197,9 +197,9 @@ def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
def test_gaussian_actor_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
@@ -209,11 +209,11 @@ def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: i
assert selected_action.shape[-1] == action_dim
def test_sac_policy_select_action_with_discrete():
def test_gaussian_actor_policy_select_action_with_discrete():
"""select_action should return continuous + discrete actions."""
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.num_discrete_actions = 3
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
with torch.no_grad():
@@ -225,9 +225,9 @@ def test_sac_policy_select_action_with_discrete():
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int):
def test_gaussian_actor_policy_forward(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
@@ -307,7 +307,7 @@ def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_
[(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_sac_policy_with_pretrained_encoder(
def test_gaussian_actor_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
@@ -415,7 +415,7 @@ def test_sac_algorithm_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
config.num_discrete_actions = 5
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.target_entropy == -3.5
@@ -425,7 +425,7 @@ def test_sac_algorithm_temperature():
config = create_default_config(continuous_action_dim=10, state_dim=10)
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.temperature == pytest.approx(1.0)
@@ -437,7 +437,7 @@ def test_sac_algorithm_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
for p in algorithm.critic_ensemble.parameters():
@@ -472,7 +472,7 @@ def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
algorithm.optimizers["critic"].step()
def test_sac_policy_save_and_load(tmp_path):
def test_gaussian_actor_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded from pretrained."""
root = tmp_path / "test_sac_save_and_load"
@@ -481,10 +481,10 @@ def test_sac_policy_save_and_load(tmp_path):
batch_size = 2
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
@@ -503,7 +503,7 @@ def test_sac_policy_save_and_load(tmp_path):
assert torch.allclose(actions, loaded_actions)
def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
def test_gaussian_actor_policy_save_and_load_with_discrete_critic(tmp_path):
"""Discrete critic should be saved/loaded as part of the policy."""
root = tmp_path / "test_sac_save_and_load_discrete"
@@ -512,11 +512,11 @@ def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_discrete_actions = 3
policy = SACPolicy(config=config)
policy = GaussianActorPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy = GaussianActorPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert loaded_policy.discrete_critic is not None

View File

@@ -21,8 +21,8 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
from lerobot.policies.gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.gaussian_actor.reward_model.processor_classifier import make_classifier_processor
from lerobot.processor import (
DataProcessorPipeline,
DeviceProcessorStep,

View File

@@ -21,8 +21,8 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DataProcessorPipeline,
@@ -38,7 +38,7 @@ from lerobot.utils.constants import ACTION, OBS_STATE
def create_default_config():
"""Create a default SAC configuration for testing."""
config = SACConfig()
config = GaussianActorConfig()
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
}
@@ -66,7 +66,7 @@ def test_make_sac_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -88,12 +88,12 @@ def test_make_sac_processor_basic():
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
def test_sac_processor_normalization_modes():
def test_gaussian_actor_processor_normalization_modes():
"""Test that SAC processor correctly handles different normalization modes."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -121,13 +121,13 @@ def test_sac_processor_normalization_modes():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_cuda():
def test_gaussian_actor_processor_cuda():
"""Test SAC processor with CUDA device."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -153,13 +153,13 @@ def test_sac_processor_cuda():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_accelerate_scenario():
def test_gaussian_actor_processor_accelerate_scenario():
"""Test SAC processor in simulated Accelerate scenario."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -180,13 +180,13 @@ def test_sac_processor_accelerate_scenario():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_sac_processor_multi_gpu():
def test_gaussian_actor_processor_multi_gpu():
"""Test SAC processor with multi-GPU setup."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -206,11 +206,11 @@ def test_sac_processor_multi_gpu():
assert processed[TransitionKey.ACTION.value].device == device
def test_sac_processor_without_stats():
def test_gaussian_actor_processor_without_stats():
"""Test SAC processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(config, dataset_stats=None)
# Should still create processors
assert preprocessor is not None
@@ -226,12 +226,12 @@ def test_sac_processor_without_stats():
assert processed is not None
def test_sac_processor_save_and_load():
def test_gaussian_actor_processor_save_and_load():
"""Test saving and loading SAC processor."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -257,14 +257,14 @@ def test_sac_processor_save_and_load():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_mixed_precision():
def test_gaussian_actor_processor_mixed_precision():
"""Test SAC processor with mixed precision."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -304,12 +304,12 @@ def test_sac_processor_mixed_precision():
assert processed[TransitionKey.ACTION.value].dtype == torch.float16
def test_sac_processor_batch_data():
def test_gaussian_actor_processor_batch_data():
"""Test SAC processor with batched data."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -329,12 +329,12 @@ def test_sac_processor_batch_data():
assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5)
def test_sac_processor_edge_cases():
def test_gaussian_actor_processor_edge_cases():
"""Test SAC processor with edge cases."""
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_sac_pre_post_processors(
preprocessor, postprocessor = make_gaussian_actor_pre_post_processors(
config,
stats,
)
@@ -358,13 +358,13 @@ def test_sac_processor_edge_cases():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_bfloat16_device_float32_normalizer():
def test_gaussian_actor_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, _ = make_sac_pre_post_processors(
preprocessor, _ = make_gaussian_actor_pre_post_processors(
config,
stats,
)

View File

@@ -28,7 +28,7 @@ from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition
from tests.utils import skip_if_package_missing
@@ -81,7 +81,7 @@ def cfg():
port = find_free_port()
policy_cfg = SACConfig()
policy_cfg = GaussianActorConfig()
policy_cfg.actor_learner_config.learner_host = "127.0.0.1"
policy_cfg.actor_learner_config.learner_port = port
policy_cfg.concurrency.actor = "threads"
@@ -312,7 +312,7 @@ def test_learner_algorithm_wiring():
"""Verify that make_algorithm constructs an SACAlgorithm from config,
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
get_weights() output is serializable."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
from lerobot.transport.utils import state_to_bytes
@@ -320,7 +320,7 @@ def test_learner_algorithm_wiring():
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -331,7 +331,7 @@ def test_learner_algorithm_wiring():
)
sac_cfg.validate_features()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
@@ -399,13 +399,13 @@ def test_learner_algorithm_wiring():
def test_initial_and_periodic_weight_push_consistency():
"""Both initial and periodic weight pushes should use algorithm.get_weights()
and produce identical structures."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -416,7 +416,7 @@ def test_initial_and_periodic_weight_push_consistency():
)
sac_cfg.validate_features()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm.make_optimizers_and_scheduler()
@@ -437,13 +437,13 @@ def test_initial_and_periodic_weight_push_consistency():
def test_actor_side_algorithm_select_action_and_load_weights():
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
sac_cfg = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -455,7 +455,7 @@ def test_actor_side_algorithm_select_action_and_load_weights():
sac_cfg.validate_features()
# Actor side: no optimizers
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.eval()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)

View File

@@ -22,8 +22,8 @@ pytest.importorskip("grpc")
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
from lerobot.rl.algorithms.factory import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
@@ -47,8 +47,8 @@ def _make_sac_config(
utd_ratio: int = 1,
policy_update_freq: int = 1,
with_images: bool = False,
) -> SACConfig:
config = SACConfig(
) -> GaussianActorConfig:
config = GaussianActorConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
@@ -79,7 +79,7 @@ def _make_algorithm(
policy_update_freq: int = 1,
num_discrete_actions: int | None = None,
with_images: bool = False,
) -> tuple[SACAlgorithm, SACPolicy]:
) -> tuple[SACAlgorithm, GaussianActorPolicy]:
sac_cfg = _make_sac_config(
state_dim=state_dim,
action_dim=action_dim,
@@ -88,7 +88,7 @@ def _make_algorithm(
num_discrete_actions=num_discrete_actions,
with_images=with_images,
)
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
@@ -349,7 +349,7 @@ def test_optimization_step_can_be_set_for_resume():
def test_make_algorithm_returns_sac_for_sac_policy():
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -358,7 +358,7 @@ def test_make_algorithm_returns_sac_for_sac_policy():
def test_make_optimizers_creates_expected_keys():
"""make_optimizers_and_scheduler() should populate the algorithm with Adam optimizers."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
optimizers = algorithm.make_optimizers_and_scheduler()
assert "actor" in optimizers
@@ -371,7 +371,7 @@ def test_make_optimizers_creates_expected_keys():
def test_actor_side_no_optimizers():
"""Actor-side usage: no optimizers needed, make_optimizers_and_scheduler is not called."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
@@ -379,7 +379,7 @@ def test_actor_side_no_optimizers():
def test_make_algorithm_copies_config_fields():
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert algorithm.config.utd_ratio == 5
assert algorithm.config.policy_update_freq == 3
@@ -404,7 +404,7 @@ def test_load_weights_round_trip():
algo_src.update(_batch_iterator())
sac_cfg = _make_sac_config(state_dim=10, action_dim=6)
policy_dst = SACPolicy(config=sac_cfg)
policy_dst = GaussianActorPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
@@ -423,7 +423,7 @@ def test_load_weights_round_trip_with_discrete_critic():
algo_src.update(_batch_iterator(action_dim=7))
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_dst = SACPolicy(config=sac_cfg)
policy_dst = GaussianActorPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
@@ -470,7 +470,7 @@ def test_build_algorithm_via_config():
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
sac_cfg = _make_sac_config(utd_ratio=2)
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = algo_config.build_algorithm(policy)
assert isinstance(algorithm, SACAlgorithm)
@@ -480,6 +480,6 @@ def test_build_algorithm_via_config():
def test_make_algorithm_uses_build_algorithm():
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
policy = GaussianActorPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)