File size: 2,937 Bytes
46a75d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch

from TTS.tts.configs.delightful_tts_config import DelightfulTTSConfig
from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel
from TTS.tts.models.delightful_tts import DelightfulTtsArgs, VocoderConfig
from TTS.tts.utils.helpers import rand_segments
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.vocoder.models.hifigan_generator import HifiganGenerator

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

args = DelightfulTtsArgs()
v_args = VocoderConfig()


config = DelightfulTTSConfig(
    model_args=args,
    # compute_f0=True,
    # f0_cache_path=os.path.join(output_path, "f0_cache"),
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    # phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
)

tokenizer, config = TTSTokenizer.init_from_config(config)


def test_acoustic_model():
    dummy_tokens = torch.rand((1, 41)).long().to(device)
    dummy_text_lens = torch.tensor([41]).long().to(device)
    dummy_spec = torch.rand((1, 100, 207)).to(device)
    dummy_spec_lens = torch.tensor([207]).to(device)
    dummy_pitch = torch.rand((1, 1, 207)).long().to(device)
    dummy_energy = torch.rand((1, 1, 207)).long().to(device)

    args.out_channels = 100
    args.num_mels = 100

    acoustic_model = AcousticModel(args=args, tokenizer=tokenizer, speaker_manager=None).to(device)
    acoustic_model = acoustic_model.train()

    output = acoustic_model(
        tokens=dummy_tokens,
        src_lens=dummy_text_lens,
        mel_lens=dummy_spec_lens,
        mels=dummy_spec,
        pitches=dummy_pitch,
        energies=dummy_energy,
        attn_priors=None,
        d_vectors=None,
        speaker_idx=None,
    )
    assert list(output["model_outputs"].shape) == [1, 207, 100]
    # output["model_outputs"].sum().backward()


def test_hifi_decoder():
    dummy_input = torch.rand((1, 207, 100)).to(device)
    dummy_spec_lens = torch.tensor([207]).to(device)

    waveform_decoder = HifiganGenerator(
        100,
        1,
        v_args.resblock_type_decoder,
        v_args.resblock_dilation_sizes_decoder,
        v_args.resblock_kernel_sizes_decoder,
        v_args.upsample_kernel_sizes_decoder,
        v_args.upsample_initial_channel_decoder,
        v_args.upsample_rates_decoder,
        inference_padding=0,
        cond_channels=0,
        conv_pre_weight_norm=False,
        conv_post_weight_norm=False,
        conv_post_bias=False,
    ).to(device)
    waveform_decoder = waveform_decoder.train()

    vocoder_input_slices, slice_ids = rand_segments(  # pylint: disable=unused-variable
        x=dummy_input.transpose(1, 2),
        x_lengths=dummy_spec_lens,
        segment_size=32,
        let_short_samples=True,
        pad_short=True,
    )

    outputs = waveform_decoder(x=vocoder_input_slices.detach())
    assert list(outputs.shape) == [1, 1, 8192]
    # outputs.sum().backward()