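"""Data-loading utilities for the Streamlit retrieval-analysis app: corpora, queries, qrels, and run files from BEIR, ir_datasets, or local JSONL/TSV datasets."""
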
import streamlit as st
import os
import pathlib
import beir
from beir import util
from beir.datasets.data_loader import GenericDataLoader
import pytrec_eval
import pandas as pd
from collections import defaultdict
import json
import copy
import ir_datasets
from constants import BEIR, IR_DATASETS, LOCAL_DATASETS
@st.cache_data
def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
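    """Load a JSONL corpus from an uploaded or opened file handle.

    Combines ``columns_to_combine`` into a single text string per document and
    returns a dict mapping document id (``_id`` or ``doc_id``) to ``{"text", "title"}``.
    Returns None when no file is given.
    """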
    if corpus_file is None:
        return None
    did2text = {}
    id_key = "_id"
    with corpus_file as f:
        for idx, line in enumerate(f):
            uses_bytes = not isinstance(line, str)
            if uses_bytes:
                if idx == 0 and "doc_id" in line.decode("utf-8"):
                    continue
                inst = json.loads(line.decode("utf-8"))
            else:
                if idx == 0 and "doc_id" in line:
                    continue
                inst = json.loads(line)
            all_text = " ".join([inst[col] for col in columns_to_combine if col in inst])
            if id_key not in inst:
                id_key = "doc_id"
            did2text[inst[id_key]] = {
                "text": all_text,
                "title": inst.get("title", ""),
            }
    return did2text

@st.cache_data
def load_local_queries(queries_file):
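    """Load JSONL queries into a dict mapping query id (``_id`` or ``query_id``) to query text.

    Returns None when no file is given.
    """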
    if queries_file is None:
        return None
    qid2text = {}
    id_key = "_id"
    with queries_file as f:
        for idx, line in enumerate(f):
            uses_bytes = not isinstance(line, str)
            if uses_bytes:
                if idx == 0 and "query_id" in line.decode("utf-8"):
                    continue
                inst = json.loads(line.decode("utf-8"))
            else:
                if idx == 0 and "query_id" in line:
                    continue
                inst = json.loads(line)
            if id_key not in inst:
                id_key = "query_id"
            qid2text[inst[id_key]] = inst["text"]
    return qid2text

@st.cache_data
def load_local_qrels(qrels_file):
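    """Load TSV qrels into a nested dict: qid -> doc_id -> relevance label.

    Handles both three-column (qid, doc_id, label) and four-column
    (qid, iteration, doc_id, label) formats and skips an optional header row.
    Returns None when no file is given.
    """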
    if qrels_file is None:
        return None
    qid2did2label = defaultdict(dict)
    with qrels_file as f:
        for idx, line in enumerate(f):
            cur_line = line.decode("utf-8") if not isinstance(line, str) else line
            # skip an optional header row
            if idx == 0 and ("qid" in cur_line or "query-id" in cur_line):
                continue
            try:
                # four-column, TREC-style qrels: qid, iteration, doc_id, label
                qid, _, doc_id, label = cur_line.split()
            except ValueError:
                # three-column qrels: qid, doc_id, label
                qid, doc_id, label = cur_line.split()
            qid2did2label[str(qid)][str(doc_id)] = int(label)
    return qid2did2label

@st.cache_data
def load_run(f_run):
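    """Parse a TREC-style run file into a pytrec_eval-style dict and a pandas DataFrame.

    The DataFrame is re-ranked per query by descending score so that the ``rank``
    column is consistent with the scores.
    """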
    run = pytrec_eval.parse_run(copy.deepcopy(f_run))
    # convert bytes to strings for keys
    new_run = defaultdict(dict)
    for key, sub_dict in run.items():
        new_run[key.decode("utf-8")] = {k.decode("utf-8"): v for k, v in sub_dict.items()}
    run_pandas = pd.read_csv(f_run, header=None, index_col=None, sep="\t")
    run_pandas.columns = ["qid", "generic", "doc_id", "rank", "score", "model"]
    run_pandas.doc_id = run_pandas.doc_id.astype(str)
    run_pandas.qid = run_pandas.qid.astype(str)
    run_pandas["rank"] = run_pandas["rank"].astype(int)
    run_pandas.score = run_pandas.score.astype(float)
    # re-rank within each query by descending score so ranks match the scores
    all_groups = []
    for qid, sub_df in run_pandas.groupby("qid"):
        sub_df = sub_df.sort_values(["score", "doc_id"], ascending=[False, False])
        sub_df["rank"] = list(range(1, len(sub_df) + 1))
        all_groups.append(sub_df)
    run_pandas = pd.concat(all_groups)
    return new_run, run_pandas

@st.cache_data
def load_jsonl(f):
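    """Read a JSONL file of questions, passages, or queries grouped by document id.

    Returns a dict mapping doc id to a list of associated texts, plus a flat
    dict of doc id to passage text for entries that carry a ``text`` field.
    """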
    did2text = defaultdict(list)
    sub_did2text = {}
    for line in f:
        inst = json.loads(line)
        if "question" in inst:
            docid = inst["metadata"][0]["passage_id"] if "doc_id" not in inst else inst["doc_id"]
            did2text[docid].append(inst["question"])
        elif "text" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["text"])
            sub_did2text[docid] = inst["text"]
        elif "query" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["query"])
        else:
            raise NotImplementedError("Need to handle this case")
    return did2text, sub_did2text

@st.cache_data(persist="disk")
def get_beir(dataset: str):
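    """Download (and cache) a BEIR dataset and return its test split as (corpus, queries, qrels)."""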
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
    data_path = util.download_and_unzip(url, out_dir)
    return GenericDataLoader(data_folder=data_path).load(split="test")

@st.cache_data(persist="disk")
def get_ir_datasets(dataset_name: str, input_fields_doc: str = None, input_fields_query: str = None):
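    """Load a dataset from ir_datasets and return (corpus, queries, qrels).

    If ``input_fields_doc`` / ``input_fields_query`` name specific fields, only those
    are concatenated into the text; otherwise every text field is used.
    """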
    dataset = ir_datasets.load(dataset_name)
    queries = {}
    for query in dataset.queries_iter():
        qid = query.query_id
        if input_fields_query is None:
            # combine all non-id fields of the query namedtuple into a single string
            all_fields = {field: getattr(query, field) for field in query._fields if field != "query_id"}
            queries[qid] = " ".join([str(v) for v in all_fields.values()])
        else:
            all_fields = {field: getattr(query, field) for field in input_fields_query}
            queries[qid] = " ".join([str(v) for v in all_fields.values()])
    corpus = {}
    for doc in dataset.docs_iter():
        if input_fields_doc is None:
            # combine all fields of the doc namedtuple into a single string
            all_fields = {field: getattr(doc, field) for field in doc._fields}
            corpus[doc.doc_id] = {"text": " ".join([str(v) for v in all_fields.values()])}
        else:
            all_fields = {field: getattr(doc, field) for field in input_fields_doc}
            corpus[doc.doc_id] = {"text": " ".join([str(v) for v in all_fields.values()])}
    return corpus, queries, dataset.qrels_dict()

@st.cache_data(persist="disk")
def get_dataset(dataset_name: str, input_fields_doc, input_fields_query):
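    """Dispatch to the matching loader (BEIR, ir_datasets, or local files) and return (corpus, queries, qrels).

    Comma-separated field names may be passed as strings and are split into lists.
    """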
    if isinstance(input_fields_doc, str):
        input_fields_doc = input_fields_doc.strip().split(",")
    if isinstance(input_fields_query, str):
        input_fields_query = input_fields_query.strip().split(",")
    if dataset_name == "":
        return {}, {}, {}
    if dataset_name in BEIR:
        return get_beir(dataset_name)
    elif dataset_name in IR_DATASETS:
        return get_ir_datasets(dataset_name, input_fields_doc, input_fields_query)
    elif dataset_name in LOCAL_DATASETS:
        base_path = f"local_datasets/{dataset_name}"
        corpus_file = open(f"{base_path}/corpus.jsonl", "r")
        queries_file = open(f"{base_path}/queries.jsonl", "r")
        qrels_file = open(f"{base_path}/qrels/test.tsv", "r")
        return load_local_corpus(corpus_file), load_local_queries(queries_file), load_local_qrels(qrels_file)
    else:
        raise NotImplementedError("Dataset not implemented")