Spaces:
Runtime error
Runtime error
File size: 2,820 Bytes
c021d8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from onmt_modules.misc import sequence_mask
class DecodeFunc_Sp(object):
    """
    Decoding functions for a frame-by-frame decoder followed by a postnet.

    The postnet output is split along its last dimension: channel 0 is the
    stop-gate logit and channels 1: are the predicted frame. ``type_out``
    selects which hparams supply the frame width and the step budget used
    by autoregressive inference.
    """

    def __init__(self, hparams, type_out):
        """
        Args:
            hparams: config object. Reads ``gate_threshold`` plus
                ``dim_freq``/``dec_steps_sp`` when type_out is 'Sp', or
                ``dim_code``/``dec_steps_tx`` when type_out is 'Tx'.
            type_out: either 'Sp' or 'Tx'.

        Raises:
            ValueError: if ``type_out`` is neither 'Sp' nor 'Tx'.
        """
        if type_out == 'Sp':
            self.dim_freq = hparams.dim_freq
            self.max_decoder_steps = hparams.dec_steps_sp
        elif type_out == 'Tx':
            self.dim_freq = hparams.dim_code
            self.max_decoder_steps = hparams.dec_steps_tx
        else:
            raise ValueError(
                f"type_out must be 'Sp' or 'Tx', got {type_out!r}")
        self.gate_threshold = hparams.gate_threshold
        self.type_out = type_out

    def _quantize_gate(self, gate_logits):
        """Map gate logits to stop flags: 1 where sigmoid > gate_threshold, else 0.

        The result is always exactly 0 or 1 and keeps the dtype of
        ``gate_logits``. The previous ``(sigmoid - thr + 0.5).round()``
        trick could produce 2.0 or negative values for saturated logits or
        out-of-range thresholds.
        """
        return (torch.sigmoid(gate_logits) > self.gate_threshold).to(
            gate_logits.dtype)

    def __call__(self, tgt, memory_bank, memory_lengths, decoder, postnet):
        """Decode a whole target sequence in one decoder call (``step=None``).

        Presumably used for teacher-forced training; ``tgt`` is passed
        straight to ``decoder``.

        Returns:
            (spect, gate): postnet output channels 1: and channel 0.
        """
        dec_outs, _ = decoder(tgt, memory_bank, step=None,
                              memory_lengths=memory_lengths)
        spect_gate = postnet(dec_outs)
        return spect_gate[:, :, 1:], spect_gate[:, :, :1]

    def infer(self, tgt_real, memory_bank, memory_lengths, decoder, postnet):
        """Greedy autoregressive decoding, one frame per step.

        Each step feeds the previously predicted frame and the previous
        step's stop flags back into ``decoder``. Decoding ends early when
        every sequence in the batch raises its stop flag on the same step.
        Otherwise it runs ``max_decoder_steps`` steps and prints a warning.

        Args:
            tgt_real: unused; kept for interface compatibility.
            memory_bank: encoder output; dim 1 is taken as the batch size B
                (presumably time-major (T_src, B, D) -- confirm with caller).
            memory_lengths: passed through to ``decoder`` unchanged.
            decoder: callable ``decoder(prev_frame, memory_bank, step,
                memory_lengths=..., tgt_words=...)`` returning a pair whose
                first item is fed to ``postnet``.
            postnet: callable whose output's last dim holds the gate logit
                at index 0 followed by ``dim_freq`` frame values.

        Returns:
            spect_outputs: (max_decoder_steps, B, dim_freq) predicted frames;
                rows after an early stop stay zero.
            len_spect: (B,) number of steps before each sequence's first stop
                flag (max_decoder_steps if it never stopped).
            gate_outputs: (max_decoder_steps, B, 1) raw gate logits; rows
                after an early stop stay zero.
        """
        batch_size = memory_bank.size(1)
        device = memory_bank.device
        spect_outputs = torch.zeros(
            (self.max_decoder_steps, batch_size, self.dim_freq),
            dtype=torch.float, device=device)
        gate_outputs = torch.zeros(
            (self.max_decoder_steps, batch_size, 1),
            dtype=torch.float, device=device)
        # Previous-step stop flags, one per sequence; all zeros at step 0.
        tgt_words = torch.zeros([batch_size, 1],
                                dtype=torch.float, device=device)
        # Previous predicted frame; an all-zero "go" frame at step 0.
        current_pred = torch.zeros([1, batch_size, self.dim_freq],
                                   dtype=torch.float, device=device)
        for t in range(self.max_decoder_steps):
            dec_outs, _ = decoder(current_pred,
                                  memory_bank, t,
                                  memory_lengths=memory_lengths,
                                  tgt_words=tgt_words)
            spect_gate = postnet(dec_outs)
            spect, gate = spect_gate[:, :, 1:], spect_gate[:, :, :1]
            spect_outputs[t:t+1] = spect
            gate_outputs[t:t+1] = gate
            stop = self._quantize_gate(gate)
            # Feed the prediction back without autograd history.
            current_pred = spect.detach()
            # (1, B, 1) -> (B, 1), the layout tgt_words started with.
            tgt_words = stop.squeeze(-1).t()
            if (stop == 1).all():
                break
        else:
            # Only reached when no early stop happened, i.e. the step budget
            # truncated decoding (not when everything stopped on the last step).
            print(f"Warning! {self.type_out} reached max decoder steps")
        stop_quant = self._quantize_gate(gate_outputs.detach()).squeeze(-1)
        # The running sum of stop flags stays zero exactly until a sequence's
        # first stop, so counting the zeros gives its length.
        len_spect = (stop_quant.cumsum(dim=0) == 0).sum(dim=0)
        return spect_outputs, len_spect, gate_outputs