|
from transformers import DPRReader, DPRReaderTokenizer |
|
from typing import List, Dict, Tuple |
|
from dotenv import load_dotenv |
|
|
|
from src.readers.base_reader import Reader |
|
|
|
|
|
# Load variables from a .env file into the process environment at import time.
# NOTE(review): nothing in this module reads those variables directly —
# presumably they configure code elsewhere; confirm before removing this call.
load_dotenv() |
|
|
|
|
|
class DprReader(Reader):
    """Extractive reader built on a pretrained DPR reader checkpoint.

    The tokenizer and the model are both loaded from the same Hugging Face
    checkpoint when the instance is constructed, so building a ``DprReader``
    may download weights and is comparatively expensive; reuse one instance
    across queries.
    """

    # Checkpoint used when the caller does not ask for a specific one.
    DEFAULT_MODEL_NAME = "facebook/dpr-reader-single-nq-base"

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME) -> None:
        """Load the DPR reader tokenizer and model.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``
                for both the tokenizer and the model. Defaults to
                ``DEFAULT_MODEL_NAME``.
        """
        self._tokenizer = DPRReaderTokenizer.from_pretrained(model_name)
        self._model = DPRReader.from_pretrained(model_name)

    def read(self,
             query: str,
             context: Dict[str, List[str]],
             num_answers: int = 5) -> List[Tuple]:
        """Extract the best answer spans for ``query`` from the passages.

        Args:
            query: The question to answer.
            context: Mapping with a ``'titles'`` list and a ``'texts'`` list
                describing the candidate passages. The two lists are assumed
                to be aligned element-wise (one title per text) — verify
                against the caller.
            num_answers: Maximum number of spans to return overall.

        Returns:
            The span predictions produced by the tokenizer's
            ``decode_best_spans`` (at most ``num_answers`` of them, at most
            two per passage), or an empty list when no passage texts are
            supplied.

        Raises:
            KeyError: If ``context`` lacks the ``'titles'`` or ``'texts'``
                key.
        """
        # Imported locally so this block does not depend on a module-level
        # import; torch is already required by the PyTorch DPR model.
        import torch

        titles = context['titles']
        texts = context['texts']

        # Nothing to read: return early rather than handing empty batches to
        # the tokenizer and the model.
        if not texts:
            return []

        encoded_inputs = self._tokenizer(
            questions=query,
            titles=titles,
            texts=texts,
            return_tensors='pt',
            truncation=True,
            padding=True,
        )

        # Inference only: skip autograd bookkeeping so no gradient graph is
        # built and retained for every query.
        with torch.no_grad():
            outputs = self._model(**encoded_inputs)

        return self._tokenizer.decode_best_spans(
            encoded_inputs,
            outputs,
            num_spans=num_answers,
            num_spans_per_passage=2,
            max_answer_length=512,
        )
|
|