feat: Enable torch.compile for DiffusionPolicy inference (#2486)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Jash Shah
2026-02-24 08:29:08 -08:00
committed by GitHub
parent 7fd71c83a3
commit dac1efd13d
2 changed files with 9 additions and 0 deletions

View File

@@ -182,6 +182,11 @@ class DiffusionModel(nn.Module):
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
if config.compile_model:
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
# common in diffusion inference.
self.unet = torch.compile(self.unet, mode=config.compile_mode)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,