"""Streamlit app that tags a paper with arXiv topics via a DistilBERT classifier.

The script loads a sequence-classification checkpoint from the Hugging Face
Hub, reads a title and/or abstract from two text areas, and prints the topics
whose sigmoid score passes a threshold.
"""

import streamlit as st
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, DistilBertForSequenceClassification

# Hub identifier used for both the tokenizer and the classification model.
my_model_name = "istassiy/ysda_2022_ml2_hw3_distilbert_base_uncased"

# Maps an arXiv category code to the coarse topic name shown to the user.
arxiv_code_to_topic = {
    "cs": "computer science",

    "q-bio": "biology",
    "q-fin": "finance",

    "astro-ph": "physics",
    "cond-mat": "physics",
    "gr-qc": "physics",
    "hep-ex": "physics",
    "hep-lat": "physics",
    "hep-ph": "physics",
    "hep-th": "physics",
    "math-ph": "physics",
    "nlin": "physics",
    "nucl-ex": "physics",
    "nucl-th": "physics",
    "quant-ph": "physics",
    "physics": "physics",

    "eess": "electrical engineering",
    "econ": "economics",
    "math": "mathematics",
    "stat": "statistics",
}

# Distinct topic names in alphabetical order. get_top_predictions pairs this
# list positionally with the model outputs.
# NOTE(review): assumes the checkpoint was trained with labels in this same
# alphabetical order - confirm against the training code.
sorted_arxiv_topics = sorted(set(arxiv_code_to_topic.values()))
NUM_LABELS = len(sorted_arxiv_topics)


def _cache_resource(func):
    """Cache *func*'s return value across Streamlit reruns.

    Prefers ``st.cache_resource`` (the replacement for the deprecated
    ``st.cache``) and falls back to ``st.cache(allow_output_mutation=True)``
    on Streamlit versions that do not provide it yet.
    """
    decorator = getattr(st, "cache_resource", None)
    if decorator is None:
        return st.cache(allow_output_mutation=True)(func)
    return decorator(func)


@_cache_resource
def load_tokenizer():
    """Load (once per server process) the tokenizer for ``my_model_name``."""
    tokenizer = AutoTokenizer.from_pretrained(my_model_name)
    return tokenizer


@_cache_resource
def load_model():
    """Load (once per server process) the classifier for ``my_model_name``."""
    model = DistilBertForSequenceClassification.from_pretrained(my_model_name)
    return model


def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``."""
    return 1 / (1 + np.exp(-x))


def get_top_predictions(predictions, threshold=0.5, coverage=0.95):
    """Turn raw per-topic scores into a ``{topic: share}`` mapping.

    Every topic whose sigmoid score is strictly above ``threshold`` is
    selected, and the selected topics share the total mass equally
    (``1 / k`` each for ``k`` selected topics). Topics are then added in
    descending share order until the accumulated share exceeds ``coverage``.

    Args:
        predictions: One raw score per entry of ``sorted_arxiv_topics``, in
            the same order.
        threshold: Sigmoid score a topic must exceed to be selected.
        coverage: Accumulated share after which no more topics are added.

    Returns:
        A dict from topic name to share, in descending share order. The dict
        is empty when no topic passes ``threshold``; previously that case
        divided by zero and reported ``nan`` for every topic.
    """
    selected = (sigmoid(np.asarray(predictions)) > threshold).astype(float)
    selected_count = np.sum(selected)
    if selected_count == 0:
        # Nothing passed the threshold: normalising would be 0 / 0.
        return {}
    probs = selected / selected_count

    res = {}
    total_prob = 0
    # sorted() is stable, so topics with equal share stay in alphabetical order.
    ranked = sorted(
        zip(sorted_arxiv_topics, probs),
        key=lambda item: item[1],
        reverse=True,
    )
    for topic, prob in ranked:
        total_prob += prob
        res[topic] = prob
        if total_prob > coverage:
            break
    return res


tokenizer = load_tokenizer()
model = load_model()

st.markdown("# Scientific paper classificator")
st.markdown(
    "Fill in paper summary and / or title below and then press open area on the page to submit inputs:",
    unsafe_allow_html=False,
)

# Strip surrounding whitespace so that blank-only input counts as empty.
paper_title = st.text_area("Paper title").strip()
paper_summary = st.text_area("Paper abstract").strip()

if not paper_title and not paper_summary:
    st.markdown("Must have non-empty title or summary")
else:
    # Inference only: no gradients are needed.
    with torch.no_grad():
        encoded_inputs = tokenizer(
            [paper_title + "." + paper_summary],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        raw_predictions = model(**encoded_inputs)
    # First output field, first (and only) item of the batch.
    results = get_top_predictions(raw_predictions[0][0].numpy())
    if not results:
        st.markdown("No topic scored above the confidence threshold.")
    else:
        st.markdown("The following are probabilities for paper topics:")
        # get_top_predictions already returns topics in descending order.
        for topic, prob in results.items():
            st.markdown(f"{topic}: {prob}")