mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
@@ -42,7 +42,11 @@ def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> Non
|
||||
"""
|
||||
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
|
||||
"""
|
||||
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
|
||||
py_state = (
|
||||
rng_state_dict["py_rng_version"].item(),
|
||||
tuple(rng_state_dict["py_rng_state"].tolist()),
|
||||
None,
|
||||
)
|
||||
random.setstate(py_state)
|
||||
|
||||
|
||||
@@ -119,7 +123,9 @@ def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
|
||||
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
|
||||
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
|
||||
torch_rng_state_dict = {
|
||||
k: v for k, v in rng_state_dict.items() if k.startswith("torch")
|
||||
}
|
||||
|
||||
deserialize_python_rng_state(py_rng_state_dict)
|
||||
deserialize_numpy_rng_state(np_rng_state_dict)
|
||||
|
||||
Reference in New Issue
Block a user