mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
feat: Enable torch.compile for DiffusionPolicy inference (#2486)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -182,6 +182,11 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
if config.compile_model:
|
||||
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||
# common in diffusion inference.
|
||||
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
|
||||
Reference in New Issue
Block a user