import random
from collections import OrderedDict
from typing import Optional, Tuple

import torch
import torch.nn as nn

from attention import Attention, ConcatScore

Tensor = torch.Tensor

class Encoder(nn.Module):
    """Single layer recurrent bidirectional encoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * hidden_dim, hidden_dim)
        self.initialize_parameters()

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode a sequence of tokens as a sequence of hidden states."""
        B, T = input.shape
        embedded = self.embedding(input)  # (B, T, D)
        output, hidden = self.gru(embedded)  # (B, T, 2*D), (2, B, D)
        # Combine the final forward and backward states into a single decoder-sized state.
        hidden = torch.cat((hidden[0], hidden[1]), dim=-1)  # (B, 2*D)
        hidden = torch.tanh(self.fc(hidden))  # (B, D)
        return output, hidden.unsqueeze(0)  # (B, T, 2*D), (1, B, D)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize embedding and linear weights with Xavier uniform, hidden-to-hidden weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                # GRU weights are stacked per gate (reset, update, new); initialize each gate separately.
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)

class Decoder(nn.Module):
    """Single layer recurrent decoder with attention over the encoder outputs."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int, temperature: float = 1.0):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.attention = Attention(ConcatScore(hidden_dim), nn.Dropout(p=0.1))
        self.gru = nn.GRU(3 * hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            OrderedDict(
                fc1=nn.Linear(4 * hidden_dim, hidden_dim),
                layer_norm=nn.LayerNorm(hidden_dim),
                gelu=nn.GELU(),
                fc2=nn.Linear(hidden_dim, vocab_size, bias=False),
            )
        )
        # Tie the output projection to the embedding weights.
        self.fc.fc2.weight = self.embedding.embedding.weight
        self.temperature = temperature
        self.initialize_parameters()

    def forward(self, input: Tensor, hidden: Tensor, encoder_output: Tensor, source_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict the next token given an input token. Returns unnormalized predictions over the vocabulary."""
        B, = input.shape  # one token per batch element
        embedded = self.embedding(input.view(B, 1))  # (B, 1, D)
        context, weights = self.attention(hidden.view(B, 1, -1), encoder_output, source_mask)  # (B, 1, 2*D), (B, 1, T)
        output, hidden = self.gru(torch.cat((embedded, context), dim=-1), hidden)  # (B, 1, D), (1, B, D)
        predictions = self.fc(torch.cat((embedded, context, output), dim=-1)) / self.temperature  # (B, 1, V)
        return predictions.view(B, -1), hidden, weights.view(B, -1)  # (B, V), (1, B, D), (B, T)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize embedding and linear weights with Xavier uniform, hidden-to-hidden weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "norm" in name:
                # Keep the LayerNorm default initialization (weight = 1, bias = 0).
                continue
            elif "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)

class Seq2Seq(nn.Module):
    """Seq2seq with attention."""

    def __init__(self, vocab_size: int, hidden_dim: int, bos_idx: int, eos_idx: int, pad_idx: int, teacher_forcing: float = 0.5, temperature: float = 1.0):
        super().__init__()
        self.encoder = Encoder(vocab_size, hidden_dim, pad_idx)
        self.decoder = Decoder(vocab_size, hidden_dim, pad_idx, temperature=temperature)
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.teacher_forcing = teacher_forcing

    def forward(self, source: Tensor, target: Tensor) -> Tensor:
        """Forward pass at training time. Returns unnormalized predictions over the vocabulary."""
        (B, T), (B, L) = source.shape, target.shape
        encoder_output, hidden = self.encoder(source)  # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)
        source_mask = source == self.pad_idx  # (B, T)
        output = []
        for i in range(L):
            predictions, hidden, _ = self.decoder(decoder_input, hidden, encoder_output, source_mask)  # (B, V), (1, B, D)
            output.append(predictions)
            # Teacher forcing: with probability `teacher_forcing`, feed the ground-truth token
            # as the next input; otherwise feed the model's own greedy prediction.
            if self.training and random.random() < self.teacher_forcing:
                decoder_input = target[:, i]  # (B,)
            else:
                decoder_input = predictions.argmax(dim=1)  # (B,)
        return torch.stack(output, dim=1)  # (B, L, V)

    @torch.inference_mode()
    def decode(self, source: Tensor, max_decode_length: int) -> Tuple[Tensor, Tensor]:
        """Greedily decode a single sequence at inference time. Returns the output sequence and attention weights."""
        B, (T,) = 1, source.shape
        encoder_output, hidden = self.encoder(source.view(B, T))  # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)
        output, attention = [], []
        for i in range(max_decode_length):
            predictions, hidden, weights = self.decoder(decoder_input, hidden, encoder_output)  # (B, V), (1, B, D), (B, T)
            output.append(predictions.argmax(dim=-1))  # (B,)
            attention.append(weights)  # (B, T)
            if output[i] == self.eos_idx:
                break
            else:
                decoder_input = output[i]  # (B,)
        return torch.cat(output, dim=0), torch.cat(attention, dim=0)  # (L,), (L, T)
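

if __name__ == "__main__":
    # Minimal smoke test, added as a sketch: it is not part of the original module, and the
    # vocabulary size, hidden size, special-token indices, and tensor shapes below are
    # illustrative assumptions chosen only to exercise the code paths defined above.
    vocab_size, hidden_dim = 100, 32
    bos_idx, eos_idx, pad_idx = 1, 2, 0
    model = Seq2Seq(vocab_size, hidden_dim, bos_idx, eos_idx, pad_idx)

    # Training-time forward pass: batched source (B, T) and target (B, L) token ids.
    source = torch.randint(3, vocab_size, (4, 7))
    target = torch.randint(3, vocab_size, (4, 5))
    predictions = model(source, target)
    assert predictions.shape == (4, 5, vocab_size)  # (B, L, V) unnormalized logits

    # Inference-time greedy decoding of a single unbatched source sequence of shape (T,).
    model.eval()
    tokens, weights = model.decode(source[0], max_decode_length=10)
    assert weights.shape[-1] == source.shape[-1]  # one attention row per generated token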