File size: 1,942 Bytes
78f91a7
0d3225e
78f91a7
51f97f8
 
3de3a79
 
 
 
51f97f8
3de3a79
02af193
3de3a79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02af193
78f91a7
 
8cb1d65
 
 
78f91a7
 
3de3a79
 
 
 
 
 
 
 
78f91a7
7d57f72
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import streamlit as st
import torch

@st.cache
def Model():
  from transformers import DebertaTokenizer, DebertaForSequenceClassification
  tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
  model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
  bn_state_dict = torch.load('model_weights.pt')
  model.load_state_dict(bn_state_dict)
  return model, tokenizer

def Predict(model, tokenizer, text):
  res = tokenizer(s, padding=True, truncation=True, return_tensors="pt", max_length=512)
  #var.to("cuda:0")
  res = model(**res)
  logits = res.logits.softmax(dim=1)
  logits = logits.numpy()[0]#logits.cpu().detach().numpy()[0]
  return logits
  
def Print(logits, dictionary):
  z = zip(logits, np.arange(0, 8))
  z = sorted(z, key=lambda x: x[0], reverse=True)
  sum, idx = 0, 0
  while sum < 0.95:
    st.markdown(f"{idx + 1}. ", dictionary[z[idx][1]])
    sum += z[idx][0]
    idx += 1

def filter(title, abstract):
  return True
 
st.title('Классификация статьи по названию и описанию')
# ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter

title = st.text_area("Введите название статьи:")

abstract = st.text_area("Введите описание статьи:")
# ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент

text = title + '. ' + abstract
dictionary = ['computer science', 'economics', 'Electrical Engineering and Systems Science', 
              'math', 'physics', 'quantitative biology', 'quantitative finance',
              'statistics']
if filter(title, abstract):
  model, tokenizer = Model()
  logits = Predict(model, tokenizer, text)
  Print(logits, dictionary)