import ast
import base64
import json
from mimetypes import guess_type

from bs4 import BeautifulSoup
# TesseractOCR and PDF are used by table_formatting_agent below; img2table is
# the assumed source of these names, which the original file never imported.
from img2table.document import PDF
from img2table.ocr import TesseractOCR

from llm_query_api import *
from prompts import *
from semantic_retrieval import *
class InputInstance:
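    """A single table-QA instance: table HTML, optional question, and the answer to attribute."""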
    def __init__(self, id=None, html_table=None, question=None, answer=None):
        self.id = id
        self.html_table = html_table
        self.question = question
        self.answer = answer
class MATSA:
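    """Multi-agent pipeline that attributes an answer to supporting rows and
    columns of an HTML table, staged as formatting, description augmentation,
    answer decomposition, semantic retrieval, and sufficiency attribution agents.
    """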
    def __init__(self, llm="gpt-4"):
        self.llm = llm
        self.llm_query_api = LLMQueryAPI()  # or LLMProxyQueryAPI()
    def table_formatting_agent(self, html_table=None, table_image_path=None):
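        """Return the table as HTML with unique ids on rows ('row-i') and header columns ('col-i').

        If a table image is supplied instead of HTML, a draft table is first
        extracted with OCR (TesseractOCR/PDF, assumed here to come from
        img2table) and then corrected by a vision LLM against the image.
        """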
        def local_image_to_data_url(image_path):
            mime_type, _ = guess_type(image_path)
            if mime_type is None:
                mime_type = 'application/octet-stream'
            with open(image_path, "rb") as image_file:
                base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
            return f"data:{mime_type};base64,{base64_encoded_data}"
        if table_image_path is not None:
            # OCR pass: extract a first-draft HTML table from the image.
            tesseract = TesseractOCR()
            pdf = PDF(src=table_image_path, pages=[0, 0])
            extracted_tables = pdf.extract_tables(ocr=tesseract,
                                                  implicit_rows=True,
                                                  borderless_tables=True)
            html_table = extracted_tables[0][0].html_repr()
            # Refinement pass: a vision LLM corrects the OCR draft against the image.
            table_image_data_url = local_image_to_data_url(table_image_path)
            query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
            html_table = self.llm_query_api.get_llm_response("gpt-4V", query, table_image_data_url)
        soup = BeautifulSoup(html_table, 'html.parser')
        tr_tags = soup.find_all('tr')
        for i, tr_tag in enumerate(tr_tags):
            tr_tag['id'] = f"row-{i + 1}"  # Assign unique ID using 'row-i' format
            if i == 0:
                th_tags = tr_tag.find_all('th')
                for j, th_tag in enumerate(th_tags):
                    th_tag['id'] = f"col-{j + 1}"  # Assign unique ID using 'col-j' format
        return str(soup)
    def description_augmentation_agent(self, html_table):
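        """Enrich the table with LLM-generated column, row, and trend
        descriptions via three successive prompt passes."""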
        query = col_description_prompt.replace("{{html_table}}", str(html_table))
        col_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)
        query = row_description_prompt.replace("{{html_table}}", str(col_augmented_html_table))
        row_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)
        query = trend_description_prompt.replace("{{html_table}}", str(row_augmented_html_table))
        trend_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)
        return trend_augmented_html_table
    def answer_decomposition_agent(self, answer):
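        """Decompose the answer into a list of atomic facts.

        The LLM is prompted to emit a Python list literal; returns None when
        the response does not parse to a list.
        """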
        query = answer_decomposition_prompt.replace("{{answer}}", answer)
        res = self.llm_query_api.get_llm_response(self.llm, query)
        # Guard against output that is not a valid Python literal.
        try:
            res = ast.literal_eval(res)
        except (ValueError, SyntaxError):
            return None
        if isinstance(res, list):
            return res
        return None
    def semantic_retrieval_agent(self, html_table, fact_list, topK=5):
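        """Mark the top-K rows/columns most similar to the facts by embedding
        similarity (delegates to get_embedding_attribution from semantic_retrieval)."""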
        attributed_html_table, row_attribution_ids, col_attribution_ids = get_embedding_attribution(html_table, fact_list, topK)
        return attributed_html_table, row_attribution_ids, col_attribution_ids
    def sufficiency_attribution_agent(self, fact_list, attributed_html_table):
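        """Verify each fact against the attributed table and return the LLM's
        JSON attribution function with row/column citations and an explanation."""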
        # Build the skeleton "attribution function" the LLM is asked to fill in.
        fact_verification_function = {}
        fact_verification_list = []
        for i, fact in enumerate(fact_list):
            fact_verification_list.append({f"Fact {i + 1}:": str(fact)})
        fact_verification_function["List of Fact"] = fact_verification_list
        fact_verification_function["Row Citations"] = "[..., ..., ...]"
        fact_verification_function["Column Citations"] = "[..., ..., ...]"
        fact_verification_function["Explanation"] = "..."
        fact_verification_function_string = json.dumps(fact_verification_function)
        query = functional_attribution_prompt.replace("{{attributed_html_table}}", str(attributed_html_table)).replace("{{fact_verification_function}}", fact_verification_function_string)
        attribution_fxn = self.llm_query_api.get_llm_response(self.llm, query)
        # Strip Markdown code fences the LLM may wrap around the JSON.
        attribution_fxn = attribution_fxn.replace("```json", "").replace("```", "").strip()
        print(attribution_fxn)
        attribution_fxn = json.loads(attribution_fxn)
        if isinstance(attribution_fxn, dict):
            return attribution_fxn
        return None
if __name__ == '__main__':
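    # Demo: run the full attribution pipeline on a toy table and print each stage's output.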
    html_table = """<table>
    <tr>
        <th rowspan="1">Sr. Number</th>
        <th colspan="3">Types</th>
        <th rowspan="1">Remark</th>
    </tr>
    <tr>
        <th> </th>
        <th>A</th>
        <th>B</th>
        <th>C</th>
        <th> </th>
    </tr>
    <tr>
        <td>1</td>
        <td>Mitten</td>
        <td>Kity</td>
        <td>Teddy</td>
        <td>Names of cats</td>
    </tr>
    <tr>
        <td>2</td>
        <td>Tommy</td>
        <td>Rudolph</td>
        <td>Jerry</td>
        <td>Names of dogs</td>
    </tr>
    </table>"""
    answer = "Tommy is a dog but Mitten is a cat."
    x = InputInstance(html_table=html_table, answer=answer)

    matsa_agent = MATSA()
    x_reformulated = matsa_agent.table_formatting_agent(x.html_table)
    print(x_reformulated)
    x_descriptions = matsa_agent.description_augmentation_agent(x_reformulated)
    print(x_descriptions)
    fact_list = matsa_agent.answer_decomposition_agent(x.answer)
    print(fact_list)
    attributed_html_table, row_attribution_ids, col_attribution_ids = matsa_agent.semantic_retrieval_agent(x_descriptions, fact_list)
    print(attributed_html_table)
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_html_table)
    print(attribution_fxn)
    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]
    print(row_attribution_set)
    print(col_attribution_set)