# Source: tts-rvc-autopst/onmt_modules/decoder_transformer.py
# Uploaded by jonathanjordan21, commit c021d8e (verified)
# Content hash: 67809715652a92b22870c50ad30f6ff38e292006aedc75ddbdc828aa856ef68f
# Size: 14.5 kB
"""
Implementation of "Attention is All You Need"
"""
import torch
import torch.nn as nn
from .decoder import DecoderBase
from .multi_headed_attn import MultiHeadedAttention
from .average_attn import AverageAttention
from .position_ffn import PositionwiseFeedForward
from .misc import sequence_mask
class TransformerDecoderLayer(nn.Module):
    """Transformer decoder layer block in Pre-Norm style.

    Pre-Norm style is an improvement w.r.t. the original paper's Post-Norm
    style, providing better convergence speed and performance. This is also
    the actual implementation in tensor2tensor and is also available in
    fairseq. See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.

    .. mermaid::

        graph LR
        %% "*SubLayer" can be self-attn, src-attn or feed forward block
        A(input) --> B[Norm]
        B --> C["*SubLayer"]
        C --> D[Drop]
        D --> E((+))
        A --> E
        E --> F(out)

    Args:
        d_model (int): the dimension of keys/values/queries in
            :class:`MultiHeadedAttention`, also the input size of
            the first layer of the :class:`PositionwiseFeedForward`.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): the ``d_ff`` size passed to
            :class:`PositionwiseFeedForward` (its second layer).
        dropout (float): dropout applied to the residual branches and
            passed to the feed-forward block.
        attention_dropout (float): dropout passed to the self-attention
            module (either type) and to the context attention.
        self_attn_type (str): type of self-attention, either
            ``"scaled-dot"`` or ``"average"``.
        max_relative_positions (int):
            Max distance between inputs in relative positions representations
            (only used by ``"scaled-dot"`` self-attention).
        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
            (only used by ``"average"`` self-attention).
        full_context_alignment (bool):
            whether to run an extra full-context (no future mask) decoder
            pass when computing the alignment attention.
        alignment_heads (int):
            number of leading cross-attention heads averaged for alignment
            guiding; ``0`` averages all heads.

    Raises:
        ValueError: if ``self_attn_type`` is not a supported value.
    """

    def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
                 self_attn_type="scaled-dot", max_relative_positions=0,
                 aan_useffn=False, full_context_alignment=False,
                 alignment_heads=0):
        super(TransformerDecoderLayer, self).__init__()
        # Fail fast: an unknown type would otherwise leave ``self.self_attn``
        # undefined and only crash later inside forward/update_dropout.
        if self_attn_type not in ("scaled-dot", "average"):
            raise ValueError(
                "Unsupported self_attn_type: {!r}; expected 'scaled-dot' "
                "or 'average'".format(self_attn_type))
        if self_attn_type == "scaled-dot":
            self.self_attn = MultiHeadedAttention(
                heads, d_model, dropout=attention_dropout,
                max_relative_positions=max_relative_positions)
        else:  # "average"
            self.self_attn = AverageAttention(d_model,
                                              dropout=attention_dropout,
                                              aan_useffn=aan_useffn)
        self.context_attn = MultiHeadedAttention(
            heads, d_model, dropout=attention_dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        self.full_context_alignment = full_context_alignment
        self.alignment_heads = alignment_heads

    def forward(self, *args, **kwargs):
        """Run the layer, optionally also returning alignment attention.

        Always performs a default (future-masked) pass via :meth:`_forward`.
        When ``with_align`` is set and ``full_context_alignment`` is enabled,
        a second, future-aware pass is run and only its attention is used
        for joint alignment learning, :cite:`garg2019jointly`.

        Args:
            *args, **kwargs: all arguments of :meth:`_forward`.
            with_align (bool): whether to return alignment attention.

        Returns:
            (FloatTensor, FloatTensor, FloatTensor or None):

            * output ``(batch_size, T, model_dim)``
            * top_attn ``(batch_size, T, src_len)``: context attention of
              the first head
            * attn_align ``(batch_size, T, src_len)`` or None
        """
        with_align = kwargs.pop('with_align', False)
        output, attns = self._forward(*args, **kwargs)
        top_attn = attns[:, 0, :, :].contiguous()
        attn_align = None
        if with_align:
            if self.full_context_alignment:
                # Second pass without the future mask; output is discarded,
                # attns is (B, heads, Q_len, K_len).
                _, attns = self._forward(*args, **kwargs, future=True)
            if self.alignment_heads > 0:
                attns = attns[:, :self.alignment_heads, :, :].contiguous()
            # Average attention across the selected heads -> ``(B, Q, K)``.
            # Case 1: no full_context, no align heads -> layer avg baseline
            # Case 2: no full_context, 1 align heads -> guided align
            # Case 3: full_context, 1 align heads -> full cte guided align
            attn_align = attns.mean(dim=1)
        return output, top_attn, attn_align

    def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
                 layer_cache=None, step=None, future=False):
        """A single forward pass of the decoder layer.

        T is 1 for stepwise decoding, otherwise tgt_len.

        Args:
            inputs (FloatTensor): ``(batch_size, T, model_dim)``
            memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
            src_pad_mask (Tensor): ``(batch_size, 1, src_len)``, nonzero
                marks positions to mask.
            tgt_pad_mask (Tensor): ``(batch_size, 1, T)``, nonzero marks
                positions to mask.
            layer_cache (dict or None): cached layer info when stepwise decode
            step (int or None): stepwise decoding counter; when given, no
                self-attention mask is built here.
            future (bool): If set True, do not apply future_mask.

        Returns:
            (FloatTensor, FloatTensor):

            * output ``(batch_size, T, model_dim)``
            * attns ``(batch_size, head, T, src_len)``
        """
        dec_mask = None
        if step is None:
            tgt_len = tgt_pad_mask.size(-1)
            if not future:  # apply future_mask, result mask in (B, T, T)
                future_mask = torch.ones(
                    [tgt_len, tgt_len],
                    device=tgt_pad_mask.device,
                    dtype=torch.uint8)
                # Strictly-upper triangle = positions after the query.
                future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
                # BoolTensor was introduced in pytorch 1.2
                try:
                    future_mask = future_mask.bool()
                except AttributeError:
                    pass
                dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
            else:  # only mask padding, result mask in (B, 1, T)
                dec_mask = tgt_pad_mask

        # Pre-norm self-attention sub-layer with residual connection.
        input_norm = self.layer_norm_1(inputs)
        if isinstance(self.self_attn, MultiHeadedAttention):
            query, _ = self.self_attn(input_norm, input_norm, input_norm,
                                      mask=dec_mask,
                                      layer_cache=layer_cache,
                                      attn_type="self")
        elif isinstance(self.self_attn, AverageAttention):
            query, _ = self.self_attn(input_norm, mask=dec_mask,
                                      layer_cache=layer_cache, step=step)
        query = self.drop(query) + inputs

        # Pre-norm context (source) attention sub-layer with residual.
        query_norm = self.layer_norm_2(query)
        mid, attns = self.context_attn(memory_bank, memory_bank, query_norm,
                                       mask=src_pad_mask,
                                       layer_cache=layer_cache,
                                       attn_type="context")
        output = self.feed_forward(self.drop(mid) + query)
        return output, attns

    def update_dropout(self, dropout, attention_dropout):
        """Update dropout probabilities of all sub-modules in place."""
        self.self_attn.update_dropout(attention_dropout)
        self.context_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.drop.p = dropout
class TransformerDecoder(DecoderBase):
    """The Transformer decoder from "Attention is All You Need".
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    .. mermaid::

        graph BT
        A[input]
        B[multi-head self-attn]
        BB[multi-head src-attn]
        C[feed forward]
        O[output]
        A --> B
        B --> BB
        BB --> C
        C --> O

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        copy_attn (bool): if set, the last layer's context attention is also
            returned under the ``"copy"`` key
        self_attn_type (str): type of self-attention, ``"scaled-dot"`` or
            ``"average"``
        dropout (float): dropout in residual connections and feed-forward
        attention_dropout (float): dropout in self- and context attention
        embeddings (onmt.modules.Embeddings):
            embeddings to use, should have positional encodings
        max_relative_positions (int):
            Max distance between inputs in relative positions representations
        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
        full_context_alignment (bool):
            whether to enable an extra full context decoder forward for
            alignment
        alignment_layer (int): index of the layer whose alignment attention
            is returned under ``"align"``
        alignment_heads (int):
            N. of cross attention heads to use for alignment guiding
    """

    def __init__(self, num_layers, d_model, heads, d_ff,
                 copy_attn, self_attn_type, dropout, attention_dropout,
                 embeddings, max_relative_positions, aan_useffn,
                 full_context_alignment, alignment_layer,
                 alignment_heads):
        super(TransformerDecoder, self).__init__()
        self.embeddings = embeddings
        # Decoder state: holds "src" and the stepwise-decoding "cache".
        self.state = {}
        self.transformer_layers = nn.ModuleList(
            [TransformerDecoderLayer(
                d_model, heads, d_ff, dropout, attention_dropout,
                self_attn_type=self_attn_type,
                max_relative_positions=max_relative_positions,
                aan_useffn=aan_useffn,
                full_context_alignment=full_context_alignment,
                alignment_heads=alignment_heads)
             for _ in range(num_layers)])
        # previously, there was a GlobalAttention module here for copy
        # attention. But it was never actually used -- the "copy" attention
        # just reuses the context attention.
        self._copy = copy_attn
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.alignment_layer = alignment_layer

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor from an options namespace.

        ``opt.dropout`` and ``opt.attention_dropout`` may each be a scalar
        or a list; for a list, its first value is used.
        """
        def _first(value):
            # Unwrap list-valued options to their first entry.
            return value[0] if isinstance(value, list) else value

        return cls(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.copy_attn,
            opt.self_attn_type,
            _first(opt.dropout),
            # Must read attention_dropout in both cases; falling back to
            # opt.dropout silently ignored the configured attention dropout.
            _first(opt.attention_dropout),
            embeddings,
            opt.max_relative_positions,
            opt.aan_useffn,
            opt.full_context_alignment,
            opt.alignment_layer,
            alignment_heads=opt.alignment_heads)

    def init_state(self, src, memory_bank, enc_hidden):
        """Initialize decoder state.

        Only ``src`` is stored; ``memory_bank`` and ``enc_hidden`` are
        accepted for interface compatibility but not used here. The cache is
        reset to None and built lazily at ``step == 0`` in :meth:`forward`.
        """
        self.state["src"] = src
        self.state["cache"] = None

    def map_state(self, fn):
        """Apply ``fn(tensor, batch_dim)`` to every tensor in the state.

        ``src`` is mapped with batch dim 1; every non-None tensor in the
        (nested) cache dict is mapped with batch dim 0.
        """
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        self.state["src"] = fn(self.state["src"], 1)
        if self.state["cache"] is not None:
            _recursive_map(self.state["cache"])

    def detach_state(self):
        """Detach ``src`` from the autograd graph."""
        self.state["src"] = self.state["src"].detach()

    def forward(self, tgt, memory_bank, step=None, **kwargs):
        """Decode, possibly stepwise.

        Args:
            tgt (LongTensor): target input; ``tgt[:, :, 0]`` is used as the
                word ids. Appears to be ``(tgt_len, batch, n_feats)`` given
                the transposes below (confirm with caller).
            memory_bank (FloatTensor): encoder outputs
                ``(src_len, batch, model_dim)``.
            step (int or None): stepwise decoding step; ``0`` (re)builds the
                per-layer cache, ``None`` decodes without a cache.
            memory_lengths (LongTensor): required keyword; source lengths
                used to build the source padding mask.
            with_align (bool): optional keyword; also return alignment
                attention.

        Returns:
            (FloatTensor, dict):

            * dec_outs ``(tgt_len, batch, model_dim)``
            * attns: ``"std"`` is the last layer's first-head context
              attention ``(tgt_len, batch, src_len)``; ``"copy"`` is the
              same tensor when ``copy_attn`` is set; ``"align"`` is the
              alignment of layer ``alignment_layer`` when ``with_align``.
        """
        if step == 0:
            self._init_cache(memory_bank)

        tgt_words = tgt[:, :, 0].transpose(0, 1)

        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # len x batch x embedding_dim

        # Layers work batch-first.
        output = emb.transpose(0, 1).contiguous()
        src_memory_bank = memory_bank.transpose(0, 1).contiguous()

        pad_idx = self.embeddings.word_padding_idx
        src_lens = kwargs["memory_lengths"]
        src_max_len = self.state["src"].shape[0]
        src_pad_mask = ~sequence_mask(src_lens, src_max_len).unsqueeze(1)
        tgt_pad_mask = tgt_words.eq(pad_idx).unsqueeze(1)  # [B, 1, T_tgt]

        with_align = kwargs.pop('with_align', False)
        attn_aligns = []

        for i, layer in enumerate(self.transformer_layers):
            layer_cache = self.state["cache"]["layer_{}".format(i)] \
                if step is not None else None
            output, attn, attn_align = layer(
                output,
                src_memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                layer_cache=layer_cache,
                step=step,
                with_align=with_align)
            if attn_align is not None:
                attn_aligns.append(attn_align)

        # Final norm is needed because layers are pre-norm.
        output = self.layer_norm(output)
        dec_outs = output.transpose(0, 1).contiguous()
        attn = attn.transpose(0, 1).contiguous()

        attns = {"std": attn}
        if self._copy:
            attns["copy"] = attn
        if with_align:
            attns["align"] = attn_aligns[self.alignment_layer]  # `(B, Q, K)`
            # attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg

        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns

    def _init_cache(self, memory_bank):
        """Create an empty per-layer cache for stepwise decoding.

        Each layer gets ``memory_keys``/``memory_values`` set to None.
        AverageAttention layers also get a zero ``prev_g`` of shape
        ``(batch, 1, depth)``; other layers get ``self_keys`` and
        ``self_values`` set to None.
        """
        self.state["cache"] = {}
        batch_size = memory_bank.size(1)
        depth = memory_bank.size(-1)

        for i, layer in enumerate(self.transformer_layers):
            layer_cache = {"memory_keys": None, "memory_values": None}
            if isinstance(layer.self_attn, AverageAttention):
                layer_cache["prev_g"] = torch.zeros((batch_size, 1, depth),
                                                    device=memory_bank.device)
            else:
                layer_cache["self_keys"] = None
                layer_cache["self_values"] = None
            self.state["cache"]["layer_{}".format(i)] = layer_cache

    def update_dropout(self, dropout, attention_dropout):
        """Update dropout probabilities of embeddings and all layers."""
        self.embeddings.update_dropout(dropout)
        for layer in self.transformer_layers:
            layer.update_dropout(dropout, attention_dropout)