Binder / retrieval /retriever.py
Timothyxxx
Init
f6f97d8
raw
history blame
3.46 kB
"""
Retriever to retrieve relevant examples from annotations.
"""
import copy
from typing import Dict, List, Tuple, Any
import nltk
from nltk.stem import SnowballStemmer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from utils.normalizer import normalize
from retrieval.retrieve_pool import OpenAIQARetrievePool, QAItem
class OpenAIQARetriever(object):
    """
    Retrieve few-shot QA examples from a retrieve pool by BLEU similarity.

    Candidates are scored on both their question text and their table header,
    each compared against the corresponding part of the incoming item.
    """

    def __init__(self, retrieve_pool: OpenAIQARetrievePool):
        # Pool of annotated QA items; iterated once per retrieval call.
        self.retrieve_pool = retrieve_pool

    @staticmethod
    def _string_bleu(q1: str, q2: str, stop_words=None, stemmer=None) -> float:
        """
        Smoothed sentence-level BLEU of ``q2`` (candidate) against ``q1`` (reference).

        Both strings are passed through ``normalize`` and tokenised with
        ``nltk.word_tokenize``; if ``stemmer`` is given, every token is stemmed.

        Args:
            q1: Reference string.
            q2: Candidate string.
            stop_words: Currently ignored; kept for interface compatibility.
            stemmer: Optional object with a ``stem(token)`` method (e.g. an NLTK
                ``SnowballStemmer``) applied to every token of both strings.

        Returns:
            The BLEU score using n-gram weights (0.25, 0.3, 0.3, 0.15).
        """
        reference_tokens = nltk.word_tokenize(normalize(q1))
        candidate_tokens = nltk.word_tokenize(normalize(q2))
        if stemmer is not None:
            reference_tokens = [stemmer.stem(tk) for tk in reference_tokens]
            candidate_tokens = [stemmer.stem(tk) for tk in candidate_tokens]
        # Smoothing avoids the hard zero BLEU otherwise produced when some
        # n-gram order has no overlap at all.
        smoothing = SmoothingFunction().method1
        return sentence_bleu(
            [reference_tokens],
            candidate_tokens,
            weights=(0.25, 0.3, 0.3, 0.15),
            smoothing_function=smoothing,
        )

    def _qh2qh_similarity(
            self,
            item: QAItem,
            num_retrieve_samples: int,
            score_func: str,
            qa_type: str,
            weight_h: float = 0.2,
            verbose: bool = False
    ) -> List[QAItem]:
        """
        Retrieve top K pool items by query&header to query&header similarity.

        Pool items store ``qa_question`` as ``"<qa_type>@<question>"``; only items
        whose prefix equals ``qa_type`` are considered. Each is scored as
        ``bleu(question) + weight_h * bleu(header)`` and the highest-scoring
        items are returned in descending score order (ties keep pool order).

        Args:
            item: The query item; its ``qa_question`` and ``table['header']`` are used.
            num_retrieve_samples: Maximum number of items to return.
            score_func: Scoring function name; only ``'bleu'`` is supported.
            qa_type: Question-type prefix that candidates must carry.
            weight_h: Weight of the header score relative to the question score.
            verbose: If True, print the retrieved items.

        Returns:
            Up to ``num_retrieve_samples`` pool items.

        Raises:
            ValueError: If ``score_func`` is not ``'bleu'``.
        """
        if score_func != 'bleu':
            raise ValueError(f'Score function {score_func} is not supported.')

        q = item.qa_question
        # Exclude the 'row_id' column from the query header; filtering (rather
        # than list.remove) tolerates headers that do not contain it.
        h = ' '.join(col for col in item.table['header'] if col != 'row_id')
        stemmer = SnowballStemmer('english')

        scored: List[Tuple[QAItem, float]] = []
        # One pass over the pool keeps question and header scores paired and
        # works even if the pool can only be iterated once.
        for candidate in self.retrieve_pool:
            cand_type, _, cand_question = candidate.qa_question.partition('@')
            if cand_type != qa_type:
                continue
            q_score = self._string_bleu(q, cand_question, stemmer=stemmer)
            # NOTE(review): 'row_id' is stripped from the query header but not
            # from candidate headers — confirm whether pool tables carry it.
            h_score = self._string_bleu(h, ' '.join(candidate.table['header']), stemmer=stemmer)
            scored.append((candidate, q_score + weight_h * h_score))

        # list.sort is stable even with reverse=True, so ties keep pool order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        retrieved = [candidate for candidate, _ in scored[:num_retrieve_samples]]
        if verbose:
            print(retrieved)
        return retrieved

    def retrieve(
            self,
            item: QAItem,
            num_shots: int,
            method: str = 'qh2qh_bleu',
            qa_type: str = 'map',
            verbose: bool = False,
            weight_h: float = 0.2,
    ) -> List[QAItem]:
        """
        Retrieve a list of relevant QA samples.

        Args:
            item: The query item.
            num_shots: Maximum number of samples to return.
            method: Retrieval method; only ``'qh2qh_bleu'`` is supported.
            qa_type: Question-type prefix that retrieved samples must carry.
            verbose: If True, print the retrieved items.
            weight_h: Weight of the header similarity relative to the question
                similarity.

        Returns:
            Up to ``num_shots`` pool items, most similar first.

        Raises:
            ValueError: If ``method`` is not supported.
        """
        if method != 'qh2qh_bleu':
            raise ValueError(f'Retrieve method {method} is not supported.')
        return self._qh2qh_similarity(
            item=item,
            num_retrieve_samples=num_shots,
            score_func='bleu',
            qa_type=qa_type,
            weight_h=weight_h,
            verbose=verbose,
        )