File size: 3,517 Bytes
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2b7e94
 
 
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import logging
import math

import torch
import torch.nn as nn

# Module-level logger named after this module; no call to it appears in the
# visible part of the file.
logger = logging.getLogger(__name__)


@torch.jit.script
def _fused_tanh_sigmoid(h):
    """
    Gated activation: split ``h`` into two chunks along dim 1 and return
    ``tanh(first) * sigmoid(second)``.

    The result therefore has half as many entries along dim 1 as ``h``
    (when that dimension is even).
    """
    filt, gate = h.chunk(2, dim=1)
    return torch.tanh(filt) * torch.sigmoid(gate)


class WNLayer(nn.Module):
    """
    A single gated residual layer of a DiffWave-like WN.

    The layer runs a dilated convolution over the hidden signal, optionally
    adds projected local/global conditioning, applies the tanh/sigmoid gate
    and emits a residual output plus a skip tensor.
    """

    def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
        """
        Args:
            hidden_dim: channel count of the hidden signal
            local_dim: channels of the local condition, or None for no `lconv`
            global_dim: channels of the global condition, or None for no `gconv`
            kernel_size: kernel size of the dilated convolution
            dilation: dilation of the dilated convolution
        """
        super().__init__()

        # The gate consumes twice the hidden width (tanh half + sigmoid half).
        gate_channels = 2 * hidden_dim

        # Conditioning projections exist only when the matching dim is given.
        if global_dim is not None:
            self.gconv = nn.Conv1d(global_dim, hidden_dim, kernel_size=1)

        if local_dim is not None:
            self.lconv = nn.Conv1d(local_dim, gate_channels, kernel_size=1)

        self.dconv = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=gate_channels,
            kernel_size=kernel_size,
            padding="same",
            dilation=dilation,
        )

        # 1x1 projection whose output is split into residual and skip halves.
        self.out = nn.Conv1d(hidden_dim, gate_channels, kernel_size=1)

    def forward(self, z, l, g):
        """
        Args:
            z: hidden signal with `hidden_dim` channels on dim 1
               (presumably (b c t) -- confirm against caller)
            l: local condition fed to `lconv`, or None
            g: global condition fed to `gconv`, or None; a 2-D tensor gets a
               trailing singleton axis appended first

        Returns:
            (o, s): `o` is (residual half + original z) / sqrt(2), `s` is the
            skip half of the output projection.
        """
        residual = z

        if g is not None:
            cond = g.unsqueeze(-1) if g.dim() == 2 else g
            z = z + self.gconv(cond)

        pre_gate = self.dconv(z)
        if l is not None:
            pre_gate = pre_gate + self.lconv(l)

        projected = self.out(_fused_tanh_sigmoid(pre_gate))
        res_half, skip = projected.chunk(2, dim=1)

        # The residual is taken against the input as it was before the
        # global condition was added.
        return (res_half + residual) / math.sqrt(2), skip


class WN(nn.Module):
    """
    A stack of `WNLayer` blocks with summed skip connections.

    The input is projected to `hidden_dim` channels by a 1x1 convolution,
    passed through `n_layers` layers whose dilation cycles through
    ``2 ** (i % dilation_cycle)``, and the skip outputs of all layers are
    summed, rescaled and projected to `output_dim` channels.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        local_dim=None,
        global_dim=None,
        n_layers=30,
        kernel_size=3,
        dilation_cycle=5,
        hidden_dim=512,
    ):
        """
        Args:
            input_dim: channels of the input signal
            output_dim: channels of the returned signal
            local_dim: channels of the local condition, or None to disable it
            global_dim: channels of the global condition, or None to disable it
            n_layers: number of stacked layers (must be >= 1)
            kernel_size: kernel size of each dilated convolution (must be odd)
            dilation_cycle: dilation restarts at 1 every this many layers
                (must be >= 1)
            hidden_dim: channels of the hidden signal (must be even)

        Raises:
            AssertionError: if `kernel_size` is even or `hidden_dim` is odd
            ValueError: if `n_layers` or `dilation_cycle` is smaller than 1
        """
        super().__init__()
        # Kept as asserts so existing callers keep seeing AssertionError.
        assert kernel_size % 2 == 1, "kernel_size must be odd"
        assert hidden_dim % 2 == 0, "hidden_dim must be even"

        # Without a layer, forward() has no skip tensors to stack; reject the
        # configuration here instead of failing on the first forward pass.
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        # `i % dilation_cycle` needs a positive modulus to yield valid
        # (positive integer) dilations.
        if dilation_cycle < 1:
            raise ValueError(f"dilation_cycle must be >= 1, got {dilation_cycle}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.local_dim = local_dim
        self.global_dim = global_dim

        self.start = nn.Conv1d(input_dim, hidden_dim, 1)
        if local_dim is not None:
            self.local_norm = nn.InstanceNorm1d(local_dim)

        self.layers = nn.ModuleList(
            [
                WNLayer(
                    hidden_dim=hidden_dim,
                    local_dim=local_dim,
                    global_dim=global_dim,
                    kernel_size=kernel_size,
                    dilation=2 ** (i % dilation_cycle),
                )
                for i in range(n_layers)
            ]
        )

        self.end = nn.Conv1d(hidden_dim, output_dim, 1)

    def forward(self, z, l=None, g=None):
        """
        Args:
            z: input (b c t)
            l: local condition (b c t)
            g: global condition (b d)

        Returns:
            Output of the final 1x1 convolution, with `output_dim` channels.

        Raises:
            ValueError: if `l` / `g` is given but the model was built with
                `local_dim=None` / `global_dim=None`
        """
        # Fail early with a clear message; otherwise the missing
        # `local_norm` / `gconv` submodules surface as AttributeError.
        if l is not None and self.local_dim is None:
            raise ValueError(
                "local condition `l` was given but the model was built with local_dim=None"
            )
        if g is not None and self.global_dim is None:
            raise ValueError(
                "global condition `g` was given but the model was built with global_dim=None"
            )

        z = self.start(z)

        if l is not None:
            l = self.local_norm(l)

        # Collect the skip output of every layer.
        skips = []
        for layer in self.layers:
            z, s = layer(z, l, g)
            skips.append(s)

        # Sum the skips and scale by 1/sqrt(number of layers).
        skip_sum = torch.stack(skips, dim=0).sum(dim=0)
        skip_sum = skip_sum / math.sqrt(len(self.layers))

        return self.end(skip_sum)

    def summarize(self, length=100):
        """
        Print a ptflops complexity report for an input of `length` time steps.

        Args:
            length: size of the last axis of the probe input
        """
        # Imported lazily so ptflops is only required when summarizing.
        from ptflops import get_model_complexity_info

        # Only the shape is reported, so build it directly rather than
        # allocating a random tensor (which would also advance the global RNG).
        input_shape = torch.Size((1, self.input_dim, length))

        macs, params = get_model_complexity_info(
            self,
            (self.input_dim, length),
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )

        print(f"Input shape: {input_shape}")
        print(f"Computational complexity: {macs}")
        print(f"Number of parameters: {params}")


if __name__ == "__main__":
    # Script entry point: build a default-sized WN with 64 input and output
    # channels and print its complexity summary.
    WN(input_dim=64, output_dim=64).summarize()