mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
Enhance training and logging functionality with accelerator support
- Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency.
This commit is contained in:
@@ -20,6 +20,7 @@ import select
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -49,13 +50,15 @@ def auto_select_torch_device() -> torch.device:
|
||||
|
||||
|
||||
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
|
||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
def get_safe_torch_device(
|
||||
try_device: str, log: bool = False, accelerator: Callable | None = None
|
||||
) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
try_device = str(try_device)
|
||||
match try_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
device = accelerator.device if accelerator else torch.device("cuda")
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
@@ -109,6 +112,7 @@ def init_logging(
|
||||
display_pid: bool = False,
|
||||
console_level: str = "INFO",
|
||||
file_level: str = "DEBUG",
|
||||
accelerator: Callable | None = None,
|
||||
):
|
||||
def custom_format(record: logging.LogRecord) -> str:
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -137,6 +141,10 @@ def init_logging(
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(console_level.upper())
|
||||
logger.addHandler(console_handler)
|
||||
if accelerator is not None and not accelerator.is_main_process:
|
||||
# Disable duplicate logging on non-main processes
|
||||
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
# Additionally write logs to file
|
||||
if log_file is not None:
|
||||
@@ -158,6 +166,10 @@ def format_big_number(num, precision=0):
|
||||
return num
|
||||
|
||||
|
||||
def is_launched_with_accelerate() -> bool:
|
||||
return "ACCELERATE_MIXED_PRECISION" in os.environ
|
||||
|
||||
|
||||
def say(text: str, blocking: bool = False):
|
||||
system = platform.system()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user