Sub threading for multiprocessing

This commit is contained in:
Michel Aractingi
2025-02-20 17:21:55 +00:00
parent ff47c0b0d3
commit a9e912a05c
5 changed files with 870 additions and 36 deletions

View File

@@ -36,6 +36,9 @@ from termcolor import colored
from torch import nn
from torch.optim.optimizer import Optimizer
# For profiling only
import datetime
from lerobot.common.datasets.factory import make_dataset
# TODO: Remove the import of maniskill
@@ -262,15 +265,15 @@ def learner_push_parameters(
while True:
with policy_lock:
params_dict = policy.actor.state_dict()
if policy.config.vision_encoder_name is not None:
if policy.config.freeze_vision_encoder:
params_dict: dict[str, torch.Tensor] = {
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
}
else:
raise NotImplementedError(
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
)
# if policy.config.vision_encoder_name is not None:
# if policy.config.freeze_vision_encoder:
# params_dict: dict[str, torch.Tensor] = {
# k: v for k, v in params_dict.items() if not k.startswith("encoder.")
# }
# else:
# raise NotImplementedError(
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
# )
params_dict = move_state_dict_to_device(params_dict, device="cpu")
# Serialize
@@ -347,6 +350,7 @@ def add_actor_information_and_train(
interaction_message, transition = None, None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
while True:
while not transition_queue.empty():
transition_list = transition_queue.get()
@@ -370,6 +374,7 @@ def add_actor_information_and_train(
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
image_features, next_image_features = None, None
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
@@ -385,6 +390,21 @@ def add_actor_information_and_train(
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
with record_function("encoder_forward"):
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
image_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_image_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -392,6 +412,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
image_features=image_features,
next_image_features=next_image_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -413,6 +435,19 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
image_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_image_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -420,6 +455,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
image_features=image_features,
next_image_features=next_image_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -431,7 +468,7 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations)
loss_actor = policy.compute_loss_actor(observations=observations, image_features=image_features)
optimizers["actor"].zero_grad()
loss_actor.backward()
@@ -439,7 +476,7 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
loss_temperature = policy.compute_loss_temperature(observations=observations, image_features=image_features)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@@ -503,6 +540,12 @@ def add_actor_information_and_train(
logging.info("Resume training")
profiler.step()
if optimization_step >= 50: # Stop profiling after 50 optimization steps
profiler.stop()
break
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
@@ -583,7 +626,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# compile policy
policy = torch.compile(policy)
# policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)