small fixes

This commit is contained in:
Pepijn
2025-10-14 10:46:19 +02:00
parent 2bc154e706
commit 6486982ab4

View File

@@ -294,7 +294,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
prefetch_factor=2,
prefetch_factor=2 if cfg.num_workers > 0 else None,
)
if accelerator:
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
@@ -369,6 +369,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
policy=policy if not accelerator else accelerator.unwrap_model(policy),
optimizer=optimizer,
scheduler=lr_scheduler,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
update_last_checkpoint(checkpoint_dir)
if wandb_logger: