tts-rvc-autopst / override_decoder.py
jonathanjordan21's picture
67809715652a92b22870c50ad30f6ff38e292006aedc75ddbdc828aa856ef68f
c021d8e verified
raw
history blame
2.34 kB
from onmt_modules.decoder_transformer import TransformerDecoder
from onmt_modules.misc import sequence_mask
class OnmtDecoder_1(TransformerDecoder):
    # Overrides ``forward`` of the parent decoder.
    # Runs without teacher forcing for stop.
    def forward(self, tgt, memory_bank, step=None, **kwargs):
        """Decode, possibly stepwise.

        Two modes are selected by ``step``:

        * ``step is None`` -- the target padding mask is built from
          ``kwargs["tgt_lengths"]`` via ``sequence_mask``.
        * otherwise -- the target padding mask marks the positions of
          ``kwargs["tgt_words"]`` equal to the embedding padding index,
          and each layer receives its entry of ``self.state["cache"]``.
          When ``step == 0`` the cache is first (re)initialised from
          ``memory_bank``.

        Args:
            tgt: Target input, passed to ``self.embeddings`` with ``step``.
            memory_bank: Encoder output; its first two axes are swapped
                before it is handed to the layers.
            step: ``None`` for whole-sequence decoding, else the decoding
                step index.
            **kwargs: Must contain ``memory_lengths`` and, depending on the
                mode, ``tgt_lengths`` or ``tgt_words``. ``with_align``
                is optional and defaults to ``False``.

        Returns:
            A ``(dec_outs, attns)`` tuple. ``attns`` always has ``"std"``
            (attention returned by the last layer, first two axes
            swapped), plus ``"copy"`` when ``self._copy`` is truthy and
            ``"align"`` when ``with_align`` is truthy.
        """
        if step == 0:
            self._init_cache(memory_bank)

        # Fetch the mode-specific target description up front so that a
        # missing key fails before any computation happens.
        full_sequence = step is None
        if full_sequence:
            tgt_lens = kwargs["tgt_lengths"]
        else:
            tgt_words = kwargs["tgt_words"]

        embedded = self.embeddings(tgt, step=step)
        assert embedded.dim() == 3  # len x batch x embedding_dim

        hidden = embedded.transpose(0, 1).contiguous()
        memory = memory_bank.transpose(0, 1).contiguous()

        pad_idx = self.embeddings.word_padding_idx
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        src_keep = sequence_mask(src_lens, src_max_len)
        src_pad_mask = ~src_keep.unsqueeze(1)

        if full_sequence:
            tgt_max_len = tgt_lens.max()
            tgt_keep = sequence_mask(tgt_lens, tgt_max_len)
            tgt_pad_mask = ~tgt_keep.unsqueeze(1)
        else:
            tgt_pad_mask = tgt_words.data.eq(pad_idx).unsqueeze(1)

        with_align = kwargs.pop("with_align", False)
        attn_aligns = []

        for idx, layer in enumerate(self.transformer_layers):
            # Per-layer caches are only used when decoding stepwise.
            if full_sequence:
                layer_cache = None
            else:
                layer_cache = self.state["cache"]["layer_{}".format(idx)]

            hidden, attn, layer_align = layer(
                hidden,
                memory,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align,
            )
            if layer_align is not None:
                attn_aligns.append(layer_align)

        normed = self.layer_norm(hidden)
        dec_outs = normed.transpose(0, 1).contiguous()
        # ``attn`` here is whatever the final layer returned.
        std_attn = attn.transpose(0, 1).contiguous()

        attns = {"std": std_attn}
        if self._copy:
            attns["copy"] = std_attn
        if with_align:
            # Alignment of the configured layer only -- `(B, Q, K)` per the
            # original note. An all-layer average was considered instead.
            attns["align"] = attn_aligns[self.alignment_layer]

        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns