|
import streamlit as st |
|
import torch |
|
|
|
# `st.cache` is deprecated, and for an unhashable return value such as a torch
# model it must not try to hash/mutation-check the output on every rerun.
# Prefer `st.cache_resource` where available; otherwise fall back to the legacy
# decorator with output-mutation checking disabled.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def Model():
    """Load the fine-tuned DeBERTa classifier and its tokenizer.

    The result is cached by Streamlit, so the (slow) download / weight loading
    happens once rather than on every script rerun.

    The base ``microsoft/deberta-base`` checkpoint is created with an
    8-way classification head, and then the weights stored in
    ``model_weights.pt`` (resolved relative to the working directory) are
    loaded over it.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    # Imported lazily so the heavy dependency is only touched when the model
    # is actually built.
    from transformers import DebertaTokenizer, DebertaForSequenceClassification

    tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
    model = DebertaForSequenceClassification.from_pretrained(
        "microsoft/deberta-base", num_labels=8
    )

    # The model lives on the CPU; map the checkpoint there too so that weights
    # saved from a GPU session can still be loaded.
    state_dict = torch.load('model_weights.pt', map_location="cpu")
    model.load_state_dict(state_dict)

    # Inference only: make sure dropout etc. are disabled.
    model.eval()
    return model, tokenizer
|
|
|
def Predict(model, tokenizer, text):
    """Return the class-probability vector for a single piece of text.

    Args:
        model: sequence-classification model. It is called with the
            tokenizer's output as keyword arguments and must return an object
            with a ``logits`` tensor of shape (batch, num_classes).
        tokenizer: callable that encodes ``text`` into PyTorch tensors.
        text: the string to classify.

    Returns:
        numpy.ndarray: softmax probabilities for the first (only) item of the
        batch, one entry per class.
    """
    encoded = tokenizer(
        text, padding=True, truncation=True, return_tensors="pt", max_length=512
    )

    # Inference does not need an autograd graph. Without no_grad the logits
    # would require grad and `.numpy()` below would raise.
    with torch.no_grad():
        output = model(**encoded)

    probs = output.logits.softmax(dim=1)
    # Move to host memory before converting, in case the model runs on a GPU.
    return probs.cpu().numpy()[0]
|
|
|
def Print(logits, dictionary, threshold=0.95, render=None):
    """List the most likely classes until their total probability reaches a threshold.

    Classes are emitted in decreasing order of probability as ``"N. <label>"``
    lines. Listing stops as soon as the running total of the already-listed
    probabilities is at least ``threshold``, or when every class has been
    listed (so probabilities that sum below the threshold cannot cause an
    IndexError).

    Args:
        logits: per-class probabilities (e.g. the output of ``Predict``),
            indexable by class id.
        dictionary: class labels; ``dictionary[i]`` names class ``i``.
        threshold: cumulative probability at which to stop listing.
        render: callable receiving each formatted line; defaults to
            ``st.markdown``.
    """
    if render is None:
        render = st.markdown

    # Stable sort: ties keep their original class order.
    ranked = sorted(enumerate(logits), key=lambda item: item[1], reverse=True)

    total = 0.0
    for rank, (class_idx, prob) in enumerate(ranked, start=1):
        if total >= threshold:
            break
        # Label and rank go into a single string. st.markdown's second
        # positional argument is `unsafe_allow_html`, not more text.
        render(f"{rank}. {dictionary[class_idx]}")
        total += prob
|
|
|
def filter(title, abstract):
    """Decide whether there is enough user input to run the classifier.

    Returns True when the title or the abstract contains at least one
    non-whitespace character. An empty form (e.g. before the user has typed
    anything) therefore does not trigger a prediction on meaningless text.

    Note: the name shadows the builtin ``filter``. It is kept for
    compatibility with the existing caller.
    """
    return bool((title or "").strip() or (abstract or "").strip())
|
|
|
# Class labels, positionally aligned with the classifier's output indices
# (Print looks up dictionary[class_index]).
dictionary = [
    'computer science',
    'economics',
    'Electrical Engineering and Systems Science',
    'math',
    'physics',
    'quantitative biology',
    'quantitative finance',
    'statistics',
]

# Page header followed by the two input fields: title first, then abstract.
st.title('Классификация статьи по названию и описанию')
title = st.text_area("Введите название статьи:")
abstract = st.text_area("Введите описание статьи:")

# The model sees one string of the form "<title>. <abstract>".
text = '. '.join((title, abstract))

# Only load the (cached) model and predict when the input passes the gate.
if filter(title, abstract):
    model, tokenizer = Model()
    logits = Predict(model, tokenizer, text)
    Print(logits, dictionary)
|
|
|
|
|
|