Enhance training and logging functionality with accelerator support

- Added support for multi-GPU training by introducing an `accelerator` parameter in training functions.
- Updated `update_policy` to handle gradient updates based on the presence of an accelerator.
- Modified logging to prevent duplicate messages in non-main processes.
- Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage.
- Updated `MetricsTracker` to account for the number of processes when calculating metrics.
- Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency.
This commit is contained in:
AdilZouitine
2025-10-02 18:11:27 +02:00
parent abde7be3b3
commit f30da2dec1
5 changed files with 137 additions and 45 deletions

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from typing import Any
from lerobot.utils.utils import format_big_number
@@ -84,6 +85,7 @@ class MetricsTracker:
"samples",
"episodes",
"epochs",
"accelerator",
]
def __init__(
@@ -93,6 +95,7 @@ class MetricsTracker:
num_episodes: int,
metrics: dict[str, AverageMeter],
initial_step: int = 0,
accelerator: Callable | None = None,
):
self.__dict__.update(dict.fromkeys(self.__keys__))
self._batch_size = batch_size
@@ -106,6 +109,7 @@ class MetricsTracker:
self.samples = self.steps * self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
self.accelerator = accelerator
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__:
@@ -128,7 +132,7 @@ class MetricsTracker:
Updates metrics that depend on 'step' for one step.
"""
self.steps += 1
self.samples += self._batch_size
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from collections.abc import Generator
from collections.abc import Callable, Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Any
@@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]):
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_seed(seed) -> None:
def set_seed(seed, accelerator: Callable | None = None) -> None:
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
if accelerator:
from accelerate.utils import set_seed
set_seed(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:

View File

@@ -20,6 +20,7 @@ import select
import subprocess
import sys
import time
from collections.abc import Callable
from copy import copy, deepcopy
from datetime import datetime
from pathlib import Path
@@ -49,13 +50,15 @@ def auto_select_torch_device() -> torch.device:
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
def get_safe_torch_device(
try_device: str, log: bool = False, accelerator: Callable | None = None
) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
try_device = str(try_device)
match try_device:
case "cuda":
assert torch.cuda.is_available()
device = torch.device("cuda")
device = accelerator.device if accelerator else torch.device("cuda")
case "mps":
assert torch.backends.mps.is_available()
device = torch.device("mps")
@@ -109,6 +112,7 @@ def init_logging(
display_pid: bool = False,
console_level: str = "INFO",
file_level: str = "DEBUG",
accelerator: Callable | None = None,
):
def custom_format(record: logging.LogRecord) -> str:
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -137,6 +141,10 @@ def init_logging(
console_handler.setFormatter(formatter)
console_handler.setLevel(console_level.upper())
logger.addHandler(console_handler)
if accelerator is not None and not accelerator.is_main_process:
# Disable duplicate logging on non-main processes
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
logging.getLogger().setLevel(logging.WARNING)
# Additionally write logs to file
if log_file is not None:
@@ -158,6 +166,10 @@ def format_big_number(num, precision=0):
return num
def is_launched_with_accelerate() -> bool:
return "ACCELERATE_MIXED_PRECISION" in os.environ
def say(text: str, blocking: bool = False):
system = platform.system()