mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation
This commit is contained in:
@@ -24,7 +24,10 @@ require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
|
||||
from .algorithms.base import RLAlgorithm as RLAlgorithm
|
||||
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
|
||||
from .algorithms.factory import make_algorithm as make_algorithm
|
||||
from .algorithms.factory import (
|
||||
make_algorithm as make_algorithm,
|
||||
make_algorithm_config as make_algorithm_config,
|
||||
)
|
||||
from .algorithms.sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
||||
from .buffer import ReplayBuffer as ReplayBuffer
|
||||
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
|
||||
@@ -35,6 +38,7 @@ __all__ = [
|
||||
"RLAlgorithmConfig",
|
||||
"TrainingStats",
|
||||
"make_algorithm",
|
||||
"make_algorithm_config",
|
||||
"SACAlgorithm",
|
||||
"SACAlgorithmConfig",
|
||||
"RLTrainer",
|
||||
|
||||
@@ -20,5 +20,28 @@ from lerobot.rl.algorithms.base import RLAlgorithm
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
|
||||
|
||||
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
    """Build the config object registered under ``algorithm_type``.

    The name is resolved through the ``RLAlgorithmConfig`` choice registry and
    the resolved class is then called with ``kwargs``.

    Args:
        algorithm_type: Name the config class is registered under (e.g. ``"sac"``).
        **kwargs: Passed through unchanged to the resolved config class.

    Returns:
        The object produced by calling the resolved class with ``kwargs``.

    Raises:
        ValueError: If the registry lookup raises ``KeyError``. The message
            lists the names currently known to the registry.
    """
    try:
        config_cls = RLAlgorithmConfig.get_choice_class(algorithm_type)
    except KeyError as err:
        # Only query the registry for its known names on the failure path.
        known = list(RLAlgorithmConfig.get_known_choices().keys())
        message = f"Algorithm type '{algorithm_type}' is not registered. " f"Available: {known}"
        raise ValueError(message) from err
    else:
        # Construction stays outside the ``try`` so that a KeyError raised by
        # the config constructor itself is not reported as an unknown type.
        return config_cls(**kwargs)
|
||||
|
||||
|
||||
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
    """Build the algorithm described by ``cfg`` for the given ``policy``.

    Construction is delegated entirely to ``cfg.build_algorithm``.

    Args:
        cfg: Algorithm config whose ``build_algorithm`` method is invoked.
        policy: Policy module handed to ``cfg.build_algorithm``.

    Returns:
        Whatever ``cfg.build_algorithm(policy)`` returns.
    """
    algorithm = cfg.build_algorithm(policy)
    return algorithm
|
||||
|
||||
@@ -21,6 +21,7 @@ from dataclasses import dataclass
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
from lerobot.rl.algorithms.factory import make_algorithm_config
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithmConfig # noqa: F401
|
||||
|
||||
|
||||
@@ -45,8 +46,7 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
super().validate()
|
||||
|
||||
if self.algorithm is None:
|
||||
sac_cls = RLAlgorithmConfig.get_choice_class("sac")
|
||||
self.algorithm = sac_cls()
|
||||
self.algorithm = make_algorithm_config("sac")
|
||||
|
||||
# The pipeline owns the policy config; inject it so the algorithm can
|
||||
# introspect policy architecture (e.g. ``num_discrete_actions``).
|
||||
|
||||
Reference in New Issue
Block a user