File size: 1,010 Bytes
7694c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch.nn as nn
import torch.nn.functional as F


class Tacotron2Loss(nn.Module):
    """Combined Tacotron 2 training criterion.

    The total loss is two mean-squared-error terms on mel spectrograms (one
    before and one after the postnet), each multiplied by ``mel_loss_scale``,
    plus an unscaled binary-cross-entropy-with-logits term on the gate output.

    Args:
        mel_loss_scale: Weight applied to each of the two mel MSE terms.
    """

    def __init__(self, mel_loss_scale=1.0):
        super().__init__()
        self.mel_loss_scale = mel_loss_scale

    def forward(self, mel_out, mel_out_postnet, mel_padded, gate_out, gate_padded):
        """Return the weighted total loss and detached per-term values.

        Args:
            mel_out: Mel prediction before the postnet, compared with
                ``mel_padded`` via ``F.mse_loss``.
            mel_out_postnet: Mel prediction after the postnet, compared with
                ``mel_padded`` via ``F.mse_loss``.
            mel_padded: Target mel spectrogram.
            gate_out: Gate logits (presumably the stop-token prediction --
                confirm against the model).
            gate_padded: Gate targets for
                ``F.binary_cross_entropy_with_logits``.

        Returns:
            ``(loss, meta)``, where ``loss`` equals
            ``scale * mse_pre + scale * mse_post + gate_bce`` and still carries
            the autograd graph. ``meta`` maps ``'mel_loss_rnn'``,
            ``'mel_loss_postnet'`` and ``'gate_loss'`` to detached copies of
            the three individual terms.
        """
        scale = self.mel_loss_scale

        # Default 'mean' reduction with no mask: every element of mel_padded,
        # including any padding, contributes.
        # NOTE(review): confirm that this is intended.
        pre_postnet_mse = F.mse_loss(mel_out, mel_padded)
        post_postnet_mse = F.mse_loss(mel_out_postnet, mel_padded)
        gate_bce = F.binary_cross_entropy_with_logits(gate_out, gate_padded)

        total = scale * pre_postnet_mse + scale * post_postnet_mse + gate_bce

        # Detached copies have no grad_fn, so holding them for logging does
        # not keep the autograd graph alive.
        meta = dict(
            mel_loss_rnn=pre_postnet_mse.detach().clone(),
            mel_loss_postnet=post_postnet_mse.detach().clone(),
            gate_loss=gate_bce.detach().clone(),
        )

        return total, meta