import torch
from torch import nn
from transformers import PreTrainedModel


class BiLSTM(PreTrainedModel):

    def __init__(self, config):
        super().__init__(config)
        self.hidden_dim = config.hidden_dim
        self.num_layers = config.num_layers
        self.predict_output = config.predict_output

        # Embedding lookup; index 0 is reserved for padding tokens.
        self.embed_layer = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=0)
        # Bidirectional LSTM: each direction uses hidden_dim // 2 units so the
        # concatenated forward/backward output has size hidden_dim.
        self.biLSTM = nn.LSTM(input_size=config.embed_dim,
                              hidden_size=config.hidden_dim // 2,
                              num_layers=config.num_layers,
                              bidirectional=True,
                              batch_first=True)
        self.linear = nn.Linear(config.hidden_dim, config.output_dim)
        self.dropout = nn.Dropout(config.dropout)
        self.elu = nn.ELU()
        self.fc = nn.Linear(config.output_dim, config.predict_output)
        self.device_ = config.device

    def forward(self, input_ids):
        x = self.embed_layer(input_ids)
        batch_size = x.size(0)
        hidden, cell = self.init_hidden(batch_size)

        # out: (batch, seq_len, hidden_dim); hidden is the (h_n, c_n) tuple
        # returned by the LSTM.
        out, hidden = self.biLSTM(x, (hidden, cell))

        out = self.dropout(out)

        out = self.elu(self.linear(out))

        out = self.fc(out)

        return out, hidden

    def init_hidden(self, batch_size):
        # Shape: (num_layers * num_directions, batch, hidden_size per direction).
        hidden = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2, device=self.device_)
        cell = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2, device=self.device_)
        return hidden, cell
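

# Minimal usage sketch (an assumption, not part of the original code): BiLSTMConfig
# below is a hypothetical PretrainedConfig subclass carrying the attributes the
# model reads, and the dimensions are placeholder values.
from transformers import PretrainedConfig


class BiLSTMConfig(PretrainedConfig):
    model_type = "bilstm"

    def __init__(self, vocab_size=10000, embed_dim=128, hidden_dim=256,
                 num_layers=1, dropout=0.1, output_dim=64, predict_output=2,
                 device="cpu", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.output_dim = output_dim
        self.predict_output = predict_output
        self.device = device


config = BiLSTMConfig()
model = BiLSTM(config)
input_ids = torch.randint(1, config.vocab_size, (4, 16))  # (batch, seq_len)
logits, _ = model(input_ids)
print(logits.shape)  # torch.Size([4, 16, 2])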