import random
from collections import OrderedDict
from typing import Optional, Tuple

import torch
import torch.nn as nn

from attention import Attention, ConcatScore

Tensor = torch.Tensor


class Encoder(nn.Module):
    """Single layer recurrent bidirectional encoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * hidden_dim, hidden_dim)
        self.initialize_parameters()

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode a sequence of tokens as a sequence of hidden states."""
        B, T = input.shape
        embedded = self.embedding(input)                    # (B, T, H)
        output, hidden = self.gru(embedded)                 # (B, T, 2H), (2, B, H)
        hidden = torch.cat((hidden[0], hidden[1]), dim=-1)  # concatenate final forward/backward states
        hidden = torch.tanh(self.fc(hidden))                # project back to (B, H)
        return output, hidden.unsqueeze(0)                  # (B, T, 2H), (1, B, H)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize linear weights uniformly, recurrent weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Decoder(nn.Module):
    """Single layer recurrent decoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int, temperature: float = 1.0):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.attention = Attention(ConcatScore(hidden_dim), nn.Dropout(p=0.1))
        self.gru = nn.GRU(3 * hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            OrderedDict(
                fc1=nn.Linear(4 * hidden_dim, hidden_dim),
                layer_norm=nn.LayerNorm(hidden_dim),
                gelu=nn.GELU(),
                fc2=nn.Linear(hidden_dim, vocab_size, bias=False),
            )
        )
        # Tie the output projection to the input embedding weights.
        self.fc.fc2.weight = self.embedding.embedding.weight
        self.temperature = temperature
        self.initialize_parameters()

    def forward(
        self,
        input: Tensor,
        hidden: Tensor,
        encoder_output: Tensor,
        source_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict the next token given an input token. Returns unnormalized predictions over the vocabulary."""
        B, = input.shape
        embedded = self.embedding(input.view(B, 1))                                # (B, 1, H)
        context, weights = self.attention(hidden.view(B, 1, -1), encoder_output, source_mask)
        output, hidden = self.gru(torch.cat((embedded, context), dim=-1), hidden)  # (B, 1, H), (1, B, H)
        predictions = self.fc(torch.cat((embedded, context, output), dim=-1)) / self.temperature
        return predictions.view(B, -1), hidden, weights.view(B, -1)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize linear weights uniformly, recurrent weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "norm" in name:
                continue
            elif "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Seq2Seq(nn.Module):
    """Seq2seq with attention."""

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int,
        bos_idx: int,
        eos_idx: int,
        pad_idx: int,
        teacher_forcing: float = 0.5,
        temperature: float = 1.0,
    ):
        super().__init__()
        self.encoder = Encoder(vocab_size, hidden_dim, pad_idx)
        self.decoder = Decoder(vocab_size, hidden_dim, pad_idx, temperature=temperature)
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.teacher_forcing = teacher_forcing

    def forward(self, source: Tensor, target: Tensor) -> Tensor:
        """Forward pass at training time. Returns unnormalized predictions over the vocabulary."""
        (B, T), (B, L) = source.shape, target.shape
        encoder_output, hidden = self.encoder(source)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)
        source_mask = source == self.pad_idx

        output = []
        for i in range(L):
            predictions, hidden, _ = self.decoder(decoder_input, hidden, encoder_output, source_mask)
            output.append(predictions)
            # Teacher forcing: feed the ground-truth token instead of the model's own prediction.
            if self.training and random.random() < self.teacher_forcing:
                decoder_input = target[:, i]
            else:
                decoder_input = predictions.argmax(dim=1)
        return torch.stack(output, dim=1)

    @torch.inference_mode()
    def decode(self, source: Tensor, max_decode_length: int) -> Tuple[Tensor, Tensor]:
        """Decode a single sequence at inference time. Returns the output sequence and attention weights."""
        B, (T,) = 1, source.shape
        encoder_output, hidden = self.encoder(source.view(B, T))
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)

        output, attention = [], []
        for i in range(max_decode_length):
            predictions, hidden, weights = self.decoder(decoder_input, hidden, encoder_output)
            output.append(predictions.argmax(dim=-1))
            attention.append(weights)
            if output[i] == self.eos_idx:
                break
            else:
                decoder_input = output[i]
        return torch.cat(output, dim=0), torch.cat(attention, dim=0)
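

# Usage sketch (illustrative, not part of the original module): a minimal smoke
# test of the tensor shapes flowing through Seq2Seq. The vocabulary size, hidden
# dimension, special-token indices, and sequence lengths below are assumptions
# chosen for the example, and it relies on the local `attention` module
# providing the `Attention` and `ConcatScore` classes imported above.
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, hidden_dim = 100, 32
    pad_idx, bos_idx, eos_idx = 0, 1, 2
    model = Seq2Seq(vocab_size, hidden_dim, bos_idx, eos_idx, pad_idx)

    source = torch.randint(3, vocab_size, (4, 7))  # (B, T) token indices
    target = torch.randint(3, vocab_size, (4, 5))  # (B, L) token indices
    logits = model(source, target)                 # (B, L, vocab_size)
    print(logits.shape)

    model.eval()
    tokens, weights = model.decode(source[0], max_decode_length=10)
    print(tokens.shape, weights.shape)             # (<=10,), (<=10, T)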