refactor: decouple policy from algorithm

This commit is contained in:
Khalil Meftah
2026-03-11 16:49:14 +01:00
parent 8d50be9faa
commit 1f5487eea8
12 changed files with 769 additions and 908 deletions

View File

@@ -61,8 +61,8 @@ from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import TransitionKey
from lerobot.rl.algorithms import RLAlgorithm, make_algorithm
from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.queue import get_last_item_from_queue
from lerobot.robots import so_follower # noqa: F401
@@ -81,6 +81,7 @@ from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.transition import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
)
from lerobot.utils.utils import (
@@ -247,9 +248,6 @@ def act_with_policy(
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
@@ -257,8 +255,6 @@ def act_with_policy(
policy = policy.eval()
assert isinstance(policy, nn.Module)
algorithm = make_algorithm(policy=policy, policy_cfg=cfg.policy, algorithm_name=cfg.algorithm)
# Build policy pre/post processors for observation normalization and action unnormalization
processor_kwargs = {}
postprocessor_kwargs = {}
@@ -324,7 +320,7 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement
with policy_timer:
action = algorithm.select_action(observation_for_inference)
action = policy.select_action(observation_for_inference)
policy_fps = policy_timer.fps_last
# Postprocess action (unnormalization, move to cpu).
@@ -397,7 +393,7 @@ def act_with_policy(
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -695,8 +691,8 @@ def interactions_stream(
# Policy functions
def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
"""Load the latest weights from the learner via the algorithm's ``load_weights`` API."""
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
"""Load the latest policy weights from the learner."""
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
@@ -711,7 +707,8 @@ def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, de
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
# Load actor state dict
algorithm.load_weights(state_dicts, device=device)
state_dicts = move_state_dict_to_device(state_dicts, device=device)
policy.load_state_dict(state_dicts)
# Utilities functions