""" Build an editable user profile based recommender. - Read the users json and read their paper reps and keyphrases into memory. - Read the candidates document (first stage retrieval) and sentence embeddings into memory (second stage retrieval). - Display the keyphrases to users and ask them to check it. - Use the keyphrases and sentence embeddings to compute keyphrase values. - Display the keyphrase selection box to users for retrieval. - Use the selected keyphrases for performing retrieval. """ import copy import json import pickle import joblib import os import collections import streamlit as st import numpy as np from scipy.spatial import distance from scipy import special from sklearn.neighbors import NearestNeighbors from sentence_transformers import SentenceTransformer, models import torch import ot # import seaborn as sns import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt plt.rcParams['figure.dpi'] = 400 plt.rcParams.update({'axes.labelsize': 'small'}) in_path = './data' ######################################## # BACKEND CODE # ######################################## def read_user(seed_json): """ Given the seed json for the user read the embedded documents for the user. :param seed_json: :return: """ if 'doc_vectors_user' not in st.session_state: uname = seed_json['username'] user_kps = seed_json['user_kps'] # Read document vectors. doc_vectors_user = np.load(os.path.join(in_path, 'users', uname, f'embeds-{uname}-doc.npy')) with open(os.path.join(in_path, 'users', uname, f'pid2idx-{uname}-doc.json'), 'r') as fp: pid2idx_user = json.load(fp) # Read sentence vectors. pid2sent_vectors = joblib.load(os.path.join(in_path, 'users', uname, f'embeds-{uname}-sent.pickle')) pid2sent_vectors_user = collections.OrderedDict() for pid in sorted(pid2sent_vectors): pid2sent_vectors_user[pid] = pid2sent_vectors[pid] st.session_state['doc_vectors_user'] = doc_vectors_user st.session_state['pid2idx_user'] = pid2idx_user st.session_state['pid2sent_vectors_user'] = pid2sent_vectors_user st.session_state['user_kps'] = user_kps return doc_vectors_user, pid2idx_user, pid2sent_vectors, user_kps else: return st.session_state.doc_vectors_user, st.session_state.pid2idx_user, \ st.session_state.pid2sent_vectors_user, st.session_state.user_kps def first_stage_ranked_docs(user_doc_queries, per_doc_to_rank, total_to_rank=2000): """ Return a list of ranked documents given a set of queries. :param user_doc_queries: read the cached query embeddings :return: """ if 'first_stage_ret_pids' not in st.session_state: # read the document vectors doc_vectors = np.load(os.path.join(in_path, 'cands', 'embeds-s2orccompsci-100k.npy')) with open(os.path.join(in_path, 'cands', 'pid2idx-s2orccompsci-100k.pickle'), 'rb') as fp: pid2idx_cands = pickle.load(fp) idx2pid_cands = dict([(v, k) for k, v in pid2idx_cands.items()]) # index the vectors into a nearest neighbors structure neighbors = NearestNeighbors(n_neighbors=per_doc_to_rank) neighbors.fit(doc_vectors) st.session_state['neighbors'] = neighbors st.session_state['idx2pid_cands'] = idx2pid_cands # Get the dists for all the query docs. nearest_dists, nearest_idxs = neighbors.kneighbors(user_doc_queries, return_distance=True) # Get the docs top_pids = [] uniq_top = set() for ranki in range(per_doc_to_rank): # Save papers by rank position for debugging. for qi in range(user_doc_queries.shape[0]): idx = nearest_idxs[qi, ranki] pid = idx2pid_cands[idx] if pid not in uniq_top: # Only save the unique papers. (ignore multiple retrievals of the same paper) top_pids.append(pid) uniq_top.add(pid) top_pids = top_pids[:total_to_rank] st.session_state['first_stage_ret_pids'] = top_pids return top_pids else: return st.session_state.first_stage_ret_pids def read_kp_encoder(in_path): """ Read the kp encoder model from disk. :param in_path: string; :return: """ if 'kp_enc_model' not in st.session_state: word_embedding_model = models.Transformer(os.path.join(in_path, 'models', 'scibert_scivocab_uncased'), max_seq_length=512) trained_model_fname = os.path.join(in_path, 'models', 'kp_encoder_cur_best.pt') if torch.cuda.is_available(): saved_model = torch.load(trained_model_fname) else: saved_model = torch.load(trained_model_fname, map_location=torch.device('cpu')) word_embedding_model.auto_model.load_state_dict(saved_model) pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='mean') kp_enc_model = SentenceTransformer(modules=[word_embedding_model, pooling_model]) st.session_state['kp_enc_model'] = kp_enc_model else: return st.session_state.kp_enc_model def read_candidates(in_path): """ Read candidate papers into pandas dataframe. :param in_path: :return: """ if 'pid2abstract' not in st.session_state: with open(os.path.join(in_path, 'cands', 'abstracts-s2orccompsci-100k.pickle'), 'rb') as fp: pid2abstract = pickle.load(fp) # read the sentence vectors pid2sent_vectors = joblib.load(os.path.join(in_path, 'cands', f'embeds-sent-s2orccompsci-100k.pickle')) st.session_state['pid2sent_vectors_cands'] = pid2sent_vectors st.session_state['pid2abstract'] = pid2abstract return pid2abstract, pid2sent_vectors else: return st.session_state.pid2abstract, st.session_state.pid2sent_vectors_cands def get_kp_embeddings(profile_keyphrases): """ Embed the passed profike keyphrases :param profile_keyphrases: list(string) :return: """ kp_enc_model = st.session_state['kp_enc_model'] if 'kp_vectors_user' not in st.session_state: kp_embeddings = kp_enc_model.encode(profile_keyphrases) kp_vectors_user = collections.OrderedDict() for i, kp in enumerate(profile_keyphrases): kp_vectors_user[kp] = kp_embeddings[i, :] st.session_state['kp_vectors_user'] = kp_vectors_user return kp_vectors_user else: uncached_kps = [kp for kp in profile_keyphrases if kp not in st.session_state.kp_vectors_user] kp_embeddings = kp_enc_model.encode(uncached_kps) for i, kp in enumerate(uncached_kps): st.session_state.kp_vectors_user[kp] = kp_embeddings[i, :] return st.session_state.kp_vectors_user def generate_profile_values(profile_keyphrases): """ - Read sentence embeddings - Read profile keyphrase embeddings - Compute alignment from sentences to keyphrases - Barycenter project the keyphrases to sentences to get kp values - Return the kp values :param profile_keyphrases: list(string) :return: """ kp_embeddings = get_kp_embeddings(profile_keyphrases) # Read sentence embeddings. user_seed_sentembeds = np.vstack(list(st.session_state.pid2sent_vectors_user.values())) # Read keyphrase embeddings. kps_embeds_flat = [] for kp in profile_keyphrases: kps_embeds_flat.append(kp_embeddings[kp]) kps_embeds_flat = np.vstack(kps_embeds_flat) # Compute transport plan from sentence to keyphrases. pair_dists = distance.cdist(user_seed_sentembeds, kps_embeds_flat, 'euclidean') a_distr = [1 / user_seed_sentembeds.shape[0]] * user_seed_sentembeds.shape[0] b_distr = [1 / kps_embeds_flat.shape[0]] * kps_embeds_flat.shape[0] # tplan = ot.bregman.sinkhorn_epsilon_scaling(a_distr, b_distr, pair_dists, 0.05, numItermax=2000) tplan = ot.partial.entropic_partial_wasserstein(a_distr, b_distr, pair_dists, 0.05, m=0.8) # Barycenter project the keyphrases to the sentences: len(profile_keyphraases) x embedding_dim proj_kp_vectors = np.matmul(user_seed_sentembeds.T, tplan).T norm = np.sum(tplan, axis=0) kp_value_vectors = proj_kp_vectors/norm[:, np.newaxis] # Return as a dict. kp2valvectors = {} for i, kp in enumerate(profile_keyphrases): kp2valvectors[kp] = kp_value_vectors[i, :] return kp2valvectors, tplan def second_stage_ranked_docs(selected_query_kps, first_stage_pids, pid2abstract, pid2sent_reps_cand, to_rank=30): """ Return a list of ranked documents given a set of queries. :param first_stage_pids: list(string) :param pid2abstract: dict(pid: paperd) :param query_paper_idxs: list(int); :return: """ if len(selected_query_kps) < 3: topk = len(selected_query_kps) else: # Use 20% of keyphrases for scoring or 3 whichever is larger topk = max(int(len(st.session_state.kp2val_vectors)*0.2), 3) query_kp_values = np.vstack([st.session_state.kp2val_vectors[kp] for kp in selected_query_kps]) pid2topkdist = dict() pid2kp_expls = collections.defaultdict(list) for i, pid in enumerate(first_stage_pids): sent_reps = pid2sent_reps_cand[pid] pair_dists = distance.cdist(query_kp_values, sent_reps) # Pick the topk unique profile concepts. kp_ind = np.argsort(pair_dists.min(axis=1))[:topk] sub_pair_dists = pair_dists[kp_ind, :] # sub_kp_reps = query_kp_values[kp_ind, :] a_distr = special.softmax(-1*np.min(sub_pair_dists, axis=1)) b_distr = [1 / sent_reps.shape[0]] * sent_reps.shape[0] tplan = ot.bregman.sinkhorn_epsilon_scaling(a_distr, b_distr, sub_pair_dists, 0.05) wd = np.sum(sub_pair_dists * tplan) # topk_dist = 0 # for k in range(topk): # topk_dist += pair_dists[kp_ind[k], sent_ind[k]] # pid2kp_expls[pid].append(selected_query_kps[kp_ind[k]]) # pid2topkdist[pid] = topk_dist pid2topkdist[pid] = wd top_pids = sorted(pid2topkdist, key=pid2topkdist.get) # Get the docs retrieved_papers = collections.OrderedDict() for pid in top_pids: retrieved_papers[pid2abstract[pid]['title']] = { 'title': pid2abstract[pid]['title'], 'kp_explanations': pid2kp_expls[pid], 'abstract': pid2abstract[pid]['abstract'] } if len(retrieved_papers) == to_rank: break return retrieved_papers ######################################## # HELPER CODE # ######################################## def parse_input_kps(unparsed_kps, initial_user_kps): """ Function to parse the input keyphrase string. :return: """ if unparsed_kps.strip(): kps = unparsed_kps.split(',') parsed_user_kps = [] uniq_kps = set() for kp in kps: kp = kp.strip() if kp not in uniq_kps: parsed_user_kps.append(kp) uniq_kps.add(kp) else: # If its an empty string use the initial kps parsed_user_kps = copy.copy(initial_user_kps) return parsed_user_kps # def plot_sent_kp_alignment(tplan, kp_labels, sent_labels): # """ # Plot the sentence keyphrase alignment. # :return: # """ # fig, ax = plt.subplots() # h = sns.heatmap(tplan.T, linewidths=.3, xticklabels=sent_labels, # yticklabels=kp_labels, cmap='Blues') # h.tick_params('y', labelsize=5) # h.tick_params('x', labelsize=2) # plt.tight_layout() # return fig def multiselect_title_formatter(title): """ Format the multi-select titles. :param title: string :return: string: formatted title """ ftitle = title.split()[:5] return ' '.join(ftitle) + '...' def format_abstract(paperd, to_display=3, markdown=True): """ Given a dict with title and abstract return a formatted text for rendering with markdown. :param paperd: :param to_display: :return: """ if len(paperd['abstract']) < to_display: sents = ' '.join(paperd['abstract']) else: sents = ' '.join(paperd['abstract'][:to_display]) + '...' try: kp_expl = ', '.join(paperd['kp_explanations']) except KeyError: kp_expl = '' if markdown: par = '
Title: {:s}
Abstract: {:s}
{:s}
- {:}
'.format(papert) st.markdown('{:}'.format(fpapert), unsafe_allow_html=True) if st.session_state.tuning_i > 0: st.download_button('Download papers', perp_result_json(), mime='json', help='Download the papers saved in the session.') with st.expander("Copy saved papers to clipboard"): st.write(json.loads(perp_result_json()))