Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
import transformers | |
import tokenizers | |
def load_model(model_name='distilbert-base-cased', weights_path='model_weights.pt', num_labels=8):
    """Load the tokenizer and the fine-tuned sequence classifier for CPU inference.

    The classifier is built from ``model_name`` with ``num_labels`` outputs and
    its parameters are then replaced by the state dict stored at ``weights_path``.

    Args:
        model_name: Hugging Face model id used for both the tokenizer and the
            base model.
        weights_path: Path to a saved state dict for that architecture.
        num_labels: Number of output classes; must match the saved weights,
            otherwise ``load_state_dict`` (strict by default) raises.

    Returns:
        ``(tokenizer, model)`` with the model switched to evaluation mode.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    # map_location lets weights that were saved on a GPU load on a CPU-only host.
    state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()  # inference mode (e.g. disables dropout)
    return tokenizer, model
def predict(title, summary, tokenizer, model, labels=None, threshold=0.95):
    """Return the most likely topics for an article, by decreasing probability.

    Topics are taken greedily from most to least probable until their
    cumulative probability reaches ``threshold`` (or every class is used).

    Args:
        title: Article title.
        summary: Article summary; may be an empty string.
        tokenizer: Tokenizer whose ``encode`` turns text into token ids.
        model: Classifier whose output's first element is the logits for a
            batch of one input; must expose ``device`` (and ``config.id2label``
            when ``labels`` is not given).
        labels: Optional mapping / sequence from class index to topic name.
            Defaults to ``model.config.id2label``.
        threshold: Cumulative-probability cutoff for the returned topics.

    Returns:
        ``(prediction, prediction_probs)``: parallel lists of topic names and
        their probabilities (Python floats), most probable first.
    """
    if labels is None:
        # No index->topic mapping is defined in this module, so use the one
        # stored in the model config.
        labels = model.config.id2label

    text = title + "\n" + summary
    # truncation=True keeps long summaries within the model's maximum input length.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens], device=model.device))[0]
    probs = torch.softmax(logits[-1, :], dim=-1).cpu()
    order = torch.argsort(probs, descending=True).tolist()

    prediction = []
    prediction_probs = []
    cumulative = 0.0
    # Bounded by the number of classes, so a threshold the probabilities
    # never reach cannot run past the end.
    for cls in order:
        if cumulative >= threshold:
            break
        prob = probs[cls].item()
        prediction.append(labels[cls])
        prediction_probs.append(prob)
        cumulative += prob
    return prediction, prediction_probs
def get_results(prediction, prediction_probs):
    """Render topics and their confidences as a tab-separated text table.

    The first line is a header; each following line holds one topic and the
    ``str()`` of its probability. Pairs are formed with ``zip``, so surplus
    items in the longer list are ignored.
    """
    header = "Topic:\t\tConfidence:\n"
    rows = [
        topic + "\t\t" + str(prob) + "\n"
        for topic, prob in zip(prediction, prediction_probs)
    ]
    return header + "".join(rows)
st.title("Arxiv articles classification")
st.markdown("This is an interface that can determine the article's topic based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")

# Streamlit re-executes this script on every widget interaction; cache the
# loaded model (when this Streamlit version supports it) so it is not
# reloaded on each rerun.
if hasattr(st, "cache_resource"):
    tokenizer, model = st.cache_resource(load_model)()
else:
    tokenizer, model = load_model()

title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)

# The title is required, so skip inference until the user has entered one.
if title.strip():
    prediction, prediction_probs = predict(title, summary, tokenizer, model)
    ans = get_results(prediction, prediction_probs)
    # st.text keeps the tab/newline layout that st.markdown would collapse.
    st.text(ans)