"""
Utility helpers for distributed checks and collectives.
"""
import os
import pickle  # nosec
from contextlib import contextmanager

import torch
import torch.distributed as dist
from accelerate import Accelerator

# module-level Accelerator instance, lazily created by load_accelerate() / is_distributed()
accelerate = None  # pylint: disable=invalid-name


def load_accelerate():
    """
    Eagerly instantiate the module-level Accelerator.
    """
    global accelerate  # pylint: disable=global-statement
    accelerate = Accelerator()


def is_distributed():
    """
    Check if distributed training is initialized.
    """
    global accelerate  # pylint: disable=global-statement
    if not accelerate:
        accelerate = Accelerator()
    return dist.is_available() and dist.is_initialized()


def barrier():
    """
    Acts as a barrier to wait for all processes. This ensures that all processes
    reach the barrier before proceeding further.
    """
    if is_distributed():
        dist.barrier()


def is_main_process():
    """
    Check if the current process is the main process.
    If not in distributed mode, always return True.
    """
    if not is_distributed():
        return True
    return dist.get_rank() == 0
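
# Illustrative usage sketch (not part of the original module): the usual pattern is to
# guard rank-0-only side effects such as logging or writing checkpoints, then sync.
#
#   if is_main_process():
#       save_checkpoint(model)  # hypothetical rank-0-only helper
#   barrier()  # all ranks wait until the main process has finished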


def get_world_size():
    """Return the world size from the WORLD_SIZE env var, defaulting to 1."""
    return int(os.getenv("WORLD_SIZE", "1"))


@contextmanager
def zero_first(is_main):
    """
    runs the wrapped context so that rank 0 runs first before other ranks
    """
    if not is_main:  # other ranks wait first
        barrier()
    yield
    if is_main:  # then rank 0 waits after it has run the context
        barrier()
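
# Illustrative usage sketch (not part of the original module): `zero_first` is typically
# used so that rank 0 performs a cacheable step (e.g. preprocessing a dataset) before
# the other ranks run the same block and pick up the cached result.
#
#   with zero_first(is_main_process()):
#       dataset = load_or_prepare_dataset()  # hypothetical helper that caches to disk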


def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
    """
    Run a callable 'fn' on all ranks and gather the results on the specified rank.

    Args:
    - fn (callable): A function that computes the scalar value on each rank. It should not have side effects.
    - world_size (int, optional): Total number of processes in the current distributed setup. Default is 1.

    Returns:
    - A list of the computed values from all ranks on rank 0 (the gathering rank), otherwise None.
      Outside of distributed mode, a single-element list with the local value.
    """
    value_scalar = fn()
    if not is_distributed():
        return [value_scalar]
    # dist.get_rank() is used as the CUDA device index here, which assumes one process per GPU on a single node
    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()

    if not is_main_process():
        dist.gather(value_tensor, dst=0)
    else:
        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
        dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

        # Convert tensors back to their original type (int or float)
        gathered_values = []
        for tensor in gathered_tensors:
            if tensor == tensor.int():
                gathered_values.append(int(tensor.item()))
            else:
                gathered_values.append(float(tensor.item()))
        return gathered_values
    return None
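
# Illustrative usage sketch (not part of the original module): gather a per-rank scalar,
# e.g. the local batch count, onto rank 0; only the main process receives the full list.
#
#   counts = gather_scalar_from_all_ranks(
#       lambda: len(train_dataloader),  # hypothetical per-rank value
#       world_size=get_world_size(),
#   )
#   if is_main_process():
#       print(f"per-rank batch counts: {counts}")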


def broadcast_dict(vals: dict):
    """
    Broadcast a picklable dict from rank 0 to all other ranks and return it on every rank.
    """
    if not is_distributed():
        return vals

    if is_main_process():
        data_byte = pickle.dumps(vals)
        data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
        data_size = torch.IntTensor([len(data_byte)]).to("cuda")
    else:
        data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
        data_size = torch.IntTensor([0]).to("cuda")

    dist.broadcast(data_size, 0)
    if not is_main_process():
        # resize
        data_tensor = data_tensor.new_empty([data_size.item()])

    dist.broadcast(data_tensor, 0)

    if not is_main_process():
        data_list = data_tensor.cpu().tolist()
        data_byte = bytes(data_list[: data_size.item()])
        vals = pickle.loads(data_byte)  # nosec

    return vals
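
# Illustrative usage sketch (not part of the original module): rank 0 fills in the dict
# and every other rank receives an identical copy via the pickle-based broadcast above.
#
#   metrics = {"lr": 1e-4, "total_steps": 100} if is_main_process() else {}
#   metrics = broadcast_dict(metrics)  # same contents on every rank afterwards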