mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
optimize data loading
This commit is contained in:
@@ -14,12 +14,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
# Fix tokenizer parallelism conflicts with multiprocessing
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
@@ -240,6 +244,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
persistent_workers=cfg.num_workers > 0, # Keep workers alive
|
||||
prefetch_factor=2 if cfg.num_workers > 0 else None, # Prefetch batches
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user