Spaces:
Runtime error
Runtime error
TinySuitStarfish
commited on
Commit
•
a372b52
1
Parent(s):
de4c369
Adding app to demo space HF
Browse files- paper_recommender.py +59 -0
paper_recommender.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import torch
|
5 |
+
import base64
|
6 |
+
import pandas as pd
|
7 |
+
import streamlit as st
|
8 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
9 |
+
from sentence_transformers import SentenceTransformer
|
10 |
+
|
11 |
+
sys.path.insert(0, os.getcwd())
|
12 |
+
|
13 |
+
st.title("ArXiV Paper Recommender")
|
14 |
+
|
15 |
+
def set_background(main_bg):
|
16 |
+
main_bg_ext = "jpg"
|
17 |
+
st.markdown(
|
18 |
+
f"""
|
19 |
+
<style>
|
20 |
+
.stApp {{
|
21 |
+
background: url(data:image/{main_bg_ext};base64,{base64.b64encode(open(main_bg, "rb").read()).decode()});
|
22 |
+
background-size: cover
|
23 |
+
}}
|
24 |
+
</style>
|
25 |
+
""",
|
26 |
+
unsafe_allow_html=True
|
27 |
+
)
|
28 |
+
|
29 |
+
set_background("../images/p1.jpg")
|
30 |
+
|
31 |
+
topic = st.text_input('What kind of paper would you wish to be recommended?', 'I want to read a paper on Bayesian Optimization!')
|
32 |
+
number = st.number_input('Show me these many papers.', min_value=1, max_value=10, value=3, step=1)
|
33 |
+
|
34 |
+
def process_text(text):
|
35 |
+
rep = {"\n": " ", "(": "", ")": "", "!": ""}
|
36 |
+
rep = dict((re.escape(k), v) for k, v in rep.items())
|
37 |
+
pattern = re.compile("|".join(rep.keys()))
|
38 |
+
text = pattern.sub(lambda m: rep[re.escape(m.group(0))], text).lower()
|
39 |
+
return text
|
40 |
+
|
41 |
+
def get_cosine_similarity(feature_vec_1, feature_vec_2):
|
42 |
+
return cosine_similarity(feature_vec_1.reshape(1, -1), feature_vec_2.reshape(1, -1))[0][0]
|
43 |
+
|
44 |
+
def get_model():
|
45 |
+
device = 'cuda' if torch.cuda.is_available() else None
|
46 |
+
model = SentenceTransformer('paraphrase-MiniLM-L6-v2', device=device)
|
47 |
+
return model
|
48 |
+
|
49 |
+
if st.button("GO!"):
|
50 |
+
prompt = process_text(topic)
|
51 |
+
model = get_model()
|
52 |
+
prompt_embedded = model.encode(prompt)
|
53 |
+
df_embed = pd.read_pickle('../data/embeddings_pkl.pkl').drop_duplicates(subset=['titles'])
|
54 |
+
df_embed["similarity_scores"] = df_embed["abstracts_embeddings"].apply(lambda x: get_cosine_similarity(x, prompt_embedded))
|
55 |
+
top_n = df_embed.nlargest(number, 'similarity_scores').head(5)["titles"].to_list()
|
56 |
+
st.text(" ")
|
57 |
+
st.subheader('Have a look at the following: :sunglasses:')
|
58 |
+
for rec_title in top_n:
|
59 |
+
st.markdown("-> " + rec_title)
|