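# Web-page header captured together with this file (not code) -- presumably the
# uploader's avatar caption, the commit message and the commit hash:
#   fadliaulawi's picture
#   Tidy up prompts
#   cda22ff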
from dotenv import load_dotenv
from img2table.document import Image
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from pdf2image import convert_from_path
from prompt import *
from prompt_old import *
from table_detector import detection_transform, device, model, ocr, outputs_to_objects
import io
import json
import os
import pandas as pd
import re
import torch
# Read variables from a .env file into the process environment before any
# client is constructed.
load_dotenv()

# Prompt registry keyed by extraction type. Each value is a two-item list:
# index 0 is used as the per-chunk (map) template and index 1 as the
# combine (reduce) template by Process.get_entity.
prompts = {
    'alls': [
        prompt_entity_chunk,
        prompt_entity_combine,
    ],
    'gsd': [
        prompt_entity_gsd_chunk,
        prompt_entity_gsd_combine,
    ],
    'summ': [
        prompt_entity_summ_chunk,
        prompt_entity_summ_combine,
    ],
    'all': [
        prompt_entities_chunk,
        prompt_entities_combine,
    ],
}
class Process:
    """LLM-backed extraction steps: entities, tables and rsIDs.

    The chat model is selected once, in ``__init__``, from the model name.
    """

    def __init__(self, llm: str):
        """Create the chat client for the model named *llm*.

        Names starting with ``gpt`` use ``ChatOpenAI``; names starting with
        ``gemini`` use ``ChatGoogleGenerativeAI``; every other name is sent
        through ``ChatOpenAI`` pointed at the Perplexity endpoint, which
        requires the ``PERPLEXITY_API_KEY`` environment variable.

        Raises:
            KeyError: if the Perplexity branch is taken and
                ``PERPLEXITY_API_KEY`` is not set.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(
                temperature=0,
                model_name=llm,
                api_key=os.environ['PERPLEXITY_API_KEY'],
                base_url="https://api.perplexity.ai",
            )

    @staticmethod
    def _extract_object(text: str) -> str:
        """Return the first ``{...}`` span in *text*, or ``''`` if there is none.

        The span ends at the first ``}``, so nested objects are not supported.
        """
        match = re.search(r'\{[^}]+\}', text)
        return match.group(0) if match else ''

    @staticmethod
    def _parse_literal(text: str):
        """Parse *text* as JSON, then as a Python literal; ``None`` on failure.

        SECURITY: this replaces ``eval`` on model output. A model reply is
        untrusted input and ``eval`` would execute any expression in it;
        ``json.loads`` / ``ast.literal_eval`` only build plain data.
        """
        # Local import: keeps this block self-contained without touching the
        # file-level import list.
        import ast

        try:
            return json.loads(text)
        except ValueError:
            pass  # not JSON (e.g. single-quoted Python dict); try a literal
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError):
            return None

    def get_entity(self, data):
        """Run the map-reduce entity extraction for one ``(chunks, types)`` pair.

        Args:
            data: two-item sequence ``(chunks, types)``. ``chunks`` is handed
                unchanged to ``MapReduceDocumentsChain.invoke`` (presumably a
                list of documents -- confirm against the caller); ``types``
                is a key of the module-level ``prompts`` registry.

        Returns:
            The raw model text when ``types == 'summ'``. Otherwise a
            ``pandas.DataFrame`` built from the first ``{...}`` object in the
            reply, with shorter columns padded with ``''``. An empty
            DataFrame is returned when no object can be parsed.
        """
        chunks, types = data

        map_prompt = PromptTemplate.from_template(prompts[types][0])
        map_chain = LLMChain(llm=self.llm, prompt=map_prompt)

        reduce_prompt = PromptTemplate.from_template(prompts[types][1])
        reduce_chain = LLMChain(llm=self.llm, prompt=reduce_prompt)

        combine_chain = StuffDocumentsChain(
            llm_chain=reduce_chain, document_variable_name="doc_summaries"
        )
        reduce_documents_chain = ReduceDocumentsChain(
            combine_documents_chain=combine_chain,
            collapse_documents_chain=combine_chain,
            token_max=100000,
        )
        map_reduce_chain = MapReduceDocumentsChain(
            llm_chain=map_chain,
            reduce_documents_chain=reduce_documents_chain,
            document_variable_name="docs",
            return_intermediate_steps=False,
        )

        result = map_reduce_chain.invoke(chunks)['output_text']
        print(types)
        print(result)

        if types == 'summ':
            return result

        parsed = self._parse_literal(self._extract_object(result))
        if not isinstance(parsed, dict):
            # No usable object in the reply: degrade to an empty result, the
            # same way get_table treats an unparseable reply.
            return pd.DataFrame()

        # Every column must be a list before padding; a scalar value would
        # otherwise break both len()/append and the DataFrame constructor.
        columns = {
            key: list(value) if isinstance(value, (list, tuple)) else [value]
            for key, value in parsed.items()
        }
        longest = max((len(values) for values in columns.values()), default=0)
        for values in columns.values():
            values.extend([''] * (longest - len(values)))
        return pd.DataFrame(columns)

    def get_entity_one(self, chunks):
        """Ask the model for entities in a single call and return them parsed.

        *chunks* is substituted into ``prompt_entity_one_chunk`` with
        ``str.format``.

        Returns:
            The first ``{...}`` object of the reply parsed as data, or ``{}``
            when the reply contains no parseable object.
        """
        result = self.llm.invoke(prompt_entity_one_chunk.format(chunks)).content
        print('One')
        print(result)

        parsed = self._parse_literal(self._extract_object(result))
        return {} if parsed is None else parsed

    def get_table(self, path, score_threshold=0.9):
        """Detect tables in a PDF, OCR them and let the model restructure them.

        Args:
            path: PDF path passed to ``pdf2image.convert_from_path``.
            score_threshold: a detection is kept only when its score is
                strictly greater than this value (default ``0.9``, the
                previously hard-coded value).

        Returns:
            A ``pandas.DataFrame`` with the rows parsed from every model
            reply, concatenated with a fresh index; empty when nothing was
            detected or parsed.
        """
        pages = convert_from_path(path)

        # Work on a copy: writing into model.config.id2label directly would
        # add one more "no object" key to the shared config per page.
        id2label = dict(model.config.id2label)
        id2label[len(id2label)] = "no object"

        tables = []
        # Loop pages
        for page in pages:
            pixel_values = detection_transform(page).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = model(pixel_values)

            # Loop table in page (if any)
            for detected in outputs_to_objects(outputs, page.size, id2label):
                # Filter first so rejected detections are never cropped.
                if detected['score'] > score_threshold:
                    cropped_table = page.crop(detected["bbox"])
                    if detected["label"] == 'table rotated':
                        cropped_table = cropped_table.rotate(270, expand=True)
                    print(detected)
                    tables.append(cropped_table)

        frames = []
        # Loop tables
        for table in tables:
            buffer = io.BytesIO()
            table.save(buffer, format='PNG')
            document = Image(buffer)

            # Extract to dataframe
            extracted_tables = document.extract_tables(
                ocr=ocr, implicit_rows=True, borderless_tables=True, min_confidence=0
            )
            if not extracted_tables:
                continue

            # Combine multiple dataframe
            df_table = pd.concat(
                [extracted.df for extracted in extracted_tables], ignore_index=True
            ).fillna('')

            # Ask LLM with JSON data
            json_table = df_table.to_json(orient='records')
            str_json_table = json.dumps(json.loads(json_table), indent=2)
            result = self.llm.invoke(prompt_table.format(str_json_table)).content
            print('table')
            print(result)

            # Keep only the outermost [...] span of the reply.
            start, end = result.find('['), result.rfind(']')
            rows = self._parse_literal(result[start:end + 1]) if 0 <= start < end else None
            if isinstance(rows, list) and rows:
                frames.append(pd.DataFrame(rows))

        if not frames:
            return pd.DataFrame()
        return pd.concat(frames, ignore_index=True)

    def get_rsid(self, text):
        """Collect the distinct rsIDs ("rs" followed by 3+ digits) in *text*.

        Returns:
            A one-column ``pandas.DataFrame`` (column ``rsID``) with one row
            per distinct identifier, in order of first appearance.
        """
        matches = re.findall(r'rs\d{3,}', text)
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # output is deterministic (set() ordering is not).
        rsids = list(dict.fromkeys(matches))
        return pd.DataFrame(rsids, columns=['rsID'])