tts-rvc-autopst / fast_decoders.py
jonathanjordan21's picture
67809715652a92b22870c50ad30f6ff38e292006aedc75ddbdc828aa856ef68f
c021d8e verified
raw
history blame
2.82 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from onmt_modules.misc import sequence_mask
class DecodeFunc_Sp(object):
    """
    Decoding functions

    Drives a ``decoder``/``postnet`` pair either in a single pass
    (``__call__``) or step by step, feeding each prediction back in
    (``infer``).

    In both modes the postnet output is split along its last axis:
    channel 0 is the stop ("gate") logit, the remaining channels are the
    predicted frame.
    """

    def __init__(self, hparams, type_out):
        """Select output size and step budget for the requested output type.

        Args:
            hparams: object exposing ``gate_threshold`` plus, for ``'Sp'``,
                ``dim_freq`` and ``dec_steps_sp``; for ``'Tx'``,
                ``dim_code`` and ``dec_steps_tx``.
            type_out: ``'Sp'`` or ``'Tx'``.

        Raises:
            ValueError: if ``type_out`` is neither ``'Sp'`` nor ``'Tx'``.
        """
        if type_out == 'Sp':
            self.dim_freq = hparams.dim_freq
            self.max_decoder_steps = hparams.dec_steps_sp
        elif type_out == 'Tx':
            self.dim_freq = hparams.dim_code
            self.max_decoder_steps = hparams.dec_steps_tx
        else:
            raise ValueError(
                f"Unsupported type_out {type_out!r}; expected 'Sp' or 'Tx'")

        self.gate_threshold = hparams.gate_threshold
        self.type_out = type_out

    def _stop_flags(self, gate):
        """Quantize gate logits into 0/1 stop flags using ``gate_threshold``."""
        # Shifting by (0.5 - threshold) before rounding maps sigmoid values
        # above the threshold to 1 and values below it to 0.
        return (torch.sigmoid(gate) - self.gate_threshold + 0.5).round()

    def __call__(self, tgt, memory_bank, memory_lengths, decoder, postnet):
        """Run the decoder once over ``tgt`` (called with ``step=None``).

        Returns:
            ``(spect, gate)``: the postnet output without its first channel,
            and that first channel alone (kept as a size-1 last axis).
        """
        dec_outs, _ = decoder(tgt, memory_bank, step=None,
                              memory_lengths=memory_lengths)
        spect_gate = postnet(dec_outs)
        spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]

        return spect, gate

    def infer(self, tgt_real, memory_bank, memory_lengths, decoder, postnet):
        """Decode autoregressively for at most ``max_decoder_steps`` steps.

        Each step feeds the previous predicted frame (detached) and the
        previous per-sequence stop flags back into ``decoder``. Decoding
        ends early once every sequence in the batch raises its stop flag
        on the same step.

        Args:
            tgt_real: unused; kept so the signature stays unchanged.
            memory_bank: encoder memory; its dimension 1 is taken as the
                batch size.
            memory_lengths: forwarded to ``decoder`` unchanged.
            decoder: callable invoked as ``decoder(prev, memory_bank, t,
                memory_lengths=..., tgt_words=...)`` returning a tuple whose
                first item is passed to ``postnet``.
            postnet: callable mapping decoder output to gate+frame channels.

        Returns:
            ``(spect_outputs, len_spect, gate_outputs)`` where
            ``spect_outputs`` is ``(max_decoder_steps, B, dim_freq)``,
            ``gate_outputs`` is ``(max_decoder_steps, B, 1)`` (both
            zero-filled past the last executed step) and ``len_spect`` is,
            per sequence, the number of steps before its first stop flag.
        """
        batch_size = memory_bank.size(1)
        device = memory_bank.device

        spect_outputs = torch.zeros(
            (self.max_decoder_steps, batch_size, self.dim_freq),
            dtype=torch.float, device=device)
        gate_outputs = torch.zeros(
            (self.max_decoder_steps, batch_size, 1),
            dtype=torch.float, device=device)
        tgt_words = torch.zeros([batch_size, 1],
                                dtype=torch.float, device=device)
        current_pred = torch.zeros([1, batch_size, self.dim_freq],
                                   dtype=torch.float, device=device)

        for t in range(self.max_decoder_steps):
            dec_outs, _ = decoder(current_pred,
                                  memory_bank, t,
                                  memory_lengths=memory_lengths,
                                  tgt_words=tgt_words)
            spect_gate = postnet(dec_outs)

            spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
            spect_outputs[t:t+1] = spect
            gate_outputs[t:t+1] = gate

            stop = self._stop_flags(gate)
            # The fed-back frame is cut from the autograd graph.
            current_pred = spect.detach()
            # presumably (1, B, 1) -> (B, 1), matching the initial tgt_words
            tgt_words = stop.squeeze(-1).t()

            if (stop == 1).all():
                break
        else:
            # Only reached when the loop ran out of steps without every
            # sequence stopping, i.e. the output was actually truncated.
            print(f"Warning! {self.type_out} reached max decoder steps")

        stop_quant = self._stop_flags(gate_outputs.detach()).squeeze(-1)
        # Steps whose running stop count is still zero precede the first
        # stop flag; counting them gives each sequence's length.
        len_spect = (stop_quant.cumsum(dim=0) == 0).sum(dim=0)

        return spect_outputs, len_spect, gate_outputs