Added caching function in the learner_server and modeling sac in order to limit the number of forward passes through the pretrained encoder when its frozen.

Added tensordict dependencies
Updated the version of torch and torchvision

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-21 10:13:43 +00:00
parent e1d55c7a44
commit d3b84ecd6f
8 changed files with 66 additions and 42 deletions

View File

@@ -169,6 +169,25 @@ def initialize_replay_buffer(
)
def get_observation_features(policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_observation_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
return observation_features, next_observation_features
def start_learner_threads(
cfg: DictConfig,
device: str,
@@ -345,9 +364,6 @@ def add_actor_information_and_train(
if len(replay_buffer) < cfg.training.online_step_before_learning:
continue
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
time_for_one_optimization_step = time.time()
for _ in range(cfg.policy.utd_ratio - 1):
batch = replay_buffer.sample(batch_size)
@@ -356,6 +372,7 @@ def add_actor_information_and_train(
batch_offline = offline_replay_buffer.sample(batch_size)
batch = concatenate_batch_transitions(batch, batch_offline)
actions = batch["action"]
rewards = batch["reward"]
observations = batch["state"]
@@ -365,6 +382,7 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -372,6 +390,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -395,6 +415,7 @@ def add_actor_information_and_train(
observations=observations, actions=actions, next_state=next_observations
)
observation_features, next_observation_features = get_observation_features(policy, observations, next_observations)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -402,6 +423,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -413,7 +436,8 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations)
loss_actor = policy.compute_loss_actor(observations=observations,
observation_features=observation_features)
optimizers["actor"].zero_grad()
loss_actor.backward()
@@ -422,7 +446,8 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(
observations=observations
observations=observations,
observation_features=observation_features
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()