File size: 2,290 Bytes
78f91a7 0d3225e ecac71c 78f91a7 049c790 51f97f8 3de3a79 5ca6967 51f97f8 3de3a79 02af193 3de3a79 fc3e0b1 3de3a79 795b9e0 3de3a79 fc3e0b1 c2990c1 fc3e0b1 3de3a79 bb61482 37e5cd3 ba3edc4 3de3a79 02af193 78f91a7 8cb1d65 78f91a7 3de3a79 2460eee 42fda87 2460eee 78f91a7 7d57f72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import streamlit as st
import torch
import numpy as np
@st.cache(allow_output_mutation=True)
def Model():
from transformers import DebertaTokenizer, DebertaForSequenceClassification
tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
bn_state_dict = torch.load('model_weights.pt', map_location=torch.device('cpu'))
model.load_state_dict(bn_state_dict)
return model, tokenizer
def Predict(model, tokenizer, text):
res = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
res = model(**res)
logits = res.logits.softmax(dim=1)
logits = logits.detach().numpy()[0]#.cpu().detach().numpy()[0]
return logits
def Print(logits, dictionary):
z = zip(logits, np.arange(0, 8))
z = sorted(z, key=lambda x: x[0], reverse=True)
summ, idx = 0, 0
while summ < 0.95:
string = dictionary[z[idx][1]]
st.markdown(f"{idx + 1}. {string}")
summ += z[idx][0]
idx += 1
def filter(title, abstract):
if len(title) < 10 or (len(abstract) > 0 and len(abstract) < 20):
st.markdown("Хммм... Вы точно не ошиблись? Слишком маленькое название или описание.")
return False
return True
st.title('Классификация статьи по названию и описанию')
# ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
title = st.text_area("Введите название статьи:")
abstract = st.text_area("Введите описание статьи:")
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
text = title + '. ' + abstract
dictionary = ['computer science', 'economics', 'Electrical Engineering and Systems Science',
'math', 'physics', 'quantitative biology', 'quantitative finance',
'statistics']
if filter(title, abstract):
model, tokenizer = Model()
logits = Predict(model, tokenizer, text)
st.header("Топ 95%:")
Print(logits, dictionary)
|