mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 03:41:25 +00:00
feat(train): add accelerate for multi gpu training (#2154)
* Enhance training and logging functionality with accelerator support - Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency. * Initialize logging in training script for both main and non-main processes - Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode. - This change enhances the clarity and consistency of logging during training sessions. * add docs and only push model once * Place logging under accelerate and update docs * fix pre commit * only log in main process * main logging * try with local rank * add tests * change runner * fix test * dont push to hub in multi gpu tests * pre download dataset in tests * small fixes * fix path optimizer state * update docs, and small improvements in train * simplify accelerate main process detection * small improvements in train * fix OOM bug * change accelerate detection * add some debugging * always use accelerate * cleanup update method * cleanup * fix bug * scale lr decay if we reduce steps * cleanup logging * fix formatting * encorperate feedback pr * add min memory to cpu tests * use accelerate to determin logging * fix precommit and fix tests * chore: minor details --------- Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -164,14 +164,20 @@ def set_rng_state(random_state_dict: dict[str, Any]):
|
||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_seed(seed) -> None:
|
||||
def set_seed(seed, accelerator: Callable | None = None) -> None:
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
if accelerator:
|
||||
from accelerate.utils import set_seed as _accelerate_set_seed
|
||||
|
||||
_accelerate_set_seed(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
|
||||
Reference in New Issue
Block a user