import torch
import torch.nn as nn
import math
import torch.distributed as dist


def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
    concat_output: bool,
):
    if world_size == 1:
        return input_
    # Split the input along scatter_dim into one contiguous chunk per rank.
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    if concat_output:
        # Reassemble the chunks received from all ranks along gather_dim.
        return torch.cat(output_list, dim=gather_dim).contiguous()
    else:
        # For multi-GPU inference the latents are identical on every rank, so keep only the first chunk.
        return output_list[0]


class _AllToAll(torch.autograd.Function):
    """Autograd-aware all-to-all: backward applies the same all-to-all with scatter_dim and gather_dim swapped."""

    @staticmethod
    def forward(ctx, input_, process_group, world_size, scatter_dim, gather_dim, concat_output):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = world_size
        ctx.concat_output = concat_output
        output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim, concat_output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
            ctx.concat_output,
        )
        # One gradient per forward input: only input_ gets a gradient; process_group,
        # world_size, scatter_dim, gather_dim, and concat_output are non-differentiable.
        return (
            grad_output,
            None,
            None,
            None,
            None,
            None,
        )


def all_to_all(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    world_size: int = 1,
    scatter_dim: int = 2,
    gather_dim: int = 1,
    concat_output: bool = True,
):
    return _AllToAll.apply(input_, process_group, world_size, scatter_dim, gather_dim, concat_output)
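

# Minimal usage sketch (not part of the original module; the launch setup below is an
# assumption): run under torchrun so every rank joins the default process group. With
# scatter_dim=2 and gather_dim=1, a (B, S_local, H, D) tensor sharded over the sequence
# dimension is re-sharded over the head dimension, giving an output of shape
# (B, S_local * world_size, H // world_size, D).
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")  # all_to_all requires the NCCL (or MPI) backend
    torch.cuda.set_device(dist.get_rank())   # assumes one GPU per rank on a single node
    world_size = dist.get_world_size()
    x = torch.randn(2, 16, 4 * world_size, 8, device="cuda")
    y = all_to_all(
        x,
        process_group=dist.group.WORLD,
        world_size=world_size,
        scatter_dim=2,
        gather_dim=1,
    )
    assert y.shape == (2, 16 * world_size, 4, 8)
    dist.destroy_process_group()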