|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
class Net(nn.Module):
    """Fully connected classifier head mapping a 768-d vector to 8 scores.

    Layers: Linear 768->512, ReLU, Linear 512->256, ReLU, Linear 256->128,
    ReLU, Linear 128->8. There is no activation after the last layer. Every
    module is held in ``self.layer``, so the state-dict keys are ``layer.0``,
    ``layer.2``, ``layer.4`` and ``layer.6`` (weight and bias for each).
    """

    def __init__(self):
        super().__init__()
        widths = (768, 512, 256, 128, 8)
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        # The output layer emits raw scores, so drop the trailing ReLU.
        modules.pop()
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the MLP; its last dimension must be 768."""
        return self.layer(x)
|
|
|
@st.cache(allow_output_mutation=True)
def GetModel():
    """Build the ``Net`` classifier and load its trained weights from ``model.dat``.

    Weights are mapped to the CPU, so no GPU is required. The result is cached
    by Streamlit, so script reruns reuse the already-loaded model.

    ``allow_output_mutation=True`` matches ``GetModelAndTokenizer``. It stops
    Streamlit from re-hashing the returned ``nn.Module`` on every cache hit
    to check whether it was mutated.
    """
    model = Net()
    # NOTE(review): torch.load unpickles the file; only ever point this at a
    # trusted model.dat.
    state = torch.load('model.dat', map_location=torch.device('cpu'))
    model.load_state_dict(state)
    # Inference only: make sure the module is not left in training mode.
    model.eval()
    return model
|
|
|
@st.cache(allow_output_mutation=True)
def GetModelAndTokenizer():
    """Load and cache everything ``BuildAnswer`` needs for inference.

    Returns:
        A tuple ``(classifier, tokenizer, seq2seq_model)``:
        - ``classifier``: the ``Net`` head from ``GetModel()``;
        - ``tokenizer``: the tokenizer of the
          ``Callidior/bert2bert-base-arxiv-titlegen`` checkpoint;
        - ``seq2seq_model``: the seq2seq model from that same checkpoint,
          used to produce text embeddings.
    """
    checkpoint = "Callidior/bert2bert-base-arxiv-titlegen"
    classifier = GetModel()
    tok = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return classifier, tok, seq2seq
|
|
|
# Category names in the order of the classifier's 8 outputs (output i -> entry i).
_ARTICLE_LABELS = (
    'Computer Science',
    'Economics',
    "Electrical Engineering And Systems Science",
    "Mathematics",
    "Physics",
    "Quantitative Biology",
    "Quantitative Finance",
    "Statistics",
)


def _embed_text(text, tokenizer, model_emb):
    """Return one embedding vector for ``text`` from the seq2seq model's decoder.

    The token ids are fed as both encoder input and decoder input. The last
    four decoder hidden-state layers are summed, and the result is averaged
    over token positions.
    """
    encoded = tokenizer.encode_plus(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = model_emb(decoder_input_ids=encoded['input_ids'],
                           output_hidden_states=True, **encoded)
    states = output['decoder_hidden_states']
    # Take batch item 0 explicitly instead of squeeze(). squeeze() would also
    # drop the sequence axis for a 1-token input, and mean(dim=0) would then
    # average over the wrong axis.
    summed = torch.stack([states[i] for i in (-4, -3, -2, -1)]).sum(0)[0]
    return summed.mean(dim=0)


def _top_labels(probs, threshold):
    """Format the most probable labels, stopping once their cumulative mass exceeds ``threshold``.

    Each entry is ``'<pct>% - <label>'``, with the percentage right-aligned to 3 characters.
    """
    lines = []
    cumulative = 0.0
    for idx in torch.argsort(probs, descending=True).tolist():
        if cumulative > threshold:
            break
        p = float(probs[idx])
        cumulative += p
        lines.append(f'{round(p * 100):3d}% - {_ARTICLE_LABELS[idx]}')
    return lines


def BuildAnswer(txt, threshold=0.95):
    """Predict arXiv categories for a paper's title/abstract text.

    Args:
        txt: Free text (title and abstract). Surrounding whitespace is ignored.
        threshold: Labels are added in order of decreasing probability until
            the accumulated probability exceeds this value (default 0.95).

    Returns:
        ``''`` for blank input (kept for backward compatibility; iterating it
        yields nothing). Otherwise, a list of strings such as
        ``' 87% - Computer Science'``, most likely label first.
    """
    txt = txt.strip()
    if txt == '':
        return ''

    model, tokenizer, model_emb = GetModelAndTokenizer()
    embed = _embed_text(txt, tokenizer, model_emb)
    # Inference only: avoid building an autograd graph for the classifier pass.
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(embed), dim=0)
    return _top_labels(probs, threshold)
|
|
|
|
|
|
|
# --- Page layout: header, banner image, and the two input boxes. ---
st.markdown("### Hello, world!")
st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)

title = st.text_area("Title:")
abstract = st.text_area("Abstract:", height=400)

# --- Classify title + abstract together and print one line per predicted label. ---
result = BuildAnswer(' '.join((title, abstract)))

for line in result:
    st.markdown(line)