annotate-relevance / dataset_loading.py
orionweller's picture
add options to turn features off
bfcacbc
raw
history blame
3.98 kB
import streamlit as st
import os
import pathlib
import pandas as pd
from collections import defaultdict
import json
import copy
import plotly.express as px
@st.cache_data
def load_local_corpus(corpus_file, columns_to_combine=["title", "text"]):
    """Load a JSONL corpus into a ``{doc_id: {"text": ..., "title": ...}}`` dict.

    Args:
        corpus_file: An open file-like object yielding ``str`` or ``bytes``
            lines (bytes are decoded as UTF-8), or ``None``. It is used as a
            context manager, so it is closed when loading finishes.
        columns_to_combine: Record fields whose values are joined with single
            spaces to build each document's ``"text"``. Fields that are
            missing or ``None`` are skipped. The default list is only read,
            never mutated.

    Returns:
        ``None`` when ``corpus_file`` is ``None``; otherwise a dict keyed by
        the record's ``"_id"`` (or ``"doc_id"`` once a record without
        ``"_id"`` is seen) mapping to ``{"text": combined_text,
        "title": record title or ""}``.

    Raises:
        json.JSONDecodeError: If a non-blank line (other than a skipped
            header) is not valid JSON.
        KeyError: If a record carries neither id key.
    """
    if corpus_file is None:
        return None

    did2text = {}
    id_key = "_id"
    with corpus_file as f:
        for idx, line in enumerate(f):
            # Decode once at the boundary; everything below works on str.
            text_line = line if isinstance(line, str) else line.decode("utf-8")
            if not text_line.strip():
                continue  # tolerate blank lines instead of failing in json.loads

            # The first line may be a column header mentioning "doc_id". A real
            # record keyed by "doc_id" also contains that substring, so only
            # treat the line as a header when it is NOT a JSON object;
            # otherwise the first document would be silently dropped.
            maybe_header = idx == 0 and "doc_id" in text_line
            try:
                inst = json.loads(text_line)
            except json.JSONDecodeError:
                if maybe_header:
                    continue
                raise
            if maybe_header and not isinstance(inst, dict):
                continue

            all_text = " ".join(
                inst[col]
                for col in columns_to_combine
                if col in inst and inst[col] is not None
            )
            # Fall back to the alternative id field name; the switch is sticky
            # for the remainder of the file.
            if id_key not in inst:
                id_key = "doc_id"
            did2text[inst[id_key]] = {
                "text": all_text,
                "title": inst["title"] if "title" in inst else "",
            }
    return did2text
@st.cache_data
def load_local_queries(queries_file):
    """Load a JSONL queries file into a ``{query_id: query_text}`` dict.

    Args:
        queries_file: An open file-like object yielding ``str`` or ``bytes``
            lines (bytes are decoded as UTF-8), or ``None``. It is used as a
            context manager, so it is closed when loading finishes.

    Returns:
        ``None`` when ``queries_file`` is ``None``; otherwise a dict keyed by
        the record's ``"_id"`` (or ``"query_id"`` once a record without
        ``"_id"`` is seen) mapping to the record's ``"text"`` value.

    Raises:
        json.JSONDecodeError: If a non-blank line (other than a skipped
            header) is not valid JSON.
        KeyError: If a record lacks ``"text"`` or carries neither id key.
    """
    if queries_file is None:
        return None

    qid2text = {}
    id_key = "_id"
    with queries_file as f:
        for idx, line in enumerate(f):
            # Decode once at the boundary; everything below works on str.
            text_line = line if isinstance(line, str) else line.decode("utf-8")
            if not text_line.strip():
                continue  # tolerate blank lines instead of failing in json.loads

            # The first line may be a column header mentioning "query_id". A
            # real record keyed by "query_id" also contains that substring, so
            # only treat the line as a header when it is NOT a JSON object;
            # otherwise the first query would be silently dropped.
            maybe_header = idx == 0 and "query_id" in text_line
            try:
                inst = json.loads(text_line)
            except json.JSONDecodeError:
                if maybe_header:
                    continue
                raise
            if maybe_header and not isinstance(inst, dict):
                continue

            # Fall back to the alternative id field name; the switch is sticky
            # for the remainder of the file.
            if id_key not in inst:
                id_key = "query_id"
            qid2text[inst[id_key]] = inst["text"]
    return qid2text
@st.cache_data
def load_local_qrels(qrels_file):
    """Load whitespace-separated relevance judgements into a nested dict.

    Each data line has either four fields (``qid  <ignored>  doc_id  label``)
    or three fields (``qid  doc_id  label``).

    Args:
        qrels_file: An open file-like object yielding ``str`` or ``bytes``
            lines (bytes are decoded as UTF-8), or ``None``. It is used as a
            context manager, so it is closed when loading finishes.

    Returns:
        ``None`` when ``qrels_file`` is ``None``; otherwise a
        ``defaultdict(dict)`` mapping ``str(qid) -> {str(doc_id): int(label)}``.

    Raises:
        ValueError: If a non-blank data line does not have 3 or 4 fields, or
            its label is not an integer.
    """
    if qrels_file is None:
        return None

    qid2did2label = defaultdict(dict)
    with qrels_file as f:
        for idx, line in enumerate(f):
            cur_line = line if isinstance(line, str) else line.decode("utf-8")
            fields = cur_line.split()
            if not fields:
                continue  # blank line

            # Header detection applies to the first line only. The previous
            # condition `idx == 0 and A or B` parsed as `(idx == 0 and A) or B`
            # and therefore dropped *every* line containing "query-id".
            # A data line whose ids merely contain "qid"/"query-id" still ends
            # in an integer label, so it is kept rather than mistaken for a
            # header.
            if idx == 0 and ("qid" in cur_line or "query-id" in cur_line):
                try:
                    int(fields[-1])
                except ValueError:
                    continue  # genuine header row

            if len(fields) == 4:
                qid, _, doc_id, label = fields
            elif len(fields) == 3:
                qid, doc_id, label = fields
            else:
                raise ValueError(
                    f"Malformed qrels line {idx + 1}: expected 3 or 4 "
                    f"whitespace-separated fields, got {len(fields)}"
                )
            qid2did2label[str(qid)][str(doc_id)] = int(label)
    return qid2did2label
@st.cache_data
def load_jsonl(f):
    """Group the text-like field of JSONL records by document id.

    Each line of ``f`` is parsed as JSON and handled according to the first
    of these keys it contains:

    * ``"question"`` - grouped under ``"doc_id"`` if present, otherwise under
      ``inst["metadata"][0]["passage_id"]``.
    * ``"text"`` - grouped under ``"doc_id"`` if present, otherwise ``"did"``;
      the text is additionally stored in the second mapping under ``"did"``.
    * ``"query"`` - grouped under ``"doc_id"`` if present, otherwise ``"did"``.

    Args:
        f: An iterable of JSON-encoded lines.

    Returns:
        A ``(did2text, sub_did2text)`` tuple: a ``defaultdict(list)`` of
        ``doc id -> [values...]`` and a plain dict of ``did -> text`` filled
        only by ``"text"`` records.

    Raises:
        NotImplementedError: If a record has none of the recognised keys.
        KeyError: If a record lacks the id field its branch reads (for
            ``"text"`` records, ``"did"`` is always read).
    """
    did2text = defaultdict(list)
    sub_did2text = {}
    for line in f:
        inst = json.loads(line)
        if "question" in inst:
            docid = inst["metadata"][0]["passage_id"] if "doc_id" not in inst else inst["doc_id"]
            did2text[docid].append(inst["question"])
        elif "text" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["text"])
            sub_did2text[inst["did"]] = inst["text"]
        elif "query" in inst:
            docid = inst["doc_id"] if "doc_id" in inst else inst["did"]
            did2text[docid].append(inst["query"])
        else:
            # A leftover breakpoint() used to sit here; in a served app it
            # would block on an interactive debugger instead of surfacing
            # the error, so fail fast.
            raise NotImplementedError("Need to handle this case")
    return did2text, sub_did2text
@st.cache_data(persist="disk")
def get_dataset(dataset_name: str, input_fields_doc, input_fields_query):
    """Return the data for a named dataset.

    Only the empty dataset name is handled in this function: it yields three
    empty dicts. Any other name raises.

    Args:
        dataset_name: Name of the dataset to load.
        input_fields_doc: Document field names, either a sequence or a
            comma-separated string (which is split into a list).
        input_fields_query: Query field names, in the same form as
            ``input_fields_doc``.

    Returns:
        ``({}, {}, {})`` when ``dataset_name`` is the empty string.

    Raises:
        NotImplementedError: For every non-empty ``dataset_name``.
    """
    # Normalise comma-separated field specs to lists. The parsed values are
    # not consumed further down in this function.
    if type(input_fields_doc) is str:
        stripped_doc_fields = input_fields_doc.strip()
        input_fields_doc = stripped_doc_fields.split(",")
    if type(input_fields_query) is str:
        stripped_query_fields = input_fields_query.strip()
        input_fields_query = stripped_query_fields.split(",")

    if dataset_name == "":
        return {}, {}, {}
    raise NotImplementedError("Dataset not implemented")