"""Streamlit page for inspecting authorship-retrieval results.

For a selected model the page shows a query text (with an optional
highlighted word), the candidates listed for that query in the model's
pairwise-distance table, and the candidates whose first author ID matches
the query's first author ID.
"""

import streamlit as st
import pandas as pd
import sys
import os
from datasets import load_from_disk, load_dataset
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import time
from annotated_text import annotated_text

ORG_ID = "cornell-authorship"


@st.cache_data
def preprocess_text(s):
    """Return the words of *s*, where a word is a run of alphanumeric characters.

    Every character that is neither alphanumeric nor a space is treated as a
    space; empty fragments produced by consecutive separators are dropped.
    """
    cleaned = ''.join(c if c.isalnum() or c == ' ' else ' ' for c in s)
    return [word for word in cleaned.split(' ') if word != '']


def _word_start_offsets(text):
    """Return the start offset in *text* of each word `preprocess_text` yields.

    `preprocess_text` maps characters one-to-one before splitting, so its
    words are exactly the maximal runs of `str.isalnum()` characters in the
    original text, in order.
    """
    offsets = []
    inside_word = False
    for position, char in enumerate(text):
        if char.isalnum():
            if not inside_word:
                offsets.append(position)
                inside_word = True
        else:
            inside_word = False
    return offsets


@st.cache_data
def get_pairwise_distances(model):
    """Load the ``{model}_distance`` train split as a DataFrame indexed by 'index'."""
    dataset = load_dataset(f"{ORG_ID}/{model}_distance")["train"]
    return pd.DataFrame(dataset).set_index('index')


@st.cache_data
def get_pairwise_distances_chunked(model, chunk):
    """Return the full pairwise-distance table for *model*.

    *chunk* is currently ignored; the chunked CSV reader below is disabled.
    """
    # for df in pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv", chunksize = 16):
    #     print(df.iloc[0]['queries'])
    #     if chunk == int(df.iloc[0]['queries']):
    #         return df
    return get_pairwise_distances(model)


@st.cache_data
def get_query_strings():
    """Load the query dataset as a DataFrame with a positional 'index' column."""
    # df = pd.read_json(hf_hub_download(repo_id=repo_id, filename="IUR_Reddit_test_queries_english.jsonl"), lines = True)
    dataset = load_dataset(f"{ORG_ID}/IUR_Reddit_test_queries_english")["train"]
    df = pd.DataFrame(dataset)
    df['index'] = df.reset_index().index
    # Disabled local-parquet alternative, kept for reference:
    # df['partition'] = df['index']%100
    # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", index = 'index', partition_cols = 'partition')
    # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", columns=['fullText', 'index', 'authorIDs'])
    return df


@st.cache_data
def get_candidate_strings():
    """Load the candidate dataset as a DataFrame with a positional 'index' column."""
    # df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.jsonl", lines = True)
    dataset = load_dataset(f"{ORG_ID}/IUR_Reddit_test_candidates_english")["train"]
    df = pd.DataFrame(dataset)
    df['index'] = df.reset_index().index
    # Disabled local-parquet alternative, kept for reference:
    # df['partition'] = df['index']%100
    # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", index = 'index', partition_cols = 'partition')
    # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", columns=['fullText', 'index', 'authorIDs'])
    return df


@st.cache_data
def get_embedding_dataset(model):
    """Load the ``{model}_embedding`` dataset (all splits)."""
    # data = load_from_disk(f"{ASSETS_PATH}/{model}/embedding")
    return load_dataset(f"{ORG_ID}/{model}_embedding")


@st.cache_data
def get_bad_queries(model):
    """Return the query rows referenced by *model*'s pairwise-distance table.

    The rows are selected by position from the query table, so the result
    keeps the original row labels; use ``.iloc`` for positional access.
    Only the 'fullText', 'index' and 'authorIDs' columns are returned.
    """
    query_positions = list(get_pairwise_distances(model)['queries'].unique())
    return get_query_strings().iloc[query_positions][['fullText', 'index', 'authorIDs']]


@st.cache_data
def get_gt_candidates(model, author):
    """Return the candidate rows whose first author ID equals *author*.

    *model* does not influence the result; it is kept for interface
    compatibility.
    """
    candidates = get_candidate_strings()
    first_author = candidates['authorIDs'].apply(lambda x: x[0])
    return candidates[first_author == author]


@st.cache_data
def get_candidate_text(l):
    """Return the 'fullText' of the candidate row labelled *l*."""
    return get_candidate_strings().at[l, 'fullText']


@st.cache_data
def get_annotated_text(text, word, pos):
    """Split *text* around the first occurrence of *word* at or after *pos*.

    Returns:
        A pair ``(segments, end)`` where *segments* is
        ``(before, (match, 'SELECTED'), after)`` and *end* is the offset
        just past the match.

    Raises:
        ValueError: if *word* does not occur in *text* at or after *pos*.
    """
    start = text.index(word, pos)
    end = start + len(word)
    return (text[:start], (text[start:end], 'SELECTED'), text[end:]), end


class AgGridBuilder:
    """Builds AgGrid widgets, giving each one a fresh numeric key."""

    # Counter used to derive a distinct widget key per grid built.
    __static_key = 0

    @staticmethod
    def build_ag_grid(table, display_columns):
        """Render *table* as a paginated, single-selection AgGrid.

        Grid options are derived from the *display_columns* subset of
        *table*; the first row is pre-selected and pages hold 10 rows.
        """
        AgGridBuilder.__static_key += 1
        options_builder = GridOptionsBuilder.from_dataframe(table[display_columns])
        options_builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=10)
        options_builder.configure_selection(selection_mode='single', pre_selected_rows=[0])
        options = options_builder.build()
        return AgGrid(
            table,
            gridOptions=options,
            fit_columns_on_grid_load=True,
            key=AgGridBuilder.__static_key,
            reload_data=True,
            update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
        )


def _render_query_section(current_model):
    """Render the query text and return ``(query_row, query_index)``.

    Stops the script run with a warning when the model has no queries.
    """
    with st.container():
        st.title("Full Text")
        col1, col2 = st.columns([14, 2])
        query_table = get_bad_queries(current_model)
        if query_table.empty:
            st.warning("No queries available for the selected model.")
            st.stop()
        with col2:
            # The input is a 0-based position, so the last valid value is len - 1.
            index = st.number_input('Enter Query number to inspect', min_value=0,
                                    max_value=query_table.shape[0] - 1, step=1)
            # query_table keeps the original row labels, so the row must be
            # picked by position; the same row feeds text, index and author.
            query_row = query_table.iloc[index]
            query_text = query_row['fullText']
            preprocessed_query_text = preprocess_text(query_text)
            # 0 means "no highlight"; word k is preprocessed_query_text[k - 1].
            text_highlight_index = st.number_input('Enter word #', min_value=0,
                                                   max_value=len(preprocessed_query_text), step=1)
            query_index = int(query_row['index'])
        with col1:
            if text_highlight_index >= 1:
                word = preprocessed_query_text[text_highlight_index - 1]
                # Locate the word directly from the text instead of from
                # per-session history, so changing the query or jumping to an
                # arbitrary word number always highlights the right occurrence.
                start = _word_start_offsets(query_text)[text_highlight_index - 1]
                segments, _ = get_annotated_text(query_text, word, start)
            else:
                # One-element tuple: a bare string would be unpacked per character.
                segments = (query_text,)
            annotated_text(*segments)
    return query_row, query_index


def _render_recommended_section(pairwise_distances, query_index):
    """Render one candidate listed for the query in the pairwise-distance table."""
    with st.container():
        st.title("Top 16 Recommended Candidates")
        col1, col2, col3 = st.columns([10, 4, 2])
        rec_rows = pairwise_distances[pairwise_distances["queries"] == query_index]
        rec_candidates = list(rec_rows['candidates'])
        print("l:", rec_candidates, query_index)
        if not rec_candidates:
            st.warning("No recommended candidates found for this query.")
            return
        with col3:
            candidate_rec_index = st.number_input('Enter recommended candidate number to inspect',
                                                  min_value=0, max_value=len(rec_candidates) - 1, step=1)
            pairwise_candidate_index = int(rec_candidates[candidate_rec_index])
        with col1:
            st.header("Text")
            t1 = time.time()
            st.write(get_candidate_text(pairwise_candidate_index))
            t2 = time.time()
        with col2:
            st.header("Cosine Distance")
            # Read the distance from the very row that was selected above.
            st.write(float(rec_rows.iloc[candidate_rec_index]['distances']))
        print(f"candidate string retreival: {t2-t1}")


def _render_ground_truth_section(current_model, embedding_dataset, query_row, query_index):
    """Render one candidate that shares the query's first author ID."""
    with st.container():
        t1 = time.time()
        st.title("Candidates With Same Authors As Query")
        col1, col2, col3 = st.columns([10, 4, 2])
        t2 = time.time()
        # Use the row already selected for display; query_index is a position
        # in the full query dataset, not in the per-model query table.
        gt_candidates = get_gt_candidates(current_model, query_row['authorIDs'][0])
        t3 = time.time()
        if gt_candidates.empty:
            st.warning("No candidates share an author with this query.")
            return
        with col3:
            candidate_index = st.number_input('Enter ground truthnumber to inspect', min_value=0,
                                              max_value=gt_candidates.shape[0] - 1, step=1)
            gt_candidate = gt_candidates.iloc[candidate_index]
            gt_candidate_index = int(gt_candidate['index'])
        with col1:
            st.header("Text")
            st.write(gt_candidate['fullText'])
        with col2:
            t4 = time.time()
            st.header("Cosine Distance")
            query_embedding = np.array([embedding_dataset['queries'][query_index]['embedding']])
            candidate_embedding = np.array([embedding_dataset['candidates'][gt_candidate_index]['embedding']])
            st.write(1 - cosine_similarity(query_embedding, candidate_embedding)[0, 0])
            t5 = time.time()
        print(f"find gt candidates: {t3-t2}, find cosine: {t5-t4}, total: {t5-t1}")


def main():
    """Lay out the page: model picker, query view, recommended and same-author candidates."""
    st.set_page_config(layout="wide")
    # models = filter(lambda file_name: os.path.isdir(f"{ASSETS_PATH}/{file_name}") and not file_name.endswith(".parquet"), os.listdir(ASSETS_PATH))
    models = ['luar_clone2_top_100']
    with st.sidebar:
        current_model = st.selectbox(
            "Select Model to analyze",
            models
        )
        pairwise_distances = get_pairwise_distances(current_model)
        embedding_dataset = get_embedding_dataset(current_model)

    # Not used anywhere else in this file.
    candidate_string_grid = None
    gt_candidate_string_grid = None

    query_row, query_index = _render_query_section(current_model)
    _render_recommended_section(pairwise_distances, query_index)
    _render_ground_truth_section(current_model, embedding_dataset, query_row, query_index)


if __name__ == '__main__':
    main()