# mColBERT / fix_colbert_docid.py
# vjeronymo2's picture
# Adding model and checkpoint
# 828992f
# raw
# history blame
# 1.76 kB
import jsonlines
import argparse
import pandas as pd
from tqdm import tqdm
def build_arg_parser():
    """Create the command-line parser for the doc-id fixing script.

    Returns:
        An ``argparse.ArgumentParser`` with the three required options
        ``--corpus``, ``--input_ranking`` and ``--output_ranking``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=lambda prog: argparse.HelpFormatter(prog, width=100))
    parser.add_argument('--corpus', metavar='FILE', type=str, required=True, help='Corpus file in jsonl')
    parser.add_argument('--input_ranking', metavar='FILE', type=str, required=True, help='Ranking file from ColBERT in tsv')
    parser.add_argument('--output_ranking', metavar='FILE', type=str, required=True, help='Ranking file with robust doc ids in tsv')
    return parser


def load_doc_ids(corpus_path):
    """Return the ``id`` field of every record in a jsonl corpus, in file order.

    The position of an id in the returned list is what the ranking file's
    doc_id column is resolved against.
    """
    with jsonlines.open(corpus_path, 'r') as reader:
        return [obj['id'] for obj in reader]


def remap_ranking(df, doc_ids, max_hits=1000):
    """Replace positional doc indices with corpus ids and re-rank per query.

    Args:
        df: DataFrame with columns ``query_id``, ``doc_id`` and ``rank``,
            where ``doc_id`` is an integer position into ``doc_ids``.
        doc_ids: Sequence of corpus ids indexed by position.
        max_hits: Maximum number of rows kept per query (default 1000).

    Returns:
        A new DataFrame (the input is not modified) with ``doc_id`` replaced
        by the corpus id, a ``score`` column equal to ``1 / rank`` of the
        input rank, duplicate (query_id, doc_id) pairs removed keeping the
        best-scored one, at most ``max_hits`` rows per query, and ``rank``
        renumbered from 0 within each query. Rows are ordered by
        ``query_id`` then ``rank``.

    Raises:
        IndexError: if a doc index is outside ``doc_ids``.
    """
    ranked = df.copy()  # do not mutate the caller's frame
    ranked['doc_id'] = [doc_ids[int(idx)] for idx in ranked['doc_id']]
    # Reciprocal of the input rank, so a better (smaller) rank scores higher.
    ranked['score'] = 1 / ranked['rank']
    # A stable sort keeps tied rows in input order, so the output is
    # deterministic from run to run.
    ranked = ranked.sort_values(by='score', ascending=False, kind='mergesort')
    # After the sort the first occurrence of a pair is its best-scored one.
    ranked = ranked.drop_duplicates(subset=['query_id', 'doc_id'])
    # Copy so that assigning the new rank does not write into a slice.
    ranked = ranked.groupby('query_id').head(max_hits).copy()
    ranked['rank'] = ranked.groupby('query_id').cumcount()
    return ranked.sort_values(['query_id', 'rank'])


def write_trec_run(df, output_path):
    """Write ``df`` as tab-separated lines ``query_id Q0 doc_id rank score ColBERT``.

    Columns are iterated individually instead of with ``DataFrame.iterrows``:
    ``iterrows`` builds one Series per row and therefore upcasts an
    all-numeric row to float, which would emit ids and ranks as ``1.0``.
    """
    rows = zip(df['query_id'], df['doc_id'], df['rank'], df['score'])
    with open(output_path, 'w') as writer:
        writer.writelines(
            f'{query_id}\tQ0\t{doc_id}\t{rank}\t{score}\tColBERT\n'
            for query_id, doc_id, rank, score in rows
        )


def main(argv=None):
    """Run the conversion: load corpus ids, remap the ranking, write the run.

    Args:
        argv: Optional list of command-line arguments; when ``None`` the
            arguments are taken from ``sys.argv``.
    """
    args = build_arg_parser().parse_args(argv)
    doc_ids = load_doc_ids(args.corpus)
    # The input is read as a headerless three-column TSV.
    df = pd.read_csv(args.input_ranking, sep='\t', header=None, names=['query_id', 'doc_id', 'rank'])
    write_trec_run(remap_ranking(df, doc_ids), args.output_ranking)


if __name__ == '__main__':
    main()
# Disabled line-by-line variant (writes four columns: no score, no run tag):
# with open(args.input_ranking, 'r') as reader_ranking:
#     with open(args.output_ranking,'w') as writer:
#         for obj in tqdm(reader_ranking):
#             query_id, doc_idx, rank = obj.replace('\n', '').split('\t')
#             doc_id = doc_ids[int(doc_idx)]
#             writer.write(f'{query_id}\tQ0\t{doc_id}\t{rank}\n')