From fd5a788120ebd19f0196678da2707219e8ecafce Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 27 Apr 2026 15:55:16 +0200 Subject: [PATCH] refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation --- src/lerobot/rl/__init__.py | 6 +++++- src/lerobot/rl/algorithms/factory.py | 23 +++++++++++++++++++++++ src/lerobot/rl/train_rl.py | 4 ++-- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/lerobot/rl/__init__.py b/src/lerobot/rl/__init__.py index 314781c9f..77efe1784 100644 --- a/src/lerobot/rl/__init__.py +++ b/src/lerobot/rl/__init__.py @@ -24,7 +24,10 @@ require_package("grpcio", extra="hilserl", import_name="grpc") from .algorithms.base import RLAlgorithm as RLAlgorithm from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats -from .algorithms.factory import make_algorithm as make_algorithm +from .algorithms.factory import ( + make_algorithm as make_algorithm, + make_algorithm_config as make_algorithm_config, +) from .algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig from .buffer import ReplayBuffer as ReplayBuffer from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer @@ -35,6 +38,7 @@ __all__ = [ "RLAlgorithmConfig", "TrainingStats", "make_algorithm", + "make_algorithm_config", "SACAlgorithm", "SACAlgorithmConfig", "RLTrainer", diff --git a/src/lerobot/rl/algorithms/factory.py b/src/lerobot/rl/algorithms/factory.py index 3704fe1e7..9e71473ec 100644 --- a/src/lerobot/rl/algorithms/factory.py +++ b/src/lerobot/rl/algorithms/factory.py @@ -20,5 +20,28 @@ from lerobot.rl.algorithms.base import RLAlgorithm from lerobot.rl.algorithms.configs import RLAlgorithmConfig +def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig: + """Instantiate an :class:`RLAlgorithmConfig` from its registered type name. + + Args: + algorithm_type: Registry key of the algorithm (e.g. ``"sac"``). + **kwargs: Keyword arguments forwarded to the config class constructor. + + Returns: + An instance of the matching ``RLAlgorithmConfig`` subclass. + + Raises: + ValueError: If ``algorithm_type`` is not registered. + """ + try: + cls = RLAlgorithmConfig.get_choice_class(algorithm_type) + except KeyError as err: + raise ValueError( + f"Algorithm type '{algorithm_type}' is not registered. " + f"Available: {list(RLAlgorithmConfig.get_known_choices().keys())}" + ) from err + return cls(**kwargs) + + def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm: return cfg.build_algorithm(policy) diff --git a/src/lerobot/rl/train_rl.py b/src/lerobot/rl/train_rl.py index 442856bf5..21ee3afb9 100644 --- a/src/lerobot/rl/train_rl.py +++ b/src/lerobot/rl/train_rl.py @@ -21,6 +21,7 @@ from dataclasses import dataclass from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.rl.algorithms.configs import RLAlgorithmConfig +from lerobot.rl.algorithms.factory import make_algorithm_config from lerobot.rl.algorithms.sac import SACAlgorithmConfig # noqa: F401 @@ -45,8 +46,7 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig): super().validate() if self.algorithm is None: - sac_cls = RLAlgorithmConfig.get_choice_class("sac") - self.algorithm = sac_cls() + self.algorithm = make_algorithm_config("sac") # The pipeline owns the policy config; inject it so the algorithm can # introspect policy architecture (e.g. ``num_discrete_actions``).