mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
Added possibility to cache the embedding of the images when the encoder choice is pretrained and frozen
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -384,6 +384,21 @@ def add_actor_information_and_train(
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
obs_features, next_obs_features = None, None
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
obs_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_obs_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
@@ -391,6 +406,8 @@ def add_actor_information_and_train(
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
obs_features=obs_features, # pass precomputed features
|
||||
next_obs_features=next_obs_features, # for target computation
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
@@ -412,6 +429,21 @@ def add_actor_information_and_train(
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
obs_features, next_obs_features = None, None
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
obs_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_obs_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
@@ -419,6 +451,8 @@ def add_actor_information_and_train(
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
obs_features=obs_features, # pass precomputed features
|
||||
next_obs_features=next_obs_features, # for target computation
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
@@ -430,7 +464,10 @@ def add_actor_information_and_train(
|
||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
with policy_lock:
|
||||
loss_actor = policy.compute_loss_actor(observations=observations)
|
||||
loss_actor = policy.compute_loss_actor(
|
||||
observations=observations,
|
||||
obs_features=obs_features, # reuse precomputed features here
|
||||
)
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
@@ -438,7 +475,10 @@ def add_actor_information_and_train(
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(observations=observations)
|
||||
loss_temperature = policy.compute_loss_temperature(
|
||||
observations=observations,
|
||||
obs_features=obs_features, # and for temperature loss as well
|
||||
)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
@@ -582,7 +622,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
)
|
||||
# compile policy
|
||||
policy = torch.compile(policy)
|
||||
# policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
|
||||
Reference in New Issue
Block a user