from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention (section 3.2.1 of "Attention Is All You Need").

    Args:
        dim (int): dimension of the key/query vectors (d_k, i.e. d_head).
        dropout_p (float): dropout probability applied to the attention weights.

    Inputs:
        query (batch, num_heads, seq_len, d_head)
        key (batch, num_heads, seq_len, d_head)
        value (batch, num_heads, seq_len, d_head)
        mask: optional tensor broadcastable to (batch, num_heads, seq_len, seq_len);
            positions where the mask is 0 are excluded from attention.

    Outputs:
        context (batch, num_heads, seq_len, d_head): context matrix.
        attn (batch, num_heads, seq_len, seq_len): attention weights, returned for visualization.
    """

    def __init__(self, dim: int, dropout_p: float) -> None:
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)
        self.dropout = nn.Dropout(p=dropout_p)
|
    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        # Scaled dot-product scores: QK^T / sqrt(d_k).
        score = torch.matmul(query, key.transpose(-2, -1)) / self.sqrt_dim

        # Mask out disallowed positions (e.g. padding or future tokens) with a
        # large negative value so their softmax weight is effectively zero
        # (-1e4 rather than -inf keeps the operation fp16-safe).
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e4)

        # Normalize the scores into attention weights, then apply dropout.
        attn = F.softmax(score, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of the values.
        context = torch.matmul(attn, value)

        return context, attn
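

# Minimal usage sketch (added as an illustration, not part of the original
# module): builds a toy causal mask and runs a forward pass to show the
# expected tensor shapes. All sizes below are illustrative assumptions.
if __name__ == "__main__":
    batch, num_heads, seq_len, d_head = 2, 8, 10, 64

    attention = ScaledDotProductAttention(dim=d_head, dropout_p=0.1)

    query = torch.randn(batch, num_heads, seq_len, d_head)
    key = torch.randn(batch, num_heads, seq_len, d_head)
    value = torch.randn(batch, num_heads, seq_len, d_head)

    # Causal (lower-triangular) mask: position i may only attend to positions <= i.
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

    context, attn = attention(query, key, value, mask=mask)
    print(context.shape)  # torch.Size([2, 8, 10, 64])
    print(attn.shape)     # torch.Size([2, 8, 10, 10])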