|
""" |
|
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" |
|
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py |
|
""" |
|
|
|
import torch |
|
from torch.nn import Module, Dropout |
|
|
|
|
|
def elu_feature_map(x):
    """Positive feature map phi(x) = elu(x) + 1 used by linear attention.

    Because elu(x) >= -1, the shifted result is non-negative, so dot products
    of mapped queries and keys can stand in for softmax similarities.

    Args:
        x: input tensor of any shape.

    Returns:
        Tensor of the same shape as ``x``.
    """
    activated = torch.nn.functional.elu(x)
    return activated.add(1)
|
|
|
|
|
class LinearAttention(Module):
    """Multi-head linear attention from "Transformers are RNNs".

    The softmax kernel is replaced by the positive feature map
    ``elu_feature_map`` so keys/values can be aggregated once per head and
    then queried, instead of forming the full L x S similarity matrix.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        # Added to the normalizer so an all-zero query row does not divide by zero.
        self.eps = eps

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-Head linear attention proposed in "Transformers are RNNs"

        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L], multiplied into the mapped queries (zero = masked out)
            kv_mask: [N, S], multiplied into the mapped keys and the values

        Returns:
            queried_values: (N, L, H, D)
        """
        phi_q = self.feature_map(queries)
        phi_k = self.feature_map(keys)

        if q_mask is not None:
            phi_q = phi_q * q_mask[:, :, None, None]
        if kv_mask is not None:
            kv_weights = kv_mask[:, :, None, None]
            phi_k = phi_k * kv_weights
            values = values * kv_weights

        # Values are divided by the source length before aggregation and the
        # result is multiplied back afterwards; the factors cancel out.
        # NOTE(review): presumably done to keep intermediate sums in range
        # (e.g. under fp16) - confirm.
        num_src = values.size(1)
        scaled_values = values / num_src

        # Per-head summary of keys and values: sum_s phi(k_s) v_s^T.
        kv_summary = torch.einsum("nshd,nshv->nhdv", phi_k, scaled_values)

        # Normalizer 1 / (phi(q) . sum_s phi(k_s) + eps) for every query.
        denom = torch.einsum("nlhd,nhd->nlh", phi_q, phi_k.sum(dim=1)) + self.eps
        normalizer = 1 / denom

        out = torch.einsum("nlhd,nhdv,nlh->nlhv", phi_q, kv_summary, normalizer) * num_src
        return out.contiguous()
|
|
|
|
|
class FullAttention(Module):
    """Multi-head scaled dot-product (softmax) attention.

    Args:
        use_dropout: when True, apply dropout to the attention weights.
        attention_dropout: dropout probability used when ``use_dropout`` is True.
    """

    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = Dropout(attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """ Multi-head scaled dot-product attention, a.k.a full attention.

        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L] or None. Nonzero/True marks a valid query.
            kv_mask: [N, S] or None. Nonzero/True marks a valid key/value.
                Either mask may be given on its own; a missing mask is treated
                as all-valid. A (query, key) pair takes part in the softmax
                only if both its query and its key are valid.

        Returns:
            queried_values: (N, L, H, D)
        """
        QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)

        # Build the pairwise validity mask from whichever masks were supplied.
        # Previously, passing kv_mask without q_mask crashed, and q_mask alone
        # was silently ignored.
        valid = None
        if q_mask is not None:
            valid = q_mask[:, :, None, None].bool()
        if kv_mask is not None:
            kv_valid = kv_mask[:, None, :, None].bool()
            valid = kv_valid if valid is None else valid & kv_valid

        if valid is not None:
            # -1e9 does not fit in float16, and masked_fill_ raises on overflow.
            # So clamp it to the dtype's most negative finite value; float32,
            # float64 and bfloat16 still use -1e9 exactly.
            fill_value = -1e9
            if QK.is_floating_point():
                fill_value = max(fill_value, torch.finfo(QK.dtype).min)
            QK.masked_fill_(~valid, fill_value)

        softmax_temp = 1. / queries.size(3) ** .5
        A = torch.softmax(softmax_temp * QK, dim=2)
        if self.use_dropout:
            A = self.dropout(A)

        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
        return queried_values.contiguous()
|
|