File size: 2,029 Bytes
b8769be
2c5279b
 
b8769be
 
2c5279b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f05ebed
 
2c5279b
26fd4ec
d52f486
 
 
2c5279b
 
5e4fa04
f6a020d
5e4fa04
2c5279b
 
5e4fa04
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
# NOTE(review): PyTorch's importable module is named "torch" (imported below);
# nothing in the visible code uses the name "pytorch". Confirm and remove this
# line, which fails unless a module called "pytorch" happens to be installed.
import pytorch
import streamlit as st
import torch
import transformers

@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_model():
    """Build the tokenizer and the 8-label classifier and load local weights.

    The result is cached by Streamlit so the model is not rebuilt on every
    script rerun. ``allow_output_mutation=True`` stops ``st.cache`` from
    hashing the returned tokenizer and model on each cache hit to look for
    mutations.

    Returns:
        tuple: ``(tokenizer, model)`` where ``model`` has had the state dict
        from ``model_weights.pt`` loaded onto the CPU and is in eval mode.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    model_name = 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)

    # map_location forces CPU tensors regardless of where the weights were saved.
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    state_dict = torch.load('model_weights.pt', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    # Inference only: switch the module to evaluation mode.
    model.eval()
    return tokenizer, model
    
@st.cache(suppress_st_warning=True)
def predict(title, summary, tokenizer, model):
    """Return the top topics whose cumulative probability reaches 0.95.

    NOTE(review): ``st.cache`` hashes every argument, including ``tokenizer``
    and ``model``; confirm that is acceptable or exclude them via
    ``hash_funcs``.

    Args:
        title: Article title text.
        summary: Article summary text; may be an empty string.
        tokenizer: Object exposing ``encode(text, truncation=True)`` that
            returns a list of token ids.
        model: Callable torch module whose first output holds the logits.

    Returns:
        tuple: ``(prediction, prediction_probs)`` - topic names in descending
        probability order and their probabilities, cut off right after the
        topic that brings the running total to at least 0.95.
    """
    text = title + "\n" + summary

    # Truncate to the tokenizer's configured maximum length so a very long
    # summary cannot exceed what the model accepts.
    tokens = tokenizer.encode(text, truncation=True)

    # Run on the device that holds the weights (the original referenced an
    # undefined global named `device`).
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens], device=device))[0]
        # A single sequence is passed in, so the last row is the only row.
        probs = torch.softmax(logits[-1, :], dim=-1).cpu().numpy()

    # NOTE(review): label_to_theme is not defined in the visible code; it is
    # presumably a module-level mapping from class index to topic name -
    # confirm it exists before deploying.
    prediction = []
    prediction_probs = []
    sum_probs = 0
    # Iterating over the sorted class ids (instead of an unbounded while loop)
    # guarantees termination even if rounding keeps the total below 0.95.
    for class_id in np.argsort(probs)[::-1]:
        if sum_probs >= 0.95:
            break
        prediction.append(label_to_theme[class_id])
        prediction_probs.append(probs[class_id])
        sum_probs += probs[class_id]

    return prediction, prediction_probs
   
@st.cache(suppress_st_warning=True)
def get_results(prediction, prediction_probs):
    """Format topics and their confidences as a tab-separated text table.

    Args:
        prediction: Topic names (strings).
        prediction_probs: Confidence values, paired positionally with
            ``prediction``.

    Returns:
        str: A header line followed by one ``topic<TAB><TAB>confidence`` line
        per pair, each terminated by a newline.
    """
    header = "Topic:\t\tConfidence:\n"
    rows = (
        label + "\t\t" + str(confidence) + "\n"
        for label, confidence in zip(prediction, prediction_probs)
    )
    return header + "".join(rows)

# --- Page layout and interaction (Streamlit reruns this on every input change) ---
st.title("Arxiv articles classification")
st.markdown("This is an interface that can determine the article's topic based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")

tokenizer, model = load_model()

title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)

# Skip inference until the user has entered some text; otherwise the model
# would be run on an empty input as soon as the page loads.
if title.strip() or summary.strip():
    prediction, prediction_probs = predict(title, summary, tokenizer, model)
    ans = get_results(prediction, prediction_probs)
    # Show the formatted result (the original passed an undefined name `text`).
    # st.text keeps the tab/newline layout that Markdown rendering would collapse.
    st.text(ans)