File size: 3,028 Bytes
910e2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import torch
import torch.distributed as dist
from .utils import is_dist_avail_and_initialized, get_rank


# Sequence-parallel state, filled in by init_sequence_parallel_group():
#   SEQ_PARALLEL_GROUP    -- process group this rank belongs to (stays None on
#                            ranks outside the first SEQ_PARALLEL_PROC_NUM ranks)
#   SEQ_PARALLEL_SIZE     -- number of ranks in each sequence-parallel group
#   SEQ_PARALLEL_PROC_NUM -- how many ranks, counted from rank 0, use sequence parallelism
SEQ_PARALLEL_GROUP = SEQ_PARALLEL_SIZE = SEQ_PARALLEL_PROC_NUM = None

# Input-synchronization state, filled in by init_sync_input_group():
#   SYNC_INPUT_GROUP -- process group this rank belongs to
#   SYNC_INPUT_SIZE  -- number of ranks in each input-sync group
SYNC_INPUT_GROUP = SYNC_INPUT_SIZE = None

def is_sequence_parallel_initialized():
    """Report whether this rank has been assigned a sequence-parallel group.

    Returns False both before init_sequence_parallel_group() runs and on ranks
    that the initializer did not place in any group.
    """
    return SEQ_PARALLEL_GROUP is not None


def init_sequence_parallel_group(args):
    """Create the sequence-parallel process groups and record this rank's group.

    Ranks ``[0, SEQ_PARALLEL_PROC_NUM)`` are split into consecutive,
    non-overlapping blocks of ``args.sp_group_size`` ranks, and each block
    becomes one process group. A rank at or beyond ``SEQ_PARALLEL_PROC_NUM``
    belongs to no group, so ``SEQ_PARALLEL_GROUP`` stays ``None`` there.

    Args:
        args: object providing ``sp_group_size`` (ranks per group) and
            ``sp_proc_num`` (number of ranks used for sequence parallelism;
            ``-1`` means the whole world).

    Raises:
        AssertionError: if this rank already has a sequence-parallel group,
            if torch.distributed is not initialized, if the group size is not
            positive, if ``SEQ_PARALLEL_PROC_NUM`` is not a multiple of the
            group size, or if it exceeds the world size.
    """
    global SEQ_PARALLEL_GROUP
    global SEQ_PARALLEL_SIZE
    global SEQ_PARALLEL_PROC_NUM

    assert SEQ_PARALLEL_GROUP is None, "sequence parallel group is already initialized"
    assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
    SEQ_PARALLEL_SIZE = args.sp_group_size

    print(f"Setting the Sequence Parallel Size {SEQ_PARALLEL_SIZE}")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    if args.sp_proc_num == -1:
        SEQ_PARALLEL_PROC_NUM = world_size
    else:
        SEQ_PARALLEL_PROC_NUM = args.sp_proc_num

    # Check positivity first so the modulo below cannot divide by zero.
    assert SEQ_PARALLEL_SIZE > 0, (
        f"The sequence parallel size must be positive, got {SEQ_PARALLEL_SIZE}"
    )
    assert SEQ_PARALLEL_PROC_NUM % SEQ_PARALLEL_SIZE == 0, "The process needs to be evenly divided"
    # Ranks >= world_size would make dist.new_group raise on every rank.
    assert SEQ_PARALLEL_PROC_NUM <= world_size, (
        f"sp_proc_num ({SEQ_PARALLEL_PROC_NUM}) exceeds the world size ({world_size})"
    )

    # dist.new_group must be entered by every rank of the default group, for
    # every group, in the same order. Do NOT stop once this rank's group is
    # found: an early exit desynchronizes group creation across ranks, which
    # breaks this and any later new_group call.
    for start in range(0, SEQ_PARALLEL_PROC_NUM, SEQ_PARALLEL_SIZE):
        ranks = list(range(start, start + SEQ_PARALLEL_SIZE))
        group = dist.new_group(ranks)
        if rank in ranks:
            SEQ_PARALLEL_GROUP = group


def init_sync_input_group(args):
    """Create the input-synchronization process groups and record this rank's group.

    All ``world_size`` ranks are split into consecutive, non-overlapping
    blocks of ``args.max_frames`` ranks, and each block becomes one process
    group.

    Args:
        args: object providing ``max_frames``, used here as the number of
            ranks per input-sync group.

    Raises:
        AssertionError: if the input-sync group is already set, if
            torch.distributed is not initialized, if ``max_frames`` is not
            positive, or if the world size is not a multiple of it.
    """
    global SYNC_INPUT_GROUP
    global SYNC_INPUT_SIZE

    assert SYNC_INPUT_GROUP is None, "parallel group is already initialized"
    assert is_dist_avail_and_initialized(), "The pytorch distributed should be initialized"
    SYNC_INPUT_SIZE = args.max_frames

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    assert SYNC_INPUT_SIZE > 0, (
        f"The sync input group size must be positive, got {SYNC_INPUT_SIZE}"
    )
    # A trailing partial group would contain ranks >= world_size, which
    # dist.new_group rejects; fail early with a clear message on every rank.
    assert world_size % SYNC_INPUT_SIZE == 0, (
        f"The world size ({world_size}) must be divisible by the sync input "
        f"group size ({SYNC_INPUT_SIZE})"
    )

    # Every rank must create every group, in the same order (see
    # dist.new_group), so iterate over all blocks without breaking early.
    for start in range(0, world_size, SYNC_INPUT_SIZE):
        ranks = list(range(start, start + SYNC_INPUT_SIZE))
        group = dist.new_group(ranks)
        if rank in ranks:
            SYNC_INPUT_GROUP = group


def get_sequence_parallel_group():
    """Return the sequence-parallel process group assigned to this rank.

    Raises:
        AssertionError: if no group has been assigned to this rank yet.
    """
    group = SEQ_PARALLEL_GROUP
    assert group is not None, "sequence parallel group is not initialized"
    return group


def get_sync_input_group():
    """Return this rank's input-sync process group, or ``None`` if none is set."""
    group = SYNC_INPUT_GROUP
    return group


def get_sequence_parallel_world_size():
    """Return the number of ranks in each sequence-parallel group.

    Raises:
        AssertionError: if the group size has not been set yet.
    """
    size = SEQ_PARALLEL_SIZE
    assert size is not None, "sequence parallel size is not initialized"
    return size


def get_sequence_parallel_rank():
    """Return this rank's position inside its sequence-parallel group.

    The initializer builds groups as contiguous blocks of SEQ_PARALLEL_SIZE
    ranks, so the position is the rank modulo the group size (assumes
    get_rank() returns the global rank -- confirm against .utils).

    Raises:
        AssertionError: if the group size has not been set yet.
    """
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    _, position = divmod(get_rank(), SEQ_PARALLEL_SIZE)
    return position


def get_sequence_parallel_group_rank():
    """Return the index of the sequence-parallel group this rank falls into.

    With groups laid out as contiguous blocks of SEQ_PARALLEL_SIZE ranks, the
    block index is the rank floor-divided by the group size (assumes
    get_rank() returns the global rank -- confirm against .utils).

    Raises:
        AssertionError: if the group size has not been set yet.
    """
    assert SEQ_PARALLEL_SIZE is not None, "sequence parallel size is not initialized"
    block_index, _ = divmod(get_rank(), SEQ_PARALLEL_SIZE)
    return block_index


def get_sequence_parallel_proc_num():
    """Return how many ranks take part in sequence parallelism (``None`` before init)."""
    proc_num = SEQ_PARALLEL_PROC_NUM
    return proc_num