Binder / retrieval /retrieve_pool.py
Timothyxxx
Init
f6f97d8
raw
history blame
1.32 kB
"""
Retrieval pool of candidates
"""
from dataclasses import dataclass
from typing import List, Dict
import json
class OpenAIQARetrievePool(object):
def __init__(self, data_path):
with open(data_path, 'r') as f:
data = json.load(f)
self.data = []
for d in data:
if isinstance(d['qa_column'], List):
d['qa_column'] = '|'.join(d['qa_column'])
qa_item = QAItem(
id=d['id'],
qa_question=d['qa_question'],
qa_column=d['qa_column'],
qa_answer=d['qa_answer'],
table=d['table'],
title=d['title']
)
self.data.append(qa_item)
self.pointer = 0
def __iter__(self):
return self
def __next__(self):
pointer = self.pointer
if pointer < len(self):
self.pointer += 1
return self.data[pointer]
else:
self.pointer = 0
raise StopIteration
def __getitem__(self, item):
return self.data[item]
def __len__(self):
return len(self.data)
@dataclass
class QAItem(object):
id: int = None
qa_question: str = None
qa_column: str = None
qa_answer: str = None
table: Dict = None
title: str = None