""" | |
utility helpers for distributed checks | |
""" | |
from contextlib import contextmanager

import torch.distributed as dist
from accelerate import Accelerator

# Module-level Accelerator instance, created lazily by load_accelerate()
# or on the first call to is_distributed().
accelerate = None  # pylint: disable=invalid-name


def load_accelerate():
    """
    Create the module-level Accelerator instance.
    """
    global accelerate  # pylint: disable=global-statement
    accelerate = Accelerator()


def is_distributed():
    """
    Check if distributed training is initialized.
    """
    global accelerate  # pylint: disable=global-statement
    if not accelerate:
        accelerate = Accelerator()
    return dist.is_available() and dist.is_initialized()


def barrier():
    """
    Acts as a barrier to wait for all processes. This ensures that all processes
    reach the barrier before proceeding further.
    """
    if is_distributed():
        dist.barrier()


def is_main_process():
    """
    Check if the current process is the main process (rank 0).
    If not in distributed mode, always return True.
    """
    if not is_distributed():
        return True
    return dist.get_rank() == 0
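
# Example usage (a minimal sketch; `save_checkpoint` and `model` are hypothetical
# and only illustrate gating a rank-0-only side effect behind is_main_process()):
#
#     if is_main_process():
#         save_checkpoint(model)  # only rank 0 writes to disk
#     barrier()                   # all ranks sync up before continuing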


@contextmanager
def zero_first(is_main):
    """
    Context manager that runs the wrapped block on rank 0 first,
    then on the remaining ranks.
    """
    if not is_main:  # other ranks wait at the barrier first
        barrier()
    yield
    if is_main:  # rank 0 reaches the barrier only after running the block, releasing the others
        barrier()
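
# Example usage (a minimal sketch; `prepare_dataset` is a hypothetical helper):
# rank 0 prepares and caches the dataset first, then the other ranks load it.
#
#     with zero_first(is_main_process()):
#         dataset = prepare_dataset()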