import streamlit as st
import os
import pandas as pd
import re
import tqdm
import plotly.express as px

from find_splitting_words import find_dividing_words
from dataset_loading import load_local_qrels, load_local_corpus, load_local_queries

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

st.set_page_config(layout="wide")

current_checkboxes = []
query_input = None


@st.cache_data
def convert_df(df):
    # IMPORTANT: cache the conversion to prevent recomputation on every rerun
    return df.to_csv(path_or_buf=None, index=False, quotechar='"').encode("utf-8")


def create_histogram_relevant_docs(relevant_df):
    # plot the distribution of relevant documents per query
    fig = px.histogram(relevant_df, x="relevant_docs")
    # make it fit in one column
    fig.update_layout(height=400, width=250)
    return fig


def get_current_data():
    # gather every document the annotator marked as non-relevant, along with
    # the (possibly edited) query text, for export as CSV
    cur_query_data = []
    cur_query = query_input.replace("\n", "\\n")
    for doc_id, checkbox in current_checkboxes:
        if checkbox:
            cur_query_data.append({
                "new_narrative": cur_query,
                "qid": st.session_state.selectbox_instance,
                "doc_id": doc_id,
                "is_relevant": 0,
            })
    # return the data as CSV bytes
    return convert_df(pd.DataFrame(cur_query_data))


@st.cache_data
def escape_markdown(text):
    text = text.replace("``", '"')
    # characters to escape, including the backslash so it escapes itself
    special_chars = ['\\', '`', '*', '_', '{', '}', '[', ']', '(', ')', '#', '+', '-', '.', '!', '|', "$"]
    # escape each special character
    escaped_text = "".join(f"\\{char}" if char in special_chars else char for char in text)
    return escaped_text


@st.cache_data
def highlight_text(text, splitting_words):
    # escape anything that would break the markdown rendering
    text = escape_markdown(text)
    changed = False
    if not splitting_words:
        return text, changed

    def replace_function(match):
        # NOTE: the highlight markup was stripped from the source; a yellow
        # background span (rendered with unsafe_allow_html) is assumed here
        return f'<span style="background-color: yellow">{match.group(0)}</span>'

    # compile a single alternation pattern covering all splitting words
    pattern = "|".join(re.escape(word) for word in splitting_words)

    # perform case-insensitive replacement
    new_text, num_subs = re.subn(pattern, replace_function, text, flags=re.IGNORECASE)
    if num_subs > 0:
        changed = True

    return new_text, changed


if "cur_instance_num" not in st.session_state:
    st.session_state.cur_instance_num = -1


def validate(config_option, file_loaded):
    if config_option != "None" and file_loaded is None:
        st.error("Please upload a file for " + config_option)
        st.stop()


with st.sidebar:
    st.title("Options")
    st.header("Upload corpus")
    corpus_file = st.file_uploader("Choose a file", key="corpus")
    corpus = load_local_corpus(corpus_file)
    st.header("Upload queries")
    queries_file = st.file_uploader("Choose a file", key="queries")
    queries = load_local_queries(queries_file)
    st.header("Upload qrels")
    qrels_file = st.file_uploader("Choose a file", key="qrels")
    qrels = load_local_qrels(qrels_file)
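    # Expected shapes of the loaded objects, inferred from how they are used
    # below (the actual parsing lives in dataset_loading.load_local_*):
    #   corpus:  {doc_id: {"title": str (optional), "text": str}}
    #   queries: {qid: str}
    #   qrels:   {qid: {doc_id: relevance_label}}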

    # make sure queries and qrels cover the same qids, warning if not
    if queries is not None and qrels is not None:
        missing_qids = set(qrels.keys()) ^ set(queries.keys())
        if len(missing_qids) > 0:
            st.warning(f"The following qids are missing from either queries or qrels and will be removed: {missing_qids}")
            # remove them from both qrels and queries
            for qid in missing_qids:
                qrels.pop(qid, None)
                queries.pop(qid, None)

        data = []
        for key, value in qrels.items():
            data.append({"relevant_docs": len(value), "qid": key})
        relevant_df = pd.DataFrame(data)

    st.header("Analysis Options")
    # slider for how many top relevant docs to consider per query
    n_relevant_docs = st.slider("Number of relevant docs", 1, 999, 300)


col1, col2 = st.columns([1, 3], gap="large")

if corpus is not None and queries is not None and qrels is not None:
    with st.sidebar:
        st.success("All files uploaded")

    with col1:
        set_of_cols = set(qrels.keys())
        container_for_nav = st.container()
        name_of_columns = sorted(set_of_cols)
        instances_to_use = name_of_columns
        st.title("Instances")

        # keep the number input and the selectbox in sync: changing either
        # widget updates the other through session_state
        def sync_from_drop():
            if st.session_state.selectbox_instance == "Overview":
                st.session_state.number_of_col = -1
                st.session_state.cur_instance_num = -1
            else:
                index_of_obj = name_of_columns.index(st.session_state.selectbox_instance)
                st.session_state.number_of_col = index_of_obj
                st.session_state.cur_instance_num = index_of_obj

        def sync_from_number():
            st.session_state.cur_instance_num = st.session_state.number_of_col
            if st.session_state.number_of_col == -1:
                st.session_state.selectbox_instance = "Overview"
            else:
                st.session_state.selectbox_instance = name_of_columns[st.session_state.number_of_col]

        number_of_col = container_for_nav.number_input(
            min_value=-1,
            step=1,
            max_value=len(instances_to_use) - 1,
            on_change=sync_from_number,
            label=f"Select instance by index (up to **{len(instances_to_use) - 1}**)",
            key="number_of_col",
        )
        selectbox_instance = container_for_nav.selectbox(
            "Select instance by ID",
            ["Overview"] + name_of_columns,
            on_change=sync_from_drop,
            key="selectbox_instance",
        )
        st.divider()

        # histogram showing how many relevant docs there are per query
        st.header("Relevant Docs Per Query")
        plotly_chart = create_histogram_relevant_docs(relevant_df)
        st.plotly_chart(plotly_chart)
        st.divider()

        # show the queries with fewer than `n_relevant_docs` relevant docs
        st.header("Relevant Docs Less Than {}:".format(n_relevant_docs))
        st.subheader(f'{relevant_df[relevant_df["relevant_docs"] < n_relevant_docs].shape[0]} Queries')
        st.text_area("QIDs", ",".join(relevant_df[relevant_df["relevant_docs"] < n_relevant_docs].qid.tolist()))

    with col2:
        # get the currently selected instance number
        inst_index = number_of_col

        if inst_index >= 0:
            inst_num = instances_to_use[inst_index]

            # NOTE: the heading markup was stripped from the source; a
            # centered <h1> is assumed
            st.markdown("<h1 style='text-align: center;'>Editor</h1>", unsafe_allow_html=True)

            container = st.container()
            container.divider()

            container.subheader("Query")
            query_text = queries[str(inst_num)].strip()
            query_input = container.text_area(f"QID: {inst_num}", query_text)
            container.divider()

            # relevant documents for this query
            relevant_docs = list(qrels[str(inst_num)].keys())[:n_relevant_docs]
            doc_texts = [
                (doc_id, corpus[doc_id].get("title", ""), corpus[doc_id]["text"])
                for doc_id in relevant_docs
            ]
            splitting_words = find_dividing_words([item[1] + " " + item[2] for item in doc_texts])

            # multiselect of the candidate splitting words
            container.subheader("Splitting Words")
            container.text("Select words that are relevant to the query")
            splitting_word_select = container.multiselect("Splitting Words", splitting_words, key="splitting_words")
            container.divider()

            current_checkboxes = []
            total_changed = 0
            highlighted_texts = []
            highlighted_titles = []
            for (docid, title, text) in tqdm.tqdm(doc_texts):
                if not splitting_word_select:
                    highlighted_texts.append(text)
                    highlighted_titles.append(title)
                    continue
                highlighted_text, changed_text = highlight_text(text, splitting_word_select)
                highlighted_title, changed_title = highlight_text(title, splitting_word_select)
                highlighted_titles.append(highlighted_title)
                highlighted_texts.append(highlighted_text)
                total_changed += int(changed_text or changed_title)

            container.subheader(f"Relevant Documents ({len(qrels[str(inst_num)])})")
            container.subheader(f"Total have these words: {total_changed}")
            container.divider()

            for i, (docid, title, text) in enumerate(doc_texts):
                container.markdown(f"## {docid}")
                container.markdown(f"#### {highlighted_titles[i]}", unsafe_allow_html=True)
                container.markdown(f"\n{highlighted_texts[i]}", unsafe_allow_html=True)
                current_checkboxes.append((docid, container.checkbox(f"{docid} is Non-Relevant", key=docid)))
                container.divider()

            if st.checkbox("Download data as CSV"):
                st.download_button(
                    label="Download data as CSV",
                    data=get_current_data(),
                    file_name=f"annotation_query_{inst_num}.csv",
                    mime="text/csv",
                )

        # overview page: no instance selected
        elif inst_index < 0:
            st.title("Overview")

else:
    st.warning("Please upload all three files (corpus, queries, qrels) in the sidebar to use the editor")