arkmartov commited on
Commit
3de3a79
1 Parent(s): 5ea5531

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -14
app.py CHANGED
@@ -3,14 +3,33 @@ import torch
3
 
4
  @st.cache
5
  def Model():
6
- from transformers import BertTokenizer, BertForSequenceClassification
7
- model_name = "google/bert_uncased_L-4_H-256_A-4"
8
- tokenizer = BertTokenizer.from_pretrained(model_name)
9
- model = BertForSequenceClassification.from_pretrained(model_name, num_labels=8)
10
- bn_state_dict = torch.load('model_w.pt')
11
  model.load_state_dict(bn_state_dict)
12
- return model
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  st.title('Классификация статьи по названию и описанию')
15
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
16
 
@@ -19,13 +38,13 @@ title = st.text_area("Введите название статьи:")
19
  abstract = st.text_area("Введите описание статьи:")
20
  # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
21
 
22
- from transformers import pipeline
23
- pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
24
- raw_predictions = pipe(title)
25
- # тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
 
 
 
 
26
 
27
- st.markdown(f"{raw_predictions}")
28
 
29
- model = Model()
30
- st.markdown(f"{model}")
31
- # выводим результаты модели в текстовое поле, на потеху пользователю
 
3
 
4
  @st.cache
5
  def Model():
6
+ from transformers import DebertaTokenizer, DebertaForSequenceClassification
7
+ tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")
8
+ model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8)
9
+ bn_state_dict = torch.load('model_weights.pt')
 
10
  model.load_state_dict(bn_state_dict)
11
+ return model, tokenizer
12
 
13
+ def Predict(model, tokenizer, text):
14
+ res = tokenizer(s, padding=True, truncation=True, return_tensors="pt", max_length=512)
15
+ #var.to("cuda:0")
16
+ res = model(**res)
17
+ logits = res.logits.softmax(dim=1)
18
+ logits = logits.numpy()[0]#logits.cpu().detach().numpy()[0]
19
+ return logits
20
+
21
+ def Print(logits, dictionary):
22
+ z = zip(logits, np.arange(0, 8))
23
+ z = sorted(z, key=lambda x: x[0], reverse=True)
24
+ sum, idx = 0, 0
25
+ while sum < 0.95:
26
+ st.markdown(f"{idx + 1}. ", dictionary[z[idx][1]])
27
+ sum += z[idx][0]
28
+ idx += 1
29
+
30
+ def filter(title, abstract):
31
+ return True
32
+
33
  st.title('Классификация статьи по названию и описанию')
34
  # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
35
 
 
38
  abstract = st.text_area("Введите описание статьи:")
39
  # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
40
 
41
+ text = title + '. ' + abstract
42
+ dictionary = ['computer science', 'economics', 'Electrical Engineering and Systems Science',
43
+ 'math', 'physics', 'quantitative biology', 'quantitative finance',
44
+ 'statistics']
45
+ if filter(title, abstract):
46
+ model, tokenizer = Model()
47
+ logits = Predict(model, tokenizer, text)
48
+ Print(logits, dictionary)
49
 
 
50