Spaces:
Runtime error
Runtime error
File size: 2,332 Bytes
4355387 de1989c e1fd88d 4355387 675d241 c31188d 4355387 c31188d 675d241 c31188d e1fd88d c31188d 4355387 4915257 4355387 5409e3f 4355387 c31188d 4355387 5409e3f de1989c 5409e3f d276604 40d8d53 7d6731a 5409e3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import streamlit as st
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, DistilBertForSequenceClassification
# Hugging Face Hub identifier of the fine-tuned checkpoint used by this app.
my_model_name = "istassiy/ysda_2022_ml2_hw3_distilbert_base_uncased"

# arXiv archive codes that are all reported under the single 'physics' topic.
_PHYSICS_CODES = (
    'astro-ph',
    'cond-mat',
    'gr-qc',
    'hep-ex',
    'hep-lat',
    'hep-ph',
    'hep-th',
    'math-ph',
    'nlin',
    'nucl-ex',
    'nucl-th',
    'quant-ph',
    'physics',
)

# Maps an arXiv archive code to the human-readable topic shown to the user.
# Several codes may share one topic, so the set of values is smaller than the
# set of keys.
arxiv_code_to_topic = {
    'cs': 'computer science',
    'q-bio': 'biology',
    'q-fin': 'finance',
    **dict.fromkeys(_PHYSICS_CODES, 'physics'),
    'eess': 'electrical engineering',
    'econ': 'economics',
    'math': 'mathematics',
    'stat': 'statistics',
}

# Distinct topic names in alphabetical order; get_top_predictions pairs model
# outputs with this list positionally.
sorted_arxiv_topics = sorted({*arxiv_code_to_topic.values()})
NUM_LABELS = len(sorted_arxiv_topics)
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour of
# st.cache_resource — confirm the pinned Streamlit version before switching.
@st.cache(allow_output_mutation=True)
def load_tokenizer():
    """Return the tokenizer stored under ``my_model_name``.

    The call is wrapped in ``st.cache`` so that Streamlit reruns of the script
    reuse the already-loaded object instead of loading it again.
    """
    return AutoTokenizer.from_pretrained(my_model_name)
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour of
# st.cache_resource — confirm the pinned Streamlit version before switching.
@st.cache(allow_output_mutation=True)
def load_model():
    """Return the sequence-classification model stored under ``my_model_name``.

    The call is wrapped in ``st.cache`` so that Streamlit reruns of the script
    reuse the already-loaded object instead of loading it again.
    """
    return DistilBertForSequenceClassification.from_pretrained(my_model_name)
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``.

    ``x`` is negated and handed to ``np.exp``, so a NumPy array is processed
    element-wise and a scalar yields a scalar.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def get_top_predictions(predictions):
    """Pick the topics the model selects and give each an equal share.

    Every raw score is passed through ``sigmoid``; a topic is selected when
    the result is above 0.5. The selected topics split a total weight of 1.0
    evenly. They are reported from the highest score downwards until the
    accumulated weight exceeds 0.95.

    If no topic clears the threshold, the single highest-scoring topic is
    returned with weight 1.0, so the result never contains NaN values from a
    division by zero.

    Args:
        predictions: One raw model score per topic, paired positionally with
            ``sorted_arxiv_topics``. Presumably a 1-D NumPy array of logits —
            confirm against the caller.

    Returns:
        A dict mapping topic name to its weight (a plain ``float``). It
        contains only selected topics and is empty only when ``predictions``
        is empty.
    """
    scores = sigmoid(np.asarray(predictions))
    if scores.size == 0:
        return {}

    selected = scores > 0.5
    if not selected.any():
        # Nothing cleared the threshold: keep the best guess instead of
        # dividing by a zero count below.
        selected[np.argmax(scores)] = True

    share = 1.0 / int(selected.sum())

    # Highest score first. sorted() is stable, so ties keep the alphabetical
    # order of sorted_arxiv_topics.
    ranked = sorted(
        zip(sorted_arxiv_topics, scores, selected),
        key=lambda item: item[1],
        reverse=True,
    )

    res = {}
    total_prob = 0.0
    for topic, _score, is_selected in ranked:
        if not is_selected:
            # Selected topics all rank ahead of unselected ones, so nothing
            # further down the list can be selected either.
            break
        res[topic] = share
        total_prob += share
        if total_prob > 0.95:
            break
    return res
# Page script: load the cached tokenizer and model, read the paper title and
# abstract from the user, and show the predicted topics.
tokenizer = load_tokenizer()
model = load_model()

st.markdown("# Scientific paper classificator")
st.markdown(
    "Fill in paper summary and / or title below:",
    unsafe_allow_html=False
)

paper_title = st.text_area("Paper title")
paper_summary = st.text_area("Paper abstract")

# Whitespace-only fields count as empty; otherwise the model would be asked
# to classify a text consisting of little more than the "." separator.
if not paper_title.strip() and not paper_summary.strip():
    st.markdown("Must have non-empty title or summary")
else:
    # The title and abstract are joined with "." and sent as a batch of one.
    encoded_inputs = tokenizer(
        [paper_title + "." + paper_summary],
        padding=True, truncation=True, return_tensors="pt"
    )
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        raw_predictions = model(**encoded_inputs)

    # [0] takes the first element of the model output and the second [0] its
    # first row, i.e. the only input in the batch.
    # NOTE(review): presumably the first output element is the logits tensor
    # — confirm against the installed transformers version.
    results = get_top_predictions(raw_predictions[0][0].numpy())

    st.markdown("The following are probabilities for paper topics:")
    for topic, prob in sorted(results.items(), key=lambda item: item[1], reverse=True):
        st.markdown(f"{topic}: {prob}")