mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
[Port HIl-Serl] Refactor gym-manipulator (#1034)
This commit is contained in:
@@ -23,8 +23,6 @@ from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import grpc
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
import torch
|
||||
from termcolor import colored
|
||||
@@ -39,8 +37,6 @@ from lerobot.common.constants import (
|
||||
TRAINING_STATE_DIR,
|
||||
)
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
@@ -62,16 +58,17 @@ from lerobot.common.utils.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.scripts.server import learner_service
|
||||
from lerobot.scripts.server.buffer import (
|
||||
ReplayBuffer,
|
||||
from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from lerobot.scripts.server.network_utils import (
|
||||
bytes_to_python_object,
|
||||
bytes_to_transitions,
|
||||
concatenate_batch_transitions,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
state_to_bytes,
|
||||
)
|
||||
from lerobot.scripts.server.utils import setup_process_handlers
|
||||
from lerobot.scripts.server.utils import (
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
setup_process_handlers,
|
||||
)
|
||||
|
||||
LOG_PREFIX = "[LEARNER]"
|
||||
|
||||
@@ -307,17 +304,10 @@ def add_actor_information_and_train(
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset is not None:
|
||||
active_action_dims = None
|
||||
# TODO: FIX THIS
|
||||
if cfg.env.wrapper.joint_masking_action_space is not None:
|
||||
active_action_dims = [
|
||||
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
|
||||
]
|
||||
offline_replay_buffer = initialize_offline_replay_buffer(
|
||||
cfg=cfg,
|
||||
device=device,
|
||||
storage_device=storage_device,
|
||||
active_action_dims=active_action_dims,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
@@ -342,7 +332,6 @@ def add_actor_information_and_train(
|
||||
break
|
||||
|
||||
# Process all available transitions to the replay buffer, send by the actor server
|
||||
logging.debug("[LEARNER] Waiting for transitions")
|
||||
process_transitions(
|
||||
transition_queue=transition_queue,
|
||||
replay_buffer=replay_buffer,
|
||||
@@ -351,35 +340,29 @@ def add_actor_information_and_train(
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
logging.debug("[LEARNER] Received transitions")
|
||||
|
||||
# Process all available interaction messages sent by the actor server
|
||||
logging.debug("[LEARNER] Waiting for interactions")
|
||||
interaction_message = process_interaction_messages(
|
||||
interaction_message_queue=interaction_message_queue,
|
||||
interaction_step_shift=interaction_step_shift,
|
||||
wandb_logger=wandb_logger,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
logging.debug("[LEARNER] Received interactions")
|
||||
|
||||
# Wait until the replay buffer has enough samples to start training
|
||||
if len(replay_buffer) < online_step_before_learning:
|
||||
continue
|
||||
|
||||
if online_iterator is None:
|
||||
logging.debug("[LEARNER] Initializing online replay buffer iterator")
|
||||
online_iterator = replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
if offline_replay_buffer is not None and offline_iterator is None:
|
||||
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
|
||||
offline_iterator = offline_replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Starting optimization loop")
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(utd_ratio - 1):
|
||||
# Sample from the iterators
|
||||
@@ -967,7 +950,6 @@ def initialize_offline_replay_buffer(
|
||||
cfg: TrainPipelineConfig,
|
||||
device: str,
|
||||
storage_device: str,
|
||||
active_action_dims: list[int] | None = None,
|
||||
) -> ReplayBuffer:
|
||||
"""
|
||||
Initialize an offline replay buffer from a dataset.
|
||||
@@ -976,7 +958,6 @@ def initialize_offline_replay_buffer(
|
||||
cfg (TrainPipelineConfig): Training configuration
|
||||
device (str): Device to store tensors on
|
||||
storage_device (str): Device for storage optimization
|
||||
active_action_dims (list[int] | None): Active action dimensions for masking
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: Initialized offline replay buffer
|
||||
@@ -997,7 +978,6 @@ def initialize_offline_replay_buffer(
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_features.keys(),
|
||||
action_mask=active_action_dims,
|
||||
storage_device=storage_device,
|
||||
optimize_memory=True,
|
||||
capacity=cfg.policy.offline_buffer_capacity,
|
||||
@@ -1096,44 +1076,6 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
|
||||
"""
|
||||
Checks whether each parameter in the module has a gradient.
|
||||
|
||||
Args:
|
||||
module (nn.Module): A PyTorch module whose parameters will be inspected.
|
||||
|
||||
Returns:
|
||||
dict[str, bool]: A dictionary where each key is the parameter name and the value is
|
||||
True if the parameter has an associated gradient (i.e. .grad is not None),
|
||||
otherwise False.
|
||||
"""
|
||||
grad_status = {}
|
||||
for name, param in module.named_parameters():
|
||||
grad_status[name] = param.grad is not None
|
||||
return grad_status
|
||||
|
||||
|
||||
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
|
||||
"""
|
||||
Returns a dictionary of parameters (from actor) that also exist in the grad_status dictionary.
|
||||
|
||||
Args:
|
||||
actor (nn.Module): The actor model.
|
||||
grad_status (dict[str, bool]): A dictionary where keys are parameter names and values indicate
|
||||
whether each parameter has a gradient.
|
||||
|
||||
Returns:
|
||||
dict[str, bool]: A dictionary containing only the overlapping parameter names and their gradient status.
|
||||
"""
|
||||
# Get actor parameter names as a set.
|
||||
model_param_names = {name for name, _ in model.named_parameters()}
|
||||
|
||||
# Intersect parameter names between actor and grad_status.
|
||||
overlapping = {name: grad_status[name] for name in grad_status if name in model_param_names}
|
||||
return overlapping
|
||||
|
||||
|
||||
def process_interaction_message(
|
||||
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user