fix(mps): gradient exploding and nan loss issues with ACT (#1490)

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Adil Zouitine
2025-07-15 10:28:19 +02:00
committed by GitHub
parent 519b76110e
commit 91b110d806
2 changed files with 8 additions and 11 deletions

View File

@@ -180,7 +180,7 @@ def train(cfg: TrainPipelineConfig):
batch_size=cfg.batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
pin_memory=device.type == "cuda",
drop_last=False,
)
dl_iter = cycle(dataloader)
@@ -207,7 +207,7 @@ def train(cfg: TrainPipelineConfig):
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
train_tracker, output_dict = update_policy(
train_tracker,