Spaces:
Runtime error
Runtime error
File size: 2,027 Bytes
b8769be 491d5a1 2c5279b b8769be 2c5279b f05ebed 2c5279b 26fd4ec d52f486 2c5279b 5e4fa04 f6a020d 5e4fa04 2c5279b 5e4fa04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import numpy as np
import streamlit as st
import torch
import transformers
# allow_output_mutation: without it st.cache re-hashes the returned tokenizer
# and model on every rerun to check for mutation, which is slow for a model
# and can fail outright on objects Streamlit cannot hash.
@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
    """Load the tokenizer and the fine-tuned 8-class DistilBERT classifier.

    The ``distilbert-base-cased`` checkpoint supplies the tokenizer and the
    model architecture; its weights are then replaced by the state dict in
    ``model_weights.pt``, which is mapped onto the CPU.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is switched to eval mode.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    model_name = 'distilbert-base-cased'
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)
    # map_location lets the weights load on a CPU-only host regardless of the
    # device they were saved from.
    state_dict = torch.load('model_weights.pt', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout for inference
    return tokenizer, model
# Not wrapped in @st.cache: that would hash the tokenizer and model
# arguments on every call, which is slow and can fail for these objects.
def predict(title, summary, tokenizer, model, label_to_theme=None, threshold=0.95):
    """Return the most likely topics for an article with their probabilities.

    Topics are taken in order of decreasing probability until their
    cumulative probability reaches ``threshold`` (or every class is used).

    Args:
        title: Article title.
        summary: Article summary; may be an empty string.
        tokenizer: Object whose ``encode(text, truncation=True)`` returns a
            list of token ids.
        model: Classifier; ``model(ids)[0]`` is expected to be a 2-D logits
            tensor whose last row is used.
        label_to_theme: Mapping or sequence from class index to topic name.
            Defaults to a module-level ``label_to_theme`` if one is defined,
            otherwise to ``model.config.id2label``.
        threshold: Cumulative probability at which to stop adding topics.

    Returns:
        ``(prediction, prediction_probs)``: topic names and their
        probabilities, most likely first.
    """
    if label_to_theme is None:
        label_to_theme = globals().get("label_to_theme")
    if label_to_theme is None:
        label_to_theme = model.config.id2label

    text = title + "\n" + summary
    # Truncate to the model's maximum length; longer inputs (e.g. > 512
    # tokens for DistilBERT) would otherwise make the forward pass fail.
    tokens = tokenizer.encode(text, truncation=True)
    # Place the input on the model's device; None means the default (CPU).
    device = getattr(model, "device", None)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens], device=device))[0]
    probs = torch.softmax(logits[-1, :], dim=-1).cpu().numpy()
    classes = np.flip(np.argsort(probs))  # class indices, most likely first

    prediction = []
    prediction_probs = []
    sum_probs = 0.0
    # Iterating over the classes bounds the loop, so float rounding in the
    # softmax can never push the index past the last class.
    for cls in classes:
        if sum_probs >= threshold:
            break
        prediction.append(label_to_theme[int(cls)])
        prediction_probs.append(probs[cls])
        sum_probs += probs[cls]
    return prediction, prediction_probs
# Not wrapped in @st.cache: formatting a few lines is cheaper than the
# argument hashing and result copying the cache would perform.
def get_results(prediction, prediction_probs):
    """Format predicted topics and their confidences as a tab-separated table.

    Args:
        prediction: Topic names, in display order.
        prediction_probs: Probability for each topic, aligned with
            ``prediction``; each is rendered with ``str``.

    Returns:
        A ``"Topic:\\t\\tConfidence:"`` header line followed by one
        ``topic\\t\\tprob`` line per topic; every line ends with ``\\n``.
    """
    lines = ["Topic:\t\tConfidence:"]
    lines.extend(
        topic + "\t\t" + str(prob)
        for topic, prob in zip(prediction, prediction_probs)
    )
    return "\n".join(lines) + "\n"
# Page layout: header, description, input boxes, and the prediction table.
st.title("Arxiv articles classification")
st.markdown("This is an interface that can determine the article's topic based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")
tokenizer, model = load_model()
title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)
# Only classify once the user has entered some text.
if title.strip() or summary.strip():
    prediction, prediction_probs = predict(title, summary, tokenizer, model)
    ans = get_results(prediction, prediction_probs)
    # st.text preserves the tabs and newlines of the table; st.markdown
    # would collapse it onto a single line.
    st.text(ans)