mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
refactor: decouple policy from algorithm
This commit is contained in:
@@ -61,8 +61,8 @@ from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.rl.algorithms import RLAlgorithm, make_algorithm
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
@@ -81,6 +81,7 @@ from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
@@ -247,9 +248,6 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
@@ -257,8 +255,6 @@ def act_with_policy(
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=cfg.policy, algorithm_name=cfg.algorithm)
|
||||
|
||||
# Build policy pre/post processors for observation normalization and action unnormalization
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
@@ -324,7 +320,7 @@ def act_with_policy(
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = algorithm.select_action(observation_for_inference)
|
||||
action = policy.select_action(observation_for_inference)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
# Postprocess action (unnormalization, move to cpu).
|
||||
@@ -397,7 +393,7 @@ def act_with_policy(
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(algorithm=algorithm, parameters_queue=parameters_queue, device=device)
|
||||
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
@@ -695,8 +691,8 @@ def interactions_stream(
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, device):
|
||||
"""Load the latest weights from the learner via the algorithm's ``load_weights`` API."""
|
||||
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
|
||||
"""Load the latest policy weights from the learner."""
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
@@ -711,7 +707,8 @@ def update_policy_parameters(algorithm: RLAlgorithm, parameters_queue: Queue, de
|
||||
# - Skip encoder params entirely when freeze_vision_encoder=True
|
||||
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
|
||||
# Load actor state dict
|
||||
algorithm.load_weights(state_dicts, device=device)
|
||||
state_dicts = move_state_dict_to_device(state_dicts, device=device)
|
||||
policy.load_state_dict(state_dicts)
|
||||
|
||||
|
||||
# Utilities functions
|
||||
|
||||
Reference in New Issue
Block a user