Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from torch import nn | |
from models.preprocess_stage.preprocess_lstm import preprocess_lstm | |
# Hyperparameters shared by the model definition and the predict helper.
EMBEDDING_DIM = 128  # used as the LSTM input size
HIDDEN_SIZE = 16  # LSTM hidden size and width of the attention layers
MAX_LEN = 125  # forwarded to preprocess_lstm
# DEVICE='cpu'

# Pretrained embedding table, read from disk once at import time and wrapped
# in a frozen (non-trainable) embedding layer.
embedding_matrix = np.load('models/datasets/embedding_matrix.npy')
embedding_layer = nn.Embedding.from_pretrained(
    torch.tensor(embedding_matrix, dtype=torch.float32),
    freeze=True,
)
class AtenttionTest(nn.Module):
    """Additive attention that pools LSTM outputs using the final hidden state.

    The per-step outputs and the final hidden state are each projected by
    their own linear layer, summed, passed through tanh and reduced to one
    score per step. A softmax over the step axis turns the scores into
    weights, which pool the *projected* per-step outputs.
    """

    def __init__(self, hidden_size=HIDDEN_SIZE):
        super().__init__()
        self.hidden_size = hidden_size
        # Layers are created in the same order and under the same attribute
        # names as before so parameter names and initialisation are unchanged.
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.tahn = nn.Tanh()
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, outputs_lmst, h_n):
        """Compute the attention-pooled context and the attention weights.

        Args:
            outputs_lmst: Per-step LSTM outputs; the bmm/transpose calls
                require a 3-D tensor, read here as
                (batch, steps, hidden_size).
            h_n: Final hidden state whose leading dimension is squeezed
                away, i.e. expected as (1, batch, hidden_size).

        Returns:
            A tuple ``(context, attention_weights)`` where ``context`` is
            (batch, hidden_size, 1) and ``attention_weights`` is
            (batch, steps), normalised over the step axis.
        """
        projected_steps = self.fc1(outputs_lmst)
        projected_state = self.fc2(h_n.squeeze(0))

        # Broadcast the state projection across every step before the tanh.
        energy = self.tahn(projected_steps + projected_state.unsqueeze(1))

        scores = self.fc3(energy).squeeze(2)
        attention_weights = torch.softmax(scores, dim=1)

        # Weighted sum over steps, expressed as a batched matrix product.
        context = torch.bmm(
            projected_steps.transpose(1, 2),
            attention_weights.unsqueeze(2),
        )
        return context, attention_weights
class LSTMnn(nn.Module):
    """Embedding -> single-layer LSTM -> attention -> small MLP with sigmoid output."""

    def __init__(self):
        super().__init__()
        # Shared module-level embedding layer built from the pretrained matrix.
        self.embedding = embedding_layer
        self.lstm = nn.LSTM(
            EMBEDDING_DIM,
            HIDDEN_SIZE,
            num_layers=1,
            batch_first=True,
        )
        self.attention = AtenttionTest(hidden_size=HIDDEN_SIZE)
        # Positions inside the Sequential are kept (Linear at 0 and 3) so the
        # parameter names stay the same.
        self.fc_out = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        """Run the network on a batch of token indices.

        Args:
            x: Index tensor fed to the embedding layer; with
                ``batch_first=True`` this is read as (batch, steps).

        Returns:
            A tuple ``(probability, attention_weights)``: the sigmoid of the
            head's single output unit, and the weights produced by the
            attention module.
        """
        embedded = self.embedding(x)
        lstm_out, state = self.lstm(embedded)
        last_hidden = state[0]  # the cell state (state[1]) is not used
        context, attention_weights = self.attention(lstm_out, last_hidden)
        # Drop the trailing singleton dimension left by the attention bmm.
        logits = self.fc_out(context.squeeze(2))
        return torch.sigmoid(logits), attention_weights
# Build the network once at import time and restore the trained weights,
# remapping any stored tensors onto the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model = LSTMnn()
model.load_state_dict(
    torch.load('models/weights/LSTMBestWeights.pt', map_location='cpu')
)
def predict_3(text) -> int:
    """Classify ``text`` with the module-level LSTM model.

    Args:
        text: Raw input forwarded unchanged to ``preprocess_lstm``
            (presumably a string -- confirm against ``preprocess_lstm``).

    Returns:
        The model's sigmoid output rounded to the nearest integer, i.e.
        0 or 1. Python's ``round`` sends an exact 0.5 to 0.
    """
    preprocessed_text = preprocess_lstm(text, MAX_LEN=MAX_LEN)

    # Switch dropout to inference behaviour before the forward pass.
    model.eval()

    # Add a leading batch dimension: the model expects a batch of sequences.
    batch = torch.tensor(preprocessed_text).unsqueeze(0)

    # Inference only: skip autograd bookkeeping so no graph is built.
    with torch.no_grad():
        probability, _ = model(batch)

    return round(probability.item())