update factory with dsrl

This commit is contained in:
Michel Aractingi
2025-10-13 16:12:39 +02:00
parent 5c9bfd57ec
commit 7cd710857d
2 changed files with 17 additions and 3 deletions

View File

@@ -249,7 +249,7 @@ class DSRLPolicy(PreTrainedPolicy):
raise ValueError(f"Unknown model type: {model}")
def update_target_networks(self):
"""Update target networks with exponential moving average"""
"""Update target networks of the action critic with exponential moving average"""
for target_param, param in zip(
self.action_critic_target.parameters(),
self.action_critic_ensemble.parameters(),

View File

@@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
@@ -58,7 +59,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
"vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla", "dsrl".
Returns:
The policy class corresponding to the given name.
@@ -106,6 +107,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "dsrl":
from lerobot.policies.dsrl.modeling_dsrl import DSRLPolicy
return DSRLPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -120,7 +125,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
"reward_classifier".
"reward_classifier", "dsrl".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -149,6 +154,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "dsrl":
return DSRLConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -307,6 +314,13 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, DSRLConfig):
from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors
processors = make_dsrl_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")