Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import pathlib | |
import beir | |
from beir import util | |
from beir.datasets.data_loader import GenericDataLoader | |
import pytrec_eval | |
import pandas as pd | |
from collections import defaultdict | |
import json | |
import copy | |
import ir_datasets | |
from constants import BEIR, IR_DATASETS, LOCAL_DATASETS | |
def load_local_corpus(corpus_file, columns_to_combine=("title", "text")):
    """Load a JSONL corpus from an open file object.

    Args:
        corpus_file: Open file-like object yielding ``str`` or ``bytes`` lines
            (non-``str`` lines are decoded as UTF-8), or ``None``. It is used
            as a context manager, so it is closed once loading finishes.
        columns_to_combine: Record keys whose values are joined with a single
            space, in this order, to build the document text. Keys missing
            from a record are skipped.

    Returns:
        ``None`` when ``corpus_file`` is ``None``; otherwise a dict mapping
        each document id to ``{"text": <combined text>, "title": <title or "">}``.

    Raises:
        json.JSONDecodeError: If a line (other than a non-JSON header on the
            first line) is not valid JSON.
        KeyError: If a record has neither an ``"_id"`` nor a ``"doc_id"`` key.
    """
    if corpus_file is None:
        return None

    did2text = {}
    id_key = "_id"
    with corpus_file as f:
        for idx, line in enumerate(f):
            # Decode once at the boundary; everything below works on str.
            if not isinstance(line, str):
                line = line.decode("utf-8")
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no record.
                continue

            if idx == 0 and "doc_id" in line:
                # A first line mentioning "doc_id" may be a column header, but
                # it may equally be a real record keyed by "doc_id". Only skip
                # it when it is not a JSON object, so the first document of a
                # "doc_id"-keyed corpus is not silently dropped.
                try:
                    inst = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(inst, dict):
                    continue
            else:
                inst = json.loads(line)

            all_text = " ".join([inst[col] for col in columns_to_combine if col in inst])
            # Once a record lacks "_id", switch to "doc_id" for the rest of the file.
            if id_key not in inst:
                id_key = "doc_id"
            did2text[inst[id_key]] = {
                "text": all_text,
                "title": inst["title"] if "title" in inst else "",
            }
    return did2text
def load_local_queries(queries_file):
    """Load a JSONL query file from an open file object.

    Args:
        queries_file: Open file-like object yielding ``str`` or ``bytes`` lines
            (non-``str`` lines are decoded as UTF-8), or ``None``. It is used
            as a context manager, so it is closed once loading finishes.

    Returns:
        ``None`` when ``queries_file`` is ``None``; otherwise a dict mapping
        each query id to the record's ``"text"`` value.

    Raises:
        json.JSONDecodeError: If a line (other than a non-JSON header on the
            first line) is not valid JSON.
        KeyError: If a record has no ``"text"`` key, or has neither an
            ``"_id"`` nor a ``"query_id"`` key.
    """
    if queries_file is None:
        return None

    qid2text = {}
    id_key = "_id"
    with queries_file as f:
        for idx, line in enumerate(f):
            # Decode once at the boundary; everything below works on str.
            if not isinstance(line, str):
                line = line.decode("utf-8")
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no record.
                continue

            if idx == 0 and "query_id" in line:
                # A first line mentioning "query_id" is only a header when it
                # is not a JSON object; otherwise it is the first real query
                # of a "query_id"-keyed file and must be kept.
                try:
                    inst = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(inst, dict):
                    continue
            else:
                inst = json.loads(line)

            # Once a record lacks "_id", switch to "query_id" for the rest of the file.
            if id_key not in inst:
                id_key = "query_id"
            qid2text[inst[id_key]] = inst["text"]
    return qid2text
def load_local_qrels(qrels_file):
    """Load relevance judgments from a whitespace-separated qrels file.

    Two row layouts are accepted:

    * four fields: ``qid <ignored> doc_id label``
    * three fields: ``qid doc_id label``

    Args:
        qrels_file: Open file-like object yielding ``str`` or ``bytes`` lines
            (non-``str`` lines are decoded as UTF-8), or ``None``. It is used
            as a context manager, so it is closed once loading finishes.

    Returns:
        ``None`` when ``qrels_file`` is ``None``; otherwise a
        ``defaultdict(dict)`` mapping ``qid -> doc_id -> int(label)`` with
        string ids.

    Raises:
        ValueError: If a non-blank row has neither three nor four fields, or
            its label is not an integer literal.
    """
    if qrels_file is None:
        return None

    qid2did2label = defaultdict(dict)
    with qrels_file as f:
        for idx, line in enumerate(f):
            cur_line = line if isinstance(line, str) else line.decode("utf-8")

            # Header detection applies to the first line only. The parentheses
            # matter: without them, `and` binds tighter than `or`, and every
            # line containing "query-id" would be skipped.
            if idx == 0 and ("qid" in cur_line or "query-id" in cur_line):
                continue

            fields = cur_line.split()
            if not fields:
                # Blank line (e.g. trailing newline at end of file).
                continue

            if len(fields) == 4:
                qid, _, doc_id, label = fields
            elif len(fields) == 3:
                qid, doc_id, label = fields
            else:
                raise ValueError(
                    f"Expected 3 or 4 whitespace-separated fields in qrels line {idx + 1}, "
                    f"got {len(fields)}: {cur_line!r}"
                )
            qid2did2label[str(qid)][str(doc_id)] = int(label)
    return qid2did2label
def load_run(f_run):
    """Load a six-column run file both as a pytrec_eval run and as a DataFrame.

    Args:
        f_run: File-like object holding the run. It is deep-copied for
            ``pytrec_eval.parse_run`` and then read directly by pandas as a
            tab-separated file without a header row; the columns are named
            ``qid, generic, doc_id, rank, score, model``.

    Returns:
        A ``(run, run_pandas)`` tuple:

        * ``run``: ``defaultdict(dict)`` holding the parsed run with ``bytes``
          keys decoded to ``str``.
        * ``run_pandas``: DataFrame with, per ``qid``, rows ordered by
          descending score (ties broken by descending ``doc_id``) and
          ``rank`` renumbered 1..n in that order.
    """

    def _as_str(value):
        # Keys arrive as bytes when f_run is a binary stream; leave str keys alone.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    # Parse a copy so that f_run itself can still be read by pandas below.
    run = pytrec_eval.parse_run(copy.deepcopy(f_run))
    new_run = defaultdict(dict)
    for key, sub_dict in run.items():
        new_run[_as_str(key)] = {_as_str(k): v for k, v in sub_dict.items()}

    run_pandas = pd.read_csv(f_run, header=None, index_col=None, sep="\t")
    run_pandas.columns = ["qid", "generic", "doc_id", "rank", "score", "model"]
    run_pandas.doc_id = run_pandas.doc_id.astype(str)
    run_pandas.qid = run_pandas.qid.astype(str)
    run_pandas["rank"] = run_pandas["rank"].astype(int)
    run_pandas.score = run_pandas.score.astype(float)

    all_groups = []
    for _, sub_df in run_pandas.groupby("qid"):
        # sort_values returns a new frame; the result must be kept, otherwise
        # ranks are assigned in file order rather than score order.
        sub_df = sub_df.sort_values(["score", "doc_id"], ascending=[False, False])
        sub_df["rank"] = list(range(1, len(sub_df) + 1))
        all_groups.append(sub_df)

    # pd.concat raises on an empty list, so only concatenate when there are groups.
    if all_groups:
        run_pandas = pd.concat(all_groups)
    return new_run, run_pandas
def load_jsonl(f):
    """Group the text of JSONL records by document id.

    Each record is dispatched on the first of these keys it contains:

    * ``"question"``: id is ``doc_id`` if present, else
      ``metadata[0]["passage_id"]``; the question is collected.
    * ``"text"``: id is ``doc_id`` if present, else ``did``; the text is
      collected and additionally stored under the record's ``did``.
    * ``"query"``: id is ``doc_id`` if present, else ``did``; the query is
      collected.

    Args:
        f: Iterable of JSON-encoded lines.

    Returns:
        A ``(did2text, sub_did2text)`` tuple: ``did2text`` is a
        ``defaultdict(list)`` mapping id to the list of collected strings, and
        ``sub_did2text`` maps ``did`` to text for ``"text"`` records.

    Raises:
        NotImplementedError: If a record has none of the keys above.
        KeyError: If a ``"text"`` record has no ``"did"`` key.
    """
    did2text = defaultdict(list)
    sub_did2text = {}
    for line in f:
        inst = json.loads(line)
        if "question" in inst:
            docid = inst["metadata"][0]["passage_id"] if "doc_id" not in inst else inst["doc_id"]
            did2text[docid].append(inst["question"])
        elif "text" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["text"])
            # NOTE(review): this always reads "did", even when the group id came
            # from "doc_id" — confirm every "text" record carries a "did".
            sub_did2text[inst["did"]] = inst["text"]
        elif "query" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["query"])
        else:
            # Fail fast; a debugger breakpoint here would stall a non-interactive run.
            raise NotImplementedError("Need to handle this case")
    return did2text, sub_did2text
def get_beir(dataset: str):
    """Fetch a BEIR dataset archive and load its ``test`` split.

    The archive URL is built from ``dataset``, handed to
    ``util.download_and_unzip`` with a ``datasets`` directory next to this
    file as the target, and the resulting folder is loaded with
    ``GenericDataLoader``.

    Args:
        dataset: BEIR dataset name, substituted into the archive URL.

    Returns:
        Whatever ``GenericDataLoader.load(split="test")`` returns.
    """
    archive_url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    module_dir = pathlib.Path(__file__).parent.absolute()
    target_dir = os.path.join(module_dir, "datasets")
    extracted_path = util.download_and_unzip(archive_url, target_dir)
    loader = GenericDataLoader(data_folder=extracted_path)
    return loader.load(split="test")
def get_ir_datasets(dataset_name: str, input_fields_doc: str = None, input_fields_query: str = None):
    """Load a dataset through ``ir_datasets`` as (corpus, queries, qrels).

    Args:
        dataset_name: Identifier passed to ``ir_datasets.load``.
        input_fields_doc: Document fields to join into the document text,
            either an iterable of field names or a comma-separated string.
            ``None`` (or a string with no names) joins every field of the
            document record.
        input_fields_query: Query fields to join into the query text, in the
            same formats. ``None`` joins every field of the query record
            except ``query_id``, so a plain (query_id, text) record yields
            just its text.

    Returns:
        ``(corpus, queries, qrels)`` where ``corpus`` maps ``doc_id`` to
        ``{"text": ...}``, ``queries`` maps ``query_id`` to text, and
        ``qrels`` is ``dataset.qrels_dict()``. Field values are converted with
        ``str`` and joined with single spaces.

    Raises:
        AttributeError: If a requested field does not exist on a record.
    """

    def _field_list(fields):
        # Accept "title, text" style strings as well as ready-made sequences.
        if isinstance(fields, str):
            names = [name.strip() for name in fields.split(",")]
            return [name for name in names if name] or None
        return fields

    input_fields_doc = _field_list(input_fields_doc)
    input_fields_query = _field_list(input_fields_query)

    dataset = ir_datasets.load(dataset_name)

    queries = {}
    # Iterate whole query records: tuple-unpacking them as (qid, query) only
    # works for two-field records and leaves a bare string on which explicit
    # field selection cannot work.
    for query in dataset.queries_iter():
        if input_fields_query is None:
            fields = [name for name in query._fields if name != "query_id"]
        else:
            fields = input_fields_query
        queries[query.query_id] = " ".join(str(getattr(query, name)) for name in fields)

    corpus = {}
    for doc in dataset.docs_iter():
        fields = doc._fields if input_fields_doc is None else input_fields_doc
        corpus[doc.doc_id] = {"text": " ".join(str(getattr(doc, name)) for name in fields)}

    return corpus, queries, dataset.qrels_dict()
def get_dataset(dataset_name: str, input_fields_doc, input_fields_query):
    """Load a dataset by name from BEIR, ir_datasets, or the local folder.

    Args:
        dataset_name: Dataset identifier. The empty string yields three empty
            dicts; otherwise the name is looked up in ``BEIR``,
            ``IR_DATASETS`` and ``LOCAL_DATASETS``, in that order.
        input_fields_doc: Document fields to use, as a comma-separated string
            or a sequence of names. Only forwarded to the ir_datasets loader.
        input_fields_query: Query fields to use, in the same formats. Only
            forwarded to the ir_datasets loader.

    Returns:
        The ``(corpus, queries, qrels)`` triple produced by the matching loader.

    Raises:
        NotImplementedError: If ``dataset_name`` is in none of the known lists.
    """

    def _field_list(fields):
        # "title, text" -> ["title", "text"]; a string with no names means
        # "no selection" rather than a single empty field name.
        if isinstance(fields, str):
            names = [name.strip() for name in fields.split(",")]
            return [name for name in names if name] or None
        return fields

    input_fields_doc = _field_list(input_fields_doc)
    input_fields_query = _field_list(input_fields_query)

    if dataset_name == "":
        return {}, {}, {}

    if dataset_name in BEIR:
        return get_beir(dataset_name)
    elif dataset_name in IR_DATASETS:
        return get_ir_datasets(dataset_name, input_fields_doc, input_fields_query)
    elif dataset_name in LOCAL_DATASETS:
        base_path = f"local_datasets/{dataset_name}"
        # Open all three files under one `with` so that none is left open if a
        # later open() or loader fails. The loaders close their own file;
        # closing it again on exit is harmless.
        with open(f"{base_path}/corpus.jsonl", "r", encoding="utf-8") as corpus_file, \
                open(f"{base_path}/queries.jsonl", "r", encoding="utf-8") as queries_file, \
                open(f"{base_path}/qrels/test.tsv", "r", encoding="utf-8") as qrels_file:
            return (
                load_local_corpus(corpus_file),
                load_local_queries(queries_file),
                load_local_qrels(qrels_file),
            )
    else:
        raise NotImplementedError("Dataset not implemented")