Files
lerobot-clone/src/lerobot/utils/random_utils.py

199 lines
7.1 KiB
Python
Raw Normal View History

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
feat(train): add accelerate for multi gpu training (#2154) * Enhance training and logging functionality with accelerator support - Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency. * Initialize logging in training script for both main and non-main processes - Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode. - This change enhances the clarity and consistency of logging during training sessions. * add docs and only push model once * Place logging under accelerate and update docs * fix pre commit * only log in main process * main logging * try with local rank * add tests * change runner * fix test * dont push to hub in multi gpu tests * pre download dataset in tests * small fixes * fix path optimizer state * update docs, and small improvements in train * simplify accelerate main process detection * small improvements in train * fix OOM bug * change accelerate detection * add some debugging * always use accelerate * cleanup update method * cleanup * fix bug * scale lr decay if we reduce steps * cleanup logging * fix formatting * encorperate feedback pr * add min memory to cpu tests * use accelerate to determin logging * fix precommit and fix tests * chore: minor details --------- Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-10-16 17:41:55 +02:00
from collections.abc import Callable, Generator
from contextlib import contextmanager
from pathlib import Path
from typing import Any
import numpy as np
import torch
from safetensors.torch import load_file, save_file
from lerobot.datasets.utils import flatten_dict, unflatten_dict
2025-09-24 11:11:53 +02:00
from lerobot.utils.constants import RNG_STATE
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using
`safetensors.save_file()` or `torch.save()`.
"""
py_state = random.getstate()
return {
"py_rng_version": torch.tensor([py_state[0]], dtype=torch.int64),
"py_rng_state": torch.tensor(py_state[1], dtype=torch.int64),
}
def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
"""
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
random.setstate(py_state)
def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using
`safetensors.save_file()` or `torch.save()`.
"""
np_state = np.random.get_state()
# Ensure no breaking changes from numpy
assert np_state[0] == "MT19937"
return {
"np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
"np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
"np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
"np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
}
def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `numpy` from a dictionary produced by `serialize_numpy_rng_state()`.
"""
np_state = (
"MT19937",
rng_state_dict["np_rng_state_values"].numpy(),
rng_state_dict["np_rng_state_index"].item(),
rng_state_dict["np_rng_has_gauss"].item(),
rng_state_dict["np_rng_cached_gaussian"].item(),
)
np.random.set_state(np_state)
def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `torch` in the form of a flat dict[str, torch.Tensor] to be saved using
`safetensors.save_file()` or `torch.save()`.
"""
torch_rng_state_dict = {"torch_rng_state": torch.get_rng_state()}
if torch.cuda.is_available():
torch_rng_state_dict["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
return torch_rng_state_dict
def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `torch` from a dictionary produced by `serialize_torch_rng_state()`.
"""
torch.set_rng_state(rng_state_dict["torch_rng_state"])
if torch.cuda.is_available() and "torch_cuda_rng_state" in rng_state_dict:
torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"])
def serialize_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `random`, `numpy`, and `torch`, in the form of a flat
dict[str, torch.Tensor] to be saved using `safetensors.save_file()` `torch.save()`.
"""
py_rng_state_dict = serialize_python_rng_state()
np_rng_state_dict = serialize_numpy_rng_state()
torch_rng_state_dict = serialize_torch_rng_state()
return {
**py_rng_state_dict,
**np_rng_state_dict,
**torch_rng_state_dict,
}
def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `random`, `numpy`, and `torch` from a dictionary produced by
`serialize_rng_state()`.
"""
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
deserialize_python_rng_state(py_rng_state_dict)
deserialize_numpy_rng_state(np_rng_state_dict)
deserialize_torch_rng_state(torch_rng_state_dict)
def save_rng_state(save_dir: Path) -> None:
rng_state_dict = serialize_rng_state()
flat_rng_state_dict = flatten_dict(rng_state_dict)
save_file(flat_rng_state_dict, save_dir / RNG_STATE)
def load_rng_state(save_dir: Path) -> None:
flat_rng_state_dict = load_file(save_dir / RNG_STATE)
rng_state_dict = unflatten_dict(flat_rng_state_dict)
deserialize_rng_state(rng_state_dict)
def get_rng_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
random_state_dict = {
"random_state": random.getstate(),
"numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
return random_state_dict
def set_rng_state(random_state_dict: dict[str, Any]):
"""Set the random state for `random`, `numpy`, and `torch`.
Args:
random_state_dict: A dictionary of the form returned by `get_rng_state`.
"""
random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"])
if torch.cuda.is_available():
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
feat(train): add accelerate for multi gpu training (#2154) * Enhance training and logging functionality with accelerator support - Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency. * Initialize logging in training script for both main and non-main processes - Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode. - This change enhances the clarity and consistency of logging during training sessions. * add docs and only push model once * Place logging under accelerate and update docs * fix pre commit * only log in main process * main logging * try with local rank * add tests * change runner * fix test * dont push to hub in multi gpu tests * pre download dataset in tests * small fixes * fix path optimizer state * update docs, and small improvements in train * simplify accelerate main process detection * small improvements in train * fix OOM bug * change accelerate detection * add some debugging * always use accelerate * cleanup update method * cleanup * fix bug * scale lr decay if we reduce steps * cleanup logging * fix formatting * encorperate feedback pr * add min memory to cpu tests * use accelerate to determin logging * fix precommit and fix tests * chore: minor details --------- Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-10-16 17:41:55 +02:00
def set_seed(seed, accelerator: Callable | None = None) -> None:
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
feat(train): add accelerate for multi gpu training (#2154) * Enhance training and logging functionality with accelerator support - Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency. * Initialize logging in training script for both main and non-main processes - Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode. - This change enhances the clarity and consistency of logging during training sessions. * add docs and only push model once * Place logging under accelerate and update docs * fix pre commit * only log in main process * main logging * try with local rank * add tests * change runner * fix test * dont push to hub in multi gpu tests * pre download dataset in tests * small fixes * fix path optimizer state * update docs, and small improvements in train * simplify accelerate main process detection * small improvements in train * fix OOM bug * change accelerate detection * add some debugging * always use accelerate * cleanup update method * cleanup * fix bug * scale lr decay if we reduce steps * cleanup logging * fix formatting * encorperate feedback pr * add min memory to cpu tests * use accelerate to determin logging * fix precommit and fix tests * chore: minor details --------- Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-10-16 17:41:55 +02:00
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
feat(train): add accelerate for multi gpu training (#2154) * Enhance training and logging functionality with accelerator support - Added support for multi-GPU training by introducing an `accelerator` parameter in training functions. - Updated `update_policy` to handle gradient updates based on the presence of an accelerator. - Modified logging to prevent duplicate messages in non-main processes. - Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage. - Updated `MetricsTracker` to account for the number of processes when calculating metrics. - Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency. * Initialize logging in training script for both main and non-main processes - Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode. - This change enhances the clarity and consistency of logging during training sessions. * add docs and only push model once * Place logging under accelerate and update docs * fix pre commit * only log in main process * main logging * try with local rank * add tests * change runner * fix test * dont push to hub in multi gpu tests * pre download dataset in tests * small fixes * fix path optimizer state * update docs, and small improvements in train * simplify accelerate main process detection * small improvements in train * fix OOM bug * change accelerate detection * add some debugging * always use accelerate * cleanup update method * cleanup * fix bug * scale lr decay if we reduce steps * cleanup logging * fix formatting * encorperate feedback pr * add min memory to cpu tests * use accelerate to determin logging * fix precommit and fix tests * chore: minor details --------- Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-10-16 17:41:55 +02:00
if accelerator:
from accelerate.utils import set_seed as _accelerate_set_seed
_accelerate_set_seed(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
"""Set the seed when entering a context, and restore the prior random state at exit.
Example usage:
```
a = random.random() # produces some random number
with seeded_context(1337):
b = random.random() # produces some other random number
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
```
"""
random_state_dict = get_rng_state()
set_seed(seed)
yield None
set_rng_state(random_state_dict)