|
from typing import Optional, Tuple
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
Tensor = torch.Tensor |
|
|
|
|
|
class Attention(nn.Module): |
|
"""Container for applying an attention scoring function.""""" |
|
|
|
    def __init__(self, score: nn.Module, dropout: Optional[nn.Module] = None):
|
super().__init__() |
|
self.score = score |
|
self.dropout = dropout |
|
|
|
    def forward(self, decoder_state: Tensor, encoder_state: Tensor, source_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
|
"""Return context and attention weights. Accepts a boolean mask indicating padding in the source sequence.""""" |
|
(B, L, D), (B, T, _) = decoder_state.shape, encoder_state.shape |
|
scores = self.score(decoder_state, encoder_state) |
|
if source_mask is not None: |
|
            # Padding positions (True in the mask) get a large negative score,
            # so they receive effectively zero weight after the softmax.
            scores.masked_fill_(source_mask.view(B, 1, T), -1e4)
|
weights = F.softmax(scores, dim=-1) |
|
if self.dropout is not None: |
|
weights = self.dropout(weights) |
|
        context = weights @ encoder_state  # (B, L, T) @ (B, T, enc_dim) -> (B, L, enc_dim)
|
return context, weights |
|
|
|
|
|
class ConcatScore(nn.Module): |
|
"""A two layer network as an attention scoring function. Expects bidirectional encoder.""""" |
|
|
|
def __init__(self, d: int): |
|
super().__init__() |
|
        self.w = nn.Linear(3*d, d)  # decoder state (d) concatenated with bidirectional encoder state (2*d)
|
self.v = nn.Linear(d, 1, bias=False) |
|
self.initialize_parameters() |
|
|
|
def forward(self, decoder_state: Tensor, encoder_state: Tensor) -> Tensor: |
|
"""Return attention scores.""""" |
|
(B, L, D), (B, T, _) = decoder_state.shape, encoder_state.shape |
|
        # Pair every decoder step with every encoder step: decoder positions are
        # interleaved T times and the encoder sequence is tiled L times, so
        # row l*T + t holds the pair (decoder[l], encoder[t]).
        decoder_state = decoder_state.repeat_interleave(T, dim=1)

        encoder_state = encoder_state.repeat(1, L, 1)
|
concatenated = torch.cat((decoder_state, encoder_state), dim=-1) |
|
scores = self.v(torch.tanh(self.w(concatenated))) |
|
return scores.view(B, L, T) |
|
|
|
@torch.no_grad() |
|
def initialize_parameters(self): |
|
nn.init.xavier_uniform_(self.w.weight) |
|
nn.init.xavier_uniform_(self.v.weight, gain=nn.init.calculate_gain("tanh")) |
|
nn.init.zeros_(self.w.bias) |
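

# A minimal usage sketch. The sizes below (batch 2, 5 decoder steps, 7 source
# steps, hidden size 8) are illustrative assumptions, not part of the modules above.
if __name__ == "__main__":
    d = 8
    attention = Attention(ConcatScore(d), dropout=nn.Dropout(0.1))
    decoder_state = torch.randn(2, 5, d)               # (B, L, D)
    encoder_state = torch.randn(2, 7, 2 * d)            # (B, T, 2*D) from a bidirectional encoder
    source_mask = torch.zeros(2, 7, dtype=torch.bool)   # True marks padded source positions
    source_mask[:, -2:] = True                           # pretend the last two source steps are padding
    context, weights = attention(decoder_state, encoder_state, source_mask)
    print(context.shape, weights.shape)                  # torch.Size([2, 5, 16]) torch.Size([2, 5, 7])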
|
|