File size: 2,789 Bytes
7694c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import random
import torch
import torch.nn.functional as F


def save_states(fname, 
                model, 
                optimizer,
                n_iter, epoch, 
                net_config, config):
    """Write a resumable training checkpoint to ``config.checkpoint_dir``.

    The checkpoint bundles the model and optimizer state dicts together
    with the current epoch / iteration counters and the network config
    dict, so training can later be restored from a single file.

    Args:
        fname: checkpoint file name (joined onto ``config.checkpoint_dir``).
        model: network whose ``state_dict()`` is saved under ``'model'``.
        optimizer: optimizer saved under ``'optim'``.
        n_iter: global step, saved under ``'iter'``.
        epoch: current epoch, saved under ``'epoch'``.
        net_config: architecture config, saved under ``'config'``.
        config: run config providing the ``checkpoint_dir`` attribute.
    """
    checkpoint = {
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'epoch': epoch,
        'iter': n_iter,
        'config': net_config,
    }
    torch.save(checkpoint, f'{config.checkpoint_dir}/{fname}')

def save_states_gan(fname, 
                    model, model_d, 
                    optimizer, optimizer_d, 
                    n_iter, epoch, 
                    net_config, config):
    """Write a resumable GAN training checkpoint to ``config.checkpoint_dir``.

    Like ``save_states`` but additionally captures the discriminator and
    its optimizer under the ``'model_d'`` / ``'optim_d'`` keys.

    Args:
        fname: checkpoint file name (joined onto ``config.checkpoint_dir``).
        model: generator, saved under ``'model'``.
        model_d: discriminator, saved under ``'model_d'``.
        optimizer: generator optimizer, saved under ``'optim'``.
        optimizer_d: discriminator optimizer, saved under ``'optim_d'``.
        n_iter: global step, saved under ``'iter'``.
        epoch: current epoch, saved under ``'epoch'``.
        net_config: architecture config, saved under ``'config'``.
        config: run config providing the ``checkpoint_dir`` attribute.
    """
    checkpoint = {
        'model': model.state_dict(),
        'model_d': model_d.state_dict(),
        'optim': optimizer.state_dict(),
        'optim_d': optimizer_d.state_dict(),
        'epoch': epoch,
        'iter': n_iter,
        'config': net_config,
    }
    torch.save(checkpoint, f'{config.checkpoint_dir}/{fname}')
    

def batch_to_device(batch, device):
    """Move one collated batch of tensors onto *device*.

    Unpacks the 5-tuple produced by the data loader (raising if the
    arity is wrong) and transfers each tensor with ``non_blocking=True``
    so host-to-device copies can overlap compute when pinned memory and
    a CUDA device are in use.

    Returns:
        tuple: ``(text_padded, input_lengths, mel_padded, gate_padded,
        output_lengths)`` on *device*, in the same order as the input.
    """
    text, text_lens, mels, gates, mel_lens = batch
    return tuple(
        t.to(device, non_blocking=True)
        for t in (text, text_lens, mels, gates, mel_lens)
    )


@torch.inference_mode()
def validate(model, test_loader, writer, device, n_iter):
    """Run one validation pass and log metrics plus a sample.

    Computes the Tacotron-style objective — pre-net and post-net mel MSE
    plus stop-token (gate) BCE — averaged per sample over the whole
    loader, logs one randomly chosen alignment/mel sample from the final
    batch alongside a free-running re-synthesis, and records the scalar
    validation loss.

    Args:
        model: acoustic model; must be callable for teacher-forced
            forward and expose an ``infer`` method for free running.
        test_loader: iterable of collated batches (see ``batch_to_device``).
        writer: logger exposing ``add_sample`` and ``add_scalar``.
        device: target device for the batch tensors.
        n_iter: global step used to tag the logged values.

    Returns:
        float: sample-weighted mean validation loss.

    Raises:
        ValueError: if *test_loader* yields no batches (previously this
            surfaced as an opaque ZeroDivisionError / NameError).
    """
    loss_sum = 0.0
    n_test_sum = 0

    model.eval()
    try:
        for batch in test_loader:
            text_padded, input_lengths, mel_padded, gate_padded, \
                output_lengths = batch_to_device(batch, device)

            y_pred = model(text_padded, input_lengths,
                           mel_padded, output_lengths)
            mel_out, mel_out_postnet, gate_pred, alignments = y_pred

            mel_loss = F.mse_loss(mel_out, mel_padded) + \
                F.mse_loss(mel_out_postnet, mel_padded)
            gate_loss = F.binary_cross_entropy_with_logits(gate_pred,
                                                           gate_padded)
            loss = mel_loss + gate_loss

            # Weight by batch size so the final mean is per-sample even
            # when the last batch is smaller.
            loss_sum += mel_padded.size(0) * loss.item()
            n_test_sum += mel_padded.size(0)

        if n_test_sum == 0:
            raise ValueError('validate() received an empty test_loader')

        val_loss = loss_sum / n_test_sum

        # Log one random sample from the last batch, re-synthesized in
        # free-running inference mode for qualitative comparison.
        idx = random.randint(0, mel_padded.size(0) - 1)
        mel_infer, *_ = model.infer(
            text_padded[idx:idx + 1], input_lengths[idx:idx + 1])

        writer.add_sample(
            alignments[idx, :, :input_lengths[idx].item()],
            mel_out[idx], mel_padded[idx], mel_infer[0],
            output_lengths[idx], n_iter)

        writer.add_scalar('loss/val_loss', val_loss, n_iter)
    finally:
        # Always restore training mode, even if evaluation raises;
        # otherwise a transient failure leaves the model in eval mode
        # (dropout/batch-norm frozen) for subsequent training steps.
        model.train()

    return val_loss