optimzize data loading

This commit is contained in:
Pepijn
2025-08-30 15:40:36 +02:00
parent 599218fe9a
commit 0b5da92a58
3 changed files with 15 additions and 6 deletions

View File

@@ -14,12 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any
import torch
# Fix tokenizer parallelism conflicts with multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
@@ -240,6 +244,8 @@ def train(cfg: TrainPipelineConfig):
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
persistent_workers=cfg.num_workers > 0, # Keep workers alive
prefetch_factor=2 if cfg.num_workers > 0 else None, # Prefetch batches
)
dl_iter = cycle(dataloader)