Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import pathlib | |
import pandas as pd | |
from collections import defaultdict | |
import json | |
import copy | |
import re | |
import tqdm | |
import plotly.express as px | |
import pandas as pd | |
from nltk.corpus import stopwords | |
from nltk.stem import PorterStemmer | |
from nltk.tokenize import word_tokenize | |
from collections import Counter | |
import string | |
import os | |
import streamlit as st | |
# Ensure you've downloaded the set of stop words the first time you run this | |
import nltk | |
# only download if they don't exist | |
# if not os.path.exists(os.path.join(nltk.data.find('corpora'), 'stopwords')): | |
nltk.download('punkt') | |
nltk.download('stopwords') | |
from dataset_loading import load_local_qrels, load_local_corpus, load_local_queries | |
def preprocess_document(doc):
    """
    Normalize a single document string into a list of word stems.

    The text is lowercased, stripped of every character in
    ``string.punctuation``, tokenized with NLTK's ``word_tokenize``, filtered
    against NLTK's English stop-word list, and finally Porter-stemmed.
    """
    punctuation_table = str.maketrans('', '', string.punctuation)
    cleaned = doc.lower().translate(punctuation_table)

    words = word_tokenize(cleaned)

    english_stop_words = set(stopwords.words('english'))
    porter = PorterStemmer()

    # Drop stop words and stem the survivors in one pass.
    return [porter.stem(word) for word in words if word not in english_stop_words]
def find_dividing_words(documents, min_fraction=0.35, max_fraction=0.75):
    """
    Identify candidate words that might split the set of documents into two groups.

    A word is a candidate when the number of documents that contain it (after
    ``preprocess_document`` normalization) lies between ``min_fraction`` and
    ``max_fraction`` of the number of documents, both bounds inclusive.

    Args:
        documents: Sized iterable of raw document strings.
        min_fraction: Smallest share of documents that must contain the word.
        max_fraction: Largest share of documents that may contain the word.

    Returns:
        List of stemmed words, ordered by their first appearance across
        ``documents``. Empty when ``documents`` is empty.
    """
    # Count every word at most once per document. ``dict.fromkeys`` removes
    # duplicates while keeping first-seen order, so the result order stays
    # deterministic (a plain ``set`` would not guarantee that).
    document_frequency = Counter()
    for doc in documents:
        for word in dict.fromkeys(preprocess_document(doc)):
            document_frequency[word] += 1

    num_docs = len(documents)
    lower_bound = min_fraction * num_docs
    upper_bound = max_fraction * num_docs
    return [
        word
        for word, doc_count in document_frequency.items()
        if lower_bound <= doc_count <= upper_bound
    ]
# Set the protobuf implementation environment variable to "python".
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Use Streamlit's wide page layout.
st.set_page_config(layout="wide")
# Module-level state read by get_current_data(): the (doc_id, checkbox value)
# pairs for the documents currently displayed, and the current query text.
current_checkboxes = []
query_input = None
def convert_df(df):
    """Serialize ``df`` to CSV text (no index column) and return it as UTF-8 bytes."""
    # NOTE: nothing here caches the result; the conversion runs on every call.
    csv_text = df.to_csv(path_or_buf=None, index=False, quotechar='"')
    return csv_text.encode('utf-8')
def create_histogram_relevant_docs(relevant_df):
    """Return a Plotly histogram over the ``relevant_docs`` column of ``relevant_df``."""
    histogram = px.histogram(relevant_df, x="relevant_docs")
    # Fixed, narrow size so the chart fits in one column.
    layout_size = {"height": 400, "width": 250}
    histogram.update_layout(**layout_size)
    return histogram
def get_current_data():
    """
    Build the CSV export (as bytes) for the documents whose checkbox is ticked.

    Reads the module-level ``query_input`` and ``current_checkboxes`` and the
    ``selectbox_instance`` entry of the Streamlit session state. Each ticked
    document becomes one row with ``is_relevant`` set to 0.
    """
    # Keep the narrative on one CSV line by writing newlines as a literal "\n".
    narrative = query_input.replace("\n", "\\n")
    rows = [
        {
            "new_narrative": narrative,
            "qid": st.session_state.selectbox_instance,
            "doc_id": doc_id,
            "is_relevant": 0,
        }
        for doc_id, is_checked in current_checkboxes
        if is_checked
    ]
    return convert_df(pd.DataFrame(rows))
def escape_markdown(text):
    r"""
    Backslash-escape Markdown special characters in ``text``.

    A pair of backticks is first replaced with a double quote. Then each of
    the characters ``\ ` * _ { } [ ] ( ) # + - . ! | $`` is prefixed with
    exactly one backslash.

    Args:
        text: The raw string to escape.

    Returns:
        The escaped string.
    """
    # The backslash is in the set so that it escapes itself as well.
    special_chars = frozenset('\\`*_{}[]()#+-.!|$')
    text = text.replace("``", "\"")
    # "$" is escaped by the single pass below. Replacing it with "\$" first
    # would get both characters escaped again and yield "\\\$", which renders
    # as a stray backslash in front of the dollar sign.
    return "".join(f"\\{char}" if char in special_chars else char for char in text)
def highlight_text(text, splitting_words):
    """
    Markdown-escape ``text`` and wrap every occurrence of a splitting word in a
    yellow-background ``<span>``.

    Matching is case-insensitive and uses one alternation of all the words.
    Returns ``(resulting_text, changed)`` where ``changed`` tells whether at
    least one occurrence was wrapped.
    """
    # Escape first so the document text cannot interfere with the markdown.
    escaped = escape_markdown(text)
    if not splitting_words:
        return escaped, False

    def wrap_match(match):
        return f'<span style="background-color: #FFFF00">{match.group(0)}</span>'

    # One pattern covering all words, each treated as a literal.
    alternation = '|'.join(map(re.escape, splitting_words))
    highlighted, hit_count = re.subn(alternation, wrap_match, escaped, flags=re.IGNORECASE)
    return highlighted, hit_count > 0
# Initialise the selected-instance index when the session does not have one
# yet; -1 is the value the navigation callbacks use for "Overview".
if 'cur_instance_num' not in st.session_state:
    st.session_state.cur_instance_num = -1
def validate(config_option, file_loaded):
    """
    Stop the Streamlit script with an error when an option other than "None"
    is selected but no file has been loaded for it.
    """
    if config_option == "None":
        return
    if file_loaded is not None:
        return
    st.error("Please upload a file for " + config_option)
    st.stop()
# Sidebar: file uploads, qrels/queries consistency check and analysis options.
# NOTE(review): the nesting below was reconstructed from the control flow;
# confirm it against the original file.
with st.sidebar:
    st.title("Options")
    st.header("Upload corpus")
    corpus_file = st.file_uploader("Choose a file", key="corpus")
    corpus = load_local_corpus(corpus_file)
    st.header("Upload queries")
    queries_file = st.file_uploader("Choose a file", key="queries")
    queries = load_local_queries(queries_file)
    st.header("Upload qrels")
    qrels_file = st.file_uploader("Choose a file", key="qrels")
    qrels = load_local_qrels(qrels_file)
    # Keep only the qids present in both qrels and queries, and warn about the
    # ones that are dropped.
    if queries is not None and qrels is not None:
        # "-" binds tighter than "|", so this is
        # (qrels - queries) | (queries - qrels): qids found on one side only.
        missing_qids = set(qrels.keys()) - set(queries.keys()) | set(queries.keys()) - set(qrels.keys())
        if len(missing_qids) > 0:
            st.warning(f"The following qids in qrels are not in queries and will be deleted: {missing_qids}")
            # remove them from qrels and queries
            for qid in missing_qids:
                if qid in qrels:
                    del qrels[qid]
                if qid in queries:
                    del queries[qid]
        # One row per remaining qid with the size of its qrels entry.
        data = []
        for key, value in qrels.items():
            data.append({"relevant_docs": len(value), "qid": key})
        relevant_df = pd.DataFrame(data)
    z = st.header("Analysis Options")
    # Slider for the relevant-document threshold: used to select queries with
    # fewer relevant docs and to cap how many documents are shown per query.
    n_relevant_docs = st.slider("Number of relevant docs", 1, 999, 100)
# Main page: a narrow navigation/statistics column and a wide editor column.
# NOTE(review): the nesting below was reconstructed from the control flow;
# confirm it against the original file.
col1, col2 = st.columns([1, 3], gap="large")
if corpus is not None and queries is not None and qrels is not None:
    with st.sidebar:
        st.success("All files uploaded")
    with col1:
        # breakpoint()
        # Only queries with fewer than `n_relevant_docs` relevant docs are selectable.
        qids_with_less = relevant_df[relevant_df["relevant_docs"] < n_relevant_docs].qid.tolist()
        set_of_cols = set(qrels.keys()).intersection(set(qids_with_less))
        # Created before the title so the navigation widgets render above it.
        container_for_nav = st.container()
        name_of_columns = sorted([item for item in set_of_cols])
        instances_to_use = name_of_columns
        st.title("Instances")
        # The two callbacks keep the index input and the ID selectbox in sync
        # through the session state; index -1 corresponds to "Overview".
        def sync_from_drop():
            if st.session_state.selectbox_instance == "Overview":
                st.session_state.number_of_col = -1
                st.session_state.cur_instance_num = -1
            else:
                index_of_obj = name_of_columns.index(st.session_state.selectbox_instance)
                # print("Index of obj: ", index_of_obj, type(index_of_obj))
                st.session_state.number_of_col = index_of_obj
                st.session_state.cur_instance_num = index_of_obj
        def sync_from_number():
            st.session_state.cur_instance_num = st.session_state.number_of_col
            # print("Session state number of col: ", st.session_state.number_of_col, type(st.session_state.number_of_col))
            if st.session_state.number_of_col == -1:
                st.session_state.selectbox_instance = "Overview"
            else:
                st.session_state.selectbox_instance = name_of_columns[st.session_state.number_of_col]
        number_of_col = container_for_nav.number_input(min_value=-1, step=1, max_value=len(instances_to_use) - 1, on_change=sync_from_number, label=f"Select instance by index (up to **{len(instances_to_use) - 1}**)", key="number_of_col")
        selectbox_instance = container_for_nav.selectbox("Select instance by ID", ["Overview"] + name_of_columns, on_change=sync_from_drop, key="selectbox_instance")
        st.divider()
        # Histogram of the number of relevant docs per query.
        st.header("Relevant Docs Per Query")
        plotly_chart = create_histogram_relevant_docs(relevant_df)
        st.plotly_chart(plotly_chart)
        st.divider()
        # now show the number with relevant docs less than `n_relevant_docs`
        st.header("Relevant Docs Less Than {}:".format(n_relevant_docs))
        st.subheader(f'{relevant_df[relevant_df["relevant_docs"] < n_relevant_docs].shape[0]} Queries')
        st.markdown(",".join(relevant_df[relevant_df["relevant_docs"] < n_relevant_docs].qid.tolist()))
    with col2:
        # get instance number
        inst_index = number_of_col
        if inst_index >= 0:
            inst_num = instances_to_use[inst_index]
            st.markdown("<h1 style='text-align: center; color: black;text-decoration: underline;'>Editor</h1>", unsafe_allow_html=True)
            container = st.container()
            container.divider()
            container.subheader(f"Query")
            query_text = queries[str(inst_num)].strip()
            # Editable query text; get_current_data() reads it back from the
            # module-level `query_input`.
            query_input = container.text_area(f"QID: {inst_num}", query_text)
            container.divider()
            ## Documents
            # relevant
            # At most `n_relevant_docs` documents, as (doc_id, title, text) triples;
            # the title falls back to "" when the corpus entry has none.
            relevant_docs = list(qrels[str(inst_num)].keys())[:n_relevant_docs]
            doc_texts = [(doc_id, corpus[doc_id]["title"] if "title" in corpus[doc_id] else "", corpus[doc_id]["text"]) for doc_id in relevant_docs]
            splitting_words = find_dividing_words([item[1] + " " + item[2] for item in doc_texts])
            # make a selectbox of these splitting words (allow multiple)
            container.subheader("Splitting Words")
            container.text("Select words that are relevant to the query")
            splitting_word_select = container.multiselect("Splitting Words", splitting_words, key="splitting_words")
            container.divider()
            current_checkboxes = []
            total_changed = 0
            highlighted_texts = []
            highlighted_titles = []
            for (docid, title, text) in tqdm.tqdm(doc_texts):
                # With no word selected the raw title/text are kept as they are
                # (highlight_text, and therefore its escaping, is skipped).
                if not len(splitting_word_select):
                    highlighted_texts.append(text)
                    highlighted_titles.append(title)
                    continue
                highlighted_text, changed_text = highlight_text(text, splitting_word_select)
                highlighted_title, changed_title = highlight_text(title, splitting_word_select)
                highlighted_titles.append(highlighted_title)
                highlighted_texts.append(highlighted_text)
                # Count the document once if its text or its title had a match.
                total_changed += int(int(changed_text) or int(changed_title))
            container.subheader(f"Relevant Documents ({len(list(qrels[str(inst_num)].keys()))})")
            container.subheader(f"Total have these words: {total_changed}")
            container.divider()
            for i, (docid, title, text) in enumerate(doc_texts):
                container.markdown(f"## {docid}")
                # The second positional argument (True) is unsafe_allow_html,
                # needed for the highlight <span> tags.
                container.markdown(f"#### {highlighted_titles[i]}", True)
                container.markdown(f"\n{highlighted_texts[i]}", True)
                # The document id doubles as the checkbox widget key.
                current_checkboxes.append((docid, container.checkbox(f'{docid} is Non-Relevant', key=docid)))
                container.divider()
            # The download button is only rendered once this checkbox is ticked;
            # get_current_data() is evaluated at that point.
            if st.checkbox("Download data as CSV"):
                st.download_button(
                    label="Download data as CSV",
                    data=get_current_data(),
                    file_name=f'annotation_query_{inst_num}.csv',
                    mime='text/csv',
                )
        # none checked
        elif inst_index < 0:
            st.title("Overview")
else:
    st.warning("Please choose a dataset and upload a run file. If you chose \"custom\" be sure that you uploaded all files (queries, corpus, qrels)")