File size: 2,332 Bytes
4355387
 
de1989c
e1fd88d
4355387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675d241
c31188d
4355387
c31188d
 
675d241
c31188d
e1fd88d
c31188d
4355387
4915257
 
 
4355387
5409e3f
4355387
 
 
 
 
 
 
 
 
 
 
c31188d
 
4355387
 
 
 
 
 
 
 
 
 
5409e3f
 
 
 
de1989c
 
 
 
 
5409e3f
d276604
40d8d53
7d6731a
5409e3f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import streamlit as st
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, DistilBertForSequenceClassification

# Hugging Face hub id of the fine-tuned DistilBERT classification checkpoint.
my_model_name = "istassiy/ysda_2022_ml2_hw3_distilbert_base_uncased"

# All arXiv physics sub-archives collapse into the single coarse topic 'physics'.
_PHYSICS_CODES = (
  'astro-ph', 'cond-mat', 'gr-qc', 'hep-ex', 'hep-lat', 'hep-ph', 'hep-th',
  'math-ph', 'nlin', 'nucl-ex', 'nucl-th', 'quant-ph', 'physics',
)

# arXiv primary-category prefix -> coarse human-readable topic label.
arxiv_code_to_topic = dict(
  [
    ('cs', 'computer science'),
    ('q-bio', 'biology'),
    ('q-fin', 'finance'),
  ]
  + [(code, 'physics') for code in _PHYSICS_CODES]
  + [
    ('eess', 'electrical engineering'),
    ('econ', 'economics'),
    ('math', 'mathematics'),
    ('stat', 'statistics'),
  ]
)

# Distinct topic names in alphabetical order; get_top_predictions zips this
# list against the model's per-label score vector, so index i names label i.
sorted_arxiv_topics = sorted(set(arxiv_code_to_topic.values()))

NUM_LABELS = len(sorted_arxiv_topics)

def load_tokenizer():
  """Download (and cache for the session) the tokenizer for `my_model_name`."""
  return AutoTokenizer.from_pretrained(my_model_name)

# `st.cache` was deprecated in Streamlit 1.18 and removed in 1.36, so the
# original `@st.cache(allow_output_mutation=True)` raises AttributeError on
# current Streamlit.  `st.cache_resource` is the documented replacement for
# unserializable objects (tokenizers, models); fall back to the legacy API
# only when running on an old Streamlit version.
if hasattr(st, "cache_resource"):
  load_tokenizer = st.cache_resource(load_tokenizer)
else:
  load_tokenizer = st.cache(allow_output_mutation=True)(load_tokenizer)

def load_model():
  """Download (and cache for the session) the fine-tuned classification model."""
  return DistilBertForSequenceClassification.from_pretrained(my_model_name)

# `st.cache` was deprecated in Streamlit 1.18 and removed in 1.36, so the
# original `@st.cache(allow_output_mutation=True)` raises AttributeError on
# current Streamlit.  `st.cache_resource` is the documented replacement for
# unserializable objects (tokenizers, models); fall back to the legacy API
# only when running on an old Streamlit version.
if hasattr(st, "cache_resource"):
  load_model = st.cache_resource(load_model)
else:
  load_model = st.cache(allow_output_mutation=True)(load_model)

def sigmoid(x):
  """Element-wise logistic function: maps real-valued logits into (0, 1)."""
  exp_neg = np.exp(-x)
  return 1.0 / (1.0 + exp_neg)

def get_top_predictions(predictions, topics=None):
  """Convert raw per-label logits into a topic -> probability-mass dict.

  Each logit is squashed through a sigmoid; labels scoring above 0.5 are
  kept, and the probability mass is spread uniformly over them.  Topics are
  then accumulated in label order until more than 0.95 of the mass is
  covered.

  Args:
    predictions: 1-D array-like of per-label logits, aligned with `topics`.
    topics: label names aligned with `predictions`; defaults to the
      module-level `sorted_arxiv_topics`.

  Returns:
    Dict mapping topic name -> probability share (values sum to <= 1).
  """
  if topics is None:
    topics = sorted_arxiv_topics

  # Logistic squashing, inlined so this function is self-contained.
  scores = 1.0 / (1.0 + np.exp(-np.asarray(predictions, dtype=float)))
  chosen = (scores > 0.5).astype(float)
  if chosen.sum() == 0:
    # Bug fix: the original divided by zero here (yielding all-NaN output)
    # whenever no label cleared the 0.5 threshold.  Fall back to the single
    # highest-scoring label so the caller always gets a usable answer.
    chosen[int(np.argmax(scores))] = 1.0
  probs = chosen / chosen.sum()

  res = {}
  total_prob = 0.0
  for topic, prob in zip(topics, probs):
    total_prob += prob
    res[topic] = prob
    if total_prob > 0.95:
      break
  return res

# --- Streamlit page: classify a paper into a coarse arXiv topic. ---
tokenizer = load_tokenizer()
model = load_model()

st.markdown("# Scientific paper classificator")
st.markdown(
    "Fill in paper summary and / or title below:",
    unsafe_allow_html=False
)

paper_title = st.text_area("Paper title")
paper_summary = st.text_area("Paper abstract")

if not paper_title and not paper_summary:
  # Nothing to classify yet: prompt the user instead of running the model.
  # (Fixed: the original used an f-string with no placeholders here.)
  st.markdown("Must have non-empty title or summary")
else:
  # Inference only — disable autograd to save memory and time.
  with torch.no_grad():
    raw_predictions = model(
      **tokenizer(
        [paper_title + "." + paper_summary],
        padding=True, truncation=True, return_tensors="pt"
      )
    )
    # raw_predictions[0] is the logits tensor of shape (batch, num_labels);
    # take row 0 since exactly one text was submitted.
    results = get_top_predictions(raw_predictions[0][0].numpy())
    st.markdown("The following are probabilities for paper topics:")
    # Render most-likely topics first.
    for topic, prob in sorted(results.items(), key=lambda item: item[1], reverse=True):
      st.markdown(f"{topic}: {prob}")