[Port HIL-SERL] Refactor gym-manipulator (#1034)

This commit is contained in:
Michel Aractingi
2025-04-25 16:34:54 +02:00
committed by GitHub
parent a8da4a347e
commit bd4db8d747
13 changed files with 624 additions and 946 deletions

View File

@@ -23,8 +23,6 @@ from pathlib import Path
from pprint import pformat
import grpc
# Import generated stubs
import hilserl_pb2_grpc # type: ignore
import torch
from termcolor import colored
@@ -39,8 +37,6 @@ from lerobot.common.constants import (
TRAINING_STATE_DIR,
)
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
@@ -62,16 +58,17 @@ from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
from lerobot.scripts.server.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.scripts.server.network_utils import (
bytes_to_python_object,
bytes_to_transitions,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
state_to_bytes,
)
from lerobot.scripts.server.utils import setup_process_handlers
from lerobot.scripts.server.utils import (
move_state_dict_to_device,
move_transition_to_device,
setup_process_handlers,
)
LOG_PREFIX = "[LEARNER]"
@@ -307,17 +304,10 @@ def add_actor_information_and_train(
offline_replay_buffer = None
if cfg.dataset is not None:
active_action_dims = None
# TODO: FIX THIS
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
]
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
device=device,
storage_device=storage_device,
active_action_dims=active_action_dims,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffer
@@ -342,7 +332,6 @@ def add_actor_information_and_train(
break
# Process all available transitions to the replay buffer, sent by the actor server
logging.debug("[LEARNER] Waiting for transitions")
process_transitions(
transition_queue=transition_queue,
replay_buffer=replay_buffer,
@@ -351,35 +340,29 @@ def add_actor_information_and_train(
dataset_repo_id=dataset_repo_id,
shutdown_event=shutdown_event,
)
logging.debug("[LEARNER] Received transitions")
# Process all available interaction messages sent by the actor server
logging.debug("[LEARNER] Waiting for interactions")
interaction_message = process_interaction_messages(
interaction_message_queue=interaction_message_queue,
interaction_step_shift=interaction_step_shift,
wandb_logger=wandb_logger,
shutdown_event=shutdown_event,
)
logging.debug("[LEARNER] Received interactions")
# Wait until the replay buffer has enough samples to start training
if len(replay_buffer) < online_step_before_learning:
continue
if online_iterator is None:
logging.debug("[LEARNER] Initializing online replay buffer iterator")
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
if offline_replay_buffer is not None and offline_iterator is None:
logging.debug("[LEARNER] Initializing offline replay buffer iterator")
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
logging.debug("[LEARNER] Starting optimization loop")
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
# Sample from the iterators
@@ -967,7 +950,6 @@ def initialize_offline_replay_buffer(
cfg: TrainPipelineConfig,
device: str,
storage_device: str,
active_action_dims: list[int] | None = None,
) -> ReplayBuffer:
"""
Initialize an offline replay buffer from a dataset.
@@ -976,7 +958,6 @@ def initialize_offline_replay_buffer(
cfg (TrainPipelineConfig): Training configuration
device (str): Device to store tensors on
storage_device (str): Device for storage optimization
active_action_dims (list[int] | None): Active action dimensions for masking
Returns:
ReplayBuffer: Initialized offline replay buffer
@@ -997,7 +978,6 @@ def initialize_offline_replay_buffer(
offline_dataset,
device=device,
state_keys=cfg.policy.input_features.keys(),
action_mask=active_action_dims,
storage_device=storage_device,
optimize_memory=True,
capacity=cfg.policy.offline_buffer_capacity,
@@ -1096,44 +1076,6 @@ def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
parameters_queue.put(state_bytes)
def check_weight_gradients(module: nn.Module) -> dict[str, bool]:
    """
    Report, for every parameter of a module, whether a gradient is attached to it.

    Args:
        module (nn.Module): A PyTorch module whose parameters will be inspected.

    Returns:
        dict[str, bool]: A dictionary keyed by the names yielded by
            ``module.named_parameters()``; each value is True if that parameter's
            ``.grad`` is not None, otherwise False. Empty if the module has no parameters.
    """
    # A single comprehension replaces the manual "create dict, loop, assign" pattern.
    return {name: param.grad is not None for name, param in module.named_parameters()}
def get_overlapping_parameters(model: nn.Module, grad_status: dict[str, bool]) -> dict[str, bool]:
    """
    Restrict a gradient-status mapping to the parameters that also exist in ``model``.

    Args:
        model (nn.Module): The model whose parameter names (as yielded by
            ``model.named_parameters()``) act as the filter.
        grad_status (dict[str, bool]): A dictionary where keys are parameter names and values
            indicate whether each parameter has a gradient.

    Returns:
        dict[str, bool]: The entries of ``grad_status`` whose key is also a parameter name of
            ``model``, in ``grad_status`` iteration order. Names present only in ``model`` or
            only in ``grad_status`` are omitted.
    """
    # Build the name set once so each membership test below is O(1).
    model_param_names = {name for name, _ in model.named_parameters()}
    # Iterate items() directly instead of iterating keys and indexing back into the dict.
    return {name: has_grad for name, has_grad in grad_status.items() if name in model_param_names}
def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):