# Matsa-demo / matsa.py
# Uploaded to a Hugging Face Space via huggingface_hub (commit 35d31f5).
import base64
import json
from prompts import *
import ast
from bs4 import BeautifulSoup
from semantic_retrieval import *
from llm_query_api import *
import base64
from mimetypes import guess_type
class InputInstance:
    """Container for one table-attribution task.

    Attributes:
        id: Optional identifier for the instance. (Parameter name shadows the
            ``id`` builtin; kept for caller compatibility.)
        html_table: The table to attribute against, as an HTML string.
        question: Optional question the answer responds to.
        answer: The natural-language answer whose facts will be attributed.
    """

    def __init__(self, id=None, html_table=None, question=None, answer=None):
        self.id = id
        self.html_table = html_table
        self.question = question
        self.answer = answer
        # (dropped the stray `return` the original __init__ ended with;
        # __init__ conventionally returns None implicitly)
class MATSA:
    """Multi-agent pipeline that attributes the facts in an answer to the
    table rows/columns that support them.

    Agents (invoked in order by the demo driver at the bottom of the file):
      1. table_formatting_agent       - normalize table, assign row/col ids
      2. description_augmentation_agent - add column/row/trend descriptions
      3. answer_decomposition_agent   - split an answer into atomic facts
      4. semantic_retreival_agent     - embedding-based candidate retrieval
      5. sufficiency_attribution_agent - LLM verification + final citations
    """

    def __init__(self, llm="gpt-4"):
        """Create a pipeline backed by the named LLM."""
        self.llm = llm
        self.llm_query_api = LLMQueryAPI()  # LLMProxyQueryAPI()

    def table_formatting_agent(self, html_table=None, table_image_path=None):
        """Return ``html_table`` with unique ids on every <tr> ('row-i') and
        on the first row's <th> cells ('col-i').

        If ``table_image_path`` is given, the table is first extracted from
        the image (OCR + vision-LLM cleanup) and ``html_table`` is ignored.
        """

        def local_image_to_data_url(image_path):
            # Encode a local image file as a data: URL for the vision model.
            mime_type, _ = guess_type(image_path)
            if mime_type is None:
                mime_type = 'application/octet-stream'  # generic fallback
            with open(image_path, "rb") as image_file:
                base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
            return f"data:{mime_type};base64,{base64_encoded_data}"

        if table_image_path is not None:
            # NOTE(review): TesseractOCR and PDF are not imported anywhere in
            # this file (presumably img2table) — confirm the upstream imports.
            tesseract = TesseractOCR()
            pdf = PDF(src=table_image_path, pages=[0, 0])
            extracted_tables = pdf.extract_tables(ocr=tesseract,
                                                  implicit_rows=True,
                                                  borderless_tables=True)
            html_table = extracted_tables[0][0].html_repr()
            table_image_data_url = local_image_to_data_url(table_image_path)
            query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
            # Bug fix: original called the bare name `llm_query_api`, which is
            # undefined at module level (NameError); the client lives on self.
            html_table = self.llm_query_api.get_llm_response(
                "gpt-4V", query, table_image_data_url)

        soup = BeautifulSoup(html_table, 'html.parser')
        for row_idx, tr_tag in enumerate(soup.find_all('tr')):
            tr_tag['id'] = f"row-{row_idx + 1}"  # unique 'row-i' id
            if row_idx == 0:
                # Distinct loop variable: the original reused `i` here,
                # shadowing the row index inside the header row.
                for col_idx, th_tag in enumerate(tr_tag.find_all('th')):
                    th_tag['id'] = f"col-{col_idx + 1}"  # unique 'col-i' id
        return str(soup)

    def description_augmentation_agent(self, html_table):
        """Augment the table with LLM-generated column, row, and trend
        descriptions (three sequential LLM calls, each fed the prior output)."""
        query = col_description_prompt.replace("{{html_table}}", str(html_table))
        col_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)
        query = row_description_prompt.replace("{{html_table}}", str(col_augmented_html_table))
        row_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)
        query = trend_description_prompt.replace("{{html_table}}", str(row_augmented_html_table))
        trend_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)
        return trend_augmented_html_table

    def answer_decomposition_agent(self, answer):
        """Split ``answer`` into a list of atomic facts via the LLM.

        Returns the list, or None when the model's reply is not a Python list
        literal. (Robustness fix: the original let ``ast.literal_eval`` raise
        on unparseable replies instead of returning None.)
        """
        query = answer_decomposition_prompt.replace("{{answer}}", answer)
        res = self.llm_query_api.get_llm_response(self.llm, query)
        try:
            facts = ast.literal_eval(res)
        except (ValueError, SyntaxError):
            return None
        return facts if isinstance(facts, list) else None

    def semantic_retreival_agent(self, html_table, fact_list, topK=5):
        """Retrieve the topK most similar rows/columns for each fact.

        Returns (attributed_html_table, row_attribution_ids, col_attribution_ids).
        """
        # The misspelling "retreival" is kept: external callers use this name,
        # so renaming would break the public interface.
        return get_embedding_attribution(html_table, fact_list, topK)

    def sufficiency_attribution_agent(self, fact_list, attributed_html_table):
        """Ask the LLM to verify ``fact_list`` against the attributed table.

        Returns the parsed citation dict (keys include "Row Citations" and
        "Column Citations"), or None when the reply is not a JSON object.
        (Robustness fix: the original let ``json.loads`` raise on malformed
        JSON instead of returning None.)
        """
        # Build the skeleton the prompt asks the model to fill in.
        fact_verification_list = [
            {f"Fact {i + 1}:": str(fact)} for i, fact in enumerate(fact_list)
        ]
        fact_verification_function = {
            "List of Fact": fact_verification_list,
            "Row Citations": "[..., ..., ...]",
            "Column Citations": "[..., ..., ...]",
            "Explanation": "...",
        }
        fact_verification_function_string = json.dumps(fact_verification_function)
        query = (functional_attribution_prompt
                 .replace("{{attributed_html_table}}", str(attributed_html_table))
                 .replace("{{fact_verification_function}}", fact_verification_function_string))
        attribution_fxn = self.llm_query_api.get_llm_response(self.llm, query)
        # Strip the markdown code fences models tend to wrap JSON in.
        attribution_fxn = attribution_fxn.replace("```json", "").replace("```", "")
        print(attribution_fxn)
        try:
            attribution_fxn = json.loads(attribution_fxn)
        except json.JSONDecodeError:
            return None
        return attribution_fxn if isinstance(attribution_fxn, dict) else None
if __name__ == '__main__':
    # Demo: run the full MATSA pipeline on a small hand-written table.
    html_table = """<table>
<tr>
<th rowspan="1">Sr. Number</th>
<th colspan="3">Types</th>
<th rowspan="1">Remark</th>
</tr>
<tr>
<th> </th>
<th>A</th>
<th>B</th>
<th>C</th>
<th> </th>
</tr>
<tr>
<td>1</td>
<td>Mitten</td>
<td>Kity</td>
<td>Teddy</td>
<td>Names of cats</td>
</tr>
<tr>
<td>1</td>
<td>Tommy</td>
<td>Rudolph</td>
<td>Jerry</td>
<td>Names of dogs</td>
</tr>
</table>"""
    answer = "Tommy is a dog but Mitten is a cat."

    instance = InputInstance(html_table=html_table, answer=answer)
    matsa_agent = MATSA()

    # 1. Normalize the table and tag rows/columns with unique ids.
    formatted_table = matsa_agent.table_formatting_agent(instance.html_table)
    print(formatted_table)

    # 2. Augment the table with column/row/trend descriptions.
    described_table = matsa_agent.description_augmentation_agent(formatted_table)
    print(described_table)

    # 3. Decompose the answer into atomic facts.
    fact_list = matsa_agent.answer_decomposition_agent(instance.answer)
    print(fact_list)

    # 4. Retrieve candidate rows/columns for each fact by embedding similarity.
    attributed_html_table, row_attribution_ids, col_attribution_ids = (
        matsa_agent.semantic_retreival_agent(described_table, fact_list))
    print(attributed_html_table)

    # 5. Verify sufficiency with the LLM and collect the final citations.
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(
        fact_list, attributed_html_table)
    print(attribution_fxn)

    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]
    print(row_attribution_set)
    print(col_attribution_set)