Spaces:
Build error
Build error
Upload folder using huggingface_hub
Browse files
- .DS_Store +0 -0
- .gitattributes +2 -0
- .github/workflows/update_space.yml +28 -0
- .ipynb_checkpoints/config_gpt35-checkpoint.yaml +3 -0
- .ipynb_checkpoints/config_gpt4-checkpoint.yaml +3 -0
- .ipynb_checkpoints/demo-checkpoint.py +221 -0
- .ipynb_checkpoints/llm_query_api-checkpoint.py +216 -0
- .ipynb_checkpoints/matsa-checkpoint.py +182 -0
- .ipynb_checkpoints/semantic_retrieval-checkpoint.py +146 -0
- README.md +3 -9
- __pycache__/demo.cpython-310.pyc +0 -0
- __pycache__/llm_query_api.cpython-310.pyc +0 -0
- __pycache__/matsa.cpython-310.pyc +0 -0
- __pycache__/prompts.cpython-310.pyc +0 -0
- __pycache__/semantic_retrieval.cpython-310.pyc +0 -0
- config_gpt35.yaml +3 -0
- config_gpt4.yaml +3 -0
- demo.py +221 -0
- llm_query_api.py +216 -0
- matsa.py +182 -0
- prompts.py +69 -0
- requirements.txt +104 -0
- semantic_retrieval.py +146 -0
- tables_folder/MATSA_aitqa.json +0 -0
- tables_folder/MATSA_fetaqa.json +3 -0
- wkhtmltox_0.12.6-1.bionic_amd64.deb +3 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tables_folder/MATSA_fetaqa.json filter=lfs diff=lfs merge=lfs -text
|
37 |
+
wkhtmltox_0.12.6-1.bionic_amd64.deb filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/update_space.yml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Deploys this repo to Hugging Face Spaces via `gradio deploy` on every push to main.
name: Run Python script

on:
  push:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        # v2 runs on a deprecated Node runtime that GitHub has removed.
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'

      - name: Install Gradio
        run: python -m pip install gradio

      - name: Log in to Hugging Face
        # Pass the token through an env var instead of interpolating it into
        # the command line, so it cannot leak via shell history or logs.
        env:
          HF_TOKEN: ${{ secrets.hf_token }}
        run: python -c 'import os, huggingface_hub; huggingface_hub.login(token=os.environ["HF_TOKEN"])'

      - name: Deploy to Spaces
        run: gradio deploy
|
.ipynb_checkpoints/config_gpt35-checkpoint.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
API_BASE: https://dil-research.openai.azure.com/
|
2 |
+
API_KEY: de6e495251174ceab84f290cd3925b07
|
3 |
+
API_VERSION: '2023-05-15'
|
.ipynb_checkpoints/config_gpt4-checkpoint.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
API_BASE: https://dil-research-sweden-central.openai.azure.com/
|
2 |
+
API_KEY: 0df7237e0fc243a8830a23b2dbba2dcb
|
3 |
+
API_VERSION: "2023-12-01-preview"
|
.ipynb_checkpoints/demo-checkpoint.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
from bs4 import BeautifulSoup
|
4 |
+
from matsa import MATSA, InputInstance
|
5 |
+
import imgkit
|
6 |
+
import tempfile
|
7 |
+
import time
|
8 |
+
import threading
|
9 |
+
|
10 |
+
TABLE_FOLDER = "./tables_folder/MATSA_fetaqa.json"
|
11 |
+
# Load data from JSON file
|
12 |
+
def load_data():
|
13 |
+
with open(TABLE_FOLDER, 'r') as json_file:
|
14 |
+
return json.load(json_file)
|
15 |
+
|
16 |
+
# Global variable to store the loaded data
|
17 |
+
TABLE_DATA = load_data()
|
18 |
+
|
19 |
+
def get_table_names():
    """Dropdown labels: one "tab_<k>" entry per loaded table, 1-based."""
    return [f"tab_{idx}" for idx in range(1, len(TABLE_DATA) + 1)]
|
21 |
+
|
22 |
+
def html_to_image(html_content):
    """Render an HTML string to a PNG file and return the file's path.

    The file is NOT deleted automatically; the caller owns its lifetime.
    """
    import os

    # BUG FIX: the previous version kept a NamedTemporaryFile handle open
    # while imgkit wrote to the same path (fails on Windows, leaks the fd).
    # Create the path with mkstemp and close the descriptor before writing.
    fd, img_path = tempfile.mkstemp(suffix='.png')
    os.close(fd)
    imgkit.from_string(html_content, img_path)
    return img_path
|
26 |
+
|
27 |
+
def highlight_table(html_table, row_ids, col_ids):
    """Return *html_table* with cited cells colour-highlighted.

    Every cell sitting on a cited column gets a blue background when its row
    is also cited, and a grey background otherwise. Row/column ids follow the
    "row-N" / "col-N" convention assigned by the formatting agent.
    """
    soup = BeautifulSoup(html_table, 'html.parser')

    # Collect the <tr> elements whose ids were cited (missing ids are skipped).
    cited_rows = []
    for rid in row_ids:
        tr = soup.find('tr', id=rid)
        if tr:
            cited_rows.append(tr)

    for cid in col_ids:
        # 'col-3' -> zero-based column index 2.
        idx = int(cid.split('-')[1]) - 1
        for tr in soup.find_all('tr'):
            cells = tr.find_all(['td', 'th'])
            if idx >= len(cells):
                continue
            shade = ('rgba(173, 216, 230, 0.7)' if tr in cited_rows
                     else 'rgba(211, 211, 211, 0.6)')
            cells[idx]['style'] = f'background-color: {shade};'

    return str(soup)
|
49 |
+
|
50 |
+
def load_table_data(table_name):
    """Resolve a dropdown value ("tab_<k>") to (image path, question, answer).

    Returns (None, "", "") when nothing is selected.
    """
    if not table_name:
        return None, "", ""

    record = TABLE_DATA[int(table_name.split('_')[1]) - 1]

    # A record may lack a question, or carry an explicit null.
    question = record.get("question", "") or ""
    answer = record['answer_statement']
    image_path = html_to_image(record['html_table'])

    return image_path, question, answer
|
65 |
+
|
66 |
+
def process_input(table_name, question, answer):
    """Run the full MATSA attribution pipeline for the selected table.

    Args:
        table_name: dropdown value of the form "tab_<k>" (1-based index).
        question: question text shown in the UI (editable by the user).
        answer: answer statement whose facts are attributed to table cells.

    Returns:
        str: JSON document with keys 'highlighted_table', 'facts',
        'row_citations', 'column_citations' and 'Explanation'; or a plain
        error-message string when no table is selected.
    """
    import ast

    if not table_name:
        return "Please select a table from the dropdown."

    # Get the data for the selected table.
    index = int(table_name.split('_')[1]) - 1
    data = TABLE_DATA[index]

    html_content = data['html_table']

    print("html_content: ", html_content)
    print("question: ", question)
    print("answer: ", answer)

    # Initialize MATSA and build the input instance.
    matsa_agent = MATSA()
    instance = InputInstance(html_table=html_content, question=question, answer=answer)

    # Apply the MATSA pipeline stage by stage.
    augmented_table = matsa_agent.description_augmentation_agent(instance.html_table)
    print("augmented_table: ", augmented_table)
    fact_list = matsa_agent.answer_decomposition_agent(instance.answer)
    print("fact_list: ", fact_list)
    attributed_table, _, _ = matsa_agent.semantic_retreival_agent(augmented_table, fact_list)
    print("attributed_table: ", attributed_table)
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_table)
    print("attribution_fxn: ", attribution_fxn)

    # Citations may come back as lists or as their string representation,
    # depending on how the LLM formatted its reply.
    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]
    explanation = attribution_fxn.get("Explanation", "")
    print("row_attribution_set: ", row_attribution_set)
    print("col_attribution_set: ", col_attribution_set)
    print("Explanation: ", explanation)

    # SECURITY FIX: these strings originate from LLM output. Parse them with
    # ast.literal_eval instead of eval() so arbitrary code cannot execute.
    if isinstance(row_attribution_set, str):
        row_ids = ast.literal_eval(row_attribution_set)
    else:
        row_ids = row_attribution_set

    if isinstance(col_attribution_set, str):
        col_ids = ast.literal_eval(col_attribution_set)
    else:
        col_ids = col_attribution_set

    # Highlight the cited cells in the original table.
    highlighted_table = highlight_table(instance.html_table, row_ids, col_ids)

    result = {
        "highlighted_table": highlighted_table,
        "facts": attribution_fxn.get("List of Facts", []),
        "row_citations": row_attribution_set,
        "column_citations": col_attribution_set,
        "Explanation": explanation,
    }

    return json.dumps(result)
|
128 |
+
|
129 |
+
# Define Gradio interface
|
130 |
+
# Define Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# MATSA: Table Question Answering with Attribution")
    gr.Markdown("Select a table from dropdown load table image, question, and answer.")
    gr.Markdown("Attributions are provided as per answer. You may change the question/answer as per your need.")

    table_dropdown = gr.Dropdown(choices=get_table_names(), label="Select Table")
    original_table = gr.Image(type="filepath", label="Original Table")
    question_box = gr.Textbox(label="Question")
    answer_box = gr.Textbox(label="Answer")

    gr.Markdown("Click 'Process' to see the highlighted relevant parts. Click 'Reset' to start over.")

    process_button = gr.Button("Process")
    reset_button = gr.Button("Reset")
    processing_time = gr.Textbox(label="Processing Time", value="0 seconds")
    highlighted_table = gr.HTML(label="Highlighted Table")
    explanation_box = gr.Textbox(label="Explanation")

    def update_table_data(table_name):
        """Dropdown callback: show the table image and prefill Q/A boxes."""
        image_path, question, answer = load_table_data(table_name)
        return image_path, question, answer, gr.update(interactive=True)

    def reset_app():
        """Reset every widget to its initial state."""
        return (
            gr.update(value="", interactive=True),  # table_dropdown
            None,                                   # original_table
            "",                                     # question_box
            "",                                     # answer_box
            "",                                     # highlighted_table
            "",                                     # explanation_box
            gr.update(interactive=True),            # process_button
            "0 seconds",                            # processing_time
        )

    def process_and_disable(table_name, question, answer):
        """Run the pipeline while the controls are disabled.

        BUG FIX: the previous version passed a *generator* function to
        threading.Thread — calling it only creates the generator, so the
        elapsed-seconds counter never ran and always reported 0. Elapsed
        time is now measured directly with time.time(). The final yield
        also returned six values while the click() handler declared five
        outputs; the surplus citations dict has been dropped.
        """
        start_time = time.time()

        # Disable the dropdown and process button during processing.
        yield (
            gr.update(interactive=False),      # table_dropdown
            gr.update(interactive=False),      # process_button
            gr.update(value="Processing..."),  # processing_time
            gr.update(),                       # highlighted_table
            gr.update(),                       # explanation_box
        )

        # Process the input.
        result = process_input(table_name, question, answer)
        result_dict = json.loads(result)
        elapsed = int(time.time() - start_time)

        # Re-enable the controls, report the elapsed time and publish results.
        yield (
            gr.update(interactive=True),                          # table_dropdown
            gr.update(interactive=True),                          # process_button
            f"Processed in {elapsed} seconds",                    # processing_time
            gr.update(value=result_dict['highlighted_table']),    # highlighted_table
            gr.update(value=result_dict.get('Explanation', '')),  # explanation_box
        )

    table_dropdown.change(update_table_data,
                          inputs=[table_dropdown],
                          outputs=[original_table, question_box, answer_box, process_button])

    process_button.click(process_and_disable,
                         inputs=[table_dropdown, question_box, answer_box],
                         outputs=[table_dropdown, process_button, processing_time, highlighted_table, explanation_box])

    reset_button.click(reset_app,
                       inputs=[],
                       outputs=[table_dropdown, original_table, question_box, answer_box, highlighted_table, explanation_box, process_button, processing_time])

# Launch the interface
iface.launch(share=True)
|
.ipynb_checkpoints/llm_query_api-checkpoint.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import copy
|
3 |
+
import io
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
import re
|
8 |
+
import json
|
9 |
+
import argparse
|
10 |
+
import yaml
|
11 |
+
import openai
|
12 |
+
from openai import AzureOpenAI
|
13 |
+
from prompts import *
|
14 |
+
import base64
|
15 |
+
from mimetypes import guess_type
|
16 |
+
# from img2table.document import Image
|
17 |
+
# from img2table.document import PDF
|
18 |
+
# from img2table.ocr import TesseractOCR
|
19 |
+
# from img2table.ocr import EasyOCR
|
20 |
+
from PIL import Image as PILImage
|
21 |
+
|
22 |
+
class LLMQueryAPI:
    """Chat-completion client for Azure-hosted OpenAI deployments.

    Credentials (API_BASE, API_KEY, API_VERSION) are re-read on every call
    from config_gpt4.yaml / config_gpt35.yaml in the working directory.
    """

    def __init__(self) -> None:
        pass

    def gpt4_chat_completion(self, query):
        """Send *query* (a list of chat-message dicts) to the Azure GPT-4
        deployment and return the reply text (str)."""

        with open('config_gpt4.yaml', 'r') as f:
            config = yaml.safe_load(f)

        API_KEY = config.get('API_KEY')
        API_VERSION = config.get('API_VERSION')
        API_BASE = config.get('API_BASE')

        client = AzureOpenAI(
            azure_endpoint= API_BASE,
            api_version= API_VERSION,
            api_key = API_KEY
            )

        deployment_name='gpt-4-2024-04-09'

        response = client.chat.completions.create(
            model=deployment_name,
            messages=query)

        return response.choices[0].message.content

    def gpt35_chat_completion(self, query):
        """Send *query* to the GPT-3.5 deployment.

        NOTE(review): this uses the legacy pre-1.0 `openai.ChatCompletion`
        API while gpt4_chat_completion uses the 1.x AzureOpenAI client —
        both cannot run against the same installed openai version; confirm
        which version is pinned. Also note it returns the whole message
        object, not just its text (inconsistent with the GPT-4 path).
        """

        with open('config_gpt35.yaml', 'r') as f:
            config = yaml.safe_load(f)

        API_KEY = config.get('API_KEY')
        API_VERSION = config.get('API_VERSION')
        API_BASE = config.get('API_BASE')

        response = openai.ChatCompletion.create(
            engine='gpt-35-turbo-0613',
            messages=query,
            request_timeout=60,
            api_key = API_KEY,
            api_version = API_VERSION,
            api_type = "azure",
            api_base = API_BASE,
        )

        return response['choices'][0]['message']

    def copilot_chat_completion(self, query):
        """Legacy-API GPT-4 completion; same caveats as gpt35_chat_completion."""

        with open('config_gpt4.yaml', 'r') as f:
            config = yaml.safe_load(f)

        API_KEY = config.get('API_KEY')
        API_VERSION = config.get('API_VERSION')
        API_BASE = config.get('API_BASE')

        response = openai.ChatCompletion.create(
            engine='gpt-4-0613',
            messages=query,
            request_timeout=60,
            api_key = API_KEY,
            api_version = API_VERSION,
            api_type = "azure",
            api_base = API_BASE,
        )
        return response['choices'][0]['message']

    def LLM_chat_query(self, query, llm):
        """Dispatch *query* to the model named by *llm*.

        NOTE(review): silently returns None for unrecognised model names.
        """

        if llm == 'gpt-3.5-turbo':
            return self.gpt35_chat_completion(query)
        elif llm == "gpt-4":
            return self.gpt4_chat_completion(query)
        # return self.copilot_chat_completion(query)

    def get_llm_response(self, llm, query):
        """Wrap the plain-text *query* as a single system message and query *llm*."""
        chat_completion = []
        chat_completion.append({"role": "system", "content": query})
        res = self.LLM_chat_query(chat_completion, llm)
        return res
|
104 |
+
|
105 |
+
class LLMProxyQueryAPI:
    """Chat-completion client for the public OpenAI API (openai>=1.0 style).

    `openai.Client()` picks up its credentials from the environment
    (presumably OPENAI_API_KEY — confirm); the yaml configs are not used.
    """

    def __init__(self) -> None:
        pass

    def gpt35_chat_completion(self, query):
        """Return the reply text for chat messages *query* from gpt-3.5-turbo-16k."""
        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-3.5-turbo-16k",
            messages=query,
        )
        return response.choices[0].message.content

    def gpt4o_chat_completion(self, query):
        """Return the reply text for chat messages *query* from gpt-4o."""
        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=query,
        )
        return response.choices[0].message.content

    def gpt4_chat_completion(self, query):
        """Return the reply text for chat messages *query* from gpt-4-1106-preview."""
        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-4-1106-preview",
            messages=query,
        )
        return response.choices[0].message.content

    def gpt4_vision(self, query, image_path):
        """Ask gpt-4-vision-preview about the image at *image_path*.

        *image_path* may be an http(s) URL or a base64 data URL; *query*
        is the accompanying plain-text prompt.
        """

        print(query)  # NOTE(review): debug print of the full prompt

        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": image_path
                            }
                        }
                    ]
                }
            ],
            max_tokens=4096,
            stream=False
        )
        return response.choices[0].message.content

    def LLM_chat_query(self, llm, query, image_path=None):
        """Dispatch to the model named by *llm*.

        NOTE(review): silently returns None for unrecognised model names.
        """

        if llm == 'gpt-3.5-turbo':
            return self.gpt35_chat_completion(query)

        elif llm == "gpt-4":
            return self.gpt4_chat_completion(query)

        elif llm == "gpt-4o":
            return self.gpt4o_chat_completion(query)

        elif llm == "gpt-4V":
            return self.gpt4_vision(query, image_path)

    def get_llm_response(self, llm, query, image_path=None):
        """Entry point used by the agents.

        Vision queries ("gpt-4V" with an image) pass the raw prompt through;
        text queries are wrapped as a single system message first.
        """

        if llm == "gpt-4V" and image_path:
            res = self.LLM_chat_query(llm, query, image_path)
            return res

        chat_completion = []
        chat_completion.append({"role": "system", "content": query})
        res = self.LLM_chat_query(llm, chat_completion)
        return res
|
187 |
+
|
188 |
+
# if __name__ == '__main__':
|
189 |
+
|
190 |
+
# llm_query_api = LLMProxyQueryAPI()
|
191 |
+
|
192 |
+
# def local_image_to_data_url(image_path):
|
193 |
+
# mime_type, _ = guess_type(image_path)
|
194 |
+
# if mime_type is None:
|
195 |
+
# mime_type = 'application/octet-stream'
|
196 |
+
|
197 |
+
# with open(image_path, "rb") as image_file:
|
198 |
+
# base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
|
199 |
+
|
200 |
+
# return f"data:{mime_type};base64,{base64_encoded_data}"
|
201 |
+
|
202 |
+
# tesseract = TesseractOCR()
|
203 |
+
# pdf = PDF(src="temp3.pdf", pages=[0, 0])
|
204 |
+
# extracted_tables = pdf.extract_tables(ocr=tesseract,
|
205 |
+
# implicit_rows=True,
|
206 |
+
# borderless_tables=True,)
|
207 |
+
# html_table = extracted_tables[0][0].html_repr()
|
208 |
+
# print(html_table)
|
209 |
+
|
210 |
+
# table_image_path = "./temp3.jpeg"
|
211 |
+
# table_image_data_url = local_image_to_data_url(table_image_path)
|
212 |
+
# print(table_image_data_url)
|
213 |
+
# query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
|
214 |
+
# html_table_refined = llm_query_api.get_llm_response("gpt-4V", query, table_image_data_url)
|
215 |
+
# print(html_table_refined)
|
216 |
+
|
.ipynb_checkpoints/matsa-checkpoint.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
from prompts import *
|
4 |
+
import ast
|
5 |
+
from bs4 import BeautifulSoup
|
6 |
+
from semantic_retrieval import *
|
7 |
+
from llm_query_api import *
|
8 |
+
import base64
|
9 |
+
from mimetypes import guess_type
|
10 |
+
|
11 |
+
class InputInstance:
    """A single table-QA example: an HTML table plus optional question/answer.

    Attributes:
        id: optional identifier of the example.
        html_table: the table markup the pipeline operates on.
        question: optional natural-language question.
        answer: answer statement whose facts are attributed to the table.
    """

    def __init__(self, id=None, html_table=None, question=None, answer=None):
        self.id = id
        self.html_table = html_table
        self.question = question
        self.answer = answer
        # (removed a stray bare `return`; __init__ must return None anyway)
|
19 |
+
|
20 |
+
class MATSA:
    """Multi-agent pipeline for attributing an answer to table rows/columns.

    Each *_agent method is one pipeline stage; all LLM calls are routed
    through self.llm_query_api.
    """

    def __init__(self, llm="gpt-4"):
        self.llm = llm
        self.llm_query_api = LLMQueryAPI()  # LLMProxyQueryAPI()

    def table_formatting_agent(self, html_table=None, table_image_path=None):
        """Normalise a table into HTML carrying row/column ids.

        When *table_image_path* is given, the table is first OCR'd from the
        image/PDF and refined via the vision model; otherwise *html_table*
        is used directly. Returns the HTML string with id="row-i" on every
        <tr> and id="col-j" on the first row's <th> cells.
        """

        def local_image_to_data_url(image_path):
            # Inline base64 data URL so the image can be sent to the API.
            mime_type, _ = guess_type(image_path)
            if mime_type is None:
                mime_type = 'application/octet-stream'

            with open(image_path, "rb") as image_file:
                base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')

            return f"data:{mime_type};base64,{base64_encoded_data}"

        if table_image_path is not None:
            # NOTE(review): TesseractOCR/PDF come from img2table, whose
            # imports are commented out at module level — this branch raises
            # NameError until they are restored. Also, the "gpt-4V" call only
            # works when self.llm_query_api is an LLMProxyQueryAPI.
            tesseract = TesseractOCR()
            pdf = PDF(src=table_image_path, pages=[0, 0])
            extracted_tables = pdf.extract_tables(ocr=tesseract,
                                                  implicit_rows=True,
                                                  borderless_tables=True,)
            html_table = extracted_tables[0][0].html_repr()

            table_image_data_url = local_image_to_data_url(table_image_path)
            query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
            # BUG FIX: was `llm_query_api.get_llm_response(...)` — an
            # undefined global; route the call through the instance API.
            html_table = self.llm_query_api.get_llm_response("gpt-4V", query, table_image_data_url)

        soup = BeautifulSoup(html_table, 'html.parser')
        tr_tags = soup.find_all('tr')
        for i, tr_tag in enumerate(tr_tags):
            tr_tag['id'] = f"row-{i + 1}"  # Assign unique ID using 'row-i' format

            if i == 0:
                # BUG FIX: the header loop reused `i`, clobbering the outer
                # row index mid-iteration; use a distinct variable.
                th_tags = tr_tag.find_all('th')
                for j, th_tag in enumerate(th_tags):
                    th_tag['id'] = f"col-{j + 1}"  # Assign unique ID using 'col-j' format

        return str(soup)

    def description_augmentation_agent(self, html_table):
        """Augment the table with LLM-written column, row and trend descriptions."""

        query = col_description_prompt.replace("{{html_table}}", str(html_table))
        col_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        query = row_description_prompt.replace("{{html_table}}", str(col_augmented_html_table))
        row_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        query = trend_description_prompt.replace("{{html_table}}", str(row_augmented_html_table))
        trend_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        return trend_augmented_html_table

    def answer_decomposition_agent(self, answer):
        """Split *answer* into a list of atomic facts; None if parsing fails."""

        prompt = answer_decomposition_prompt
        query = prompt.replace("{{answer}}", answer)
        res = self.llm_query_api.get_llm_response(self.llm, query)
        # literal_eval (not eval) — the reply is untrusted LLM output.
        res = ast.literal_eval(res)
        if isinstance(res, list):
            return res
        else:
            return None

    def semantic_retreival_agent(self, html_table, fact_list, topK=5):
        """Embedding-based retrieval of candidate rows/columns for the facts."""

        attributed_html_table, row_attribution_ids, col_attribution_ids = get_embedding_attribution(html_table, fact_list, topK)
        return attributed_html_table, row_attribution_ids, col_attribution_ids

    def sufficiency_attribution_agent(self, fact_list, attributed_html_table):
        """Ask the LLM to verify the facts and emit row/column citations.

        Returns the parsed JSON dict, or None when the reply is not a dict.
        """

        fact_verification_function = {}
        fact_verification_list = []

        for i in range(len(fact_list)):
            fact = fact_list[i]
            fxn = {}
            fxn["Fact " + str(i + 1) + ":"] = str(fact)
            # fxn["Verified"] = "..."
            fact_verification_list.append(fxn)

        fact_verification_function["List of Fact"] = fact_verification_list

        fact_verification_function["Row Citations"] = "[..., ..., ...]"
        fact_verification_function["Column Citations"] = "[..., ..., ...]"
        fact_verification_function["Explanation"] = "..."

        fact_verification_function_string = json.dumps(fact_verification_function)

        query = functional_attribution_prompt.replace("{{attributed_html_table}}", str(attributed_html_table)).replace("{{fact_verification_function}}", fact_verification_function_string)
        attribution_fxn = self.llm_query_api.get_llm_response(self.llm, query)

        # Strip a possible ```json code fence before parsing.
        attribution_fxn = attribution_fxn.replace("```json", "")
        attribution_fxn = attribution_fxn.replace("```", "")
        print(attribution_fxn)
        attribution_fxn = json.loads(attribution_fxn)

        if isinstance(attribution_fxn, dict):
            return attribution_fxn
        else:
            return None
|
124 |
+
|
125 |
+
if __name__ == '__main__':

    # Minimal end-to-end smoke test of the MATSA pipeline on a toy table.
    html_table = """<table>
<tr>
<th rowspan="1">Sr. Number</th>
<th colspan="3">Types</th>
<th rowspan="1">Remark</th>
</tr>
<tr>
<th> </th>
<th>A</th>
<th>B</th>
<th>C</th>
<th> </th>
</tr>
<tr>
<td>1</td>
<td>Mitten</td>
<td>Kity</td>
<td>Teddy</td>
<td>Names of cats</td>
</tr>
<tr>
<td>1</td>
<td>Tommy</td>
<td>Rudolph</td>
<td>Jerry</td>
<td>Names of dogs</td>
</tr>
</table>"""

    answer = "Tommy is a dog but Mitten is a cat."

    x = InputInstance(html_table=html_table, answer=answer)

    matsa_agent = MATSA()

    # Stage 1: normalise the table and attach row-/col- ids.
    x_reformulated = matsa_agent.table_formatting_agent(x.html_table)
    print(x_reformulated)

    # Stage 2: augment with LLM-generated descriptions.
    x_descriptions = matsa_agent.description_augmentation_agent(x_reformulated)
    print(x_descriptions)

    # Stage 3: decompose the answer into atomic facts.
    fact_list = matsa_agent.answer_decomposition_agent(x.answer)
    print(fact_list)

    # Stage 4: retrieve candidate rows/columns by embedding similarity.
    attributed_html_table, row_attribution_ids, col_attribution_ids = matsa_agent.semantic_retreival_agent(x_descriptions, fact_list)
    print(attributed_html_table)

    # Stage 5: final LLM verification and citation extraction.
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_html_table)
    print(attribution_fxn)

    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]

    print(row_attribution_set)
    print(col_attribution_set)
|
.ipynb_checkpoints/semantic_retrieval-checkpoint.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from bs4 import BeautifulSoup
|
3 |
+
from sklearn.preprocessing import minmax_scale
|
4 |
+
from sentence_transformers import SentenceTransformer, util
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
7 |
+
sbert = SentenceTransformer("all-MiniLM-L6-v2")
|
8 |
+
from llm_query_api import *
|
9 |
+
|
10 |
+
def get_row_embedding(html_table):
    """Embed the 'description' attribute of every <tr> in *html_table*.

    Returns (embeddings, element_ids): a numpy array of SBERT embeddings and
    the parallel list of row id strings.
    """

    def _collect_rows(markup):
        parsed = BeautifulSoup(markup, 'html.parser')
        collected = []
        for tag in parsed.find_all('tr'):
            text = " " + str(tag.get('description'))
            try:
                collected.append({'id': str(tag.get('id')), 'text': text})
            except:
                pass
        return collected

    elements = _collect_rows(html_table)
    sentences = [element['text'] for element in elements]
    element_ids = [element['id'] for element in elements]

    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
|
35 |
+
|
36 |
+
def get_col_embedding(html_table):
    """Embed the 'description' attribute of every <th> in *html_table*.

    Returns (embeddings, element_ids): a numpy array of SBERT embeddings and
    the parallel list of column id strings.
    """

    def _collect_headers(markup):
        parsed = BeautifulSoup(markup, 'html.parser')
        collected = []
        for tag in parsed.find_all('th'):
            text = " " + str(tag.get('description'))
            try:
                collected.append({'id': str(tag.get('id')), 'text': text})
            except:
                pass
        return collected

    elements = _collect_headers(html_table)
    sentences = [element['text'] for element in elements]
    element_ids = [element['id'] for element in elements]

    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
|
62 |
+
|
63 |
+
def normalize_list_numpy(list_numpy):
    """Min-max scale the values of *list_numpy* into [0, 1]."""
    return minmax_scale(list_numpy)
|
66 |
+
|
67 |
+
def get_answer_embedding(answer):
    """SBERT-embed the answer string; returns a (1, dim) numpy array."""
    encoded = sbert.encode([answer], convert_to_tensor=True)
    return encoded.cpu().numpy()
|
69 |
+
|
70 |
+
def row_attribution(answer, html_table, topk=5, threshold = 0.7):
    """Return the ids of the table rows most similar to *answer*.

    k is max(topk, 30% of the rows), capped at the number of rows; results
    come back ordered by descending (min-max normalised) cosine similarity.
    *threshold* is accepted for API compatibility but currently unused.
    """
    answer_vec = get_answer_embedding(answer)
    row_vecs, row_ids = get_row_embedding(html_table)

    scores = cosine_similarity(row_vecs, answer_vec.reshape(1, -1)).flatten()
    scores = normalize_list_numpy(scores)

    # Take at least `topk` rows (or a third of them, whichever is larger),
    # but never more rows than exist.
    k = min(max(topk, int(0.3 * len(scores))), len(scores))
    candidate_idx = np.argpartition(scores, -k)[-k:]
    ordered_idx = candidate_idx[np.argsort(scores[candidate_idx])][::-1]

    return [row_ids[pos] for pos in ordered_idx]
|
87 |
+
|
88 |
+
def col_attribution(answer, html_table, topk=5, threshold = 0.7):
    """Rank table columns by semantic similarity to the answer.

    Keeps max(topk, 30% of the columns), capped at the column count, and
    returns their ids most-similar first. `threshold` is accepted for
    interface compatibility but is currently unused.
    """
    answer_vec = get_answer_embedding(answer)
    col_vecs, col_ids = get_col_embedding(html_table)

    scores = cosine_similarity(col_vecs, answer_vec.reshape(1, -1)).flatten()
    scores = normalize_list_numpy(scores)

    # If no of cols >= 5, take max of (5, 1/3 x cols);
    # else if no of cols < 5, take least of (5, cols).
    k = min(max(topk, int(0.3 * len(scores))), len(scores))
    candidate_idx = np.argpartition(scores, -k)[-k:]
    ranked_idx = candidate_idx[np.argsort(scores[candidate_idx])][::-1]

    return [col_ids[idx] for idx in ranked_idx]
|
105 |
+
|
106 |
+
def retain_rows_and_columns(augmented_html_table, row_ids, column_ids):
    """Return a copy of the table HTML keeping only the cited rows/columns.

    Rows whose <tr> id is not in `row_ids` are removed; columns whose header
    <th> id (taken from the first row) is not in `column_ids` are removed
    from every remaining row. Both id lists are de-duplicated first.
    """
    soup = BeautifulSoup(augmented_html_table, 'html.parser')

    row_ids = list(set(row_ids))
    column_ids = list(set(column_ids))

    # Retain specified rows and remove others
    all_rows = soup.find_all('tr')
    for row in all_rows:
        if row.get('id') not in row_ids:
            row.decompose()

    # Retain specified columns and remove others.
    if all_rows:
        all_columns = all_rows[0].find_all(['th'])
        # BUGFIX: the original re-fetched each row's cells once per removed
        # column; after decomposing a low-index cell, later indices shifted
        # and the wrong columns were deleted. Compute all indices to drop
        # up front, then delete them from a single snapshot of each row's
        # cells, so positions cannot shift between deletions.
        drop_indices = [i for i, col in enumerate(all_columns)
                        if col.get('id') not in column_ids]
        for row in soup.find_all('tr'):
            cells = row.find_all(['td', 'th'])
            for i in sorted(drop_indices, reverse=True):
                if len(cells) > i:
                    cells[i].decompose()

    return str(soup)
|
129 |
+
|
130 |
+
def get_embedding_attribution(augmented_html_table, decomposed_fact_list, topK=5, threshold = 0.7):
    """Attribute each decomposed fact to table rows/columns via embeddings.

    For every fact, the top-K most similar rows and columns are collected;
    the table is then pruned down to the union of those citations.
    Returns (pruned_table_html, row_ids, col_ids).
    """
    row_attribution_ids = []
    col_attribution_ids = []

    for fact in decomposed_fact_list:
        row_attribution_ids.extend(row_attribution(fact, augmented_html_table, topK))
        col_attribution_ids.extend(col_attribution(fact, augmented_html_table, topK))

    attributed_html_table = retain_rows_and_columns(
        augmented_html_table, row_attribution_ids, col_attribution_ids)

    return attributed_html_table, row_attribution_ids, col_attribution_ids
|
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title: Matsa
|
3 |
-
|
4 |
-
colorFrom: pink
|
5 |
-
colorTo: red
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Matsa-demo
|
3 |
+
app_file: demo.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 4.41.0
|
|
|
|
|
6 |
---
|
|
|
|
__pycache__/demo.cpython-310.pyc
ADDED
Binary file (5.95 kB). View file
|
|
__pycache__/llm_query_api.cpython-310.pyc
ADDED
Binary file (4.44 kB). View file
|
|
__pycache__/matsa.cpython-310.pyc
ADDED
Binary file (5.21 kB). View file
|
|
__pycache__/prompts.cpython-310.pyc
ADDED
Binary file (6.37 kB). View file
|
|
__pycache__/semantic_retrieval.cpython-310.pyc
ADDED
Binary file (3.89 kB). View file
|
|
config_gpt35.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
API_BASE: https://dil-research.openai.azure.com/
|
2 |
+
API_KEY: ''  # SECURITY: a live Azure OpenAI key was committed here — rotate it immediately and load the key from an environment variable / HF Space secret instead of this file
|
3 |
+
API_VERSION: '2023-05-15'
|
config_gpt4.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
API_BASE: https://dil-research-sweden-central.openai.azure.com/
|
2 |
+
API_KEY: ''  # SECURITY: a live Azure OpenAI key was committed here — rotate it immediately and load the key from an environment variable / HF Space secret instead of this file
|
3 |
+
API_VERSION: "2023-12-01-preview"
|
demo.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
from bs4 import BeautifulSoup
|
4 |
+
from matsa import MATSA, InputInstance
|
5 |
+
import imgkit
|
6 |
+
import tempfile
|
7 |
+
import time
|
8 |
+
import threading
|
9 |
+
|
10 |
+
TABLE_FOLDER = "./tables_folder/MATSA_fetaqa.json"
|
11 |
+
# Load data from JSON file
|
12 |
+
def load_data():
    """Load and return the list of table records from TABLE_FOLDER.

    Each record is a dict with at least 'html_table' and 'answer_statement',
    and optionally 'question'.
    """
    # Explicit encoding so the demo behaves the same on every platform
    # (the original relied on the locale-dependent default).
    with open(TABLE_FOLDER, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)
|
15 |
+
|
16 |
+
# Global cache of all table records, loaded once at import time.
TABLE_DATA = load_data()
|
18 |
+
|
19 |
+
def get_table_names():
    """Return display names ("tab_1", "tab_2", ...) for every loaded table."""
    return ["tab_{}".format(idx + 1) for idx in range(len(TABLE_DATA))]
|
21 |
+
|
22 |
+
def html_to_image(html_content):
    """Render an HTML string to a PNG via imgkit and return the file path.

    The temp file is created with delete=False so it survives the context
    manager; the caller is responsible for eventual cleanup.
    """
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
        out_path = temp_img.name
        imgkit.from_string(html_content, out_path)
    return out_path
|
26 |
+
|
27 |
+
def highlight_table(html_table, row_ids, col_ids):
    """Color the cited cells of an HTML table.

    Cells at the intersection of a cited row and a cited column get a light
    blue background; cells in a cited column but an uncited row get light
    grey. Returns the modified HTML string.
    """
    soup = BeautifulSoup(html_table, 'html.parser')

    # Resolve cited row ids to their <tr> tags.
    cited_rows = []
    for row_id in row_ids:
        tr = soup.find('tr', id=row_id)
        if tr is not None:
            cited_rows.append(tr)

    for col_id in col_ids:
        # "col-1" -> index 0, "col-2" -> index 1, etc.
        col_index = int(col_id.split('-')[1]) - 1
        for tr in soup.find_all('tr'):
            cells = tr.find_all(['td', 'th'])
            if col_index >= len(cells):
                continue
            if tr in cited_rows:
                cells[col_index]['style'] = 'background-color: rgba(173, 216, 230, 0.7);'
            else:
                cells[col_index]['style'] = 'background-color: rgba(211, 211, 211, 0.6);'

    return str(soup)
|
49 |
+
|
50 |
+
def load_table_data(table_name):
    """Fetch (image_path, question, answer) for a dropdown selection.

    Returns (None, "", "") when nothing is selected; a missing or None
    question is normalised to the empty string.
    """
    if not table_name:
        return None, "", ""

    record = TABLE_DATA[int(table_name.split('_')[1]) - 1]

    html_content = record['html_table']
    question = record.get("question", "")
    if question is None:
        question = ""
    answer = record['answer_statement']

    image_path = html_to_image(html_content)
    return image_path, question, answer
|
65 |
+
|
66 |
+
def process_input(table_name, question, answer):
    """Run the full MATSA attribution pipeline for the selected table.

    Returns a JSON string with the highlighted table HTML, the decomposed
    facts, row/column citations, and the model's explanation (or an error
    message string when no table is selected).
    """
    import ast  # local import: safe parsing of list literals from LLM output

    if not table_name:
        return "Please select a table from the dropdown."

    # Get the data for the selected table
    index = int(table_name.split('_')[1]) - 1
    data = TABLE_DATA[index]

    html_content = data['html_table']

    print("html_content: ", html_content)
    print("question: ", question)
    print("answer: ", answer)

    # Initialize MATSA
    matsa_agent = MATSA()

    # Create input instance
    instance = InputInstance(html_table=html_content, question=question, answer=answer)

    # Apply MATSA pipeline
    # formatted_table = matsa_agent.table_formatting_agent(instance.html_table)
    augmented_table = matsa_agent.description_augmentation_agent(instance.html_table)
    print("augmented_table: ", augmented_table)
    fact_list = matsa_agent.answer_decomposition_agent(instance.answer)
    print("fact_list: ", fact_list)
    attributed_table, _, _ = matsa_agent.semantic_retreival_agent(augmented_table, fact_list)
    print("attributed_table: ", attributed_table)
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_table)
    print("attribution_fxn: ", attribution_fxn)

    # Get row and column attributions
    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]
    explanation = attribution_fxn.get("Explanation", "")
    print("row_attribution_set: ", row_attribution_set)
    print("col_attribution_set: ", col_attribution_set)
    print("Explanation: ", explanation)

    # Convert string representations to lists. SECURITY FIX: the original
    # used eval() on LLM-generated text, which can execute arbitrary code;
    # ast.literal_eval only accepts Python literals.
    if isinstance(row_attribution_set, str):
        row_ids = ast.literal_eval(row_attribution_set)
    else:
        row_ids = row_attribution_set

    if isinstance(col_attribution_set, str):
        col_ids = ast.literal_eval(col_attribution_set)
    else:
        col_ids = col_attribution_set

    # Highlight the table
    highlighted_table = highlight_table(instance.html_table, row_ids, col_ids)

    result = {
        "highlighted_table": highlighted_table,
        "facts": attribution_fxn.get("List of Facts", []),
        "row_citations": row_attribution_set,
        "column_citations": col_attribution_set,
        "Explanation": explanation,
    }

    return json.dumps(result)
|
128 |
+
|
129 |
+
# Define the Gradio interface.
with gr.Blocks() as iface:
    gr.Markdown("# MATSA: Table Question Answering with Attribution")
    gr.Markdown("Select a table from dropdown load table image, question, and answer.")
    gr.Markdown("Attributions are provided as per answer. You may change the question/answer as per your need.")

    table_dropdown = gr.Dropdown(choices=get_table_names(), label="Select Table")
    original_table = gr.Image(type="filepath", label="Original Table")
    question_box = gr.Textbox(label="Question")
    answer_box = gr.Textbox(label="Answer")

    gr.Markdown("Click 'Process' to see the highlighted relevant parts. Click 'Reset' to start over.")

    process_button = gr.Button("Process")
    reset_button = gr.Button("Reset")
    processing_time = gr.Textbox(label="Processing Time", value="0 seconds")
    highlighted_table = gr.HTML(label="Highlighted Table")
    explanation_box = gr.Textbox(label="Explanation")

    def update_table_data(table_name):
        # Populate the image/question/answer widgets for the chosen table.
        image_path, question, answer = load_table_data(table_name)
        return image_path, question, answer, gr.update(interactive=True)

    def reset_app():
        # Restore every widget to its initial state.
        return (
            gr.update(value="", interactive=True),  # table_dropdown
            None,                                   # original_table
            "",                                     # question_box
            "",                                     # answer_box
            "",                                     # highlighted_table
            "",                                     # explanation_box
            gr.update(interactive=True),            # process_button
            "0 seconds",                            # processing_time
        )

    def process_and_disable(table_name, question, answer):
        # Generator event handler: the first yield disables the inputs while
        # the pipeline runs; the second re-enables them and shows the result.
        #
        # BUGFIX: the original spawned a threading.Thread whose target was a
        # *generator* function — calling it only created a generator object,
        # so the seconds counter never ran and always reported 0. It also
        # yielded SIX values to FIVE declared outputs on the final step.
        # Elapsed time is now measured directly and the yield arity matches
        # the `outputs` list of process_button.click.
        start_time = time.time()

        # Disable the dropdown and process button during processing.
        yield (
            gr.update(interactive=False),      # table_dropdown
            gr.update(interactive=False),      # process_button
            gr.update(value="Processing..."),  # processing_time
            gr.update(),                       # highlighted_table
            gr.update(),                       # explanation_box
        )

        # Process the input.
        result = process_input(table_name, question, answer)
        result_dict = json.loads(result)

        elapsed = int(time.time() - start_time)

        # Re-enable the controls, report the elapsed time, show the result.
        yield (
            gr.update(interactive=True),                         # table_dropdown
            gr.update(interactive=True),                         # process_button
            f"Processed in {elapsed} seconds",                   # processing_time
            gr.update(value=result_dict['highlighted_table']),   # highlighted_table
            gr.update(value=result_dict.get('Explanation', '')), # explanation_box
        )

    table_dropdown.change(update_table_data,
                          inputs=[table_dropdown],
                          outputs=[original_table, question_box, answer_box, process_button])

    process_button.click(process_and_disable,
                         inputs=[table_dropdown, question_box, answer_box],
                         outputs=[table_dropdown, process_button, processing_time, highlighted_table, explanation_box])

    reset_button.click(reset_app,
                       inputs=[],
                       outputs=[table_dropdown, original_table, question_box, answer_box, highlighted_table, explanation_box, process_button, processing_time])

# Launch the interface
iface.launch(share=True)
|
llm_query_api.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import copy
|
3 |
+
import io
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
import re
|
8 |
+
import json
|
9 |
+
import argparse
|
10 |
+
import yaml
|
11 |
+
import openai
|
12 |
+
from openai import AzureOpenAI
|
13 |
+
from prompts import *
|
14 |
+
import base64
|
15 |
+
from mimetypes import guess_type
|
16 |
+
# from img2table.document import Image
|
17 |
+
# from img2table.document import PDF
|
18 |
+
# from img2table.ocr import TesseractOCR
|
19 |
+
# from img2table.ocr import EasyOCR
|
20 |
+
from PIL import Image as PILImage
|
21 |
+
|
22 |
+
class LLMQueryAPI:
    """Thin wrapper around Azure OpenAI chat-completion deployments.

    Credentials (API_BASE / API_KEY / API_VERSION) are re-read from the
    per-model YAML config files in the working directory on every call.
    NOTE(review): gpt4_chat_completion returns the reply *content string*
    while the legacy-API methods return the raw message dict.
    """

    def __init__(self) -> None:
        pass

    def gpt4_chat_completion(self, query):
        """Send a chat message list to the GPT-4 Azure deployment (new-style client)."""
        with open('config_gpt4.yaml', 'r') as f:
            config = yaml.safe_load(f)

        API_KEY = config.get('API_KEY')
        API_VERSION = config.get('API_VERSION')
        API_BASE = config.get('API_BASE')

        # openai>=1.0 Azure client, constructed per call.
        client = AzureOpenAI(
            azure_endpoint= API_BASE,
            api_version= API_VERSION,
            api_key = API_KEY
        )

        deployment_name='gpt-4-2024-04-09'

        response = client.chat.completions.create(
            model=deployment_name,
            messages=query)

        return response.choices[0].message.content

    def gpt35_chat_completion(self, query):
        """Send a chat message list to the GPT-3.5 deployment (legacy openai<1.0 API)."""
        with open('config_gpt35.yaml', 'r') as f:
            config = yaml.safe_load(f)

        API_KEY = config.get('API_KEY')
        API_VERSION = config.get('API_VERSION')
        API_BASE = config.get('API_BASE')

        response = openai.ChatCompletion.create(
            engine='gpt-35-turbo-0613',
            messages=query,
            request_timeout=60,
            api_key = API_KEY,
            api_version = API_VERSION,
            api_type = "azure",
            api_base = API_BASE,
        )

        # Returns the raw message dict, not the content string.
        return response['choices'][0]['message']

    def copilot_chat_completion(self, query):
        """Alternative GPT-4 path via the legacy API; currently unused (see LLM_chat_query)."""
        with open('config_gpt4.yaml', 'r') as f:
            config = yaml.safe_load(f)

        API_KEY = config.get('API_KEY')
        API_VERSION = config.get('API_VERSION')
        API_BASE = config.get('API_BASE')

        response = openai.ChatCompletion.create(
            engine='gpt-4-0613',
            messages=query,
            request_timeout=60,
            api_key = API_KEY,
            api_version = API_VERSION,
            api_type = "azure",
            api_base = API_BASE,
        )
        return response['choices'][0]['message']

    def LLM_chat_query(self, query, llm):
        """Dispatch a chat message list by model name.

        NOTE(review): implicitly returns None for any other `llm` value;
        callers should pass 'gpt-3.5-turbo' or 'gpt-4'.
        """
        if llm == 'gpt-3.5-turbo':
            return self.gpt35_chat_completion(query)
        elif llm == "gpt-4":
            return self.gpt4_chat_completion(query)
        # return self.copilot_chat_completion(query)

    def get_llm_response(self, llm, query):
        """Wrap a plain-text prompt as a single system message and query `llm`."""
        chat_completion = []
        chat_completion.append({"role": "system", "content": query})
        res = self.LLM_chat_query(chat_completion, llm)
        return res
|
104 |
+
|
105 |
+
class LLMProxyQueryAPI:
    """OpenAI-proxy variant of LLMQueryAPI using the openai>=1.0 client.

    Credentials come from the standard OPENAI_* environment variables
    (openai.Client() reads them itself). Also supports GPT-4 vision queries
    with an image supplied as a URL or base64 data URL.
    """

    def __init__(self) -> None:
        pass

    def gpt35_chat_completion(self, query):
        """Send a chat message list to gpt-3.5-turbo-16k; return the content string."""
        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-3.5-turbo-16k",
            messages=query,
        )
        return response.choices[0].message.content

    def gpt4o_chat_completion(self, query):
        """Send a chat message list to gpt-4o; return the content string."""
        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=query,
        )
        return response.choices[0].message.content

    def gpt4_chat_completion(self, query):
        """Send a chat message list to gpt-4-1106-preview; return the content string."""
        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-4-1106-preview",
            messages=query,
        )
        return response.choices[0].message.content

    def gpt4_vision(self, query, image_path):
        """Multimodal query: `query` is plain prompt TEXT here (not a message
        list) and `image_path` is a URL / base64 data URL for the image."""
        print(query)

        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": image_path
                            }
                        }
                    ]
                }
            ],
            max_tokens=4096,
            stream=False
        )
        return response.choices[0].message.content

    def LLM_chat_query(self, llm, query, image_path=None):
        """Dispatch by model name.

        NOTE(review): the argument order here is (llm, query) — the REVERSE
        of LLMQueryAPI.LLM_chat_query — and unknown `llm` values return None.
        """
        if llm == 'gpt-3.5-turbo':
            return self.gpt35_chat_completion(query)

        elif llm == "gpt-4":
            return self.gpt4_chat_completion(query)

        elif llm == "gpt-4o":
            return self.gpt4o_chat_completion(query)

        elif llm == "gpt-4V":
            return self.gpt4_vision(query, image_path)

    def get_llm_response(self, llm, query, image_path=None):
        """Query `llm` with a plain-text prompt.

        Vision queries ("gpt-4V" + image) pass the raw prompt through;
        text models get the prompt wrapped as a single system message.
        """
        if llm == "gpt-4V" and image_path:
            res = self.LLM_chat_query(llm, query, image_path)
            return res

        chat_completion = []
        chat_completion.append({"role": "system", "content": query})
        res = self.LLM_chat_query(llm, chat_completion)
        return res
|
187 |
+
|
188 |
+
# if __name__ == '__main__':
|
189 |
+
|
190 |
+
# llm_query_api = LLMProxyQueryAPI()
|
191 |
+
|
192 |
+
# def local_image_to_data_url(image_path):
|
193 |
+
# mime_type, _ = guess_type(image_path)
|
194 |
+
# if mime_type is None:
|
195 |
+
# mime_type = 'application/octet-stream'
|
196 |
+
|
197 |
+
# with open(image_path, "rb") as image_file:
|
198 |
+
# base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
|
199 |
+
|
200 |
+
# return f"data:{mime_type};base64,{base64_encoded_data}"
|
201 |
+
|
202 |
+
# tesseract = TesseractOCR()
|
203 |
+
# pdf = PDF(src="temp3.pdf", pages=[0, 0])
|
204 |
+
# extracted_tables = pdf.extract_tables(ocr=tesseract,
|
205 |
+
# implicit_rows=True,
|
206 |
+
# borderless_tables=True,)
|
207 |
+
# html_table = extracted_tables[0][0].html_repr()
|
208 |
+
# print(html_table)
|
209 |
+
|
210 |
+
# table_image_path = "./temp3.jpeg"
|
211 |
+
# table_image_data_url = local_image_to_data_url(table_image_path)
|
212 |
+
# print(table_image_data_url)
|
213 |
+
# query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
|
214 |
+
# html_table_refined = llm_query_api.get_llm_response("gpt-4V", query, table_image_data_url)
|
215 |
+
# print(html_table_refined)
|
216 |
+
|
matsa.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import json
|
3 |
+
from prompts import *
|
4 |
+
import ast
|
5 |
+
from bs4 import BeautifulSoup
|
6 |
+
from semantic_retrieval import *
|
7 |
+
from llm_query_api import *
|
8 |
+
import base64
|
9 |
+
from mimetypes import guess_type
|
10 |
+
|
11 |
+
class InputInstance:
    """Container for one table-QA example.

    Holds the example id, the table as HTML, and the question/answer pair;
    every field defaults to None.
    """

    def __init__(self, id=None, html_table=None, question=None, answer=None):
        self.id = id
        self.html_table = html_table
        self.question = question
        self.answer = answer
|
19 |
+
|
20 |
+
class MATSA:
    """Multi-agent table structure attribution pipeline.

    Orchestrates LLM "agents" that (1) normalise a table to id-annotated
    HTML, (2) augment it with column/row/trend descriptions, (3) decompose
    an answer into atomic facts, (4) retrieve candidate rows/columns by
    embedding similarity, and (5) ask the LLM for the final citations.
    """

    def __init__(self, llm = "gpt-4"):
        self.llm = llm
        self.llm_query_api = LLMQueryAPI() #LLMProxyQueryAPI()

    def table_formatting_agent(self, html_table = None, table_image_path = None):
        """Return the table HTML with unique "row-i" / "col-i" ids injected.

        When `table_image_path` is given, the table is first OCR'd out of
        the image/PDF and refined with GPT-4V. NOTE(review): that path needs
        the img2table imports (TesseractOCR, PDF), which are commented out
        in llm_query_api.py, and an LLM API with vision support — enable
        both before using it.
        """

        def local_image_to_data_url(image_path):
            # Inline a local image file as a base64 data URL for GPT-4V.
            mime_type, _ = guess_type(image_path)
            if mime_type is None:
                mime_type = 'application/octet-stream'

            with open(image_path, "rb") as image_file:
                base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')

            return f"data:{mime_type};base64,{base64_encoded_data}"

        if table_image_path != None:
            tesseract = TesseractOCR()
            pdf = PDF(src=table_image_path, pages=[0, 0])
            extracted_tables = pdf.extract_tables(ocr=tesseract,
                                                  implicit_rows=True,
                                                  borderless_tables=True,)
            html_table = extracted_tables[0][0].html_repr()

            table_image_data_url = local_image_to_data_url(table_image_path)
            query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
            # BUGFIX: the original referenced the undefined global
            # `llm_query_api` (NameError); use the instance's query API.
            html_table = self.llm_query_api.get_llm_response("gpt-4V", query, table_image_data_url)

        soup = BeautifulSoup(html_table, 'html.parser')
        tr_tags = soup.find_all('tr')
        for i, tr_tag in enumerate(tr_tags):
            tr_tag['id'] = f"row-{i + 1}"  # Assign unique ID using 'row-i' format

            if i == 0:
                # Column ids come from the header cells of the first row.
                # (Inner index renamed from `i` to `j` to avoid shadowing.)
                th_tags = tr_tag.find_all('th')
                for j, th_tag in enumerate(th_tags):
                    th_tag['id'] = f"col-{j + 1}"  # Assign unique ID using 'col-i' format

        return str(soup)

    def description_augmentation_agent(self, html_table):
        """Augment the table HTML with column, row, and trend descriptions
        via three successive LLM passes."""
        query = col_description_prompt.replace("{{html_table}}", str(html_table))
        col_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        query = row_description_prompt.replace("{{html_table}}", str(col_augmented_html_table))
        row_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        query = trend_description_prompt.replace("{{html_table}}", str(row_augmented_html_table))
        trend_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        return trend_augmented_html_table

    def answer_decomposition_agent(self, answer):
        """Split an answer statement into a list of atomic facts.

        Returns the list, or None if the LLM reply does not parse to a list.
        """
        prompt = answer_decomposition_prompt
        query = prompt.replace("{{answer}}", answer)
        res = self.llm_query_api.get_llm_response(self.llm, query)
        # literal_eval (not eval): the LLM reply cannot execute code.
        res = ast.literal_eval(res)
        if isinstance(res, list):
            return res
        else:
            return None

    def semantic_retreival_agent(self, html_table, fact_list, topK=5):
        """Prune the table to the rows/columns most similar to the facts.

        Returns (pruned_table_html, row_ids, col_ids).
        """
        attributed_html_table, row_attribution_ids, col_attribution_ids = get_embedding_attribution(html_table, fact_list, topK)
        return attributed_html_table, row_attribution_ids, col_attribution_ids

    def sufficiency_attribution_agent(self, fact_list, attributed_html_table):
        """Ask the LLM for the final row/column citations and explanation.

        Builds a JSON scaffold listing the facts with empty citation fields,
        sends it with the pruned table, and parses the LLM's filled-in JSON.
        Returns the resulting dict, or None if the reply is not a dict.
        """
        fact_verification_function = {}

        fact_verification_list = []

        for i in range(len(fact_list)):
            fact=fact_list[i]
            fxn = {}
            fxn["Fact " + str(i+1)+":"] = str(fact)
            # fxn["Verified"] = "..."
            fact_verification_list.append(fxn)

        # BUGFIX: demo.py reads "List of Facts" from the returned dict, but
        # the scaffold key was "List of Fact", so the facts never reached
        # the UI. The key now matches the consumer.
        fact_verification_function["List of Facts"] = fact_verification_list

        fact_verification_function["Row Citations"] = "[..., ..., ...]"
        fact_verification_function["Column Citations"] = "[..., ..., ...]"
        fact_verification_function["Explanation"] = "..."

        fact_verification_function_string = json.dumps(fact_verification_function)

        query = functional_attribution_prompt.replace("{{attributed_html_table}}", str(attributed_html_table)).replace("{{fact_verification_function}}", fact_verification_function_string)
        attribution_fxn = self.llm_query_api.get_llm_response(self.llm, query)

        # Strip the markdown code fences the model sometimes wraps JSON in.
        attribution_fxn = attribution_fxn.replace("```json", "")
        attribution_fxn = attribution_fxn.replace("```", "")
        print(attribution_fxn)
        attribution_fxn = json.loads(attribution_fxn)

        if isinstance(attribution_fxn, dict):
            return attribution_fxn
        else:
            return None
|
124 |
+
|
125 |
+
if __name__ == '__main__':
    # Manual end-to-end smoke test of the MATSA pipeline on a toy table.
    # Requires valid Azure OpenAI credentials in the config YAML files and
    # network access; it prints the intermediate output of every agent.

    html_table = """<table>
<tr>
<th rowspan="1">Sr. Number</th>
<th colspan="3">Types</th>
<th rowspan="1">Remark</th>
</tr>
<tr>
<th> </th>
<th>A</th>
<th>B</th>
<th>C</th>
<th> </th>
</tr>
<tr>
<td>1</td>
<td>Mitten</td>
<td>Kity</td>
<td>Teddy</td>
<td>Names of cats</td>
</tr>
<tr>
<td>1</td>
<td>Tommy</td>
<td>Rudolph</td>
<td>Jerry</td>
<td>Names of dogs</td>
</tr>
</table>"""

    answer = "Tommy is a dog but Mitten is a cat."

    x = InputInstance(html_table=html_table, answer=answer)

    matsa_agent = MATSA()

    # 1. Normalise the table and inject row/col ids.
    x_reformulated = matsa_agent.table_formatting_agent(x.html_table)
    print(x_reformulated)

    # 2. Augment with natural-language descriptions.
    x_descriptions = matsa_agent.description_augmentation_agent(x_reformulated)
    print(x_descriptions)

    # 3. Decompose the answer into atomic facts.
    fact_list = matsa_agent.answer_decomposition_agent(x.answer)
    print(fact_list)

    # 4. Retrieve candidate rows/columns by embedding similarity.
    attributed_html_table, row_attribution_ids, col_attribution_ids = matsa_agent.semantic_retreival_agent(x_descriptions, fact_list)
    print(attributed_html_table)

    # 5. Final citations from the sufficiency-attribution agent.
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_html_table)
    print(attribution_fxn)

    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]

    print(row_attribution_set)
    print(col_attribution_set)
|
prompts.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
row_description_prompt = """You are a brilliant table assistant with the capabilities information retrieval, table parsing, and semantic understanding of the structural information of the table.
|
2 |
+
|
3 |
+
Here is a table in html format: \"{{html_table}}\"
|
4 |
+
|
5 |
+
We need to add a detailed description for each row of the table denoted by the "<tr>" tag element. The description should discuss the overall information inferred from the row. It should also mention all the elements, numbers, and figures present in the row. Also, include any hierarchical row/column header information.
|
6 |
+
|
7 |
+
Add this description as a "description" attribute in the <tr> tag. Repeat this process for ALL <tr> tags in the provided HTML. Do NOT delete any information that already exists in the other tags. Just print the html and do NOT output any other message. The HTML structure should NOT be changed.
|
8 |
+
"""
|
9 |
+
|
10 |
+
col_description_prompt = """You are a brilliant table assistant with the capabilities information retrieval, table parsing, and semantic understanding of the structural information of the table.
|
11 |
+
|
12 |
+
Here is a table in html format: \"{{html_table}}\"
|
13 |
+
|
14 |
+
First, write a caption for the entire table that captures the general information being presented by all rows and columns in this table. Add it using the <caption> tag in the HTML.
|
15 |
+
|
16 |
+
Next, need to add a detailed description for each column of the table denoted by the "<th>" tag element. The description should discuss what elements are present in the column and the overall information that can be inferred by the column, including any hierarchical column header information.
|
17 |
+
|
18 |
+
Add this description as a "description" in the <th> tag. Repeat this process for ALL <th> tags in the provided HTML. Do NOT delete any information that already exists in the other tags. Just print the html and do NOT output any other message. The HTML structure should NOT be changed.
|
19 |
+
"""
|
20 |
+
|
21 |
+
trend_description_prompt = """You are a brilliant table assistant with the capabilities information retrieval, table parsing, semantic understanding, and trend analysis of the structural information of the table.
|
22 |
+
|
23 |
+
Here is a table in html format: \"{{html_table}}\".
|
24 |
+
|
25 |
+
We need to add a trend analysis on the elements in the given row compared to its own constituent cells and other rows in the table. The description should discuss semantic descriptions of numerical data, summarizing key quantitative characteristics and tendencies across the table row and across different columns.
|
26 |
+
|
27 |
+
Add this analysis in the \"description\" of the <tr> tag. Repeat this process for ALL <tr> tags in the provided HTML. Do NOT delete any information that already exists in the other tags. Just print the html and do NOT output any other message. The HTML structure should NOT be changed.
|
28 |
+
"""
|
29 |
+
|
30 |
+
functional_attribution_prompt = """You are a brilliant assistant with the capabilities of information retrieval, fact checking, and semantic understanding of tabular data.
|
31 |
+
|
32 |
+
Here is the html table - \"{{attributed_html_table}}\"
|
33 |
+
|
34 |
+
We have a list of facts pertaining to this table present in this JSON structure - \"{{fact_verification_function}}\".
|
35 |
+
The JSON structure contains three empty fields - "Row Citations", "Column Citations", and "Explanation" that need to be filled with relevant information.
|
36 |
+
|
37 |
+
We want to identify all the ROWS in the table that are important to support these facts. In other words, which rows are needed to collectively verify the facts. Please copy the "row-id" of all relevant table rows in the "Row Citations" field of the JSON structure. All row IDs should be added in the form of a LIST "[... , ... , ...]" with no value repeated. Here is a sample: "Row Citations": ["row-2", "row-3"].
|
38 |
+
|
39 |
+
Similar to rows, we want to identify all the COLUMNS in the table that are important to support these facts. In other words, which columns are needed to collectively verify the facts. Please copy the "col-id" of all relevant table columns in the "Column Citations" field of the JSON structure. All column IDs should be added in the form of a LIST "[... , ... , ...]" with no value repeated. Here is a sample: "Column Citations": ["col-1", "col-5", "col-7"]
|
40 |
+
|
41 |
+
"Explanation" field should contain a detailed explanation of how the rows and columns identified in the "Row Citations" and "Column Citations" fields respectively, are important to verify the facts present in the JSON structure. The explanation should be coherent and provide a clear rationale for the selection of rows and columns.
|
42 |
+
|
43 |
+
The final result should be a complete JSON structure ONLY. Do not print any extra information. Make sure to fill the JSON structure accurately as provided in the prompt. 'Row citations', 'Column Citations', and 'Explanation' should not be empty.
|
44 |
+
"""
|
45 |
+
|
46 |
+
|
47 |
+
answer_decomposition_prompt = """
|
48 |
+
|
49 |
+
Here is a passage: \"{{answer}}\"
|
50 |
+
|
51 |
+
Convert the given passage into a list of short facts which specifically answer the given question.
|
52 |
+
Make sure that the facts can be found in the given passage.
|
53 |
+
The facts should be coherent and succinct sentences with clear and simple syntax.
|
54 |
+
Do not use pronouns as the subject or object in the syntax of each fact.
|
55 |
+
The facts should be independent to each other.
|
56 |
+
Do not create facts from the passage which are not answering the given question.
|
57 |
+
ONLY return a python LIST of strings separated by comma (,). Do NOT output any extra explanation.
|
58 |
+
"""
|
59 |
+
|
60 |
+
table_image_to_html_prompt = """ Here is an image of a table. Please convert this table image into a HTML representation with accurate table cell data.
|
61 |
+
In order to help you in this process, here is a noisy HTML representation of the table extracted from the image: \"{{html_table}}\"
|
62 |
+
You may use this as a noisy reference and further refine the HTML structure to accurately represent the table data.
|
63 |
+
You should also add any information pertaining to row and column spans. In case of nested rows/columns with multiple spans, be very careful to leave blank cells to ensure the semantic structure is maintained.
|
64 |
+
Be careful to handle hierarchical and nested rows/columns.
|
65 |
+
Each HTML should start with "<table>" opening tag.
|
66 |
+
Each HTML should end with "</table>" closing tag.
|
67 |
+
Do NOT output any other explanation or text apart from the HTML code of the table.
|
68 |
+
"""
|
69 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f56ag8l7kr/croot/aiofiles_1683773599608/work
|
2 |
+
altair @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a8x4081_4h/croot/altair_1687526044471/work
|
3 |
+
annotated-types @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1fa2djihwb/croot/annotated-types_1709542925772/work
|
4 |
+
anyio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a17a7759g2/croot/anyio_1706220182417/work
|
5 |
+
argcomplete==3.4.0
|
6 |
+
attrs @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_224434dqzl/croot/attrs_1695717839274/work
|
7 |
+
beautifulsoup4==4.12.3
|
8 |
+
Bottleneck @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2bxpizxa3c/croot/bottleneck_1707864819812/work
|
9 |
+
Brotli @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_27zk0eqdh0/croot/brotli-split_1714483157007/work
|
10 |
+
certifi @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_70oty9s9jh/croot/certifi_1720453497032/work/certifi
|
11 |
+
charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
|
12 |
+
click==8.1.7
|
13 |
+
colorama @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_100_k35lkb/croot/colorama_1672386539781/work
|
14 |
+
contourpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_041uwyxdzo/croot/contourpy_1700583585236/work
|
15 |
+
cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
|
16 |
+
distro==1.9.0
|
17 |
+
exceptiongroup @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b2258scr33/croot/exceptiongroup_1706031391815/work
|
18 |
+
fastapi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ab51gv9t42/croot/fastapi_1693295410365/work
|
19 |
+
ffmpy @ file:///home/conda/feedstock_root/build_artifacts/ffmpy_1659474992694/work
|
20 |
+
filelock @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d3quwmvouf/croot/filelock_1700591194006/work
|
21 |
+
fonttools @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_60c8ux4mkl/croot/fonttools_1713551354374/work
|
22 |
+
fsspec @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_19mkn689lo/croot/fsspec_1714461553219/work
|
23 |
+
gradio @ file:///home/conda/feedstock_root/build_artifacts/gradio_1721770785660/work
|
24 |
+
gradio_client @ file:///home/conda/feedstock_root/build_artifacts/gradio-client_1721697178984/work
|
25 |
+
h11 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_110bmw2coo/croot/h11_1706652289620/work
|
26 |
+
httpcore @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_fcxiho9nv7/croot/httpcore_1706728465004/work
|
27 |
+
httpx @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_727e6zfsxn/croot/httpx_1706887102687/work
|
28 |
+
huggingface-hub==0.24.2
|
29 |
+
idna @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a12xpo84t2/croot/idna_1714398852854/work
|
30 |
+
img2table==1.2.11
|
31 |
+
imgkit==1.2.3
|
32 |
+
importlib_resources @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_784gdd8gq5/croot/importlib_resources-suite_1720641109833/work
|
33 |
+
Jinja2 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_44yzu12j7f/croot/jinja2_1716993410427/work
|
34 |
+
joblib==1.4.2
|
35 |
+
jsonschema @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_27o3go8sqa/croot/jsonschema_1699041627313/work
|
36 |
+
jsonschema-specifications @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d38pclgu95/croot/jsonschema-specifications_1699032390832/work
|
37 |
+
kiwisolver @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_93o8te804v/croot/kiwisolver_1672387163224/work
|
38 |
+
llvmlite==0.43.0
|
39 |
+
markdown-it-py @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_43l_4ajkho/croot/markdown-it-py_1684279912406/work
|
40 |
+
MarkupSafe @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a84ni4pci8/croot/markupsafe_1704206002077/work
|
41 |
+
matplotlib @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a8crvoz7ca/croot/matplotlib-suite_1713336381679/work
|
42 |
+
mdurl @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_0a8xm6w4wv/croots/recipe/mdurl_1659716035810/work
|
43 |
+
mpmath==1.3.0
|
44 |
+
networkx==3.3
|
45 |
+
numba==0.60.0
|
46 |
+
numexpr @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_45yefq0kt6/croot/numexpr_1696515289183/work
|
47 |
+
numpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a51i_mbs7m/croot/numpy_and_numpy_base_1708638620867/work/dist/numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl#sha256=c4b11b3c4d4fdb810039503fe01f311ade06cd1d675fcd6d208800a393f19b69
|
48 |
+
openai==1.37.0
|
49 |
+
opencv-contrib-python==4.10.0.84
|
50 |
+
orjson @ file:///var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_910h7wyrw8/croot/orjson_1711143065818/work/target/wheels/orjson-3.9.15-cp310-cp310-macosx_11_0_arm64.whl#sha256=fb0784f650f58e15827dd32f58284d983bb401ec3b85b38321eea14ebd2d01e9
|
51 |
+
packaging==24.1
|
52 |
+
pandas @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_b53hgou29t/croot/pandas_1718308972393/work/dist/pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl#sha256=7e70989ec6c6e08f2fd87d3940d106704e9feb936002297690d7462e35b5da35
|
53 |
+
pdfkit==1.0.0
|
54 |
+
pillow @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_617xe_y58w/croot/pillow_1721059447446/work
|
55 |
+
pipx==1.6.0
|
56 |
+
platformdirs==4.2.2
|
57 |
+
polars==1.2.1
|
58 |
+
pyarrow==17.0.0
|
59 |
+
pydantic @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_0ai8cvgm2c/croot/pydantic_1709577986211/work
|
60 |
+
pydantic_core @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_06smitnu98/croot/pydantic-core_1709573985903/work
|
61 |
+
pydub @ file:///home/conda/feedstock_root/build_artifacts/pydub_1615612442567/work
|
62 |
+
Pygments @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_29bs9f_dh9/croot/pygments_1684279974747/work
|
63 |
+
PyMuPDF==1.24.9
|
64 |
+
PyMuPDFb==1.24.9
|
65 |
+
pyparsing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_3b_3vxnd07/croots/recipe/pyparsing_1661452540919/work
|
66 |
+
PySocks @ file:///Users/ktietz/ci_310/pysocks_1643961536721/work
|
67 |
+
python-dateutil @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_66ud1l42_h/croot/python-dateutil_1716495741162/work
|
68 |
+
python-multipart @ file:///home/conda/feedstock_root/build_artifacts/python-multipart_1707760088566/work
|
69 |
+
pytz @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a4b76c83ik/croot/pytz_1713974318928/work
|
70 |
+
PyYAML @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a8_sdgulmz/croot/pyyaml_1698096054705/work
|
71 |
+
referencing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_5cz64gsx70/croot/referencing_1699012046031/work
|
72 |
+
regex==2024.7.24
|
73 |
+
requests @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_70sm12ba9w/croot/requests_1721414707360/work
|
74 |
+
rich @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_bcb8hqb_r4/croot/rich_1720637498249/work
|
75 |
+
rpds-py @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f8jkozoefm/croot/rpds-py_1698945944860/work
|
76 |
+
ruff @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_e70v1wc31c/croot/ruff_1713372744076/work
|
77 |
+
safetensors==0.4.3
|
78 |
+
scikit-learn==1.5.1
|
79 |
+
scipy==1.14.0
|
80 |
+
semantic-version @ file:///tmp/build/80754af9/semantic_version_1613321057691/work
|
81 |
+
sentence-transformers==3.0.1
|
82 |
+
shellingham @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_9bf5wowles/croot/shellingham_1669142181600/work
|
83 |
+
six @ file:///tmp/build/80754af9/six_1644875935023/work
|
84 |
+
sniffio @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1573pknjrg/croot/sniffio_1705431298885/work
|
85 |
+
soupsieve==2.5
|
86 |
+
starlette @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_7feeoam7sd/croot/starlette-recipe_1692980426894/work
|
87 |
+
sympy==1.13.1
|
88 |
+
threadpoolctl==3.5.0
|
89 |
+
tokenizers==0.19.1
|
90 |
+
tomli==2.0.1
|
91 |
+
tomlkit @ file:///home/conda/feedstock_root/build_artifacts/tomlkit_1690458286251/work
|
92 |
+
toolz @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_362wyqvvgy/croot/toolz_1667464079070/work
|
93 |
+
torch==2.4.0
|
94 |
+
tqdm @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f76_dxtcsh/croot/tqdm_1716395948224/work
|
95 |
+
transformers==4.43.2
|
96 |
+
typer @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e62p5mo8z5/croot/typer_1684251930377/work
|
97 |
+
typing_extensions @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_93dg13ilv4/croot/typing_extensions_1715268840722/work
|
98 |
+
tzdata @ file:///croot/python-tzdata_1690578112552/work
|
99 |
+
unicodedata2 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a3epjto7gs/croot/unicodedata2_1713212955584/work
|
100 |
+
urllib3 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_cao7_u9937/croot/urllib3_1718912649114/work
|
101 |
+
userpath==1.9.2
|
102 |
+
uvicorn @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_75zglvwlmy/croot/uvicorn-split_1678090090396/work
|
103 |
+
websockets @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d8ij4gljy1/croot/websockets_1678966799107/work
|
104 |
+
XlsxWriter==3.2.0
|
semantic_retrieval.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from bs4 import BeautifulSoup
|
3 |
+
from sklearn.preprocessing import minmax_scale
|
4 |
+
from sentence_transformers import SentenceTransformer, util
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
7 |
+
sbert = SentenceTransformer("all-MiniLM-L6-v2")
|
8 |
+
from llm_query_api import *
|
9 |
+
|
10 |
+
def get_row_embedding(html_table):
    """Embed the per-row ``description`` attributes of an HTML table.

    Parses *html_table*, collects the ``description`` attribute of every
    ``<tr>`` tag, and encodes the collected texts with the module-level
    SBERT model.

    Args:
        html_table: HTML string containing a table whose ``<tr>`` tags
            carry ``id`` and ``description`` attributes.

    Returns:
        tuple: ``(embeddings, element_ids)`` — a numpy array of shape
        (n_rows, dim) and the parallel list of each row's ``id`` attribute
        (stringified).
    """
    soup = BeautifulSoup(html_table, 'html.parser')
    # NOTE(review): a <tr> lacking a 'description' attribute contributes the
    # literal text " None" (str(None)) — preserved from the original code.
    # The original wrapped the append in a dead `try/except: pass`; nothing
    # in that body can raise, so it has been removed.
    rows = [
        {'id': str(t.get('id')), 'text': " " + str(t.get('description'))}
        for t in soup.find_all('tr')
    ]

    sentences = [r['text'] for r in rows]
    element_ids = [r['id'] for r in rows]

    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
|
35 |
+
|
36 |
+
def get_col_embedding(html_table):
    """Embed the per-column ``description`` attributes of an HTML table.

    Parses *html_table*, collects the ``description`` attribute of every
    ``<th>`` tag, and encodes the collected texts with the module-level
    SBERT model.

    Args:
        html_table: HTML string containing a table whose ``<th>`` tags
            carry ``id`` and ``description`` attributes.

    Returns:
        tuple: ``(embeddings, element_ids)`` — a numpy array of shape
        (n_cols, dim) and the parallel list of each column header's ``id``
        attribute (stringified).
    """
    soup = BeautifulSoup(html_table, 'html.parser')
    # NOTE(review): a <th> lacking a 'description' attribute contributes the
    # literal text " None" (str(None)) — preserved from the original code.
    # The original's `try/except: pass` around the append was dead code and
    # has been removed (mirrors the cleanup in get_row_embedding).
    cols = [
        {'id': str(t.get('id')), 'text': " " + str(t.get('description'))}
        for t in soup.find_all('th')
    ]

    sentences = [c['text'] for c in cols]
    element_ids = [c['id'] for c in cols]

    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
|
62 |
+
|
63 |
+
def normalize_list_numpy(list_numpy):
    """Min-max scale *list_numpy* so its values span the [0, 1] range."""
    return minmax_scale(list_numpy)
|
66 |
+
|
67 |
+
def get_answer_embedding(answer):
    """Encode a single answer string into a (1, dim) SBERT embedding array."""
    encoded = sbert.encode([answer], convert_to_tensor=True)
    return encoded.cpu().numpy()
|
69 |
+
|
70 |
+
def row_attribution(answer, html_table, topk=5, threshold = 0.7):
    """Return ids of the table rows most semantically similar to *answer*.

    Embeds the answer and every row description, ranks rows by cosine
    similarity, and returns an adaptive top-k selection of row ids.

    Args:
        answer: Fact/answer text to attribute.
        html_table: Description-augmented HTML table (see get_row_embedding).
        topk: Minimum number of rows to return (capped at the row count).
        threshold: Currently unused; kept for interface compatibility.

    Returns:
        list[str]: Row ``id`` values, most similar first. Empty list when
        the table has no rows (the original code crashed on this case).
    """
    answer_embedding = get_answer_embedding(answer)
    row_embeddings, row_ids = get_row_embedding(html_table)

    # Guard: argpartition on an empty array raises; an empty table simply
    # has no attributable rows.
    if len(row_ids) == 0:
        return []

    sims = cosine_similarity(row_embeddings, answer_embedding.reshape(1, -1)).flatten()
    sims = normalize_list_numpy(sims)

    # Adaptive cutoff: at least `topk` rows, or 30% of all rows for large
    # tables, but never more rows than exist.
    k = max(topk, int(0.3 * len(sims)))
    k = min(k, len(sims))
    top_k_indices = np.argpartition(sims, -k)[-k:]
    sorted_indices = top_k_indices[np.argsort(sims[top_k_indices])][::-1]
    return [row_ids[idx] for idx in sorted_indices]
|
87 |
+
|
88 |
+
def col_attribution(answer, html_table, topk=5, threshold = 0.7):
    """Return ids of the table columns most semantically similar to *answer*.

    Embeds the answer and every column-header description, ranks columns by
    cosine similarity, and returns an adaptive top-k selection of column ids.

    Args:
        answer: Fact/answer text to attribute.
        html_table: Description-augmented HTML table (see get_col_embedding).
        topk: Minimum number of columns to return (capped at column count).
        threshold: Currently unused; kept for interface compatibility.

    Returns:
        list[str]: Column ``id`` values, most similar first. Empty list when
        the table has no <th> headers (the original code crashed here).
    """
    answer_embedding = get_answer_embedding(answer)
    col_embeddings, col_ids = get_col_embedding(html_table)

    # Guard: argpartition on an empty array raises; mirrors row_attribution.
    if len(col_ids) == 0:
        return []

    sims = cosine_similarity(col_embeddings, answer_embedding.reshape(1, -1)).flatten()
    sims = normalize_list_numpy(sims)

    # Adaptive cutoff: at least `topk` columns, or 30% of all columns for
    # wide tables, but never more columns than exist.
    k = max(topk, int(0.3 * len(sims)))
    k = min(k, len(sims))
    top_k_indices = np.argpartition(sims, -k)[-k:]
    sorted_indices = top_k_indices[np.argsort(sims[top_k_indices])][::-1]
    return [col_ids[idx] for idx in sorted_indices]
|
105 |
+
|
106 |
+
def retain_rows_and_columns(augmented_html_table, row_ids, column_ids):
    """Prune an HTML table, keeping only the cited rows and columns.

    Removes every ``<tr>`` whose ``id`` is not in *row_ids*, then removes
    every cell whose column position matches a ``<th>`` whose ``id`` is not
    in *column_ids*. Returns the pruned table as an HTML string.
    """
    soup = BeautifulSoup(augmented_html_table, 'html.parser')

    # De-duplicate the citation lists; ordering is irrelevant here.
    row_ids = list(set(row_ids))
    column_ids = list(set(column_ids))

    # Retain specified rows and remove others
    all_rows = soup.find_all('tr')
    for row in all_rows:
        if row.get('id') not in row_ids:
            row.decompose()

    # Retain specified columns and remove others
    # NOTE(review): column matching is positional — the index of each <th>
    # in the FIRST row determines which cell is deleted in every remaining
    # row. This assumes no rowspan/colspan offsets; confirm for nested tables.
    # NOTE(review): all_rows[0] may itself have been decomposed above if the
    # header row's id was not cited; in that case find_all(['th']) yields
    # nothing and no columns are removed — presumably the header row is
    # always retained, but that is not enforced here.
    if all_rows:
        all_columns = all_rows[0].find_all(['th'])
        for i, col in enumerate(all_columns):
            if col.get('id') not in column_ids:
                for row in soup.find_all('tr'):
                    cells = row.find_all(['td', 'th'])
                    # Skip short rows so positional deletion never overruns.
                    if len(cells) > i:
                        cells[i].decompose()

    return str(soup)
|
129 |
+
|
130 |
+
def get_embedding_attribution(augmented_html_table, decomposed_fact_list, topK=5, threshold = 0.7):
    """Attribute each decomposed fact to table rows/columns via embeddings.

    For every fact, collects the top-k semantically similar row and column
    ids, then prunes the table down to the union of all cited rows/columns.

    Returns:
        tuple: (attributed_html_table, row_attribution_ids, col_attribution_ids).
    """
    row_attribution_ids = []
    col_attribution_ids = []

    # Accumulate citations fact by fact; duplicates are de-duplicated later
    # inside retain_rows_and_columns.
    for fact in decomposed_fact_list:
        row_attribution_ids.extend(row_attribution(fact, augmented_html_table, topK))
        col_attribution_ids.extend(col_attribution(fact, augmented_html_table, topK))

    attributed_html_table = retain_rows_and_columns(
        augmented_html_table, row_attribution_ids, col_attribution_ids
    )

    return attributed_html_table, row_attribution_ids, col_attribution_ids
|
tables_folder/MATSA_aitqa.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tables_folder/MATSA_fetaqa.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:365905f7696b0c00498f7b88cf602769677998c5b32614823cb86127ccfd13f9
|
3 |
+
size 18562195
|
wkhtmltox_0.12.6-1.bionic_amd64.deb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:503a8a97fcf8fd397ed52c1789471e0f2513f5752f3e214d3a5eda30caa0354b
|
3 |
+
size 15729530
|