mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
1 Commits
codex/fix-
...
fix/force_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74b7cd246e |
@@ -263,7 +263,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||||
# Force the device to be CPU when policy.device is set to CPU.
|
# Force the device to be CPU when policy.device is set to CPU.
|
||||||
force_cpu = cfg.policy.device == "cpu"
|
# Note (maractin): cfg.policy may be None before validate() fully loads from pretrained_path
|
||||||
|
force_cpu = cfg.policy is not None and cfg.policy.device == "cpu"
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
step_scheduler_with_optimizer=False,
|
step_scheduler_with_optimizer=False,
|
||||||
kwargs_handlers=[ddp_kwargs],
|
kwargs_handlers=[ddp_kwargs],
|
||||||
|
|||||||
Reference in New Issue
Block a user