refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation

This commit is contained in:
Khalil Meftah
2026-04-27 15:55:16 +02:00
parent 9ce9e01469
commit fd5a788120
3 changed files with 30 additions and 3 deletions

View File

@@ -24,7 +24,10 @@ require_package("grpcio", extra="hilserl", import_name="grpc")
from .algorithms.base import RLAlgorithm as RLAlgorithm
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
from .algorithms.factory import make_algorithm as make_algorithm
from .algorithms.factory import (
make_algorithm as make_algorithm,
make_algorithm_config as make_algorithm_config,
)
from .algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
from .buffer import ReplayBuffer as ReplayBuffer
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
@@ -35,6 +38,7 @@ __all__ = [
"RLAlgorithmConfig",
"TrainingStats",
"make_algorithm",
"make_algorithm_config",
"SACAlgorithm",
"SACAlgorithmConfig",
"RLTrainer",

View File

@@ -20,5 +20,28 @@ from lerobot.rl.algorithms.base import RLAlgorithm
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
    """Build the algorithm config registered under ``algorithm_type``.

    The config class is resolved through
    ``RLAlgorithmConfig.get_choice_class`` and then called with ``kwargs``.

    Args:
        algorithm_type: Name under which the config class is registered
            (e.g. ``"sac"``).
        **kwargs: Passed unchanged to the constructor of the resolved class.

    Returns:
        The object produced by calling the resolved config class.

    Raises:
        ValueError: If the registry lookup raises ``KeyError``, i.e. nothing
            is registered under ``algorithm_type``. The message lists the
            names reported by ``RLAlgorithmConfig.get_known_choices()``.
    """
    try:
        config_cls = RLAlgorithmConfig.get_choice_class(algorithm_type)
    except KeyError as exc:
        # Surface the registered names so a typo is easy to spot, and keep
        # the original KeyError attached as the cause.
        known_names = list(RLAlgorithmConfig.get_known_choices().keys())
        message = (
            f"Algorithm type '{algorithm_type}' is not registered. "
            f"Available: {known_names}"
        )
        raise ValueError(message) from exc
    else:
        return config_cls(**kwargs)
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
    """Create the algorithm described by ``cfg`` for the given ``policy``.

    Args:
        cfg: Algorithm config; construction is delegated to its
            ``build_algorithm`` method.
        policy: Module handed to ``cfg.build_algorithm``.

    Returns:
        Whatever ``cfg.build_algorithm(policy)`` returns.
    """
    algorithm = cfg.build_algorithm(policy)
    return algorithm

View File

@@ -21,6 +21,7 @@ from dataclasses import dataclass
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
from lerobot.rl.algorithms.factory import make_algorithm_config
from lerobot.rl.algorithms.sac import SACAlgorithmConfig # noqa: F401
@@ -45,8 +46,7 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
super().validate()
if self.algorithm is None:
sac_cls = RLAlgorithmConfig.get_choice_class("sac")
self.algorithm = sac_cls()
self.algorithm = make_algorithm_config("sac")
# The pipeline owns the policy config; inject it so the algorithm can
# introspect policy architecture (e.g. ``num_discrete_actions``).