import json

import torch
import torch.nn as nn


def attention(Q, K, V):
    """Scaled dot-product attention."""
    d = K.shape[-1]
    scores = Q @ K.transpose(-2, -1) / (d ** 0.5)
    weights = torch.softmax(scores, dim=-1)
    return weights @ V


class Attention(nn.Module):
    """Multi-head self-attention (no learned Q/K/V projections)."""

    def __init__(self, emb_dim, n_heads):
        super().__init__()
        self.emb_dim = emb_dim
        self.n_heads = n_heads

    def forward(self, X):
        # X: (batch_size, seq_len, emb_dim)
        batch_size, seq_len, emb_dim = X.size()
        n_heads = self.n_heads
        emb_dim_per_head = emb_dim // n_heads
        assert emb_dim == self.emb_dim
        assert emb_dim_per_head * n_heads == emb_dim

        # Split the embedding dimension into heads:
        # (batch_size, seq_len, emb_dim) -> (batch_size, n_heads, seq_len, emb_dim_per_head)
        X = X.view(batch_size, seq_len, n_heads, emb_dim_per_head).transpose(1, 2)

        output = attention(X, X, X)       # (batch_size, n_heads, seq_len, emb_dim_per_head)
        output = output.transpose(1, 2)   # (batch_size, seq_len, n_heads, emb_dim_per_head)
        output = output.contiguous().view(batch_size, seq_len, emb_dim)  # (batch_size, seq_len, emb_dim)
        return output


class ClassifierAttention(nn.Module):
    """Binary classifier: embedding -> LSTM -> multi-head attention -> LSTM -> MLP head."""

    def __init__(self, vocab_size, emb_dim, padding_idx, hidden_size,
                 n_layers, attention_heads, hidden_layer_units, dropout):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=padding_idx,
        )
        self.rnn_1 = nn.LSTM(
            emb_dim,
            hidden_size,
            n_layers,
            bidirectional=False,
            batch_first=True,
        )
        self.attention = Attention(hidden_size, attention_heads)
        self.rnn_2 = nn.LSTM(
            hidden_size,
            hidden_size,
            n_layers,
            bidirectional=False,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

        # Fully connected head: hidden_size -> hidden_layer_units ... -> 1
        hidden_layer_units = [hidden_size, *hidden_layer_units]
        self.hidden_layers = nn.ModuleList([])
        for in_unit, out_unit in zip(hidden_layer_units[:-1], hidden_layer_units[1:]):
            self.hidden_layers.append(nn.Linear(in_unit, out_unit))
            self.hidden_layers.append(nn.ReLU())
            self.hidden_layers.append(self.dropout)
        self.hidden_layers.append(nn.Linear(hidden_layer_units[-1], 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: (batch_size, seq_len)
        out = self.embedding(x)                             # (batch_size, seq_len, emb_dim)
        out, (hidden_state, cell_state) = self.rnn_1(out)   # out: (batch_size, seq_len, hidden_size)
        out = self.attention(out)                           # (batch_size, seq_len, hidden_size)
        out = self.dropout(out)
        out, (hidden_state, cell_state) = self.rnn_2(out)   # hidden_state: (n_layers*n_directions, batch_size, hidden_size)
        out = hidden_state[-1]                              # last layer's final hidden state: (batch_size, hidden_size)
        out = self.dropout(out)
        for layer in self.hidden_layers:
            out = layer(out)
        out = self.sigmoid(out)                             # (batch_size, 1)
        out = out.squeeze(-1)                               # (batch_size,)
        return out


def get_model(model_path, params_path):
    """Load a trained ClassifierAttention from a state-dict file and a JSON of constructor args."""
    with open(params_path, 'r') as f:
        params = json.load(f)
    # Expects a JSON list of positional constructor arguments; use **params instead
    # if the file stores them as a dict keyed by parameter name.
    model = ClassifierAttention(*params)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    return model
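
# --- Usage sketch (illustrative, not part of the original module) -------------
# A minimal smoke test of the untrained model. The hyperparameters and the random
# token IDs below are hypothetical assumptions, not values from any saved
# params.json or checkpoint; for a trained model, call get_model() instead.
if __name__ == "__main__":
    model = ClassifierAttention(
        vocab_size=10_000,        # assumed vocabulary size
        emb_dim=128,              # assumed embedding dimension
        padding_idx=0,            # assumed padding token id
        hidden_size=64,           # must be divisible by attention_heads
        n_layers=2,
        attention_heads=4,
        hidden_layer_units=[32, 16],
        dropout=0.3,
    )
    model.eval()

    x = torch.randint(1, 10_000, (8, 50))   # (batch_size=8, seq_len=50) random token ids
    with torch.no_grad():
        probs = model(x)                     # (batch_size,) probabilities in [0, 1]
    print(probs.shape, probs.min().item(), probs.max().item())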