mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
Sub threading for multiprocessing
This commit is contained in:
@@ -36,6 +36,9 @@ from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
# For profiling only
|
||||
import datetime
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
@@ -262,15 +265,15 @@ def learner_push_parameters(
|
||||
while True:
|
||||
with policy_lock:
|
||||
params_dict = policy.actor.state_dict()
|
||||
if policy.config.vision_encoder_name is not None:
|
||||
if policy.config.freeze_vision_encoder:
|
||||
params_dict: dict[str, torch.Tensor] = {
|
||||
k: v for k, v in params_dict.items() if not k.startswith("encoder.")
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||
)
|
||||
# if policy.config.vision_encoder_name is not None:
|
||||
# if policy.config.freeze_vision_encoder:
|
||||
# params_dict: dict[str, torch.Tensor] = {
|
||||
# k: v for k, v in params_dict.items() if not k.startswith("encoder.")
|
||||
# }
|
||||
# else:
|
||||
# raise NotImplementedError(
|
||||
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||
# )
|
||||
|
||||
params_dict = move_state_dict_to_device(params_dict, device="cpu")
|
||||
# Serialize
|
||||
@@ -347,6 +350,7 @@ def add_actor_information_and_train(
|
||||
interaction_message, transition = None, None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
|
||||
while True:
|
||||
while not transition_queue.empty():
|
||||
transition_list = transition_queue.get()
|
||||
@@ -370,6 +374,7 @@ def add_actor_information_and_train(
|
||||
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
|
||||
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
|
||||
|
||||
image_features, next_image_features = None, None
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(cfg.policy.utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
@@ -385,6 +390,21 @@ def add_actor_information_and_train(
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
with record_function("encoder_forward"):
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
image_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_image_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
@@ -392,6 +412,8 @@ def add_actor_information_and_train(
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
image_features=image_features,
|
||||
next_image_features=next_image_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
@@ -413,6 +435,19 @@ def add_actor_information_and_train(
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
image_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_image_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
@@ -420,6 +455,8 @@ def add_actor_information_and_train(
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
image_features=image_features,
|
||||
next_image_features=next_image_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
@@ -431,7 +468,7 @@ def add_actor_information_and_train(
|
||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
with policy_lock:
|
||||
loss_actor = policy.compute_loss_actor(observations=observations)
|
||||
loss_actor = policy.compute_loss_actor(observations=observations, image_features=image_features)
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
@@ -439,7 +476,7 @@ def add_actor_information_and_train(
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(observations=observations)
|
||||
loss_temperature = policy.compute_loss_temperature(observations=observations, image_features=image_features)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
@@ -503,6 +540,12 @@ def add_actor_information_and_train(
|
||||
|
||||
logging.info("Resume training")
|
||||
|
||||
profiler.step()
|
||||
|
||||
if optimization_step >= 50: # Profile for 500 steps
|
||||
profiler.stop()
|
||||
break
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
||||
"""
|
||||
@@ -583,7 +626,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
)
|
||||
# compile policy
|
||||
policy = torch.compile(policy)
|
||||
# policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
|
||||
Reference in New Issue
Block a user