Spaces:
Runtime error
Runtime error
# Copied from PyTorch Lightning:
# https://github.com/Lightning-AI/lightning/blob/0d52f4577310b5a1624bed4d23d49e37fb05af9e/src/lightning_fabric/utilities/seed.py
# and
# https://github.com/Lightning-AI/lightning/blob/98f7696d1681974d34fad59c03b4b58d9524ed13/src/pytorch_lightning/utilities/seed.py
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager | |
from typing import Generator, Dict, Any | |
import torch | |
import numpy as np | |
from random import getstate as python_get_rng_state | |
from random import setstate as python_set_rng_state | |
def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
    """Snapshot the global RNG state of :mod:`torch`, :mod:`numpy` and Python's :mod:`random`.

    Args:
        include_cuda: When true, the value of ``torch.cuda.get_rng_state_all()`` is stored as well,
            under the ``"torch.cuda"`` key.

    Returns:
        A dictionary with the keys ``"torch"``, ``"numpy"`` and ``"python"``, plus ``"torch.cuda"``
        when ``include_cuda`` is true.
    """
    snapshot: Dict[str, Any] = {}
    snapshot["torch"] = torch.get_rng_state()
    snapshot["numpy"] = np.random.get_state()
    snapshot["python"] = python_get_rng_state()
    if not include_cuda:
        # Caller asked us not to touch the CUDA generators at all.
        return snapshot
    snapshot["torch.cuda"] = torch.cuda.get_rng_state_all()
    return snapshot
def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    """Restore the global RNG state of :mod:`torch`, :mod:`numpy` and Python's :mod:`random`.

    Args:
        rng_state_dict: Mapping with the keys ``"torch"``, ``"numpy"`` and ``"python"``, and optionally
            ``"torch.cuda"``. The CUDA generators are only touched when that optional key is present.
    """
    torch.set_rng_state(rng_state_dict["torch"])

    # The CUDA entry is optional (torch.cuda rng_state is only included since v1.8).
    cuda_key = "torch.cuda"
    if cuda_key in rng_state_dict:
        torch.cuda.set_rng_state_all(rng_state_dict[cuda_key])

    np.random.set_state(rng_state_dict["numpy"])

    py_version, py_internal_state, py_gauss_next = rng_state_dict["python"]
    # The middle element is coerced to a tuple before being handed to ``random.setstate``;
    # presumably this lets a state whose inner sequence came back as a list still be accepted.
    restored_python_state = (py_version, tuple(py_internal_state), py_gauss_next)
    python_set_rng_state(restored_python_state)
@contextmanager
def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
    """A context manager that resets the global random state on exit to what it was before entering.

    It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
    The saved state is restored on every exit path, including when the body of the ``with`` block raises.

    Args:
        include_cuda: Whether to allow this function to also control the `torch.cuda` random number generator.
            Set this to ``False`` when using the function in a forked process where CUDA re-initialization is
            prohibited.

    Example:
        >>> import torch
        >>> torch.manual_seed(1)  # doctest: +ELLIPSIS
        <torch._C.Generator object at ...>
        >>> with isolate_rng():
        ...     [torch.rand(1) for _ in range(3)]
        [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
        >>> torch.rand(1)
        tensor([0.7576])
    """
    states = _collect_rng_states(include_cuda)
    try:
        yield
    finally:
        # Restore in ``finally`` so an exception inside the ``with`` body cannot leave the
        # global generators in the modified state.
        _set_rng_states(states)