mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
fix accel
This commit is contained in:
@@ -21,6 +21,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
@@ -176,9 +177,12 @@ def train(cfg: TrainPipelineConfig):
|
||||
# Initialize Accelerate if requested
|
||||
accelerator = None
|
||||
if cfg.use_accelerate:
|
||||
# Configure DDP to handle unused parameters
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
mixed_precision=cfg.mixed_precision,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
)
|
||||
device = accelerator.device
|
||||
if accelerator.is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user