import random
from collections import OrderedDict
from typing import Optional, Tuple

import torch
import torch.nn as nn

from attention import Attention
from attention import ConcatScore

Tensor = torch.Tensor


class Encoder(nn.Module):
    """Single layer recurrent bidirectional encoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*hidden_dim, hidden_dim)
        self.initialize_parameters()

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Encode a sequence of tokens as a sequence of hidden states."""
        B, T = input.shape
        embedded = self.embedding(input)  # (B, T, D)
        output, hidden = self.gru(embedded)  # (B, T, 2*D), (2, B, D)
        hidden = torch.cat((hidden[0], hidden[1]), dim=-1)  # (B, 2*D)
        hidden = torch.tanh(self.fc(hidden))  # (B, D)
        return output, hidden.unsqueeze(0)  # (B, T, 2*D), (1, B, D)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize embedding and linear weights with Xavier uniform, recurrent weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Decoder(nn.Module):
    """Single layer recurrent decoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, pad_idx: int, temperature: float = 1.0):
        super().__init__()
        self.embedding = nn.Sequential(
            OrderedDict(
                embedding=nn.Embedding(vocab_size, hidden_dim, padding_idx=pad_idx),
                dropout=nn.Dropout(p=0.33),
            )
        )
        self.attention = Attention(ConcatScore(hidden_dim), nn.Dropout(p=0.1))
        self.gru = nn.GRU(3*hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            OrderedDict(
                fc1=nn.Linear(4*hidden_dim, hidden_dim),
                layer_norm=nn.LayerNorm(hidden_dim),
                gelu=nn.GELU(),
                fc2=nn.Linear(hidden_dim, vocab_size, bias=False),
            )
        )
        self.fc.fc2.weight = self.embedding.embedding.weight  # tie output projection to input embedding
        self.temperature = temperature
        self.initialize_parameters()

    def forward(self, input: Tensor, hidden: Tensor, encoder_output: Tensor, source_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        """Predict the next token given an input token.

        Returns unnormalized predictions over the vocabulary."""
        B, = input.shape  # L = 1: decode a single step per call
        embedded = self.embedding(input.view(B, 1))  # (B, 1, D)
        context, weights = self.attention(hidden.view(B, 1, -1), encoder_output, source_mask)  # (B, 1, 2*D), (B, 1, T)
        output, hidden = self.gru(torch.cat((embedded, context), dim=-1), hidden)  # (B, 1, D), (1, B, D)
        predictions = self.fc(torch.cat((embedded, context, output), dim=-1)) / self.temperature  # (B, 1, V)
        return predictions.view(B, -1), hidden, weights.view(B, -1)  # (B, V), (1, B, D), (B, T)

    @torch.no_grad()
    def initialize_parameters(self):
        """Initialize embedding and linear weights with Xavier uniform, recurrent weights orthogonally, and biases to zero."""
        for name, parameters in self.named_parameters():
            if "norm" in name:
                continue
            elif "embedding" in name:
                nn.init.xavier_uniform_(parameters)
            elif "weight_ih" in name:
                w_ir, w_iz, w_in = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.xavier_uniform_(w_ir)
                nn.init.xavier_uniform_(w_iz)
                nn.init.xavier_uniform_(w_in)
            elif "weight_hh" in name:
                w_hr, w_hz, w_hn = torch.chunk(parameters, chunks=3, dim=0)
                nn.init.orthogonal_(w_hr)
                nn.init.orthogonal_(w_hz)
                nn.init.orthogonal_(w_hn)
            elif "weight" in name:
                nn.init.xavier_uniform_(parameters)
            elif "bias" in name:
                nn.init.zeros_(parameters)


class Seq2Seq(nn.Module):
    """Seq2seq with attention."""

    def __init__(self, vocab_size: int, hidden_dim: int, bos_idx: int, eos_idx: int, pad_idx: int, teacher_forcing: float = 0.5, temperature: float = 1.0):
        super().__init__()
        self.encoder = Encoder(vocab_size, hidden_dim, pad_idx)
        self.decoder = Decoder(vocab_size, hidden_dim, pad_idx, temperature=temperature)
        self.bos_idx = bos_idx
        self.eos_idx = eos_idx
        self.pad_idx = pad_idx
        self.teacher_forcing = teacher_forcing

    def forward(self, source: Tensor, target: Tensor) -> Tensor:
        """Forward pass at training time. Returns unnormalized predictions over the vocabulary."""
        (B, T), (B, L) = source.shape, target.shape
        encoder_output, hidden = self.encoder(source)  # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)
        source_mask = source == self.pad_idx  # (B, T)
        output = []
        for i in range(L):
            predictions, hidden, _ = self.decoder(decoder_input, hidden, encoder_output, source_mask)  # (B, V), (1, B, D)
            output.append(predictions)
            if self.training and random.random() < self.teacher_forcing:
                decoder_input = target[:, i]  # (B,)
            else:
                decoder_input = predictions.argmax(dim=1)  # (B,)
        return torch.stack(output, dim=1)  # (B, L, V)

    @torch.inference_mode()
    def decode(self, source: Tensor, max_decode_length: int) -> Tuple[Tensor, Tensor]:
        """Decode a single sequence at inference time. Returns output sequence and attention weights."""
        B, (T,) = 1, source.shape
        encoder_output, hidden = self.encoder(source.view(B, T))  # (B, T, 2*D), (1, B, D)
        decoder_input = torch.full((B,), self.bos_idx, device=source.device)  # (B,)
        output, attention = [], []
        for i in range(max_decode_length):
            predictions, hidden, weights = self.decoder(decoder_input, hidden, encoder_output)  # (B, V), (1, B, D), (B, T)
            output.append(predictions.argmax(dim=-1))  # (B,)
            attention.append(weights)  # (B, T)
            if output[i] == self.eos_idx:
                break
            else:
                decoder_input = output[i]  # (B,)
        return torch.cat(output, dim=0), torch.cat(attention, dim=0)  # (L,), (L, T)
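

# A minimal usage sketch, not part of the original module: it assumes attention.py
# provides the Attention and ConcatScore modules imported above, and the vocabulary
# size, hidden size, and special-token indices below are made-up illustrative values.
if __name__ == "__main__":
    vocab_size, hidden_dim = 100, 32
    bos_idx, eos_idx, pad_idx = 1, 2, 0
    model = Seq2Seq(vocab_size, hidden_dim, bos_idx, eos_idx, pad_idx)

    # Training-style forward pass on random token ids: (B, T) source, (B, L) target.
    source = torch.randint(3, vocab_size, (4, 7))
    target = torch.randint(3, vocab_size, (4, 5))
    logits = model(source, target)
    print(logits.shape)  # torch.Size([4, 5, 100]) -> (B, L, V) unnormalized predictions

    # Greedy single-sequence decoding with the model in eval mode
    # (disables dropout and teacher forcing).
    model.eval()
    tokens, weights = model.decode(source[0], max_decode_length=10)
    print(tokens.shape, weights.shape)  # (L,), (L, T) with L <= max_decode_length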