# Page header captured with this file (not part of the program):
#   ki33elev's picture | Update app.py | dee08fe | raw | history blame | 2.49 kB
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
import tokenizers
@st.cache(
    suppress_st_warning=True,
    hash_funcs={tokenizers.Tokenizer: lambda _: None},
)
def load_model():
    """Load the tokenizer and the fine-tuned 8-class sequence classifier.

    The architecture and tokenizer come from the 'distilbert-base-cased'
    checkpoint; the classification weights are then overwritten with the
    state dict stored in 'model_weights.pt', mapped onto the CPU.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model switched to eval mode.
    """
    # Imported lazily, as in the original, so the names stay function-local.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    checkpoint = 'distilbert-base-cased'

    tok = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=8,
    )

    # Load the fine-tuned weights on CPU regardless of where they were saved.
    cpu = torch.device('cpu')
    state_dict = torch.load('model_weights.pt', map_location=cpu)
    classifier.load_state_dict(state_dict)

    # Inference only: put the module into evaluation mode.
    classifier.eval()
    return tok, classifier
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def predict(title, summary, tokenizer, model):
    """Predict the most likely topics for an article.

    The title and summary are joined with a newline, tokenized, and passed
    through the classifier. Topics are then taken in order of decreasing
    probability until their cumulative probability reaches 0.95.

    Args:
        title: Article title text.
        summary: Article summary text (may be an empty string).
        tokenizer: Tokenizer exposing ``encode(text, truncation=...)``.
        model: Classifier whose first output holds the logits.

    Returns:
        A ``(prediction, prediction_probs)`` tuple of two parallel lists:
        topic names from ``label_to_theme`` and their probabilities,
        sorted from most to least probable.
    """
    text = title + "\n" + summary
    # Truncate to the tokenizer's configured maximum length. Without this, a
    # long summary yields more tokens than the model accepts and the forward
    # pass fails.
    tokens = tokenizer.encode(text, truncation=True)

    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
    # The batch holds a single sequence, so the last row is the only row.
    probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()

    prediction = []
    prediction_probs = []
    sum_probs = 0
    # Walk the classes from most to least probable; iterating over the sorted
    # indices bounds the loop by the number of classes.
    for cls in np.flip(np.argsort(probs)):
        # Written as "not <" so a NaN running sum stops the loop, as before.
        if not sum_probs < 0.95:
            break
        prediction.append(label_to_theme[cls])
        prediction_probs.append(probs[cls])
        sum_probs += probs[cls]
    return prediction, prediction_probs
@st.cache(suppress_st_warning=True)
def get_results(prediction, prediction_probs):
    """Build the results table shown to the user.

    Args:
        prediction: Topic names, most probable first.
        prediction_probs: Probabilities (0..1) parallel to ``prediction``.

    Returns:
        A ``pandas.DataFrame`` with a 'Topic' column and a 'Confidence'
        column holding each probability as a percentage string with two
        decimals (e.g. ``"87.50%"``).
    """
    # The previous loop only rebound its loop variable, so the formatted
    # strings were thrown away. Build a new list instead (the input list is
    # left untouched) and use it for the column.
    confidence = ["{:.2f}".format(100 * prob) + "%" for prob in prediction_probs]
    return pd.DataFrame({
        'Topic': prediction,
        'Confidence': confidence,
    })
# Maps a class index produced by the classifier to its topic name.
label_to_theme = {
    0: 'Computer science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Math',
    4: 'Quantitative biology',
    5: 'Quantitative Finance',
    6: 'Statistics',
    7: 'Physics',
}

# --- Page layout -----------------------------------------------------------
st.title("Arxiv articles classification")
st.markdown(
    "This is an interface that can determine the article's topic based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality."
)

# Load the tokenizer and model before rendering the input widgets.
tokenizer, model = load_model()

title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)

# Run inference only when the user presses the button.
if st.button('Run'):
    prediction, prediction_probs = predict(title, summary, tokenizer, model)
    st.write('Results: ', get_results(prediction, prediction_probs))