Ramon Meffert
commited on
Commit
·
1f08ed2
1
Parent(s):
a1746cf
Add query cli w/ argparse
Browse files
query.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import transformers
|
4 |
+
|
5 |
+
from typing import List
|
6 |
+
from datasets import load_dataset, DatasetDict
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
|
9 |
+
from src.readers.dpr_reader import DprReader
|
10 |
+
from src.retrievers.base_retriever import Retriever
|
11 |
+
from src.retrievers.es_retriever import ESRetriever
|
12 |
+
from src.retrievers.faiss_retriever import FaissRetriever
|
13 |
+
from src.utils.preprocessing import result_to_reader_input
|
14 |
+
from src.utils.log import get_logger
|
15 |
+
|
16 |
+
|
17 |
+
def get_retriever(r: str, ds: DatasetDict) -> Retriever:
|
18 |
+
retriever = ESRetriever if r == "es" else FaissRetriever
|
19 |
+
return retriever(ds)
|
20 |
+
|
21 |
+
|
22 |
+
def print_name(contexts: dict, section: str, id: int):
|
23 |
+
name = contexts[section][id]
|
24 |
+
if name != 'nan':
|
25 |
+
print(f" {section}: {name}")
|
26 |
+
|
27 |
+
|
28 |
+
def print_answers(answers: List[tuple], scores: List[float], contexts: dict):
|
29 |
+
# calculate answer scores
|
30 |
+
sm = torch.nn.Softmax(dim=0)
|
31 |
+
d_scores = sm(torch.Tensor(
|
32 |
+
[pred.relevance_score for pred in answers]))
|
33 |
+
s_scores = sm(torch.Tensor(
|
34 |
+
[pred.span_score for pred in answers]))
|
35 |
+
|
36 |
+
for pos, answer in enumerate(answers):
|
37 |
+
print(f"{pos + 1:>4}. {answer.text}")
|
38 |
+
print(f" {'-' * len(answer.text)}")
|
39 |
+
print_name(contexts, 'chapter', answer.doc_id)
|
40 |
+
print_name(contexts, 'section', answer.doc_id)
|
41 |
+
print_name(contexts, 'subsection', answer.doc_id)
|
42 |
+
print(f" retrieval score: {scores[answer.doc_id]:6.02f}%")
|
43 |
+
print(f" document score: {d_scores[pos] * 100:6.02f}%")
|
44 |
+
print(f" span score: {s_scores[pos] * 100:6.02f}%")
|
45 |
+
print()
|
46 |
+
|
47 |
+
|
48 |
+
def main(args: argparse.Namespace):
|
49 |
+
# Initialize dataset
|
50 |
+
dataset = load_dataset("GroNLP/ik-nlp-22_slp")
|
51 |
+
|
52 |
+
# Retrieve
|
53 |
+
retriever = get_retriever(args.retriever, dataset)
|
54 |
+
scores, contexts = retriever.retrieve(args.query)
|
55 |
+
|
56 |
+
# Read
|
57 |
+
reader = DprReader()
|
58 |
+
reader_input = result_to_reader_input(contexts)
|
59 |
+
answers = reader.read(args.query, reader_input, num_answers=args.top)
|
60 |
+
|
61 |
+
# Print output
|
62 |
+
print_answers(answers, scores, contexts)
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
# Setup environment
|
67 |
+
load_dotenv()
|
68 |
+
logger = get_logger()
|
69 |
+
transformers.logging.set_verbosity_error()
|
70 |
+
|
71 |
+
# Set up CLI arguments
|
72 |
+
parser = argparse.ArgumentParser(
|
73 |
+
formatter_class=argparse.MetavarTypeHelpFormatter
|
74 |
+
)
|
75 |
+
parser.add_argument("query", type=str,
|
76 |
+
help="The question to feed to the QA system")
|
77 |
+
parser.add_argument("--top", "-t", type=int, default=1,
|
78 |
+
help="The number of answers to retrieve")
|
79 |
+
parser.add_argument("--retriever", "-r", type=str.lower,
|
80 |
+
choices=["faiss", "es"], default="faiss",
|
81 |
+
help="The retrieval method to use")
|
82 |
+
|
83 |
+
args = parser.parse_args()
|
84 |
+
main(args)
|