tts-rvc-autopst / onmt_modules /multi_headed_attn.py
jonathanjordan21's picture
67809715652a92b22870c50ad30f6ff38e292006aedc75ddbdc828aa856ef68f
c021d8e verified
raw
history blame
8.12 kB
""" Multi-Head Attention module """
import math
import torch
import torch.nn as nn
from .misc import generate_relative_positions_matrix,\
relative_matmul
# from onmt.utils.misc import aeq
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention module from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

       graph BT
          A[key]
          B[value]
          C[query]
          O[output]
          subgraph Attn
            D[Attn 1]
            E[Attn 2]
            F[Attn N]
          end
          A --> D
          C --> D
          A --> E
          C --> E
          A --> F
          C --> F
          D --> O
          E --> O
          F --> O
          B --> O

    Also includes several additional tricks.

    Args:
       head_count (int): number of parallel heads
       model_dim (int): the dimension of keys/values/queries,
           must be divisible by head_count
       dropout (float): dropout parameter
       max_relative_positions (int): when greater than 0, self-attention
           additionally uses a learned relative-position embedding table
           with ``2 * max_relative_positions + 1`` entries of size
           ``model_dim // head_count``; 0 disables it
    """

    def __init__(self, head_count, model_dim, dropout=0.1,
                 max_relative_positions=0):
        # Initialise nn.Module bookkeeping before touching any attribute.
        super(MultiHeadedAttention, self).__init__()

        assert model_dim % head_count == 0, \
            "model_dim must be divisible by head_count"
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        self.head_count = head_count

        projected_dim = head_count * self.dim_per_head
        self.linear_keys = nn.Linear(model_dim, projected_dim)
        self.linear_values = nn.Linear(model_dim, projected_dim)
        self.linear_query = nn.Linear(model_dim, projected_dim)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.final_linear = nn.Linear(model_dim, model_dim)

        self.max_relative_positions = max_relative_positions
        if max_relative_positions > 0:
            vocab_size = max_relative_positions * 2 + 1
            self.relative_positions_embeddings = nn.Embedding(
                vocab_size, self.dim_per_head)

    def forward(self, key, value, query, mask=None,
                layer_cache=None, attn_type=None):
        """
        Compute the context vector and the attention vectors.

        Args:
           key (FloatTensor): set of `key_len`
               key vectors ``(batch, key_len, dim)``
           value (FloatTensor): set of `key_len`
               value vectors ``(batch, key_len, dim)``
           query (FloatTensor): set of `query_len`
               query vectors ``(batch, query_len, dim)``
           mask: binary mask 1/0 indicating which keys have
               zero / non-zero attention ``(batch, query_len, key_len)``;
               non-boolean masks are converted to bool (non-zero = masked)
           layer_cache (dict or None): optional cache that is read and
               updated in place. With ``attn_type == "self"`` the entries
               ``"self_keys"`` / ``"self_values"`` are extended with the
               projections of the current query; with
               ``attn_type == "context"`` the entries ``"memory_keys"`` /
               ``"memory_values"`` are filled on first use and reused
               afterwards (the ``key`` / ``value`` arguments are then
               ignored).
           attn_type (str or None): ``"self"`` or ``"context"``. Required
               when ``layer_cache`` is given; ``"self"`` also enables the
               relative-position terms when ``max_relative_positions > 0``.

        Returns:
           (FloatTensor, FloatTensor):

           * output context vectors ``(batch, query_len, dim)``
           * Attention vector in heads ``(batch, head, query_len, key_len)``
             (taken before attention dropout).

        Raises:
           ValueError: if ``layer_cache`` is given but ``attn_type`` is
               neither ``"self"`` nor ``"context"``.
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            """Split the last dim into heads: (batch, head, len, dim_per_head)."""
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            """Merge heads back: (batch, len, head * dim_per_head)."""
            return x.transpose(1, 2).contiguous() \
                .view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if attn_type == "self":
                # Keys and values are projections of the query input itself.
                source = query
                query = self.linear_query(source)
                key = shape(self.linear_keys(source))
                value = shape(self.linear_values(source))
                # Append to what earlier calls stored (dim 2 is length).
                if layer_cache["self_keys"] is not None:
                    key = torch.cat(
                        (layer_cache["self_keys"], key), dim=2)
                if layer_cache["self_values"] is not None:
                    value = torch.cat(
                        (layer_cache["self_values"], value), dim=2)
                layer_cache["self_keys"] = key
                layer_cache["self_values"] = value
            elif attn_type == "context":
                query = self.linear_query(query)
                if layer_cache["memory_keys"] is None:
                    key = shape(self.linear_keys(key))
                    value = shape(self.linear_values(value))
                else:
                    # Reuse the projections stored by an earlier call.
                    key = layer_cache["memory_keys"]
                    value = layer_cache["memory_values"]
                layer_cache["memory_keys"] = key
                layer_cache["memory_values"] = value
            else:
                # Without a projection the tensors below would have the
                # wrong rank and fail with an opaque dimension error.
                raise ValueError(
                    "attn_type must be 'self' or 'context' when "
                    "layer_cache is provided, got %r" % (attn_type,))
        else:
            key = shape(self.linear_keys(key))
            value = shape(self.linear_values(value))
            query = self.linear_query(query)

        use_relative = self.max_relative_positions > 0 \
            and attn_type == "self"
        if use_relative:
            # 1 or key_len x key_len
            relative_positions_matrix = generate_relative_positions_matrix(
                key.size(2), self.max_relative_positions,
                cache=layer_cache is not None)
            # 1 or key_len x key_len x dim_per_head
            relations_keys = self.relative_positions_embeddings(
                relative_positions_matrix.to(key.device))
            # Keys and values share the same embedding table and indices,
            # so a single lookup serves both.
            relations_values = relations_keys

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        # batch x num_heads x query_len x key_len
        scores = torch.matmul(query, key.transpose(2, 3))
        if use_relative:
            scores = scores + relative_matmul(query, relations_keys, True)
        scores = scores.float()

        if mask is not None:
            # Add a head dimension so the mask broadcasts over all heads;
            # masked_fill needs a boolean mask.
            mask = mask.unsqueeze(1).bool()
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.
        attn = self.softmax(scores).to(query.dtype)
        drop_attn = self.dropout(attn)

        context = torch.matmul(drop_attn, value)
        if use_relative:
            context = context + relative_matmul(
                drop_attn, relations_values, False)

        output = self.final_linear(unshape(context))

        # Return multi-head attn
        attns = attn.view(batch_size, head_count, query_len, key_len)

        return output, attns

    def update_dropout(self, dropout):
        """Set the attention dropout probability to ``dropout``."""
        self.dropout.p = dropout