import html
import json
import sys
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import streamlit as st
# Append the current working directory to the import path
# (presumably so the local `permsc` package resolves -- confirm).
path_root = Path("./")
sys.path.append(str(path_root))

st.set_page_config(page_title="PSC Runtime",
                   page_icon='🌸', layout="centered")

# Both selectors start empty (index=None), so `name` / `model_name` are None
# until the user has made a choice.
name = st.selectbox(
    "Choose a dataset",
    ["dl19", "dl20"],
    index=None,
    placeholder="Choose a dataset..."
)
model_name = st.selectbox(
    "Choose a model",
    ["gpt-3.5", "gpt-4"],
    index=None,
    placeholder="Choose a model..."
)
# Load the stored outputs for the chosen dataset/model pair and let the user
# pick one of the queries they contain.
if name and model_name:
    import torch

    fn = f"{name}-{model_name}.pt"
    if not Path(fn).is_file():
        # Fail with a readable message instead of a raw traceback.
        st.error(f"Result file not found: {fn}")
        st.stop()

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only point this at result files you trust.
    loaded = torch.load(fn)
    # Element 2 holds one entry per query; each entry is a list of runs, and
    # every run is a mapping with a 'query' string and a 'hits' list.
    outputs = loaded[2]

    query2outputs = {}
    for output in outputs:
        all_queries = {x['query'] for x in output}
        if len(all_queries) != 1:
            raise ValueError(
                f"Expected all runs in an entry to share one query, "
                f"got {len(all_queries)}: {sorted(all_queries)}"
            )
        (query,) = all_queries
        query2outputs[query] = [x['hits'] for x in output]

    search_query = st.selectbox(
        "Choose a query from the list",
        sorted(query2outputs),
    )
def preferences_from_hits(list_of_hits):
    """Convert rankings of document dicts into rankings of dense integer ids.

    Each distinct ``doc["docid"]`` is assigned the next free integer id in
    order of first appearance across all rankings.

    Args:
        list_of_hits: Iterable of rankings; each ranking is a sequence of
            dicts that contain at least a ``"docid"`` key.

    Returns:
        A tuple ``(preferences, id2doc)`` where ``preferences`` is
        ``np.array`` of the per-ranking integer id lists and ``id2doc`` maps
        every integer id to the first document dict seen with that docid.
    """
    docid2id = {}
    id2doc = {}
    preferences = []
    for hits in list_of_hits:
        ranking = []
        for doc in hits:
            docid = doc["docid"]
            idx = docid2id.get(docid)
            if idx is None:
                # First occurrence: allocate the next dense id and remember
                # the document so the id can be mapped back later.
                idx = len(docid2id)
                docid2id[docid] = idx
                id2doc[idx] = doc
            ranking.append(idx)
        preferences.append(ranking)
    return np.array(preferences), id2doc
def load_qrels(name):
    """Load relevance judgements for a supported dataset via ``ir_datasets``.

    Args:
        name: Short dataset key, either ``"dl19"`` or ``"dl20"``.

    Returns:
        A ``defaultdict(dict)`` mapping ``query_id -> {doc_id: relevance}``.

    Raises:
        ValueError: If ``name`` is not a supported dataset key.
    """
    ds_names = {
        "dl19": "msmarco-passage/trec-dl-2019/judged",
        "dl20": "msmarco-passage/trec-dl-2020/judged",
    }
    # Validate before importing ir_datasets so a bad key fails fast.
    try:
        ds_name = ds_names[name]
    except (KeyError, TypeError):
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of {sorted(ds_names)}"
        ) from None

    import ir_datasets

    dataset = ir_datasets.load(ds_name)
    qrels = defaultdict(dict)
    for qrel in dataset.qrels_iter():
        qrels[qrel.query_id][qrel.doc_id] = qrel.relevance
    return qrels
def aggregate(list_of_hits):
    """Combine several rankings of the same documents into a single ranking.

    The rankings are converted to integer preference lists with
    ``preferences_from_hits`` and passed to permsc's
    ``KemenyOptimalAggregator``; the ids it returns are mapped back to the
    original document dicts.

    Args:
        list_of_hits: Iterable of rankings, each a sequence of document dicts
            with a ``"docid"`` key.

    Returns:
        A list of document dicts in the order produced by the aggregator.
    """
    from permsc import KemenyOptimalAggregator

    preferences, id2doc = preferences_from_hits(list_of_hits)
    y_optimal = KemenyOptimalAggregator().aggregate(preferences)
    return [id2doc[doc_idx] for doc_idx in y_optimal]
def write_ranking(search_results, text):
    """Render one ranked list of documents into the active Streamlit container.

    Each document is shown with its rank, document id, relevance label and
    contents, coloured by the relevance found in the module-level ``qrels``
    mapping (``query_id -> {doc_id: relevance}``). Documents without a
    judgement are shown in grey as "Unlabeled".

    Args:
        search_results: Sequence of dicts with ``"qid"``, ``"docid"`` and
            ``"content"`` keys; all entries must share the same ``"qid"``.
        text: Caption written right-aligned above the list.

    Raises:
        ValueError: If the results do not all belong to exactly one query.
    """
    default_display = ('style="color:grey;"', "Unlabeled")
    relevance_display = {
        3: ('style="color:rgb(237, 125, 12);"', "High"),
        2: ('style="color:rgb(244, 185, 66);"', "Medium"),
        1: ('style="color:rgb(241, 177, 118);"', "Low"),
        0: ('style="color:black;"', "Not Relevance"),
    }

    # NOTE(review): the caption always ends in " ms" although the visible
    # callers pass labels such as "w/o PSC" -- confirm the suffix is intended.
    st.write(f'<p align="right" style="color:grey;"> {text} ms</p>', unsafe_allow_html=True)

    qids = {result["qid"] for result in search_results}
    if len(qids) != 1:
        raise ValueError(
            f"Expected results for exactly one query, got {len(qids)}"
        )
    (qid,) = qids
    # qrels keys are looked up as strings, so normalise the id types here.
    query_qrels = qrels[str(qid)]

    for rank, result in enumerate(search_results, start=1):
        result_id = result["docid"]
        contents = result["content"]
        label = query_qrels.get(str(result_id), -1)
        style, label_text = relevance_display.get(label, default_display)

        # The markup is rendered with unsafe_allow_html=True, so data-derived
        # text must be escaped to avoid being interpreted as HTML.
        safe_id = html.escape(str(result_id))
        safe_contents = html.escape(str(contents))

        output_1 = f'<div class="row" {style}> <b>Rank</b>: {rank} | <b>Document ID</b>: {safe_id}</div>'
        output_2 = f'<div class="row" {style}> <b>True Relevance</b>: {label_text}</div>'
        try:
            st.write(output_1, unsafe_allow_html=True)
            st.write(output_2, unsafe_allow_html=True)
            st.write(
                f'<div class="row" {style}>{safe_contents}</div>', unsafe_allow_html=True)
        except Exception as exc:
            # Stay best-effort (keep rendering the remaining documents) but
            # surface the failure instead of hiding it.
            st.warning(f"Could not render document {safe_id}: {exc}")
        st.write('---')
# Show the first stored ranking next to the PSC-aggregated ranking once a
# dataset, a model and a query have all been selected.
if name and model_name and search_query:
    aggregated_ranking = aggregate(query2outputs[search_query])
    # write_ranking reads this module-level mapping.
    qrels = load_qrels(name)

    col1, col2 = st.columns([5, 5])
    with col1:
        write_ranking(search_results=query2outputs[search_query][0], text="w/o PSC")
    with col2:
        write_ranking(search_results=aggregated_ranking, text="w/ PSC")