|
import time |
|
import json |
|
import numpy as np |
|
|
|
import streamlit as st |
|
from pathlib import Path |
|
from collections import defaultdict |
|
|
|
import sys |
|
path_root = Path("./") |
|
sys.path.append(str(path_root)) |
|
|
|
|
|
st.set_page_config(page_title="PSC Runtime", |
|
page_icon='🌸', layout="centered") |
|
|
|
|
|
|
|
name = st.selectbox( |
|
"Choose a dataset", |
|
["dl19", "dl20"], |
|
index=None, |
|
placeholder="Choose a dataset..." |
|
) |
|
|
|
model_name = st.selectbox( |
|
"Choose a model", |
|
["gpt-3.5", "gpt-4"], |
|
index=None, |
|
placeholder="Choose a model..." |
|
) |
|
|
|
|
|
if name and model_name: |
|
import torch |
|
|
|
fn = f"{name}-{model_name}.pt" |
|
object = torch.load(fn) |
|
|
|
outputs = object[2] |
|
query2outputs = {} |
|
for output in outputs: |
|
all_queries = {x['query'] for x in output} |
|
assert len(all_queries) == 1 |
|
query = list(all_queries)[0] |
|
query2outputs[query] = [x['hits'] for x in output] |
|
|
|
search_query = st.selectbox( |
|
"Choose a query from the list", |
|
sorted(query2outputs), |
|
|
|
|
|
) |
|
|
|
def preferences_from_hits(list_of_hits): |
|
docid2id = {} |
|
id2doc = {} |
|
preferences = [] |
|
|
|
for result in list_of_hits: |
|
for doc in result: |
|
if doc["docid"] not in docid2id: |
|
id = len(docid2id) |
|
docid2id[doc["docid"]] = id |
|
id2doc[id] = doc |
|
print([doc["docid"] for doc in result]) |
|
print([docid2id[doc["docid"]] for doc in result]) |
|
preferences.append([docid2id[doc["docid"]] for doc in result]) |
|
|
|
|
|
return np.array(preferences), id2doc |
|
|
|
|
|
def load_qrels(name): |
|
import ir_datasets |
|
if name == "dl19": |
|
ds_name = "msmarco-passage/trec-dl-2019/judged" |
|
elif name == "dl20": |
|
ds_name = "msmarco-passage/trec-dl-2020/judged" |
|
else: |
|
raise ValueError(name) |
|
|
|
dataset = ir_datasets.load(ds_name) |
|
qrels = defaultdict(dict) |
|
for qrel in dataset.qrels_iter(): |
|
qrels[qrel.query_id][qrel.doc_id] = qrel.relevance |
|
return qrels |
|
|
|
|
|
def aggregate(list_of_hits): |
|
import numpy as np |
|
from permsc import KemenyOptimalAggregator, sum_kendall_tau, ranks_from_preferences |
|
from permsc import BordaRankAggregator |
|
|
|
preferences, id2doc = preferences_from_hits(list_of_hits) |
|
y_optimal = KemenyOptimalAggregator().aggregate(preferences) |
|
|
|
|
|
return [id2doc[id] for id in y_optimal] |
|
|
|
|
|
def write_ranking(search_results, text): |
|
st.write(f'<p align=\"right\" style=\"color:grey;\"> {text} ms</p>', unsafe_allow_html=True) |
|
|
|
qid = {result["qid"] for result in search_results} |
|
assert len(qid) == 1 |
|
qid = list(qid)[0] |
|
|
|
for i, result in enumerate(search_results): |
|
result_id = result["docid"] |
|
contents = result["content"] |
|
|
|
label = qrels[str(qid)].get(str(result_id), -1) |
|
label_text = "Unlabeled" |
|
if label == 3: |
|
style = "style=\"color:rgb(237, 125, 12);\"" |
|
label_text = "High" |
|
elif label == 2: |
|
style = "style=\"color:rgb(244, 185, 66);\"" |
|
label_text = "Medium" |
|
elif label == 1: |
|
style = "style=\"color:rgb(241, 177, 118);\"" |
|
label_text = "Low" |
|
elif label == 0: |
|
style = "style=\"color:black;\"" |
|
label_text = "Not Relevance" |
|
else: |
|
style = "style=\"color:grey;\"" |
|
|
|
print(qid, result_id, label, style) |
|
|
|
output_1 = f'<div class="row" {style}> <b>Rank</b>: {i+1} | <b>Document ID</b>: {result_id}</div>' |
|
output_2 = f'<div class="row" {style}> <b>True Relevance</b>: {label_text}</div>' |
|
|
|
try: |
|
st.write(output_1, unsafe_allow_html=True) |
|
st.write(output_2, unsafe_allow_html=True) |
|
st.write( |
|
f'<div class="row" {style}>{contents}</div>', unsafe_allow_html=True) |
|
|
|
except: |
|
pass |
|
st.write('---') |
|
|
|
|
|
aggregated_ranking = aggregate(query2outputs[search_query]) |
|
qrels = load_qrels(name) |
|
col1, col2 = st.columns([5, 5]) |
|
|
|
if search_query: |
|
with col1: |
|
if search_query or button_clicked: |
|
write_ranking(search_results=query2outputs[search_query][0], text="w/o PSC") |
|
|
|
with col2: |
|
if search_query or button_clicked: |
|
write_ranking(search_results=aggregated_ranking, text="w/ PSC") |
|
|