# Uploaded by istassiy, commit f8aac2d.
import streamlit as st
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, DistilBertForSequenceClassification
# Hub id of the fine-tuned DistilBERT checkpoint; both the tokenizer and the
# classifier below are loaded from it via `from_pretrained`.
my_model_name = "istassiy/ysda_2022_ml2_hw3_distilbert_base_uncased"

# arXiv archive codes that are all collapsed into the single "physics" topic.
_PHYSICS_ARXIV_CODES = (
    'astro-ph', 'cond-mat', 'gr-qc', 'hep-ex', 'hep-lat', 'hep-ph', 'hep-th',
    'math-ph', 'nlin', 'nucl-ex', 'nucl-th', 'quant-ph', 'physics',
)

# arXiv archive code -> coarse human-readable topic. Insertion order is kept
# identical to the original literal: cs, q-bio, q-fin, physics archives,
# eess, econ, math, stat.
arxiv_code_to_topic = {
    'cs': 'computer science',
    'q-bio': 'biology',
    'q-fin': 'finance',
    **dict.fromkeys(_PHYSICS_ARXIV_CODES, 'physics'),
    'eess': 'electrical engineering',
    'econ': 'economics',
    'math': 'mathematics',
    'stat': 'statistics',
}

# Distinct topic names in alphabetical order. Position i is paired with
# logit i of the classifier output.
# NOTE(review): assumes the model was trained with this same label order;
# confirm against the training code.
sorted_arxiv_topics = sorted({topic for topic in arxiv_code_to_topic.values()})
NUM_LABELS = len(sorted_arxiv_topics)
@st.cache(allow_output_mutation=True)
def load_tokenizer():
    """Return the tokenizer stored under ``my_model_name``.

    Decorated with Streamlit's ``st.cache`` so the tokenizer object is reused
    across script reruns instead of being fetched again each time.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource``; consider migrating once the pinned
    Streamlit version allows it.
    """
    return AutoTokenizer.from_pretrained(my_model_name)
@st.cache(allow_output_mutation=True)
def load_model():
    """Return the DistilBERT sequence classifier stored under ``my_model_name``.

    Decorated with Streamlit's ``st.cache`` so the weights are loaded once and
    the same model object is reused across script reruns.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource``; consider migrating once the pinned
    Streamlit version allows it.
    """
    return DistilBertForSequenceClassification.from_pretrained(my_model_name)
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of *x*, elementwise.

    Computed in a numerically stable way: ``exp`` is only applied to
    ``-|x|``, so it never overflows. The naive formula overflows ``exp`` for
    large negative inputs, producing an ``inf`` intermediate and a
    ``RuntimeWarning``.

    Args:
        x: scalar or array-like of real numbers.

    Returns:
        Values in [0, 1] with the same shape as *x*; a NumPy scalar when *x*
        is a scalar.
    """
    x = np.asarray(x)
    # exp(-|x|) lies in (0, 1], so neither branch below can overflow.
    e = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x).  x < 0: the algebraically equal e^x / (1 + e^x),
    # which keeps full relative precision for tiny outputs.
    result = np.where(x >= 0, 1 / (1 + e), e / (1 + e))
    # Indexing with () turns a 0-d array back into a NumPy scalar and is a
    # no-op view for higher-rank arrays.
    return result[()]
def get_top_predictions(predictions, *, threshold=0.95):
    """Return the most likely topics whose probabilities together exceed *threshold*.

    The raw classifier scores are squashed with ``sigmoid`` and then rescaled
    so they sum to 1. Topics are taken in descending-probability order until
    their cumulative probability exceeds *threshold*. The topic that crosses
    the threshold is included.

    NOTE(review): sigmoid-then-normalize is not the same as softmax. If the
    model was trained as a single-label classifier, softmax would be the
    matching choice; confirm against the training setup.

    Args:
        predictions: 1-D array of logits. Position i is paired with
            ``sorted_arxiv_topics[i]``.
        threshold: cumulative-probability cut-off (keyword-only). Defaults
            to 0.95, the previously hard-coded value.

    Returns:
        dict mapping topic name -> normalized probability, ordered from most
        to least likely.

    Raises:
        ValueError: if the number of scores differs from the number of topics.
    """
    probs = sigmoid(predictions)
    probs = probs / np.sum(probs)
    if len(probs) != len(sorted_arxiv_topics):
        # zip() would otherwise silently drop the surplus and mislabel topics.
        raise ValueError(
            f"expected {len(sorted_arxiv_topics)} scores (one per topic), "
            f"got {len(probs)}"
        )

    ranked = sorted(
        zip(sorted_arxiv_topics, probs), key=lambda item: item[1], reverse=True
    )
    res = {}
    total_prob = 0
    for topic, prob in ranked:
        res[topic] = prob
        total_prob += prob
        if total_prob > threshold:
            break
    return res
tokenizer = load_tokenizer()
model = load_model()

st.markdown("# Scientific paper classificator")
st.markdown(
    "Fill in paper summary and / or title below and then press open area on the page to submit inputs:",
    unsafe_allow_html=False
)
paper_title = st.text_area("Paper title")
paper_summary = st.text_area("Paper abstract")

# Whitespace-only input carries nothing for the classifier; treat it as empty.
if not paper_title.strip() and not paper_summary.strip():
    st.markdown("Must have non-empty title or summary")
else:
    # Title and abstract are fed as one sequence joined by "." (unchanged
    # from the original input format).
    model_inputs = tokenizer(
        [paper_title + "." + paper_summary],
        padding=True, truncation=True, return_tensors="pt",
    )
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        model_output = model(**model_inputs)
    # Batch of one sequence -> take its single row of logits.
    results = get_top_predictions(model_output.logits[0].numpy())
    st.markdown("The following are probabilities for paper topics:")
    for topic, prob in sorted(results.items(), key=lambda item: item[1], reverse=True):
        st.markdown(f"{topic}: {prob}")