File size: 2,290 Bytes
78f91a7
0d3225e
ecac71c
78f91a7
049c790
51f97f8
3de3a79
 
 
5ca6967
51f97f8
3de3a79
02af193
3de3a79
fc3e0b1
3de3a79
 
795b9e0
3de3a79
 
 
 
 
fc3e0b1
 
c2990c1
 
fc3e0b1
3de3a79
 
 
bb61482
37e5cd3
ba3edc4
3de3a79
 
02af193
78f91a7
 
8cb1d65
 
 
78f91a7
 
3de3a79
 
 
 
 
 
2460eee
42fda87
2460eee
78f91a7
7d57f72
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import streamlit as st
import torch
import numpy as np

@st.cache(allow_output_mutation=True)
def Model():
  from transformers import DebertaTokenizer, DebertaForSequenceClassification
  tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
  model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
  bn_state_dict = torch.load('model_weights.pt', map_location=torch.device('cpu'))
  model.load_state_dict(bn_state_dict)
  return model, tokenizer

def Predict(model, tokenizer, text):
  res = tokenizer(text, padding=True, truncation=True, return_tensors="pt", max_length=512)
  res = model(**res)
  logits = res.logits.softmax(dim=1)
  logits = logits.detach().numpy()[0]#.cpu().detach().numpy()[0]
  return logits
  
def Print(logits, dictionary):
  z = zip(logits, np.arange(0, 8))
  z = sorted(z, key=lambda x: x[0], reverse=True)
  summ, idx = 0, 0
  while summ < 0.95:
    string = dictionary[z[idx][1]]
    st.markdown(f"{idx + 1}. {string}")
    summ += z[idx][0]
    idx += 1

def filter(title, abstract):
  if len(title) < 10 or (len(abstract) > 0 and len(abstract) < 20):
    st.markdown("Хммм... Вы точно не ошиблись? Слишком маленькое название или описание.")
    return False
  return True
 
st.title('Классификация статьи по названию и описанию')
# ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter

title = st.text_area("Введите название статьи:")

abstract = st.text_area("Введите описание статьи:")
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент

text = title + '. ' + abstract
dictionary = ['computer science', 'economics', 'Electrical Engineering and Systems Science', 
              'math', 'physics', 'quantitative biology', 'quantitative finance',
              'statistics']
if filter(title, abstract):
  model, tokenizer = Model()
  logits = Predict(model, tokenizer, text)
  st.header("Топ 95%:")
  Print(logits, dictionary)