|
""" |
|
utility helpers for distributed checks |
|
""" |
|
import os |
|
from contextlib import contextmanager |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from accelerate import Accelerator |
|
|
|
# Module-level Accelerator instance; stays None until load_accelerate() or
# is_distributed() creates one.
accelerate = None
|
|
|
|
|
def load_accelerate() -> None:
    """
    Build a new ``Accelerator`` and store it in the module-level ``accelerate`` global.

    Any instance previously stored there is replaced.
    """
    global accelerate
    accelerate = Accelerator()
|
|
|
|
|
def is_distributed():
    """
    Check if distributed training is initialized.

    As a side effect, an ``Accelerator`` is created through ``load_accelerate()``
    the first time this is called while the module-level ``accelerate`` is
    still ``None``.

    Returns:
    - bool: True when ``torch.distributed`` is available and its default
      process group has been initialized, False otherwise.
    """
    # NOTE(review): presumably the Accelerator is built first so that it can
    # set up the process group before it is queried below -- confirm.
    if accelerate is None:
        load_accelerate()
    return dist.is_available() and dist.is_initialized()
|
|
|
|
|
def barrier():
    """
    Block until every process has reached this point.

    Does nothing when distributed training is not initialized.
    """
    if not is_distributed():
        return
    dist.barrier()
|
|
|
|
|
def is_main_process():
    """
    Tell whether this process is the main one (rank 0).

    Outside of distributed mode the answer is always True.
    """
    return not is_distributed() or dist.get_rank() == 0
|
|
|
|
|
def get_world_size():
    """
    Return the number of processes as given by the ``WORLD_SIZE`` environment variable.

    Falls back to 1 when the variable is not set.
    """
    raw_size = os.environ.get("WORLD_SIZE", "1")
    return int(raw_size)
|
|
|
|
|
@contextmanager
def zero_first(is_main):
    """
    Context manager that lets the main process run the wrapped body before the others.

    Args:
    - is_main: truthy on the process that should go first.
    """
    if is_main:
        # The main process runs the body right away, then joins the barrier.
        yield
        barrier()
    else:
        # Every other process waits at the barrier before running the body.
        barrier()
        yield
|
|
|
|
|
def gather_scalar_from_all_ranks(fn, world_size=1):
    """
    Run a callable 'fn' on all ranks and gather the results on rank 0.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.
    - world_size (int, optional): Total number of processes in the current distributed setup.
      Only used when distributed, to size the receive buffer on rank 0. Default is 1.

    Returns:
    - A list of computed values from all ranks if on the gathering rank (rank 0), otherwise None.
      When distributed training is not initialized, a one-element list holding the local value.
    """
    value_scalar = fn()

    # Checked before any other torch.distributed call: without an initialized
    # process group there is nothing to gather, and get_rank()/gather() would fail.
    if not is_distributed():
        return [value_scalar]

    # Use the device selected for this process rather than the global rank,
    # which is not a valid local device index on multi-node jobs.
    # NOTE(review): assumes the launcher/Accelerator has already selected this
    # process's CUDA device -- confirm.
    device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    value_tensor = torch.tensor(value_scalar, device=device).float()

    if not is_main_process():
        dist.gather(value_tensor, dst=0)
        return None

    gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

    gathered_values = []
    for tensor in gathered_tensors:
        # Values travel as float32; report integral ones as ints again.
        if tensor == tensor.int():
            gathered_values.append(int(tensor.item()))
        else:
            gathered_values.append(float(tensor.item()))
    return gathered_values
|
|