File size: 3,466 Bytes
7694c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from typing import Optional, List


def extract_chunks(A: Tensor,
                   ofx: Tensor,
                   mel_ids: Optional[Tensor] = None,
                   chunk_len: int = 128):
    """Gather fixed-length time chunks from a batch of spectrograms.

    Chunk ``i`` equals ``A[mel_ids[i], :, ofx[i]:ofx[i] + chunk_len]``.

    Args:
        A (Tensor): spectrograms [B, F, T]
        ofx (Tensor): start offset along T of each chunk [num_chunks,]
        mel_ids (Tensor, optional): batch index of A each chunk is taken
            from [num_chunks,]. If None, defaults to ``arange(B)`` (chunk
            ``i`` comes from spectrogram ``i``), which requires
            ``num_chunks == B`` (or broadcastable with it).
        chunk_len (int): number of time frames per chunk.
    Returns:
        chunks (Tensor): [num_chunks, F, chunk_len]

    Note:
        Offsets are not bounds-checked: a chunk with ``ofx + chunk_len > T``
        reads into the next spectrogram of the flattened batch (or raises an
        IndexError past the last one).
    """
    if mel_ids is None:
        # Default: one chunk per spectrogram, in batch order.
        mel_ids = torch.arange(0, A.size(0), device=A.device)

    # Time indices of each chunk inside its own spectrogram: [num_chunks, chunk_len]
    ids = torch.arange(0, chunk_len, device=A.device)[None, :] + ofx[:, None]
    # Shift into the flattened (B * T) time axis so a single gather serves all chunks.
    ids = ids + mel_ids[:, None] * A.size(2)

    # [B, F, T] -> [F, B*T] -> gather -> [F, num_chunks, chunk_len] -> [num_chunks, F, chunk_len]
    chunks = A.transpose(0, 1).flatten(1)[:, ids.long()].transpose(0, 1)
    return chunks


def calc_feature_match_loss(fmaps_gen: List[Tensor],
                            fmaps_org: List[Tensor]
                            ):
    """Feature-matching loss: mean L1 distance per layer, averaged over layers.

    Gradients flow only through ``fmaps_gen``. The reference maps are detached
    out-of-place, so the caller's ``fmaps_org`` tensors are left unchanged and
    still attached to their own autograd graph.

    Args:
        fmaps_gen: feature maps computed from the generated input, one per layer.
        fmaps_org: matching feature maps computed from the original input.
    Returns:
        Scalar tensor: ``sum_l mean(|gen_l - org_l|) / len(fmaps_gen)``.
    Raises:
        ValueError: if ``fmaps_gen`` is empty.
    """
    if len(fmaps_gen) == 0:
        raise ValueError("fmaps_gen must contain at least one feature map")

    loss_fmatch = 0.
    for fmap_gen, fmap_org in zip(fmaps_gen, fmaps_org):
        # Out-of-place detach: ``detach_()`` would strip the caller's tensor
        # from its graph (and is not allowed on views).
        loss_fmatch += (fmap_gen - fmap_org.detach()).abs().mean()

    loss_fmatch = loss_fmatch / len(fmaps_gen)
    return loss_fmatch


class Conv2DSpectralNorm(nn.Conv2d):
    """Conv2d that rescales its own weight to unit spectral norm on every call.

    Each forward pass works on the weight viewed as a matrix of shape
    ``[cnum_out, cnum_in * kh * kw]``:

    - It runs ``n_iter`` steps of power iteration on that matrix.
    - It estimates the largest singular value ``sigma``.
    - It divides the stored weight by ``sigma`` in place, then convolves.

    No gradient flows through ``sigma``.

    Args:
        cnum_in: input channels.
        cnum_out: output channels.
        kernel_size, stride, padding, bias: forwarded to ``nn.Conv2d``.
        n_iter: power-iteration steps per forward call (must be >= 1).
        eps: epsilon for the L2 normalisations; also the lower bound on
            ``sigma`` so an all-zero weight is not turned into NaNs.
    Raises:
        ValueError: if ``n_iter < 1``.
    """

    def __init__(self, cnum_in: int, cnum_out: int,
                 kernel_size: int, stride: int, padding: int = 0,
                 n_iter: int = 1, eps: float = 1e-12,
                 bias: bool = True):
        if n_iter < 1:
            # forward() needs at least one iteration to define ``v``.
            raise ValueError(f"n_iter must be >= 1, got {n_iter}")
        super().__init__(cnum_in,
                         cnum_out, kernel_size=kernel_size,
                         stride=stride, padding=padding, bias=bias)
        # Left singular vector estimate [cnum_out, 1]; a buffer, so it is
        # persisted across calls and saved in the state_dict.
        self.register_buffer("weight_u", torch.empty(self.weight.size(0), 1))
        nn.init.trunc_normal_(self.weight_u)
        self.n_iter = n_iter
        self.eps = eps

    def l2_norm(self, x):
        """Normalise ``x`` to unit L2 norm along dim 0."""
        return F.normalize(x, p=2, dim=0, eps=self.eps)

    def forward(self, x):
        # Detached matrix view of the weight: [cnum_out, cnum_in * kh * kw].
        weight_orig = self.weight.flatten(1).detach()

        for _ in range(self.n_iter):
            v = self.l2_norm(weight_orig.t() @ self.weight_u)
            # Assigning to a registered buffer name updates the buffer.
            self.weight_u = self.l2_norm(weight_orig @ v)

        # Floor at eps: an all-zero weight would otherwise give 0 / 0 = NaN.
        sigma = (self.weight_u.t() @ weight_orig @ v).clamp_min(self.eps)
        # ``.data`` on purpose: this in-place rescale is not version-tracked by
        # autograd, so graphs built by earlier calls (which saved the weight)
        # can still be back-propagated; they will see the rescaled values.
        self.weight.data.div_(sigma)

        x = super().forward(x)

        return x


class DConv(nn.Module):
    """Spectrally normalised convolution followed by LeakyReLU(0.2).

    Args:
        cnum_in: input channels.
        cnum_out: output channels.
        ksize: square kernel size.
        stride: convolution stride.
        padding: explicit padding, or ``'auto'`` to use ``(ksize - 1) // 2``.
    """

    def __init__(self, cnum_in,
                 cnum_out, ksize=5, stride=2, padding='auto'):
        super().__init__()
        if padding == 'auto':
            padding = (ksize - 1) // 2
        self.conv_sn = Conv2DSpectralNorm(cnum_in, cnum_out, ksize, stride, padding)
        self.leaky = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x):
        return self.leaky(self.conv_sn(x))


class PatchDiscriminator(nn.Module):
    """Five stacked ``DConv`` stages with a flattened final output.

    Channel widths: cnum_in -> cnum -> 2*cnum -> 4*cnum -> 4*cnum -> 4*cnum.

    Args:
        cnum_in: channels of the input tensor.
        cnum: base channel width.
    """

    def __init__(self, cnum_in, cnum):
        super().__init__()
        self.conv1 = DConv(cnum_in, cnum)
        self.conv2 = DConv(cnum, 2 * cnum)
        self.conv3 = DConv(2 * cnum, 4 * cnum)
        self.conv4 = DConv(4 * cnum, 4 * cnum)
        self.conv5 = DConv(4 * cnum, 4 * cnum)

    def forward(self, x):
        """Return ``(out, feats)``.

        ``out`` is the conv5 output flattened from dim 1 onward. ``feats``
        holds the outputs of conv1..conv4 (conv5's output is not included).
        Presumably ``feats`` feeds a feature-matching loss; confirm with the
        caller.
        """
        feats = []
        hidden = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = stage(hidden)
            feats.append(hidden)
        out = torch.flatten(self.conv5(hidden), start_dim=1)
        return out, feats