"""Streamlit app for inspecting a retrieval dataset (corpus, queries, qrels).

Files are uploaded from the sidebar; once all three are present the app shows
per-query statistics on the number of relevant documents and lets the user
navigate between query instances.
"""
import os

# The protobuf runtime reads this variable once, when google.protobuf is first
# imported, and `import streamlit` already imports protobuf. It therefore has
# to be set before the imports below; assigning it afterwards is a no-op.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import copy
import json
import pathlib
import re
from collections import defaultdict

import pandas as pd
import plotly.express as px
import streamlit as st
import tqdm

from dataset_loading import load_local_corpus, load_local_qrels, load_local_queries
from find_splitting_words import find_dividing_words

st.set_page_config(layout="wide")

# Module-level state read by get_current_data(); presumably populated further
# down the script — confirm.
current_checkboxes = []  # (doc_id, checked) pairs
query_input = None  # narrative text entered by the user


@st.cache_data
def convert_df(df):
    """Serialize *df* to UTF-8 encoded CSV bytes (without the index).

    Cached so the conversion is not recomputed on every Streamlit rerun.
    """
    return df.to_csv(path_or_buf=None, index=False, quotechar='"').encode('utf-8')


def create_histogram_relevant_docs(relevant_df):
    """Return a Plotly histogram of the ``relevant_docs`` column of *relevant_df*.

    The figure is sized small (400x250) so it fits in a narrow column.
    """
    fig = px.histogram(relevant_df, x="relevant_docs")
    fig.update_layout(height=400, width=250)
    return fig


def get_current_data():
    """Return the checked documents for the selected query as CSV bytes.

    One row is emitted per checked entry of ``current_checkboxes`` with the
    current narrative (newlines replaced by the two characters backslash + n),
    the qid selected in ``st.session_state.selectbox_instance``, the doc id,
    and ``is_relevant`` set to 0.
    """
    cur_query = query_input.replace("\n", "\\n")
    cur_query_data = [
        {
            "new_narrative": cur_query,
            "qid": st.session_state.selectbox_instance,
            "doc_id": doc_id,
            "is_relevant": 0,
        }
        for doc_id, checkbox in current_checkboxes
        if checkbox
    ]
    return convert_df(pd.DataFrame(cur_query_data))


@st.cache_data
def escape_markdown(text):
    """Backslash-escape Markdown special characters in *text*.

    Double backticks are first turned into a double quote. The backslash itself
    is in the escape set so literal backslashes survive rendering.
    """
    text = text.replace("``", "\"")
    special_chars = set("\\`*_{}[]()#+-.!|$")
    return "".join(f"\\{char}" if char in special_chars else char for char in text)


@st.cache_data
def highlight_text(text, splitting_words):
    """Markdown-escape *text* and substitute every occurrence of *splitting_words*.

    Matching is case-insensitive and on substrings.

    Returns:
        ``(new_text, changed)``, where *changed* is True iff at least one
        word occurrence was found.
    """
    # remove anything that will mess up markdown
    text = escape_markdown(text)
    changed = False

    # Drop empty strings: an empty alternative matches at every position and
    # would report a change even when no word occurs.
    words = [word for word in splitting_words if word]
    if not words:
        return text, changed

    def replace_function(match):
        # NOTE(review): returns the match unchanged, so the substitution does
        # not alter the text; presumably highlight markup was intended — confirm.
        return f'{match.group(0)}'

    # The text has already been Markdown-escaped, so each needle must be escaped
    # the same way (otherwise words containing e.g. '-' or '.' never match).
    # Longest words first so a shorter prefix cannot shadow a longer word in
    # the regex alternation.
    needles = sorted((escape_markdown(word) for word in words), key=len, reverse=True)
    pattern = '|'.join(re.escape(needle) for needle in needles)

    # Perform case-insensitive replacement
    new_text, num_subs = re.subn(pattern, replace_function, text, flags=re.IGNORECASE)
    if num_subs > 0:
        changed = True
    return new_text, changed


if 'cur_instance_num' not in st.session_state:
    # -1 corresponds to the "Overview" (no single instance selected).
    st.session_state.cur_instance_num = -1


def validate(config_option, file_loaded):
    """Show an error and stop the script if *config_option* needs a file that is missing."""
    if config_option != "None" and file_loaded is None:
        st.error("Please upload a file for " + config_option)
        st.stop()


with st.sidebar:
    st.title("Options")

    # The loaders are presumed to return None when no file is uploaded (the
    # checks below rely on that) — confirm in dataset_loading.
    st.header("Upload corpus")
    corpus_file = st.file_uploader("Choose a file", key="corpus")
    corpus = load_local_corpus(corpus_file)

    st.header("Upload queries")
    queries_file = st.file_uploader("Choose a file", key="queries")
    queries = load_local_queries(queries_file)

    st.header("Upload qrels")
    qrels_file = st.file_uploader("Choose a file", key="qrels")
    qrels = load_local_qrels(qrels_file)

    # Keep only qids that appear in BOTH qrels and queries; warn about the rest.
    if queries is not None and qrels is not None:
        missing_qids = set(qrels.keys()) ^ set(queries.keys())
        if missing_qids:
            st.warning(
                f"The following qids are not in both qrels and queries and will be deleted: {missing_qids}"
            )
            for qid in missing_qids:
                if qid in qrels:
                    del qrels[qid]
                if qid in queries:
                    del queries[qid]

        # One row per query: its qid and how many relevant docs it has.
        data = [{"relevant_docs": len(value), "qid": key} for key, value in qrels.items()]
        relevant_df = pd.DataFrame(data)

    # NOTE(review): `z` looks unused in the visible code.
    z = st.header("Analysis Options")
    # slider for the relevant-docs threshold used in the analysis below
    n_relevant_docs = st.slider("Number of relevant docs", 1, 999, 300)

col1, col2 = st.columns([1, 3], gap="large")

if corpus is not None and queries is not None and qrels is not None:
    with st.sidebar:
        st.success("All files uploaded")

    with col1:
        set_of_cols = set(qrels.keys())
        container_for_nav = st.container()
        # Sorted list of qids; the navigation widgets index into it.
        name_of_columns = sorted(set_of_cols)
        instances_to_use = name_of_columns
st.title("Instances") def sync_from_drop(): if st.session_state.selectbox_instance == "Overview": st.session_state.number_of_col = -1 st.session_state.cur_instance_num = -1 else: index_of_obj = name_of_columns.index(st.session_state.selectbox_instance) # print("Index of obj: ", index_of_obj, type(index_of_obj)) st.session_state.number_of_col = index_of_obj st.session_state.cur_instance_num = index_of_obj def sync_from_number(): st.session_state.cur_instance_num = st.session_state.number_of_col # print("Session state number of col: ", st.session_state.number_of_col, type(st.session_state.number_of_col)) if st.session_state.number_of_col == -1: st.session_state.selectbox_instance = "Overview" else: st.session_state.selectbox_instance = name_of_columns[st.session_state.number_of_col] number_of_col = container_for_nav.number_input(min_value=-1, step=1, max_value=len(instances_to_use) - 1, on_change=sync_from_number, label=f"Select instance by index (up to **{len(instances_to_use) - 1}**)", key="number_of_col") selectbox_instance = container_for_nav.selectbox("Select instance by ID", ["Overview"] + name_of_columns, on_change=sync_from_drop, key="selectbox_instance") st.divider() # make pie plot showing how many relevant docs there are per query histogram st.header("Relevant Docs Per Query") plotly_chart = create_histogram_relevant_docs(relevant_df) st.plotly_chart(plotly_chart) st.divider() # now show the number with relevant docs less than `n_relevant_docs` st.header("Relevant Docs Less Than {}:".format(n_relevant_docs)) st.subheader(f'{relevant_df[relevant_df["relevant_docs"] < n_relevant_docs].shape[0]} Queries') st.text_area(",".join(relevant_df[relevant_df["relevant_docs"] < n_relevant_docs].qid.tolist())) with col2: # get instance number inst_index = number_of_col if inst_index >= 0: inst_num = instances_to_use[inst_index] st.markdown("