friendshipkim commited on
Commit
e899844
1 Parent(s): 1ece544

change to hf datasets

Browse files
Files changed (1) hide show
  1. app.py +140 -142
app.py CHANGED
@@ -9,179 +9,177 @@ import numpy as np
9
  import time
10
  from annotated_text import annotated_text
11
 
12
- from huggingface_hub import hf_hub_download
13
- repo_id = "friendshipkim/IUR_Reddit"
14
 
15
- # ABSOLUTE_PATH = os.path.dirname(__file__)
16
- # ASSETS_PATH = os.path.join(ABSOLUTE_PATH, 'model_assets')
17
-
18
- @st.cache
19
  def preprocess_text(s):
20
- return list(filter(lambda x: x!= '', (''.join(c if c.isalnum() or c == ' ' else ' ' for c in s)).split(' ')))
21
 
22
- @st.cache
23
  def get_pairwise_distances(model):
24
- # df = pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv").set_index('index')
25
- df = pd.read_csv(hf_hub_download(repo_id=repo_id, filename="pairwise_distances.csv")).set_index('index')
26
  return df
27
 
28
- @st.cache
29
  def get_pairwise_distances_chunked(model, chunk):
30
- # for df in pd.read_csv(f"{ASSETS_PATH}/{model}/pairwise_distances.csv", chunksize = 16):
31
- # print(df.iloc[0]['queries'])
32
- # if chunk == int(df.iloc[0]['queries']):
33
- # return df
34
- return get_pairwise_distances(model)
35
 
36
- @st.cache
37
  def get_query_strings():
38
- # df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.jsonl", lines = True)
39
- df = pd.read_json(hf_hub_download(repo_id=repo_id, filename="IUR_Reddit_test_queries_english.jsonl"), lines = True)
 
40
  df['index'] = df.reset_index().index
41
  return df
42
- # df['partition'] = df['index']%100
43
- # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", index = 'index', partition_cols = 'partition')
44
-
45
- # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_queries_english.parquet", columns=['fullText', 'index', 'authorIDs'])
46
 
47
- @st.cache
48
  def get_candidate_strings():
49
  # df = pd.read_json(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.jsonl", lines = True)
50
- df = pd.read_json(hf_hub_download(repo_id=repo_id, filename="IUR_Reddit_test_candidates_english.jsonl"), lines = True)
 
51
  df['index'] = df.reset_index().index
52
  return df
53
- # df['partition'] = df['index']%100
54
- # df.to_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", index = 'index', partition_cols = 'partition')
55
- # return pd.read_parquet(f"{ASSETS_PATH}/IUR_Reddit_test_candidates_english.parquet", columns=['fullText', 'index', 'authorIDs'])
56
 
57
- @st.cache
58
  def get_embedding_dataset(model):
59
  # data = load_from_disk(f"{ASSETS_PATH}/{model}/embedding")
60
- data = load_dataset("friendshipkim/luar_clone2_top_100_embedding")
61
  return data
62
 
63
- @st.cache
64
  def get_bad_queries(model):
65
- df = get_query_strings().iloc[list(get_pairwise_distances(model)['queries'].unique())][['fullText', 'index', 'authorIDs']]
66
- return df
67
 
68
- @st.cache
69
  def get_gt_candidates(model, author):
70
- gt_candidates = get_candidate_strings()
71
- df = gt_candidates[gt_candidates['authorIDs'].apply(lambda x: x[0]) == author]
72
- return df
73
 
74
- @st.cache
75
  def get_candidate_text(l):
76
- return get_candidate_strings().at[l,'fullText']
77
 
78
- @st.cache
79
  def get_annotated_text(text, word, pos):
80
- # print("here", word, pos)
81
- start= text.index(word, pos)
82
- end = start+len(word)
83
- return (text[:start], (text[start:end ], 'SELECTED'), text[end:]), end
84
 
85
  class AgGridBuilder:
86
- __static_key = 0
87
- def build_ag_grid(table, display_columns):
88
- AgGridBuilder.__static_key += 1
89
- options_builder = GridOptionsBuilder.from_dataframe(table[display_columns])
90
- options_builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=10)
91
- options_builder.configure_selection(selection_mode= 'single', pre_selected_rows = [0])
92
- options = options_builder.build()
93
- return AgGrid(table, gridOptions = options, fit_columns_on_grid_load=True, key = AgGridBuilder.__static_key, reload_data = True, update_mode = GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED)
94
 
95
  if __name__ == '__main__':
96
- st.set_page_config(layout="wide")
97
 
98
- # models = filter(lambda file_name: os.path.isdir(f"{ASSETS_PATH}/{file_name}") and not file_name.endswith(".parquet"), os.listdir(ASSETS_PATH))
99
- models = ['luar_clone2_top_100']
100
 
101
- with st.sidebar:
102
- current_model = st.selectbox(
103
- "Select Model to analyze",
104
- models
105
- )
106
-
107
- pairwise_distances = get_pairwise_distances(current_model)
108
- embedding_dataset = get_embedding_dataset(current_model)
109
-
110
- candidate_string_grid = None
111
- gt_candidate_string_grid = None
112
- with st.container():
113
- t1 = time.time()
114
- st.title("Full Text")
115
- col1, col2 = st.columns([14, 2])
116
- t2 = time.time()
117
- query_table = get_bad_queries(current_model)
118
- t3 = time.time()
119
- # print(query_table)
120
- with col2:
121
- index = st.number_input('Enter Query number to inspect', min_value = 0, max_value = query_table.shape[0], step = 1)
122
- query_text = query_table.loc[index]['fullText']
123
- preprocessed_query_text = preprocess_text(query_text)
124
- text_highlight_index = st.number_input('Enter word #', min_value = 0, max_value = len(preprocessed_query_text), step = 1)
125
- query_index = int(query_table.iloc[index]['index'])
126
-
127
- with col1:
128
- if 'pos_highlight' not in st.session_state or text_highlight_index == 0:
129
- st.session_state['pos_highlight'] = text_highlight_index
130
- st.session_state['pos_history'] = [0]
131
-
132
- if st.session_state['pos_highlight'] > text_highlight_index:
133
- st.session_state['pos_history'] = st.session_state['pos_history'][:-2]
134
- if len(st.session_state['pos_history']) == 0:
135
- st.session_state['pos_history'] = [0]
136
- # print("pos", st.session_state['pos_history'], st.session_state['pos_highlight'], text_highlight_index)
137
- anotated_text_, pos = get_annotated_text(query_text, preprocessed_query_text[text_highlight_index-1], st.session_state['pos_history'][-1]) if text_highlight_index >= 1 else ((query_text), 0)
138
- if st.session_state['pos_highlight'] < text_highlight_index:
139
- st.session_state['pos_history'].append(pos)
140
- st.session_state['pos_highlight'] = text_highlight_index
141
- annotated_text(*anotated_text_)
142
- # annotated_text("Lol, this" , ('guy', 'SELECTED') , "is such a PR chameleon. \n\n In the Chan Zuckerberg Initiative announcement, he made it sound like he was giving away all his money to charity <PERSON> or <PERSON>. http://www.businessinsider.in/Mark-Zuckerberg-says-hes-giving-99-of-his-Facebook-shares-45-billion-to-charity/articleshow/50005321.cms Apparently, its just a VC fund. And there are still people out there who believe Facebook.org was an initiative to bring Internet to the poor.")
143
- t4 = time.time()
144
-
145
- # print(f"query time query text: {t3-t2}, total time: {t4-t1}")
146
- with st.container():
147
- st.title("Top 16 Recommended Candidates")
148
- col1, col2, col3 = st.columns([10, 4, 2])
149
- rec_candidates = pairwise_distances[pairwise_distances["queries"]==query_index]['candidates']
150
- # print(rec_candidates)
151
- l = list(rec_candidates)
152
- with col3:
153
- candidate_rec_index = st.number_input('Enter recommended candidate number to inspect', min_value = 0, max_value = len(l), step = 1)
154
- print("l:",l, query_index)
155
- pairwise_candidate_index = int(l[candidate_rec_index])
156
- with col1:
157
- st.header("Text")
158
- t1 = time.time()
159
- st.write(get_candidate_text(pairwise_candidate_index))
160
- t2 = time.time()
161
- with col2:
162
- st.header("Cosine Distance")
163
- st.write(float(pairwise_distances[\
164
- ( pairwise_distances['queries'] == query_index ) \
165
- &
166
- ( pairwise_distances['candidates'] == pairwise_candidate_index)]['distances']))
167
- print(f"candidate string retreival: {t2-t1}")
168
- with st.container():
169
- t1 = time.time()
170
- st.title("Candidates With Same Authors As Query")
171
- col1, col2, col3 = st.columns([10, 4, 2])
172
- t2 = time.time()
173
- gt_candidates = get_gt_candidates(current_model, query_table.iloc[query_index]['authorIDs'][0])
174
- t3 = time.time()
175
-
176
- with col3:
177
- candidate_index = st.number_input('Enter ground truthnumber to inspect', min_value = 0, max_value = gt_candidates.shape[0], step = 1)
178
- gt_candidate_index = int(gt_candidates.iloc[candidate_index]['index'])
179
- with col1:
180
- st.header("Text")
181
- st.write(gt_candidates.iloc[candidate_index]['fullText'])
182
- with col2:
183
- t4 = time.time()
184
- st.header("Cosine Distance")
185
- st.write(1-cosine_similarity(np.array([embedding_dataset['queries'][query_index]['embedding']]), np.array([embedding_dataset['candidates'][gt_candidate_index]['embedding']]))[0,0])
186
- t5 = time.time()
187
- print(f"find gt candidates: {t3-t2}, find cosine: {t5-t4}, total: {t5-t1}")
 
9
  import time
10
  from annotated_text import annotated_text
11
 
12
+ ORG_ID = "cornell-authorship"
 
13
 
14
@st.cache_data
def preprocess_text(s):
    """Tokenize *s*: every character that is not alphanumeric and not a space
    is replaced by a space, then the text is split on single spaces and empty
    tokens are dropped.

    Returns a list of the surviving word tokens, in order.
    """
    cleaned = ''.join(c if (c.isalnum() or c == ' ') else ' ' for c in s)
    return [token for token in cleaned.split(' ') if token != '']
17
 
18
@st.cache_data
def get_pairwise_distances(model):
    """Load the precomputed pairwise query/candidate distance table for *model*.

    Pulls the "train" split of the Hub dataset ``{ORG_ID}/{model}_distance``
    and returns it as a DataFrame indexed by its 'index' column.
    """
    split = load_dataset(f"{ORG_ID}/{model}_distance")["train"]
    frame = pd.DataFrame(split)
    return frame.set_index('index')
23
 
24
@st.cache_data
def get_pairwise_distances_chunked(model, chunk):
    """Chunked-loading variant of :func:`get_pairwise_distances`.

    NOTE: *chunk* is currently ignored — the full distance table is returned.
    The signature is kept so callers written for per-chunk loading keep working.
    """
    return get_pairwise_distances(model)
31
 
32
@st.cache_data
def get_query_strings():
    """Load the Reddit test-query texts from the Hub.

    Returns a DataFrame of the "train" split of
    ``{ORG_ID}/IUR_Reddit_test_queries_english`` with an added 'index' column
    holding each row's 0-based position.
    """
    split = load_dataset(f"{ORG_ID}/IUR_Reddit_test_queries_english")["train"]
    frame = pd.DataFrame(split)
    frame['index'] = frame.reset_index().index
    return frame
43
 
44
@st.cache_data
def get_candidate_strings():
    """Load the Reddit test-candidate texts from the Hub.

    Returns a DataFrame of the "train" split of
    ``{ORG_ID}/IUR_Reddit_test_candidates_english`` with an added 'index'
    column holding each row's 0-based position.
    """
    split = load_dataset(f"{ORG_ID}/IUR_Reddit_test_candidates_english")["train"]
    frame = pd.DataFrame(split)
    frame['index'] = frame.reset_index().index
    return frame
54
 
55
@st.cache_data
def get_embedding_dataset(model):
    """Return the Hub embedding dataset ``{ORG_ID}/{model}_embedding``.

    The raw DatasetDict is returned unchanged (callers index its
    'queries'/'candidates' splits directly).
    """
    return load_dataset(f"{ORG_ID}/{model}_embedding")
60
 
61
@st.cache_data
def get_bad_queries(model):
    """Return the query rows that appear in *model*'s distance table.

    Selects, by position, every unique value of the distance table's
    'queries' column from the full query frame, keeping only the columns
    the UI needs.
    """
    query_positions = list(get_pairwise_distances(model)['queries'].unique())
    selected = get_query_strings().iloc[query_positions]
    return selected[['fullText', 'index', 'authorIDs']]
65
 
66
@st.cache_data
def get_gt_candidates(model, author):
    """Return candidate rows whose first author ID equals *author*.

    *model* is unused but kept in the signature for caching/call-site
    symmetry with the other per-model loaders.
    """
    candidates = get_candidate_strings()
    first_authors = candidates['authorIDs'].apply(lambda ids: ids[0])
    return candidates[first_authors == author]
71
 
72
@st.cache_data
def get_candidate_text(l):
    """Return the 'fullText' of the candidate row with label *l*."""
    candidates = get_candidate_strings()
    return candidates.at[l, 'fullText']
75
 
76
@st.cache_data
def get_annotated_text(text, word, pos):
    """Split *text* around the first occurrence of *word* at/after offset *pos*.

    Returns a pair: the triple ``(before, (match, 'SELECTED'), after)`` suitable
    for ``annotated_text(*...)``, and the end offset of the match (used by the
    caller as the next search position).

    Raises ValueError (from ``str.index``) when *word* is not found.
    """
    start = text.index(word, pos)
    end = start + len(word)
    pieces = (text[:start], (text[start:end], 'SELECTED'), text[end:])
    return pieces, end
82
 
83
class AgGridBuilder:
    """Thin helper that renders a DataFrame as a paginated, selectable AgGrid."""

    # Monotonically increasing widget key: each rendered grid must get a
    # distinct Streamlit key, or repeated renders would collide.
    __static_key = 0

    def build_ag_grid(table, display_columns):
        """Render *table* showing *display_columns*: 10 rows per page,
        single-row selection with the first row pre-selected."""
        AgGridBuilder.__static_key += 1
        builder = GridOptionsBuilder.from_dataframe(table[display_columns])
        builder.configure_pagination(paginationAutoPageSize=False, paginationPageSize=10)
        builder.configure_selection(selection_mode='single', pre_selected_rows=[0])
        grid_options = builder.build()
        return AgGrid(
            table,
            gridOptions=grid_options,
            fit_columns_on_grid_load=True,
            key=AgGridBuilder.__static_key,
            reload_data=True,
            update_mode=GridUpdateMode.SELECTION_CHANGED | GridUpdateMode.VALUE_CHANGED,
        )
92
 
93
if __name__ == '__main__':
    st.set_page_config(layout="wide")

    # Models with precomputed assets published on the Hub.
    models = ['luar_clone2_top_100']

    with st.sidebar:
        current_model = st.selectbox(
            "Select Model to analyze",
            models
        )

    pairwise_distances = get_pairwise_distances(current_model)
    embedding_dataset = get_embedding_dataset(current_model)

    candidate_string_grid = None
    gt_candidate_string_grid = None

    # --- Section 1: the selected query's full text, with word highlighting ---
    with st.container():
        t1 = time.time()
        st.title("Full Text")
        col1, col2 = st.columns([14, 2])
        t2 = time.time()
        query_table = get_bad_queries(current_model)
        t3 = time.time()
        with col2:
            # FIX: max_value is shape[0] - 1 — the table is accessed by
            # position below, so shape[0] itself was one past the last row.
            index = st.number_input('Enter Query number to inspect',
                                    min_value=0, max_value=query_table.shape[0] - 1, step=1)
            # FIX: positional access (.iloc). The frame keeps the original,
            # non-contiguous labels of get_query_strings(), so .loc[index]
            # raised KeyError for most widget values; the sibling lookup of
            # the 'index' column already used .iloc.
            query_text = query_table.iloc[index]['fullText']
            preprocessed_query_text = preprocess_text(query_text)
            # Word 0 means "no highlight"; words are 1-based in the UI, so
            # max_value == len(tokens) is correct (token len-1 is the last).
            text_highlight_index = st.number_input('Enter word #', min_value=0,
                                                   max_value=len(preprocessed_query_text), step=1)
            # Global dataset index of the selected query (the 'index' column).
            query_index = int(query_table.iloc[index]['index'])

        with col1:
            # Reset the highlight history when the app first runs or the
            # user returns to word 0.
            if 'pos_highlight' not in st.session_state or text_highlight_index == 0:
                st.session_state['pos_highlight'] = text_highlight_index
                st.session_state['pos_history'] = [0]

            # Stepping backwards: drop the last two recorded offsets so the
            # search restarts before the previous match.
            if st.session_state['pos_highlight'] > text_highlight_index:
                st.session_state['pos_history'] = st.session_state['pos_history'][:-2]
                if len(st.session_state['pos_history']) == 0:
                    st.session_state['pos_history'] = [0]

            # FIX: the no-highlight fallback must be a 1-tuple
            # ((query_text,), 0); the old ((query_text), 0) was just the bare
            # string, so annotated_text(*payload) splatted it char by char.
            if text_highlight_index >= 1:
                annotated_payload, pos = get_annotated_text(
                    query_text,
                    preprocessed_query_text[text_highlight_index - 1],
                    st.session_state['pos_history'][-1])
            else:
                annotated_payload, pos = (query_text,), 0

            # Stepping forwards: remember where this match ended.
            if st.session_state['pos_highlight'] < text_highlight_index:
                st.session_state['pos_history'].append(pos)
            st.session_state['pos_highlight'] = text_highlight_index
            annotated_text(*annotated_payload)
            t4 = time.time()

    # --- Section 2: the model's top-ranked candidates for this query ---
    with st.container():
        st.title("Top 16 Recommended Candidates")
        col1, col2, col3 = st.columns([10, 4, 2])
        rec_candidates = pairwise_distances[pairwise_distances["queries"] == query_index]['candidates']
        l = list(rec_candidates)
        with col3:
            # FIX: max_value len(l) - 1 — len(l) indexed one past the end.
            candidate_rec_index = st.number_input('Enter recommended candidate number to inspect',
                                                  min_value=0, max_value=len(l) - 1, step=1)
            print("l:", l, query_index)
            pairwise_candidate_index = int(l[candidate_rec_index])
        with col1:
            st.header("Text")
            t1 = time.time()
            st.write(get_candidate_text(pairwise_candidate_index))
            t2 = time.time()
        with col2:
            st.header("Cosine Distance")
            # FIX: take the scalar explicitly with .iloc[0]; float(Series) is
            # deprecated/ambiguous in modern pandas.
            distance_row = pairwise_distances[
                (pairwise_distances['queries'] == query_index)
                &
                (pairwise_distances['candidates'] == pairwise_candidate_index)]
            st.write(float(distance_row['distances'].iloc[0]))
            # FIX: typo "retreival" -> "retrieval".
            print(f"candidate string retrieval: {t2-t1}")

    # --- Section 3: ground-truth candidates (same author as the query) ---
    with st.container():
        t1 = time.time()
        st.title("Candidates With Same Authors As Query")
        col1, col2, col3 = st.columns([10, 4, 2])
        t2 = time.time()
        # FIX: look up the author of the *selected row* (position `index`);
        # the old code used the global `query_index` as a position into the
        # filtered table, which pointed at the wrong row or went out of range.
        gt_candidates = get_gt_candidates(current_model, query_table.iloc[index]['authorIDs'][0])
        t3 = time.time()

        with col3:
            # FIX: label typo ("ground truthnumber") and off-by-one max_value.
            candidate_index = st.number_input('Enter ground truth number to inspect',
                                              min_value=0, max_value=gt_candidates.shape[0] - 1, step=1)
            gt_candidate_index = int(gt_candidates.iloc[candidate_index]['index'])
        with col1:
            st.header("Text")
            st.write(gt_candidates.iloc[candidate_index]['fullText'])
        with col2:
            t4 = time.time()
            st.header("Cosine Distance")
            st.write(1 - cosine_similarity(
                np.array([embedding_dataset['queries'][query_index]['embedding']]),
                np.array([embedding_dataset['candidates'][gt_candidate_index]['embedding']]))[0, 0])
            t5 = time.time()
            print(f"find gt candidates: {t3-t2}, find cosine: {t5-t4}, total: {t5-t1}")