mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
small fixes
This commit is contained in:
@@ -294,7 +294,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
prefetch_factor=2,
|
||||
prefetch_factor=2 if cfg.num_workers > 0 else None,
|
||||
)
|
||||
if accelerator:
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
@@ -369,6 +369,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
policy=policy if not accelerator else accelerator.unwrap_model(policy),
|
||||
optimizer=optimizer,
|
||||
scheduler=lr_scheduler,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
|
||||
Reference in New Issue
Block a user