File size: 3,138 Bytes
46a75d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch as T

from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask


def average_over_durations_test():
    """Compare ``average_over_durations`` with a hand-computed mean per duration span.

    21 random integer durations are rescaled (floored) so they cover the 128
    pitch frames; whatever the flooring loses is added to the last duration so
    the spans sum to exactly 128.
    """
    n_frames = 128
    pitch = T.rand(1, 1, n_frames)

    raw = T.randint(1, 5, (1, 21))
    scaled = T.floor(raw * (128.0 / raw.sum()))
    remainder = 128.0 - scaled.sum()
    scaled[0, -1] += remainder
    durations = scaled.long()

    pitch_avg = average_over_durations(pitch, durations)

    start = 0
    for i, length in enumerate(durations[0].tolist()):
        expected = pitch[0, 0, start : start + length].mean()
        assert abs(pitch_avg[0, 0, i] - expected) < 1e-5
        start += length


def seqeunce_mask_test():
    """Check that each row of ``sequence_mask`` is set for exactly its first ``length`` entries."""
    lengths = T.randint(10, 15, (8,))
    mask = sequence_mask(lengths)
    for i, length in enumerate(lengths.tolist()):
        row = mask[i]
        # Everything up to the length is on, everything after it is off.
        assert row[:length].sum() == length
        assert row[length:].sum() == 0


def segment_test():
    """Check ``segment`` slicing, its rejection of overruns, and ``pad_short`` padding.

    ``x`` holds 8 identical rows with the values 0..11 along the last axis,
    so every window's sum identifies which frames it contains.
    """
    # Same values/dtype as the deprecated ``T.range(0, 11)`` (inclusive end): 12 float frames.
    x = T.arange(0.0, 12.0)
    x = x.repeat(8, 1).unsqueeze(1)  # (8, 1, 12)
    segment_ids = T.randint(0, 7, (8,))

    segments = segment(x, segment_ids, segment_size=4)
    for idx, start_indx in enumerate(segment_ids):
        assert x[idx, :, start_indx : start_indx + 4].sum() == segments[idx, :, :].sum()

    # Without pad_short, a window running past the end must be rejected. Use a start
    # index that always overruns (6 + 10 > 12) so the check does not depend on the RNG.
    # The failure sentinel lives in ``else`` so it cannot be swallowed by the handler.
    overrun_ids = T.full_like(segment_ids, 6)
    try:
        segment(x, overrun_ids, segment_size=10)
    except Exception:  # pylint: disable=broad-except
        pass
    else:
        raise AssertionError("segment() should have failed for a window past the end of the input")

    # With pad_short the overrun is filled with zeros, so the sums still match.
    segments = segment(x, segment_ids, segment_size=10, pad_short=True)
    for idx, start_indx in enumerate(segment_ids):
        assert x[idx, :, start_indx : start_indx + 10].sum() == segments[idx, :, :].sum()


def rand_segments_test():
    """Check ``rand_segments`` shapes and indices, overrun rejection, and that lengths are not mutated."""
    x = T.rand(2, 3, 4)
    x_lens = T.randint(3, 4, (2,))  # randint's upper bound is exclusive: every length is 3
    segments, seg_idxs = rand_segments(x, x_lens, segment_size=3)
    assert segments.shape == (2, 3, 3)
    assert all(seg_idxs >= 0), seg_idxs

    # A segment longer than the samples must be rejected when padding is not requested.
    # The failure sentinel lives in ``else`` so it cannot be swallowed by the handler.
    try:
        rand_segments(x, x_lens, segment_size=5)
    except Exception:  # pylint: disable=broad-except
        pass
    else:
        raise AssertionError("rand_segments() should have failed for a segment longer than the samples")

    # Pass ``x_lens`` itself (not a copy) so the final assertion can actually detect
    # an in-place modification of the caller's lengths.
    x_lens_back = x_lens.clone()
    segments, seg_idxs = rand_segments(x, x_lens, segment_size=5, pad_short=True, let_short_samples=True)
    assert segments.shape == (2, 3, 5)
    assert all(seg_idxs >= 0), seg_idxs
    assert all(x_lens_back == x_lens)


def generate_path_test():
    """Check that ``generate_path`` maps each input step to a contiguous run of ones.

    Row ``t`` of the path must be 1 exactly over the ``durations[b, t]`` output
    frames that follow all earlier steps' frames, and 0 everywhere else.
    """
    durations = T.randint(1, 4, (10, 21))
    x_length = T.randint(18, 22, (10,))
    # Pin one length to the full width so the mask is always 21 columns wide (assuming
    # sequence_mask sizes its output by the longest length). Otherwise, whenever every
    # sampled length is < 21, the product with ``durations`` below fails to broadcast.
    x_length[0] = durations.shape[1]
    x_mask = sequence_mask(x_length).unsqueeze(1).long()
    # Zero out durations of padded input steps.
    durations = durations * x_mask.squeeze(1)
    y_length = durations.sum(1)
    y_mask = sequence_mask(y_length).unsqueeze(1).long()
    attn_mask = (T.unsqueeze(x_mask, -1) * T.unsqueeze(y_mask, 2)).squeeze(1).long()
    path = generate_path(durations, attn_mask)
    assert path.shape == (10, 21, y_length.max().item())
    for b in range(durations.shape[0]):
        current_idx = 0
        for t in range(durations.shape[1]):
            dur = durations[b, t].item()
            assert all(path[b, t, current_idx : current_idx + dur] == 1.0)
            assert all(path[b, t, :current_idx] == 0.0)
            assert all(path[b, t, current_idx + dur :] == 0.0)
            current_idx += dur