# Provenance: file view by istassiy, commit 4915257 (raw size 1.84 kB).
import numpy as np
import streamlit as st
import torch
# Checkpoint identifier handed to transformers' `from_pretrained` when loading.
my_model_name = "istassiy/ysda_2022_ml2_hw3_distilbert_base_uncased"

# arXiv archives that are all folded into the single "physics" label.
_PHYSICS_ARCHIVES = (
    'astro-ph', 'cond-mat', 'gr-qc', 'hep-ex', 'hep-lat', 'hep-ph',
    'hep-th', 'math-ph', 'nlin', 'nucl-ex', 'nucl-th', 'quant-ph', 'physics',
)

# arXiv category prefix -> coarse topic label shown to the user.
arxiv_code_to_topic = {
    'cs': 'computer science',
    'q-bio': 'biology',
    'q-fin': 'finance',
    **dict.fromkeys(_PHYSICS_ARCHIVES, 'physics'),
    'eess': 'electrical engineering',
    'econ': 'economics',
    'math': 'mathematics',
    'stat': 'statistics',
}

# Distinct labels in alphabetical order. Model outputs are paired with these
# labels by position, so this ordering presumably matches the label order used
# at training time — TODO confirm against the training code.
sorted_arxiv_topics = sorted({*arxiv_code_to_topic.values()})
NUM_LABELS = len(sorted_arxiv_topics)
def _cache_resource(func):
    """Cache *func*'s return value across Streamlit reruns without hashing it.

    Uses ``st.cache_resource`` when this Streamlit version provides it; older
    versions fall back to the legacy ``st.cache`` with output-mutation checks
    disabled, since hashing a model object on every rerun is slow or fails.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_model():
    """Load the tokenizer and classification model for ``my_model_name``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model includes its sequence-
        classification head, so calling it yields an output with ``.logits``.
    """
    # Imported lazily so the heavy transformers import happens only once,
    # inside the cached call.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(my_model_name)
    # AutoModel would load only the bare encoder (hidden states, no logits);
    # the classification head is needed to score topics.
    model = AutoModelForSequenceClassification.from_pretrained(my_model_name)
    return tokenizer, model
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))``, element-wise for arrays."""
    denominator = np.exp(-x) + 1
    return 1 / denominator
def get_top_predictions(predictions, threshold=0.95):
    """Return the most likely topics until their cumulative probability exceeds *threshold*.

    The raw scores are squashed with ``sigmoid``, normalised to sum to 1, and
    visited from most to least likely. Topics are collected until the running
    total first exceeds *threshold*.

    Args:
        predictions: Array-like of raw model scores (logits), one per entry of
            ``sorted_arxiv_topics`` and in the same order. Any singleton
            dimensions (e.g. a batch of size 1) are flattened away.
        threshold: Cumulative-probability cut-off; defaults to 0.95.

    Returns:
        dict mapping topic name to its normalised probability (Python float),
        ordered from most to least likely. Empty if no score is usable.

    Raises:
        ValueError: if the number of scores differs from the number of topics.
    """
    scores = np.asarray(predictions, dtype=float).reshape(-1)
    if scores.size != len(sorted_arxiv_topics):
        raise ValueError(
            f"expected {len(sorted_arxiv_topics)} scores, got {scores.size}"
        )

    # Use the sigmoid values themselves as unnormalised weights; a `> 0` mask
    # on them would always be true and erase all information.
    probs = sigmoid(scores)
    total = probs.sum()
    if not total > 0:
        # All scores underflowed to 0 (or are NaN): nothing meaningful to rank.
        return {}
    probs = probs / total

    res = {}
    total_prob = 0.0
    # Highest probability first; the stable sort keeps alphabetical order
    # among exact ties.
    for idx in np.argsort(-probs, kind="stable"):
        prob = float(probs[idx])
        res[sorted_arxiv_topics[idx]] = prob
        total_prob += prob
        if total_prob > threshold:
            break
    return res
# Streamlit re-runs this script top to bottom on every interaction;
# load_model is wrapped in a Streamlit cache so the model loads once.
tokenizer, model = load_model()

st.markdown("# Scientific paper classificator")
st.markdown(
    "Fill in paper summary and / or title below:",
    unsafe_allow_html=False,
)

paper_title = st.text_area("Paper title")
paper_summary = st.text_area("Paper abstract")

# Whitespace-only input counts as empty.
if not paper_title.strip() and not paper_summary.strip():
    st.markdown("Must have non-empty title or summary")
else:
    # Title and abstract are joined with "."; presumably the input format the
    # checkpoint was fine-tuned on — TODO confirm against training code.
    encoded = tokenizer(
        paper_title + "." + paper_summary,
        return_tensors="pt",  # the model's forward expects tensors, not lists
        truncation=True,  # long abstracts would exceed the model's max length
    )
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(**encoded)

    logits = getattr(outputs, "logits", None)
    if logits is None:
        # A bare encoder (no classification head) returns hidden states only.
        st.error("Loaded model does not return classification logits.")
    else:
        # logits[0]: the single input's row of per-topic scores.
        results = get_top_predictions(logits[0].cpu().numpy())
        for topic, prob in results.items():
            st.markdown(f"- **{topic}**: {prob:.1%}")