Update app.py
app.py CHANGED
@@ -18,6 +18,8 @@ def load_model():
 @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
 def predict(title, summary, tokenizer, model):
     text = title + "\n" + summary
+    if len(text) < 20:
+        return 'error'
     tokens = tokenizer.encode(text)
     with torch.no_grad():
         logits = model(torch.as_tensor([tokens]))[0]
@@ -38,7 +40,7 @@ def predict(title, summary, tokenizer, model):

 @st.cache(suppress_st_warning=True)
 def get_results(prediction, prediction_probs):
-    frame = pd.DataFrame({'
+    frame = pd.DataFrame({'Category': prediction, 'Confidence': prediction_probs})
     frame.index = np.arange(1, len(frame) + 1)
     return frame

@@ -47,7 +49,7 @@ label_to_theme = {0: 'Computer science', 1: 'Economics', 2: 'Electrical Engineer

 st.title("Arxiv articles classification")
 st.markdown("<h1 style='text-align: center;'><img width=300px src='https://media.wired.com/photos/592700e3cfe0d93c474320f1/191:100/w_1200,h_630,c_limit/faces-icon.jpg'>", unsafe_allow_html=True)
-st.markdown("This is an interface that can determine the article's
+st.markdown("This is an interface that can determine the article's category based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")

 tokenizer, model = load_model()

@@ -58,5 +60,8 @@ button = st.button('Run')
 if button:
     prediction, prediction_probs = predict(title, summary, tokenizer, model)
     ans = get_results(prediction, prediction_probs)
-
-
+    if ans == 'error':
+        st.error("Your input is too short. It is probably not a real article, please try again.")
+    else:
+        st.subheader('Results:')
+        st.write(ans)
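Below is a small stand-alone sketch of the short-input guard this commit introduces. It is not the app's code: predict_stub is a hypothetical stand-in for predict() in app.py, and the returned labels and probabilities are invented. What it illustrates is that predict() now returns the bare string 'error' for inputs shorter than 20 characters, so the caller has to check for that sentinel before unpacking the usual (prediction, prediction_probs) pair.

# Hypothetical stand-in for predict() in app.py; the real function tokenizes the
# text and runs the model instead of returning fixed values.
def predict_stub(title, summary):
    text = title + "\n" + summary
    if len(text) < 20:        # same threshold as the committed change
        return 'error'        # sentinel for inputs that are too short
    return ['Computer science'], [0.97]   # invented label/confidence

result = predict_stub("Hi", "")
if result == 'error':
    # app.py surfaces this case with st.error(...)
    print("Your input is too short. It is probably not a real article, please try again.")
else:
    prediction, prediction_probs = result   # unpack only after ruling out the sentinel
    print(prediction, prediction_probs)

Handled in this order, the sentinel string never reaches the tuple unpacking or the DataFrame built by get_results.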
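Separately, a quick illustration of the results table that the new line in get_results builds. The category names come from label_to_theme in app.py, while the confidence values here are invented; pandas and numpy are used under the pd and np names, as in the diff.

import numpy as np
import pandas as pd

# Same construction as the added line in get_results: one row per candidate
# category, with the index starting at 1 so the table reads as a ranking.
prediction = ['Computer science', 'Economics']
prediction_probs = [0.83, 0.17]

frame = pd.DataFrame({'Category': prediction, 'Confidence': prediction_probs})
frame.index = np.arange(1, len(frame) + 1)
print(frame)   # two rows indexed 1 and 2, columns Category and Confidence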