fix accel

This commit is contained in:
Pepijn
2025-09-23 22:32:22 +02:00
parent b794fc3c70
commit fc7998a3d5

View File

@@ -21,6 +21,7 @@ from typing import Any
import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
@@ -176,9 +177,12 @@ def train(cfg: TrainPipelineConfig):
# Initialize Accelerate if requested
accelerator = None
if cfg.use_accelerate:
# Configure DDP to handle unused parameters
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
mixed_precision=cfg.mixed_precision,
kwargs_handlers=[ddp_kwargs],
)
device = accelerator.device
if accelerator.is_main_process: