File size: 2,578 Bytes
1df74c6
d2b7e94
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1df74c6
 
 
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
1df74c6
32b2aaa
 
 
 
 
 
 
 
 
1df74c6
 
 
32b2aaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from typing import Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm

from ..hparams import HParams
from .lvcnet import LVCBlock
from .mrstft import MRSTFTLoss


class UnivNet(nn.Module):
    """Generator that synthesizes a waveform from acoustic features.

    A Gaussian noise sequence with ``d_noise`` channels and one time step per
    input frame is projected by ``conv_pre``. It then passes through one
    ``LVCBlock`` per entry of ``strides``, each conditioned on the acoustic
    features. Finally ``conv_post`` (LeakyReLU -> Conv1d to 1 channel -> tanh)
    produces a single-channel signal.
    """

    @property
    def d_noise(self):
        """Number of channels of the input noise sequence."""
        return 128

    @property
    def strides(self):
        """Per-block strides passed to ``LVCBlock`` (product: 7*5*4*3 = 420)."""
        return [7, 5, 4, 3]

    @property
    def dilations(self):
        """Dilations passed to every ``LVCBlock``."""
        return [1, 3, 9, 27]

    @property
    def nc(self):
        """Hidden channel count, read from ``hp.univnet_nc``."""
        return self.hp.univnet_nc

    @property
    def scale_factor(self) -> int:
        """Output samples per input frame, read from ``hp.hop_size``.

        NOTE(review): ``forward`` trims ``scale_factor * npad`` samples, which
        presumably requires ``hp.hop_size == prod(strides)`` (420) — confirm.
        """
        return self.hp.hop_size

    def __init__(self, hp: HParams, d_input):
        """
        Args:
            hp: hyper-parameters; ``univnet_nc`` and ``hop_size`` are read by
                this class, and ``hp`` is also handed to ``MRSTFTLoss``.
            d_input: number of channels of the acoustic features ``x``.
        """
        super().__init__()
        self.d_input = d_input

        # Must be set before ``self.nc`` is accessed below.
        self.hp = hp

        # Block i is conditioned with hop length prod(strides[: i + 1]).
        # Cast np.int64 -> int so LVCBlock receives plain Python integers.
        hop_lengths = [int(h) for h in np.cumprod(self.strides)]

        self.blocks = nn.ModuleList(
            [
                LVCBlock(
                    self.nc,
                    d_input,
                    stride=stride,
                    dilations=self.dilations,
                    cond_hop_length=hop_length,
                    kpnet_conv_size=3,
                )
                for stride, hop_length in zip(self.strides, hop_lengths)
            ]
        )

        self.conv_pre = weight_norm(
            nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect")
        )

        self.conv_post = nn.Sequential(
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
            nn.Tanh(),
        )

        self.mrstft = MRSTFTLoss(hp)

    @property
    def eps(self):
        """Small constant (1e-5); not used in the code visible here."""
        return 1e-5

    def forward(self, x: Tensor, y: Union[Tensor, None] = None, npad=10):
        """Synthesize a waveform from acoustic features.

        Args:
            x: (b c t), acoustic features; ``c`` must equal ``d_input``.
            y: (b t), waveform. When given, ``self.losses`` is set to
                ``self.mrstft(z, y)``. When omitted, ``self.losses`` is left
                unchanged from any previous call.
            npad: number of zero frames appended to the end of ``x`` before
                synthesis. The last ``scale_factor * npad`` output samples are
                trimmed afterwards. Must be >= 0; 0 disables padding.
        Returns:
            z: (b t), waveform
        """
        assert x.ndim == 3, "x must be 3D tensor"
        assert y is None or y.ndim == 2, "y must be 2D tensor"
        assert (
            x.shape[1] == self.d_input
        ), f"x.shape[1] must be {self.d_input}, but got {x.shape}"
        assert npad >= 0, "npad must be positive or zero"

        x = F.pad(x, (0, npad), "constant", 0)

        # One noise step per (padded) input frame, matching x's dtype/device.
        z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
        z = self.conv_pre(z)  # (b c t)

        for block in self.blocks:
            z = block(z, x)  # (b c t)

        z = self.conv_post(z)  # (b 1 t)

        # Drop the samples generated from the padded frames. This is guarded
        # because with npad == 0 the slice ``[: -0]`` equals ``[:0]`` and would
        # discard the whole waveform.
        if npad > 0:
            z = z[..., : -self.scale_factor * npad]

        z = z.squeeze(1)  # (b t)

        if y is not None:
            self.losses = self.mrstft(z, y)

        return z