feat(train): add accelerate for multi gpu training (#2154)

* Enhance training and logging functionality with accelerator support

- Added support for multi-GPU training by introducing an `accelerator` parameter in training functions.
- Updated `update_policy` to handle gradient updates based on the presence of an accelerator.
- Modified logging to prevent duplicate messages in non-main processes.
- Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage.
- Updated `MetricsTracker` to account for the number of processes when calculating metrics.
- Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency.

* Initialize logging in training script for both main and non-main processes

- Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode.
- This change enhances the clarity and consistency of logging during training sessions.

* add docs and only push model once

* Place  logging under accelerate and update docs

* fix pre commit

* only log in main process

* main logging

* try with local rank

* add tests

* change runner

* fix test

* dont push to hub in multi gpu tests

* pre download dataset in tests

* small fixes

* fix path optimizer state

* update docs, and small improvements in train

* simplify accelerate main process detection

* small improvements in train

* fix OOM bug

* change accelerate detection

* add some debugging

* always use accelerate

* cleanup update method

* cleanup

* fix bug

* scale lr decay if we reduce steps

* cleanup logging

* fix formatting

* encorperate feedback pr

* add min memory to cpu tests

* use accelerate to determin logging

* fix precommit and fix tests

* chore: minor details

---------

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
Pepijn
2025-10-16 17:41:55 +02:00
committed by GitHub
parent 845b359d39
commit e82e7a02e9
13 changed files with 625 additions and 134 deletions

View File

@@ -27,6 +27,7 @@ from statistics import mean
import numpy as np
import torch
from accelerate import Accelerator
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
@@ -110,36 +111,50 @@ def init_logging(
display_pid: bool = False,
console_level: str = "INFO",
file_level: str = "DEBUG",
accelerator: Accelerator | None = None,
):
"""Initialize logging configuration for LeRobot.
In multi-GPU training, only the main process logs to console to avoid duplicate output.
Non-main processes have console logging suppressed but can still log to file.
Args:
log_file: Optional file path to write logs to
display_pid: Include process ID in log messages (useful for debugging multi-process)
console_level: Logging level for console output
file_level: Logging level for file output
accelerator: Optional Accelerator instance (for multi-GPU detection)
"""
def custom_format(record: logging.LogRecord) -> str:
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
# NOTE: Display PID is useful for multi-process logging.
if display_pid:
pid_str = f"[PID: {os.getpid()}]"
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
else:
message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
return message
pid_str = f"[PID: {os.getpid()}] " if display_pid else ""
return f"{record.levelname} {pid_str}{dt} {fnameline[-15:]:>15} {record.getMessage()}"
formatter = logging.Formatter()
formatter.format = custom_format
logger = logging.getLogger()
logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages
logger.setLevel(logging.NOTSET)
# Remove unused default handlers
for handler in logger.handlers[:]:
logger.removeHandler(handler)
# Clear any existing handlers
logger.handlers.clear()
# Write logs to console
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler)
# Determine if this is a non-main process in distributed training
is_main_process = accelerator.is_main_process if accelerator is not None else True
# Console logging (main process only)
if is_main_process:
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler)
else:
# Suppress console output for non-main processes
logger.addHandler(logging.NullHandler())
logger.setLevel(logging.ERROR)
# Additionally write logs to file
if log_file is not None:
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)