import torch
from torch import nn
from transformers import PreTrainedModel


class BiLSTM(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.hidden_dim = config.hidden_dim
        self.num_layers = config.num_layers
        self.predict_output = config.predict_output
        self.embed_layer = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=0)
        # Each direction gets hidden_dim // 2 units; the BiLSTM concatenates the
        # two directions, so its output feature size is hidden_dim.
        self.biLSTM = nn.LSTM(input_size=config.embed_dim,
                              hidden_size=config.hidden_dim // 2,
                              num_layers=config.num_layers,
                              bidirectional=True,
                              batch_first=True)
        self.linear = nn.Linear(config.hidden_dim, config.output_dim)
        self.dropout = nn.Dropout(config.dropout)
        self.elu = nn.ELU()
        self.fc = nn.Linear(config.output_dim, config.predict_output)
        self.device_ = config.device

    def forward(self, input_ids):
        # input_ids: LongTensor of token indices, shape (batch_size, seq_len)
        x = self.embed_layer(input_ids)  # (batch_size, seq_len, embed_dim) -- batch dim first only because batch_first=True
        batch_size = x.size(0)
        hidden, cell = self.init_hidden(batch_size)
        # out: (batch_size, seq_len, hidden_dim), both directions concatenated;
        # hidden is the (h_n, c_n) tuple of final states.
        out, hidden = self.biLSTM(x, (hidden, cell))
        out = self.dropout(out)
        out = self.elu(self.linear(out))  # (batch_size, seq_len, output_dim)
        out = self.fc(out)                # (batch_size, seq_len, predict_output)
        return out, hidden

    def init_hidden(self, batch_size):
        # The initial states need one slot per layer per direction:
        # (num_layers * 2, batch_size, hidden_dim // 2). Hard-coding 2 here
        # would break whenever num_layers > 1.
        num_states = self.num_layers * 2
        hidden = torch.zeros(num_states, batch_size, self.hidden_dim // 2, device=self.device_)
        cell = torch.zeros(num_states, batch_size, self.hidden_dim // 2, device=self.device_)
        return hidden, cell
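
# A minimal usage sketch. Assumptions: the config attribute names below are the
# ones the class reads, and the concrete values (vocab size, dims, tag count)
# are illustrative only. PretrainedConfig is the standard Hugging Face config
# base class and accepts arbitrary keyword attributes.
from transformers import PretrainedConfig

config = PretrainedConfig(
    vocab_size=10_000,
    embed_dim=100,
    hidden_dim=256,    # must be even: split across the two directions
    num_layers=2,
    output_dim=128,
    dropout=0.3,
    predict_output=9,  # e.g. number of tag classes for token-level prediction
    device="cpu",
)

model = BiLSTM(config)
input_ids = torch.randint(1, config.vocab_size, (4, 12))  # (batch_size=4, seq_len=12)
logits, (h_n, c_n) = model(input_ids)
print(logits.shape)  # torch.Size([4, 12, 9]) -- one score vector per token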