diff --git a/src/lerobot/policies/dsrl/modeling_dsrl.py b/src/lerobot/policies/dsrl/modeling_dsrl.py index 0faf3ee73..c1c431947 100644 --- a/src/lerobot/policies/dsrl/modeling_dsrl.py +++ b/src/lerobot/policies/dsrl/modeling_dsrl.py @@ -249,7 +249,7 @@ class DSRLPolicy(PreTrainedPolicy): raise ValueError(f"Unknown model type: {model}") def update_target_networks(self): - """Update target networks with exponential moving average""" + """Update target networks of the action critic with exponential moving average""" for target_param, param in zip( self.action_critic_target.parameters(), self.action_critic_ensemble.parameters(), diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ac76baf9f..76dd0c84b 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pi05.configuration_pi05 import PI05Config @@ -58,7 +59,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla". + "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla", "dsrl". Returns: The policy class corresponding to the given name. @@ -106,6 +107,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy + elif name == "dsrl": + from lerobot.policies.dsrl.modeling_dsrl import DSRLPolicy + + return DSRLPolicy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -120,7 +125,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", "diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla", - "reward_classifier". + "reward_classifier", "dsrl". **kwargs: Keyword arguments to be passed to the configuration class constructor. Returns: @@ -149,6 +154,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return SmolVLAConfig(**kwargs) elif policy_type == "reward_classifier": return RewardClassifierConfig(**kwargs) + elif policy_type == "dsrl": + return DSRLConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") @@ -307,6 +314,13 @@ def make_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, DSRLConfig): + from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors + + processors = make_dsrl_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) else: raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")