# multi30k/models.py
import random
from collections import OrderedDict
from typing import Optional, Tuple

import torch
import torch.nn as nn

from attention import Attention
from attention import ConcatScore

Tensor = torch.Tensor
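
# The Attention and ConcatScore modules are defined in the sibling attention.py.
# Based on their usage below, the assumed interface is: Attention(score_fn, dropout)
# is called as attention(query, keys, source_mask=None) and returns (context, weights),
# with query (B, 1, D), keys = encoder output (B, T, 2*D), context (B, 1, 2*D), and
# weights (B, 1, T); any padding-mask handling is assumed to happen inside the module.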


class Encoder(nn.Module):
    """Single layer recurrent bidirectional encoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*hidden_dim, hidden_dim)
        self.initialize_parameters()

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode a sequence of tokens as a sequence of hidden states."""
        B, T = input.shape
        embedded = self.embedding(input)                     # (B, T, D)
        output, hidden = self.gru(embedded)                  # (B, T, 2*D), (2, B, D)
        hidden = torch.cat((hidden[0], hidden[1]), dim=-1)   # (B, 2*D)
        hidden = torch.tanh(self.fc(hidden))                 # (B, D)
        return output, hidden.unsqueeze(0)                   # (B, T, 2*D), (1, B, D)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize embedding and linear weights Xavier-uniform, recurrent weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)
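

# Usage sketch (illustrative, not part of the original module): encode a batch of
# padded token ids. vocab_size=8000 and hidden_dim=256 are assumed values.
#
#   encoder = Encoder(vocab_size=8000, hidden_dim=256, pad_idx=1)
#   tokens = torch.randint(2, 8000, (32, 20))     # (B=32, T=20)
#   output, hidden = encoder(tokens)              # (32, 20, 512), (1, 32, 256)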


class Decoder(nn.Module):
    """Single layer recurrent decoder with attention."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int, temperature: float = 1.0):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.attention = Attention(ConcatScore(hidden_dim), nn.Dropout(p=0.1))
        self.gru = nn.GRU(3*hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            OrderedDict(
                fc1=nn.Linear(4*hidden_dim, hidden_dim),
                layer_norm=nn.LayerNorm(hidden_dim),
                gelu=nn.GELU(),
                fc2=nn.Linear(hidden_dim, vocab_size, bias=False),
            )
        )
        self.fc.fc2.weight = self.embedding.embedding.weight  # tie the output projection to the input embedding
        self.temperature = temperature
        self.initialize_parameters()

    def forward(self, input: Tensor, hidden: Tensor, encoder_output: Tensor, source_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict the next token given an input token. Returns unnormalized predictions over the vocabulary."""
        B, = input.shape                                                                          # one token per sequence
        embedded = self.embedding(input.view(B, 1))                                               # (B, 1, D)
        context, weights = self.attention(hidden.view(B, 1, -1), encoder_output, source_mask)     # (B, 1, 2*D), (B, 1, T)
        output, hidden = self.gru(torch.cat((embedded, context), dim=-1), hidden)                 # (B, 1, D), (1, B, D)
        predictions = self.fc(torch.cat((embedded, context, output), dim=-1)) / self.temperature  # (B, 1, V)
        return predictions.view(B, -1), hidden, weights.view(B, -1)                               # (B, V), (1, B, D), (B, T)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize embedding and linear weights Xavier-uniform, recurrent weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "norm" in name:
                continue
            elif "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)
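

# Usage sketch (illustrative, continues the encoder example above): a single decoding
# step. Shapes assume vocab_size=8000, hidden_dim=256, B=32, T=20.
#
#   decoder = Decoder(vocab_size=8000, hidden_dim=256, pad_idx=1)
#   token = torch.full((32,), 2)                             # (B,) e.g. <bos> ids
#   predictions, hidden, weights = decoder(token, hidden, output)
#   # predictions: (32, 8000), hidden: (1, 32, 256), weights: (32, 20)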


class Seq2Seq(nn.Module):
    """Seq2seq with attention."""

    def __init__(self, vocab_size: int, hidden_dim: int, bos_idx: int, eos_idx: int, pad_idx: int, teacher_forcing: float = 0.5, temperature: float = 1.0):
        super().__init__()
        self.encoder = Encoder(vocab_size, hidden_dim, pad_idx)
        self.decoder = Decoder(vocab_size, hidden_dim, pad_idx, temperature=temperature)
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.teacher_forcing = teacher_forcing

    def forward(self, source: Tensor, target: Tensor) -> Tensor:
        """Forward pass at training time. Returns unnormalized predictions over the vocabulary."""
        (B, T), (B, L) = source.shape, target.shape
        encoder_output, hidden = self.encoder(source)                          # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)   # (B,)
        source_mask = source == self.pad_idx                                   # (B, T)
        output = []
        for i in range(L):
            predictions, hidden, _ = self.decoder(decoder_input, hidden, encoder_output, source_mask)  # (B, V), (1, B, D)
            output.append(predictions)
            if self.training and random.random() < self.teacher_forcing:
                decoder_input = target[:, i]                                   # (B,)
            else:
                decoder_input = predictions.argmax(dim=1)                      # (B,)
        return torch.stack(output, dim=1)                                      # (B, L, V)

    @torch.inference_mode()
    def decode(self, source: Tensor, max_decode_length: int) -> Tuple[Tensor, Tensor]:
        """Decode a single sequence at inference time. Returns the output sequence and attention weights."""
        B, (T,) = 1, source.shape
        encoder_output, hidden = self.encoder(source.view(B, T))               # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)   # (B,)
        output, attention = [], []
        for i in range(max_decode_length):
            predictions, hidden, weights = self.decoder(decoder_input, hidden, encoder_output)  # (B, V), (1, B, D), (B, T)
            output.append(predictions.argmax(dim=-1))                          # (B,)
            attention.append(weights)                                          # (B, T)
            if output[i] == self.eos_idx:
                break
            else:
                decoder_input = output[i]                                      # (B,)
        return torch.cat(output, dim=0), torch.cat(attention, dim=0)           # (L,), (L, T)
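

if __name__ == "__main__":
    # Minimal smoke test (not part of the original file). Vocabulary size, hidden size,
    # and special token indices below are illustrative assumptions; running it requires
    # the Attention and ConcatScore modules from attention.py.
    model = Seq2Seq(vocab_size=100, hidden_dim=32, bos_idx=0, eos_idx=1, pad_idx=2)
    source = torch.randint(3, 100, (4, 7))        # (B=4, T=7)
    target = torch.randint(3, 100, (4, 9))        # (B=4, L=9)
    predictions = model(source, target)
    print(predictions.shape)                      # torch.Size([4, 9, 100])

    model.eval()
    output, attention = model.decode(source[0], max_decode_length=12)
    print(output.shape, attention.shape)          # (L,), (L, 7) with L <= 12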