Spaces:
Sleeping
Sleeping
Orion Weller
commited on
Commit
·
a09b56d
1
Parent(s):
56649db
saliency maps
Browse files- .gitignore +3 -1
- analysis.py +93 -1
- app.py +88 -11
- dataset_loading.py +11 -2
- requirements.txt +3 -1
.gitignore
CHANGED
@@ -1,3 +1,5 @@
|
|
1 |
datasets/
|
2 |
__pycache__/
|
3 |
-
env/
|
|
|
|
|
|
1 |
datasets/
|
2 |
__pycache__/
|
3 |
+
env/
|
4 |
+
.ipynb_checkpoints/
|
5 |
+
*.ipynb
|
analysis.py
CHANGED
@@ -1,8 +1,21 @@
|
|
1 |
import pandas as pd
|
2 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
3 |
import plotly.express as px
|
4 |
import plotly.figure_factory as ff
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def results_to_df(results: dict, metric_name: str):
|
8 |
metric_scores = []
|
@@ -38,4 +51,83 @@ def create_boxplot_diff(results1, results2, metric_name):
|
|
38 |
|
39 |
x_axis = f"Difference in {metric_name} from 1 to 2"
|
40 |
fig = px.histogram(pd.DataFrame({x_axis: diff}), x=x_axis, marginal="box")
|
41 |
-
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
import numpy as np
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
from transformers import pipeline
|
6 |
+
import streamlit as st
|
7 |
+
|
8 |
import plotly.express as px
|
9 |
import plotly.figure_factory as ff
|
10 |
|
11 |
+
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
|
12 |
+
from captum.attr import visualization as viz
|
13 |
+
from captum import attr
|
14 |
+
from captum.attr._utils.visualization import format_word_importances, format_special_tokens, _get_color
|
15 |
+
|
16 |
+
|
17 |
+
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
18 |
+
|
19 |
|
20 |
def results_to_df(results: dict, metric_name: str):
|
21 |
metric_scores = []
|
|
|
51 |
|
52 |
x_axis = f"Difference in {metric_name} from 1 to 2"
|
53 |
fig = px.histogram(pd.DataFrame({x_axis: diff}), x=x_axis, marginal="box")
|
54 |
+
return fig
|
55 |
+
|
56 |
+
|
57 |
+
def summarize_attributions(attributions):
|
58 |
+
attributions = attributions.sum(dim=-1).squeeze(0)
|
59 |
+
attributions = attributions / torch.norm(attributions)
|
60 |
+
return attributions
|
61 |
+
|
62 |
+
|
63 |
+
def get_words(words, importances):
|
64 |
+
words_colored = []
|
65 |
+
for word, importance in zip(words, importances[: len(words)]):
|
66 |
+
word = format_special_tokens(word)
|
67 |
+
color = _get_color(importance)
|
68 |
+
unwrapped_tag = '<span style="background-color: {color}; opacity:1.0; line-height:1.75">{word}</span>'.format(
|
69 |
+
color=color, word=word
|
70 |
+
)
|
71 |
+
words_colored.append(unwrapped_tag)
|
72 |
+
return words_colored
|
73 |
+
|
74 |
+
@st.cache_resource
|
75 |
+
def get_model(model_name: str):
|
76 |
+
if model_name == "MonoT5":
|
77 |
+
pipe = pipeline('text2text-generation',
|
78 |
+
model='castorini/monot5-small-msmarco-10k',
|
79 |
+
tokenizer='castorini/monot5-small-msmarco-10k',
|
80 |
+
device='cpu')
|
81 |
+
def formatter(query, doc):
|
82 |
+
return f"Query: {query} Document: {doc} Relevant:"
|
83 |
+
|
84 |
+
return pipe, formatter
|
85 |
+
|
86 |
+
def prep_func(pipe, formatter):
|
87 |
+
# variables that only need to be run once
|
88 |
+
decoder_input_ids = pipe.tokenizer(["<pad>"], return_tensors="pt", add_special_tokens=False, truncation=True).input_ids.to('cpu')
|
89 |
+
decoder_embedding_layer = pipe.model.base_model.decoder.embed_tokens
|
90 |
+
decoder_inputs_emb = decoder_embedding_layer(decoder_input_ids)
|
91 |
+
|
92 |
+
token_false_id = pipe.tokenizer.get_vocab()['▁false']
|
93 |
+
token_true_id = pipe.tokenizer.get_vocab()["▁true"]
|
94 |
+
|
95 |
+
# this function needs to be run for each combination
|
96 |
+
@st.cache_data
|
97 |
+
def get_saliency(query, doc):
|
98 |
+
input_ids = pipe.tokenizer(
|
99 |
+
[formatter(query, doc)],
|
100 |
+
padding=False,
|
101 |
+
truncation=True,
|
102 |
+
return_tensors="pt",
|
103 |
+
max_length=pipe.tokenizer.model_max_length,
|
104 |
+
)["input_ids"].to('cpu')
|
105 |
+
|
106 |
+
embedding_layer = pipe.model.base_model.encoder.embed_tokens
|
107 |
+
inputs_emb = embedding_layer(input_ids)
|
108 |
+
|
109 |
+
def forward_from_embeddings(inputs_embeds, decoder_inputs_embeds):
|
110 |
+
logits = pipe.model.forward(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds)['logits'][:, -1, :]
|
111 |
+
batch_scores = logits[:, [token_false_id, token_true_id]]
|
112 |
+
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
113 |
+
scores = batch_scores[:, 1].exp() # relevant token
|
114 |
+
return scores
|
115 |
+
|
116 |
+
lig = attr.Saliency(forward_from_embeddings)
|
117 |
+
attributions_ig, delta = lig.attribute(
|
118 |
+
inputs=(inputs_emb, decoder_inputs_emb)
|
119 |
+
)
|
120 |
+
attributions_normed = summarize_attributions(attributions_ig)
|
121 |
+
return "\n".join(get_words(pipe.tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).tolist()), attributions_normed))
|
122 |
+
|
123 |
+
return get_saliency
|
124 |
+
|
125 |
+
|
126 |
+
if __name__ == "__main__":
|
127 |
+
query = "how to add dll to visual studio?"
|
128 |
+
doc = "StackOverflow In the days of 16-bit Windows, a WPARAM was a 16-bit word, while LPARAM was a 32-bit long. These distinctions went away in Win32; they both became 32-bit values. ... WPARAM is defined as UINT_PTR , which in 64-bit Windows is an unsigned, 64-bit value."
|
129 |
+
model, formatter = get_model("MonoT5")
|
130 |
+
get_saliency = prep_func(model, formatter)
|
131 |
+
print(get_saliency(query, doc))
|
132 |
+
|
133 |
+
|
app.py
CHANGED
@@ -13,9 +13,10 @@ import plotly.express as px
|
|
13 |
|
14 |
from constants import ALL_DATASETS, ALL_METRICS
|
15 |
from dataset_loading import get_dataset, load_run, load_local_qrels, load_local_corpus, load_local_queries
|
16 |
-
from analysis import create_boxplot_1df, create_boxplot_2df, create_boxplot_diff
|
17 |
|
18 |
|
|
|
19 |
st.set_page_config(layout="wide")
|
20 |
|
21 |
|
@@ -41,6 +42,7 @@ def check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus)
|
|
41 |
return True
|
42 |
return False
|
43 |
|
|
|
44 |
def validate(config_option, file_loaded):
|
45 |
if config_option != "None" and file_loaded is None:
|
46 |
st.error("Please upload a file for " + config_option)
|
@@ -90,6 +92,14 @@ with st.sidebar:
|
|
90 |
incorrect_only = st.checkbox("Show only incorrect instances", value=False)
|
91 |
one_better_than_two = st.checkbox("Show only instances where run 1 is better than run 2", value=False)
|
92 |
two_better_than_one = st.checkbox("Show only instances where run 2 is better than run 1", value=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
advanced_options1 = st.checkbox("Show advanced options for Run 1", value=False)
|
94 |
doc_expansion1 = doc_expansion2 = None
|
95 |
query_expansion1 = query_expansion2 = None
|
@@ -307,9 +317,16 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
|
|
307 |
if doc_expansion1 is not None and run1_uses_doc_expansion != "None" and not show_orig_rel:
|
308 |
alt_text = doc_expansion1[docid]["text"]
|
309 |
text = combine(text, alt_text, run1_uses_doc_expansion)
|
310 |
-
st.text_area(f"{docid}:", text)
|
311 |
|
312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
|
314 |
pred_doc = run1_pandas[run1_pandas.doc_id.isin(relevant_docs)]
|
315 |
rank_pred = pred_doc[pred_doc.qid == str(inst_num)]["rank"].tolist()
|
@@ -320,6 +337,7 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
|
|
320 |
ranking_str = "--"
|
321 |
rank_col.metric(f"Rank of Relevant Doc(s)", ranking_str)
|
322 |
|
|
|
323 |
st.divider()
|
324 |
|
325 |
# top ranked
|
@@ -336,10 +354,22 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
|
|
336 |
for d_idx, doc in enumerate(run1_top_n_docs):
|
337 |
alt_text = run1_top_n_docs_alt[d_idx]["text"]
|
338 |
doc_text = combine(doc["text"], alt_text, run1_uses_doc_expansion)
|
339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
340 |
else:
|
341 |
for d_idx, doc in enumerate(run1_top_n_docs):
|
342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
st.divider()
|
344 |
|
345 |
# none checked
|
@@ -384,20 +414,28 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
|
|
384 |
combined_text2 = combine(query_text_og, alt_text2, run2_uses_query_expansion)
|
385 |
col_run1.markdown(combined_text1)
|
386 |
col_run2.markdown(combined_text2)
|
|
|
|
|
387 |
elif run1_uses_query_expansion != "None":
|
388 |
alt_text = query_expansion1[str(inst_num)]
|
389 |
combined_text1 = combine(query_text_og, alt_text, run1_uses_query_expansion)
|
390 |
col_run1.markdown(combined_text1)
|
391 |
col_run2.markdown(query_text_og)
|
|
|
|
|
392 |
elif run2_uses_query_expansion != "None":
|
393 |
alt_text = query_expansion2[str(inst_num)]
|
394 |
combined_text2 = combine(query_text_og, alt_text, run2_uses_query_expansion)
|
395 |
col_run1.markdown(query_text_og)
|
396 |
col_run2.markdown(combined_text2)
|
|
|
|
|
397 |
else:
|
398 |
query_text = query_text_og
|
399 |
col_run1.markdown(query_text)
|
400 |
col_run2.markdown(query_text)
|
|
|
|
|
401 |
|
402 |
st.divider()
|
403 |
|
@@ -420,13 +458,27 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
|
|
420 |
if doc_expansion1 is not None and run1_uses_doc_expansion != "None" and not show_orig_rel1:
|
421 |
alt_text = doc_expansion1[docid]["text"]
|
422 |
text = combine(text, alt_text, run1_uses_doc_expansion)
|
423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
|
425 |
for (docid, title, text) in doc_texts:
|
426 |
if doc_expansion2 is not None and run2_uses_doc_expansion != "None" and not show_orig_rel2:
|
427 |
alt_text = doc_expansion2[docid]["text"]
|
428 |
text = combine(text, alt_text, run2_uses_doc_expansion)
|
429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
|
431 |
# top ranked
|
432 |
# NOTE: BEIR calls trec_eval which ranks by score, then doc_id for ties
|
@@ -474,10 +526,23 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
|
|
474 |
for d_idx, doc in enumerate(run1_top_n_docs):
|
475 |
alt_text = run1_top_n_docs_alt[d_idx]["text"]
|
476 |
doc_text = combine(doc["text"], alt_text, run1_uses_doc_expansion)
|
477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
478 |
else:
|
479 |
for d_idx, doc in enumerate(run1_top_n_docs):
|
480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
|
482 |
if col_run2.checkbox('Show top ranked documents for Run 2', key=f"{inst_index}top-2run"):
|
483 |
col_run2.subheader("Top N Ranked Documents")
|
@@ -492,10 +557,22 @@ if check_valid_args(run1_file, run2_file, dataset_name, qrels, queries, corpus):
|
|
492 |
for d_idx, doc in enumerate(run2_top_n_docs):
|
493 |
alt_text = run2_top_n_docs_alt[d_idx]["text"]
|
494 |
doc_text = combine(doc["text"], alt_text, run2_uses_doc_expansion)
|
495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
else:
|
497 |
for d_idx, doc in enumerate(run2_top_n_docs):
|
498 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
|
500 |
st.divider()
|
501 |
|
|
|
13 |
|
14 |
from constants import ALL_DATASETS, ALL_METRICS
|
15 |
from dataset_loading import get_dataset, load_run, load_local_qrels, load_local_corpus, load_local_queries
|
16 |
+
from analysis import create_boxplot_1df, create_boxplot_2df, create_boxplot_diff, get_model, prep_func
|
17 |
|
18 |
|
19 |
+
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
20 |
st.set_page_config(layout="wide")
|
21 |
|
22 |
|
|
|
42 |
return True
|
43 |
return False
|
44 |
|
45 |
+
|
46 |
def validate(config_option, file_loaded):
|
47 |
if config_option != "None" and file_loaded is None:
|
48 |
st.error("Please upload a file for " + config_option)
|
|
|
92 |
incorrect_only = st.checkbox("Show only incorrect instances", value=False)
|
93 |
one_better_than_two = st.checkbox("Show only instances where run 1 is better than run 2", value=False)
|
94 |
two_better_than_one = st.checkbox("Show only instances where run 2 is better than run 1", value=False)
|
95 |
+
use_model_saliency = st.checkbox("Use model saliency (slow!)", value=False)
|
96 |
+
if use_model_saliency:
|
97 |
+
# choose from a list of models
|
98 |
+
model_name = st.selectbox("Choose from a list of models", ["MonoT5"])
|
99 |
+
model, formatter = get_model("MonoT5")
|
100 |
+
get_saliency = prep_func(model, formatter)
|
101 |
+
|
102 |
+
|
103 |
advanced_options1 = st.checkbox("Show advanced options for Run 1", value=False)
|
104 |
doc_expansion1 = doc_expansion2 = None
|
105 |
query_expansion1 = query_expansion2 = None
|
|
|
317 |
if doc_expansion1 is not None and run1_uses_doc_expansion != "None" and not show_orig_rel:
|
318 |
alt_text = doc_expansion1[docid]["text"]
|
319 |
text = combine(text, alt_text, run1_uses_doc_expansion)
|
|
|
320 |
|
321 |
+
if use_model_saliency:
|
322 |
+
if st.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency", value=False):
|
323 |
+
st.markdown(get_saliency(query_text, doc_texts),unsafe_allow_html=True)
|
324 |
+
else:
|
325 |
+
st.text_area(f"{docid}:", text)
|
326 |
+
|
327 |
+
else:
|
328 |
+
st.text_area(f"{docid}:", text)
|
329 |
+
|
330 |
|
331 |
pred_doc = run1_pandas[run1_pandas.doc_id.isin(relevant_docs)]
|
332 |
rank_pred = pred_doc[pred_doc.qid == str(inst_num)]["rank"].tolist()
|
|
|
337 |
ranking_str = "--"
|
338 |
rank_col.metric(f"Rank of Relevant Doc(s)", ranking_str)
|
339 |
|
340 |
+
|
341 |
st.divider()
|
342 |
|
343 |
# top ranked
|
|
|
354 |
for d_idx, doc in enumerate(run1_top_n_docs):
|
355 |
alt_text = run1_top_n_docs_alt[d_idx]["text"]
|
356 |
doc_text = combine(doc["text"], alt_text, run1_uses_doc_expansion)
|
357 |
+
if use_model_saliency:
|
358 |
+
if st.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency", value=False):
|
359 |
+
st.markdown(get_saliency(query_text, doc_text),unsafe_allow_html=True)
|
360 |
+
else:
|
361 |
+
st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}")
|
362 |
+
else:
|
363 |
+
st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}")
|
364 |
else:
|
365 |
for d_idx, doc in enumerate(run1_top_n_docs):
|
366 |
+
if use_model_saliency:
|
367 |
+
if st.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked", value=False):
|
368 |
+
st.markdown(get_saliency(query_text, doc),unsafe_allow_html=True)
|
369 |
+
else:
|
370 |
+
st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}")
|
371 |
+
else:
|
372 |
+
st.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}")
|
373 |
st.divider()
|
374 |
|
375 |
# none checked
|
|
|
414 |
combined_text2 = combine(query_text_og, alt_text2, run2_uses_query_expansion)
|
415 |
col_run1.markdown(combined_text1)
|
416 |
col_run2.markdown(combined_text2)
|
417 |
+
query_text1 = combined_text1
|
418 |
+
query_text2 = combined_text2
|
419 |
elif run1_uses_query_expansion != "None":
|
420 |
alt_text = query_expansion1[str(inst_num)]
|
421 |
combined_text1 = combine(query_text_og, alt_text, run1_uses_query_expansion)
|
422 |
col_run1.markdown(combined_text1)
|
423 |
col_run2.markdown(query_text_og)
|
424 |
+
query_text1 = combined_text1
|
425 |
+
query_text2 = query_text_og
|
426 |
elif run2_uses_query_expansion != "None":
|
427 |
alt_text = query_expansion2[str(inst_num)]
|
428 |
combined_text2 = combine(query_text_og, alt_text, run2_uses_query_expansion)
|
429 |
col_run1.markdown(query_text_og)
|
430 |
col_run2.markdown(combined_text2)
|
431 |
+
query_text1 = query_text_og
|
432 |
+
query_text2 = combined_text2
|
433 |
else:
|
434 |
query_text = query_text_og
|
435 |
col_run1.markdown(query_text)
|
436 |
col_run2.markdown(query_text)
|
437 |
+
query_text1 = query_text
|
438 |
+
query_text2 = query_text
|
439 |
|
440 |
st.divider()
|
441 |
|
|
|
458 |
if doc_expansion1 is not None and run1_uses_doc_expansion != "None" and not show_orig_rel1:
|
459 |
alt_text = doc_expansion1[docid]["text"]
|
460 |
text = combine(text, alt_text, run1_uses_doc_expansion)
|
461 |
+
|
462 |
+
if use_model_saliency:
|
463 |
+
if col_run1.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{docid}relevant", value=False):
|
464 |
+
col_run1.markdown(get_saliency(query_text1, text),unsafe_allow_html=True)
|
465 |
+
else:
|
466 |
+
col_run1.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}1")
|
467 |
+
else:
|
468 |
+
col_run1.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}1")
|
469 |
|
470 |
for (docid, title, text) in doc_texts:
|
471 |
if doc_expansion2 is not None and run2_uses_doc_expansion != "None" and not show_orig_rel2:
|
472 |
alt_text = doc_expansion2[docid]["text"]
|
473 |
text = combine(text, alt_text, run2_uses_doc_expansion)
|
474 |
+
|
475 |
+
if use_model_saliency:
|
476 |
+
if col_run2.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{docid}relevant2", value=False):
|
477 |
+
col_run2.markdown(get_saliency(query_text2, text),unsafe_allow_html=True)
|
478 |
+
else:
|
479 |
+
col_run2.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}2")
|
480 |
+
else:
|
481 |
+
col_run2.text_area(f"{docid}:", text, key=f"{inst_num}doc{docid}2")
|
482 |
|
483 |
# top ranked
|
484 |
# NOTE: BEIR calls trec_eval which ranks by score, then doc_id for ties
|
|
|
526 |
for d_idx, doc in enumerate(run1_top_n_docs):
|
527 |
alt_text = run1_top_n_docs_alt[d_idx]["text"]
|
528 |
doc_text = combine(doc["text"], alt_text, run1_uses_doc_expansion)
|
529 |
+
if use_model_saliency:
|
530 |
+
if col_run1.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked1", value=False):
|
531 |
+
col_run1.markdown(get_saliency(query_text1, doc_text),unsafe_allow_html=True)
|
532 |
+
else:
|
533 |
+
col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}1")
|
534 |
+
else:
|
535 |
+
col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}1")
|
536 |
else:
|
537 |
for d_idx, doc in enumerate(run1_top_n_docs):
|
538 |
+
if use_model_saliency:
|
539 |
+
if col_run1.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked1", value=False):
|
540 |
+
col_run1.markdown(get_saliency(query_text1, doc),unsafe_allow_html=True)
|
541 |
+
else:
|
542 |
+
col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}1")
|
543 |
+
else:
|
544 |
+
col_run1.text_area(f"{run1_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}1")
|
545 |
+
|
546 |
|
547 |
if col_run2.checkbox('Show top ranked documents for Run 2', key=f"{inst_index}top-2run"):
|
548 |
col_run2.subheader("Top N Ranked Documents")
|
|
|
557 |
for d_idx, doc in enumerate(run2_top_n_docs):
|
558 |
alt_text = run2_top_n_docs_alt[d_idx]["text"]
|
559 |
doc_text = combine(doc["text"], alt_text, run2_uses_doc_expansion)
|
560 |
+
if use_model_saliency:
|
561 |
+
if col_run2.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked2", value=False):
|
562 |
+
col_run2.markdown(get_saliency(query_text2, doc_text),unsafe_allow_html=True)
|
563 |
+
else:
|
564 |
+
col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}2")
|
565 |
+
else:
|
566 |
+
col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc_text, key=f"{inst_num}doc{d_idx}2")
|
567 |
else:
|
568 |
for d_idx, doc in enumerate(run2_top_n_docs):
|
569 |
+
if use_model_saliency:
|
570 |
+
if col_run2.checkbox("Show Model Saliency", key=f"{inst_index}model_saliency{d_idx}ranked2", value=False):
|
571 |
+
col_run2.markdown(get_saliency(query_text2, doc),unsafe_allow_html=True)
|
572 |
+
else:
|
573 |
+
col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}2")
|
574 |
+
else:
|
575 |
+
col_run2.text_area(f"{run2_top_n['doc_id'].iloc[d_idx]}: ", doc["text"], key=f"{inst_num}doc{d_idx}2")
|
576 |
|
577 |
st.divider()
|
578 |
|
dataset_loading.py
CHANGED
@@ -14,6 +14,8 @@ import ir_datasets
|
|
14 |
|
15 |
from constants import BEIR, IR_DATASETS, LOCAL_DATASETS
|
16 |
|
|
|
|
|
17 |
def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
|
18 |
if corpus_file is None:
|
19 |
return None
|
@@ -39,6 +41,8 @@ def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
|
|
39 |
}
|
40 |
return did2text
|
41 |
|
|
|
|
|
42 |
def load_local_queries(queries_file):
|
43 |
if queries_file is None:
|
44 |
return None
|
@@ -60,6 +64,8 @@ def load_local_queries(queries_file):
|
|
60 |
qid2text[inst[id_key]] = inst["text"]
|
61 |
return qid2text
|
62 |
|
|
|
|
|
63 |
def load_local_qrels(qrels_file):
|
64 |
if qrels_file is None:
|
65 |
return None
|
@@ -84,6 +90,7 @@ def load_local_qrels(qrels_file):
|
|
84 |
return qid2did2label
|
85 |
|
86 |
|
|
|
87 |
def load_run(f_run):
|
88 |
run = pytrec_eval.parse_run(copy.deepcopy(f_run))
|
89 |
# convert bytes to strings for keys
|
@@ -102,7 +109,7 @@ def load_run(f_run):
|
|
102 |
return new_run, run_pandas
|
103 |
|
104 |
|
105 |
-
|
106 |
def load_jsonl(f):
|
107 |
did2text = defaultdict(list)
|
108 |
sub_did2text = {}
|
@@ -126,7 +133,7 @@ def load_jsonl(f):
|
|
126 |
return did2text, sub_did2text
|
127 |
|
128 |
|
129 |
-
|
130 |
def get_beir(dataset: str):
|
131 |
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
|
132 |
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
|
@@ -134,6 +141,7 @@ def get_beir(dataset: str):
|
|
134 |
return GenericDataLoader(data_folder=data_path).load(split="test")
|
135 |
|
136 |
|
|
|
137 |
def get_ir_datasets(dataset_name: str):
|
138 |
dataset = ir_datasets.load(dataset_name)
|
139 |
queries = {}
|
@@ -145,6 +153,7 @@ def get_ir_datasets(dataset_name: str):
|
|
145 |
return dataset.doc_store(), queries, dataset.qrels_dict()
|
146 |
|
147 |
|
|
|
148 |
def get_dataset(dataset_name: str):
|
149 |
if dataset_name == "":
|
150 |
return {}, {}, {}
|
|
|
14 |
|
15 |
from constants import BEIR, IR_DATASETS, LOCAL_DATASETS
|
16 |
|
17 |
+
|
18 |
+
@st.cache_data
|
19 |
def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
|
20 |
if corpus_file is None:
|
21 |
return None
|
|
|
41 |
}
|
42 |
return did2text
|
43 |
|
44 |
+
|
45 |
+
@st.cache_data
|
46 |
def load_local_queries(queries_file):
|
47 |
if queries_file is None:
|
48 |
return None
|
|
|
64 |
qid2text[inst[id_key]] = inst["text"]
|
65 |
return qid2text
|
66 |
|
67 |
+
|
68 |
+
@st.cache_data
|
69 |
def load_local_qrels(qrels_file):
|
70 |
if qrels_file is None:
|
71 |
return None
|
|
|
90 |
return qid2did2label
|
91 |
|
92 |
|
93 |
+
@st.cache_data
|
94 |
def load_run(f_run):
|
95 |
run = pytrec_eval.parse_run(copy.deepcopy(f_run))
|
96 |
# convert bytes to strings for keys
|
|
|
109 |
return new_run, run_pandas
|
110 |
|
111 |
|
112 |
+
@st.cache_data
|
113 |
def load_jsonl(f):
|
114 |
did2text = defaultdict(list)
|
115 |
sub_did2text = {}
|
|
|
133 |
return did2text, sub_did2text
|
134 |
|
135 |
|
136 |
+
@st.cache_data
|
137 |
def get_beir(dataset: str):
|
138 |
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
|
139 |
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
|
|
|
141 |
return GenericDataLoader(data_folder=data_path).load(split="test")
|
142 |
|
143 |
|
144 |
+
@st.cache_data
|
145 |
def get_ir_datasets(dataset_name: str):
|
146 |
dataset = ir_datasets.load(dataset_name)
|
147 |
queries = {}
|
|
|
153 |
return dataset.doc_store(), queries, dataset.qrels_dict()
|
154 |
|
155 |
|
156 |
+
@st.cache_data
|
157 |
def get_dataset(dataset_name: str):
|
158 |
if dataset_name == "":
|
159 |
return {}, {}, {}
|
requirements.txt
CHANGED
@@ -5,4 +5,6 @@ streamlit==1.24.1
|
|
5 |
ir_datasets==0.5.5
|
6 |
pyserini==0.21.0
|
7 |
torch==2.0.1
|
8 |
-
plotly==5.15.0
|
|
|
|
|
|
5 |
ir_datasets==0.5.5
|
6 |
pyserini==0.21.0
|
7 |
torch==2.0.1
|
8 |
+
plotly==5.15.0
|
9 |
+
captum==0.6.0
|
10 |
+
protobuf==4.21.11
|