File size: 3,605 Bytes
46a75d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch

from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.utils.helpers import sequence_mask

# Shared by every test in this module: use CUDA device 0 when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def test_encoder():
    """Check the output shape of each ``Encoder`` variant exercised here.

    A random ``(8, 14, 37)`` batch and its length mask are pushed through a
    relative-position transformer, a residual conv-BN and an FFTransformer
    encoder. Each result must have shape
    ``[batch, out_channels, time]``.
    """
    batch_size, in_channels, max_len = 8, 14, 37

    features = torch.rand(batch_size, in_channels, max_len).to(device)
    lengths = torch.randint(31, 37, (batch_size,)).long().to(device)
    # randint's upper bound is exclusive, so pin the last item to the full length.
    lengths[-1] = max_len
    mask = sequence_mask(lengths, features.size(2)).unsqueeze(1).to(device)

    # (encoder_type, encoder_params, out_channels) for each variant under test.
    cases = (
        (
            "relative_position_transformer",
            {
                "hidden_channels_ffn": 768,
                "num_heads": 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "rel_attn_window_size": 4,
                "input_length": None,
            },
            11,
        ),
        (
            "residual_conv_bn",
            {"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
            11,
        ),
        (
            "fftransformer",
            {"hidden_channels_ffn": 31, "num_heads": 2, "num_layers": 2, "dropout_p": 0.1},
            14,
        ),
    )

    for encoder_type, encoder_params, out_channels in cases:
        encoder = Encoder(
            out_channels=out_channels,
            in_hidden_channels=in_channels,
            encoder_type=encoder_type,
            encoder_params=encoder_params,
        ).to(device)
        result = encoder(features, mask)
        assert list(result.shape) == [batch_size, out_channels, max_len]


def test_decoder():
    """Check the output shape of each ``Decoder`` variant exercised here.

    A random ``(8, 128, 37)`` batch and its length mask are pushed through the
    default decoder, a relative-position transformer, a wavenet and an
    FFTransformer decoder. Every decoder is built with ``out_channels=11`` and
    must return a tensor of shape ``[8, 11, 37]``.
    """
    input_dummy = torch.rand(8, 128, 37).to(device)
    input_lengths = torch.randint(31, 37, (8,)).long().to(device)
    # randint's upper bound is exclusive, so pin the last item to the full length.
    input_lengths[-1] = 37

    input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
    expected_shape = [8, 11, 37]

    # residual bn conv decoder (the Decoder default: no decoder_type is passed)
    layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == expected_shape

    # transformer decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="relative_position_transformer",
        decoder_params={
            "hidden_channels_ffn": 128,
            "num_heads": 2,
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 8,
            "rel_attn_window_size": 4,
            "input_length": None,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == expected_shape

    # wavenet decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="wavenet",
        decoder_params={
            "num_blocks": 12,
            "hidden_channels": 192,
            "kernel_size": 5,
            "dilation_rate": 1,
            "num_layers": 4,
            "dropout_p": 0.05,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    # Verify the wavenet result too; without this check the forward pass is
    # only exercised for crashes and its output is silently discarded.
    assert list(output.shape) == expected_shape

    # FFTransformer decoder
    layer = Decoder(
        out_channels=11,
        in_hidden_channels=128,
        decoder_type="fftransformer",
        decoder_params={
            "hidden_channels_ffn": 31,
            "num_heads": 2,
            "dropout_p": 0.1,
            "num_layers": 2,
        },
    ).to(device)
    output = layer(input_dummy, input_mask)
    assert list(output.shape) == expected_shape