Added the ability to cache image embeddings when the vision encoder is pretrained and frozen

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-18 08:28:13 +00:00
parent befa1fe9af
commit 8469d13681
3 changed files with 74 additions and 32 deletions

View File

@@ -384,6 +384,21 @@ def add_actor_information_and_train(
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
obs_features, next_obs_features = None, None
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
obs_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_obs_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -391,6 +406,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
obs_features=obs_features, # pass precomputed features
next_obs_features=next_obs_features, # for target computation
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -412,6 +429,21 @@ def add_actor_information_and_train(
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
# Precompute encoder features from the frozen vision encoder if enabled
obs_features, next_obs_features = None, None
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
with torch.no_grad():
obs_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
)
next_obs_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
)
with policy_lock:
loss_critic = policy.compute_loss_critic(
observations=observations,
@@ -419,6 +451,8 @@ def add_actor_information_and_train(
rewards=rewards,
next_observations=next_observations,
done=done,
obs_features=obs_features, # pass precomputed features
next_obs_features=next_obs_features, # for target computation
)
optimizers["critic"].zero_grad()
loss_critic.backward()
@@ -430,7 +464,10 @@ def add_actor_information_and_train(
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
loss_actor = policy.compute_loss_actor(observations=observations)
loss_actor = policy.compute_loss_actor(
observations=observations,
obs_features=obs_features, # reuse precomputed features here
)
optimizers["actor"].zero_grad()
loss_actor.backward()
@@ -438,7 +475,10 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
loss_temperature = policy.compute_loss_temperature(
observations=observations,
obs_features=obs_features, # and for temperature loss as well
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@@ -582,7 +622,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# compile policy
policy = torch.compile(policy)
# policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)