puneetm commited on
Commit
35d31f5
·
verified ·
1 Parent(s): 7309f50

Upload folder using huggingface_hub

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tables_folder/MATSA_fetaqa.json filter=lfs diff=lfs merge=lfs -text
37
+ wkhtmltox_0.12.6-1.bionic_amd64.deb filter=lfs diff=lfs merge=lfs -text
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Run Python script
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+
8
+ jobs:
9
+ build:
10
+ runs-on: ubuntu-latest
11
+
12
+ steps:
13
+ - name: Checkout
14
+ uses: actions/checkout@v2
15
+
16
+ - name: Set up Python
17
+ uses: actions/setup-python@v2
18
+ with:
19
+ python-version: '3.9'
20
+
21
+ - name: Install Gradio
22
+ run: python -m pip install gradio
23
+
24
+ - name: Log in to Hugging Face
25
+ run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
26
+
27
+ - name: Deploy to Spaces
28
+ run: gradio deploy
.ipynb_checkpoints/config_gpt35-checkpoint.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ API_BASE: https://dil-research.openai.azure.com/
2
+ API_KEY: "<REDACTED>"  # SECURITY: a live Azure OpenAI key was committed here — rotate it immediately and load it from an environment variable instead
3
+ API_VERSION: '2023-05-15'
.ipynb_checkpoints/config_gpt4-checkpoint.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ API_BASE: https://dil-research-sweden-central.openai.azure.com/
2
+ API_KEY: "<REDACTED>"  # SECURITY: a live Azure OpenAI key was committed here — rotate it immediately and load it from an environment variable instead
3
+ API_VERSION: "2023-12-01-preview"
.ipynb_checkpoints/demo-checkpoint.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ from bs4 import BeautifulSoup
4
+ from matsa import MATSA, InputInstance
5
+ import imgkit
6
+ import tempfile
7
+ import time
8
+ import threading
9
+
10
TABLE_FOLDER = "./tables_folder/MATSA_fetaqa.json"

# Load data from JSON file
def load_data():
    """Load the list of table records from the bundled JSON file.

    Returns:
        The parsed JSON payload (a list of dicts, one per table).
    """
    # Explicit encoding: JSON files are UTF-8 by convention; relying on the
    # platform default encoding can break on Windows.
    with open(TABLE_FOLDER, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)

# Global variable to store the loaded data (read once at import time).
TABLE_DATA = load_data()

def get_table_names():
    """Return display names ("tab_1", "tab_2", ...) for the table dropdown."""
    return [f"tab_{i+1}" for i in range(len(TABLE_DATA))]
21
+
22
def html_to_image(html_content):
    """Render an HTML snippet to a PNG file and return its path."""
    # delete=False: the file must outlive this call because Gradio reads the
    # image back from the returned path later.
    temp_img = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    with temp_img:
        imgkit.from_string(html_content, temp_img.name)
    return temp_img.name
26
+
27
def highlight_table(html_table, row_ids, col_ids):
    """Return *html_table* with attributed cells colour-coded inline.

    Cells on both a cited row and a cited column are tinted blue; cells on a
    cited column but an uncited row are tinted grey. Rows are matched by
    their ``id`` attribute; columns by the index encoded in ids like "col-3".
    """
    soup = BeautifulSoup(html_table, 'html.parser')

    # Resolve cited row ids to the actual <tr> tags present in the table.
    selected_rows = []
    for rid in row_ids:
        tr = soup.find('tr', id=rid)
        if tr:
            selected_rows.append(tr)

    blue = 'background-color: rgba(173, 216, 230, 0.7);'
    grey = 'background-color: rgba(211, 211, 211, 0.6);'

    for col_id in col_ids:
        # "col-1" -> index 0, "col-2" -> index 1, etc.
        col_index = int(col_id.split('-')[1]) - 1
        for tr in soup.find_all('tr'):
            cells = tr.find_all(['td', 'th'])
            if col_index >= len(cells):
                continue
            cells[col_index]['style'] = blue if tr in selected_rows else grey

    return str(soup)
49
+
50
def load_table_data(table_name):
    """Fetch the rendered table image, question and answer for a selection.

    Returns ``(image_path, question, answer)``; empty placeholders when
    *table_name* is falsy.
    """
    if not table_name:
        return None, "", ""

    # "tab_7" -> TABLE_DATA[6]
    record = TABLE_DATA[int(table_name.split('_')[1]) - 1]

    question = record.get("question")
    if question is None:
        question = ""

    image_path = html_to_image(record['html_table'])
    return image_path, question, record['answer_statement']
65
+
66
def process_input(table_name, question, answer):
    """Run the MATSA attribution pipeline for the selected table.

    Returns a JSON string holding the highlighted table, decomposed facts,
    row/column citations and an explanation — or a plain error message when
    no table is selected.
    """
    import ast  # local import: only needed to parse LLM-produced literals

    if not table_name:
        return "Please select a table from the dropdown."

    # Get the data for the selected table ("tab_3" -> index 2).
    index = int(table_name.split('_')[1]) - 1
    data = TABLE_DATA[index]

    html_content = data['html_table']

    print("html_content: ", html_content)
    print("question: ", question)
    print("answer: ", answer)

    # Initialize MATSA
    matsa_agent = MATSA()

    # Create input instance
    instance = InputInstance(html_table=html_content, question=question, answer=answer)

    # Apply MATSA pipeline
    augmented_table = matsa_agent.description_augmentation_agent(instance.html_table)
    print("augmented_table: ", augmented_table)
    fact_list = matsa_agent.answer_decomposition_agent(instance.answer)
    print("fact_list: ", fact_list)
    attributed_table, _, _ = matsa_agent.semantic_retreival_agent(augmented_table, fact_list)
    print("attributed_table: ", attributed_table)
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_table)
    print("attribution_fxn: ", attribution_fxn)

    # Get row and column attributions
    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]
    explanation = attribution_fxn.get("Explanation", "")
    print("row_attribution_set: ", row_attribution_set)
    print("col_attribution_set: ", col_attribution_set)
    print("Explanation: ", explanation)

    # SECURITY FIX: the LLM sometimes returns the citation lists as string
    # literals such as "['row-1', 'row-2']". The original code parsed them
    # with eval(), which executes arbitrary code embedded in model output;
    # ast.literal_eval only accepts Python literals.
    if isinstance(row_attribution_set, str):
        row_ids = ast.literal_eval(row_attribution_set)
    else:
        row_ids = row_attribution_set

    if isinstance(col_attribution_set, str):
        col_ids = ast.literal_eval(col_attribution_set)
    else:
        col_ids = col_attribution_set

    # Highlight the table
    highlighted_table = highlight_table(instance.html_table, row_ids, col_ids)

    result = {
        "highlighted_table": highlighted_table,
        "facts": attribution_fxn.get("List of Facts", []),
        "row_citations": row_attribution_set,
        "column_citations": col_attribution_set,
        "Explanation": explanation,
    }

    return json.dumps(result)
128
+
129
# Define Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# MATSA: Table Question Answering with Attribution")
    gr.Markdown("Select a table from dropdown load table image, question, and answer.")
    gr.Markdown("Attributions are provided as per answer. You may change the question/answer as per your need.")

    table_dropdown = gr.Dropdown(choices=get_table_names(), label="Select Table")
    original_table = gr.Image(type="filepath", label="Original Table")
    question_box = gr.Textbox(label="Question")
    answer_box = gr.Textbox(label="Answer")

    gr.Markdown("Click 'Process' to see the highlighted relevant parts. Click 'Reset' to start over.")

    process_button = gr.Button("Process")
    reset_button = gr.Button("Reset")
    processing_time = gr.Textbox(label="Processing Time", value="0 seconds")
    highlighted_table = gr.HTML(label="Highlighted Table")
    explanation_box = gr.Textbox(label="Explanation")

    def update_table_data(table_name):
        """Populate image/question/answer when a table is picked."""
        image_path, question, answer = load_table_data(table_name)
        return image_path, question, answer, gr.update(interactive=True)

    def reset_app():
        """Restore every widget to its initial state."""
        return (
            gr.update(value="", interactive=True),  # table_dropdown
            None,                                   # original_table
            "",                                     # question_box
            "",                                     # answer_box
            "",                                     # highlighted_table
            "",                                     # explanation_box
            gr.update(interactive=True),            # process_button
            "0 seconds",                            # processing_time
        )

    def process_and_disable(table_name, question, answer):
        """Run the pipeline with the controls locked, timing the work.

        Yields twice: once to disable the UI, once with the results.
        """
        processing = True
        counter = 0

        # BUG FIX: the original counter function contained a ``yield``,
        # turning it into a generator — running it as a Thread target did
        # nothing, so the counter never advanced. A plain loop works.
        def update_counter():
            nonlocal counter
            while processing:
                counter += 1
                time.sleep(1)

        counter_thread = threading.Thread(target=update_counter)
        counter_thread.start()

        # Disable the dropdown and process button during processing.
        yield (
            gr.update(interactive=False),       # table_dropdown
            gr.update(interactive=False),       # process_button
            gr.update(value="Processing..."),   # processing_time
            gr.update(),                        # highlighted_table
            gr.update(),                        # explanation_box
        )

        # Process the input.
        result = process_input(table_name, question, answer)
        result_dict = json.loads(result)

        # Stop the counter.
        processing = False
        counter_thread.join()

        # Re-enable the controls, report the elapsed time, show the result.
        # BUG FIX: the original yielded a 6th element (a citations dict) but
        # only 5 outputs are wired up, which Gradio rejects — dropped it.
        yield (
            gr.update(interactive=True),                          # table_dropdown
            gr.update(interactive=True),                          # process_button
            f"Processed in {counter} seconds",                    # processing_time
            gr.update(value=result_dict['highlighted_table']),    # highlighted_table
            gr.update(value=result_dict.get('Explanation', '')),  # explanation_box
        )

    table_dropdown.change(update_table_data,
                          inputs=[table_dropdown],
                          outputs=[original_table, question_box, answer_box, process_button])

    process_button.click(process_and_disable,
                         inputs=[table_dropdown, question_box, answer_box],
                         outputs=[table_dropdown, process_button, processing_time, highlighted_table, explanation_box])

    reset_button.click(reset_app,
                       inputs=[],
                       outputs=[table_dropdown, original_table, question_box, answer_box, highlighted_table, explanation_box, process_button, processing_time])

# Launch the interface
iface.launch(share=True)
.ipynb_checkpoints/llm_query_api-checkpoint.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import copy
3
+ import io
4
+ import os
5
+ import random
6
+ import time
7
+ import re
8
+ import json
9
+ import argparse
10
+ import yaml
11
+ import openai
12
+ from openai import AzureOpenAI
13
+ from prompts import *
14
+ import base64
15
+ from mimetypes import guess_type
16
+ # from img2table.document import Image
17
+ # from img2table.document import PDF
18
+ # from img2table.ocr import TesseractOCR
19
+ # from img2table.ocr import EasyOCR
20
+ from PIL import Image as PILImage
21
+
22
class LLMQueryAPI:
    """Thin wrapper around Azure OpenAI chat-completion deployments.

    Credentials are read per-call from config_gpt4.yaml / config_gpt35.yaml.
    """

    def __init__(self) -> None:
        pass

    def gpt4_chat_completion(self, query):
        """Send *query* (a chat messages list) to the GPT-4 deployment."""
        with open('config_gpt4.yaml', 'r') as f:
            config = yaml.safe_load(f)

        client = AzureOpenAI(
            azure_endpoint=config.get('API_BASE'),
            api_version=config.get('API_VERSION'),
            api_key=config.get('API_KEY'),
        )

        deployment_name = 'gpt-4-2024-04-09'

        response = client.chat.completions.create(
            model=deployment_name,
            messages=query)

        return response.choices[0].message.content

    def gpt35_chat_completion(self, query):
        """Send *query* to the GPT-3.5 deployment via the legacy openai API."""
        with open('config_gpt35.yaml', 'r') as f:
            config = yaml.safe_load(f)

        response = openai.ChatCompletion.create(
            engine='gpt-35-turbo-0613',
            messages=query,
            request_timeout=60,
            api_key=config.get('API_KEY'),
            api_version=config.get('API_VERSION'),
            api_type="azure",
            api_base=config.get('API_BASE'),
        )

        # CONSISTENCY FIX: previously returned the whole message dict while
        # gpt4_chat_completion returned the content string; downstream code
        # (prompt .replace / json parsing in the MATSA agents) needs a string.
        return response['choices'][0]['message']['content']

    def copilot_chat_completion(self, query):
        """Send *query* to the legacy gpt-4-0613 deployment."""
        with open('config_gpt4.yaml', 'r') as f:
            config = yaml.safe_load(f)

        response = openai.ChatCompletion.create(
            engine='gpt-4-0613',
            messages=query,
            request_timeout=60,
            api_key=config.get('API_KEY'),
            api_version=config.get('API_VERSION'),
            api_type="azure",
            api_base=config.get('API_BASE'),
        )
        # Same consistency fix as gpt35_chat_completion: return the text.
        return response['choices'][0]['message']['content']

    def LLM_chat_query(self, query, llm):
        """Dispatch *query* to the backend named *llm*.

        Returns None for unrecognised model names (original behaviour).
        """
        if llm == 'gpt-3.5-turbo':
            return self.gpt35_chat_completion(query)
        elif llm == "gpt-4":
            return self.gpt4_chat_completion(query)

    def get_llm_response(self, llm, query):
        """Wrap *query* as a single system message and return the reply."""
        chat_completion = [{"role": "system", "content": query}]
        return self.LLM_chat_query(chat_completion, llm)
104
+
105
class LLMProxyQueryAPI:
    """Query OpenAI-hosted chat models through the plain openai client."""

    def __init__(self) -> None:
        pass

    def _chat(self, model, query):
        # Shared plumbing for the text-only chat-completion endpoints.
        client = openai.Client()
        response = client.chat.completions.create(model=model, messages=query)
        return response.choices[0].message.content

    def gpt35_chat_completion(self, query):
        """Chat completion against gpt-3.5-turbo-16k."""
        return self._chat("gpt-3.5-turbo-16k", query)

    def gpt4o_chat_completion(self, query):
        """Chat completion against gpt-4o."""
        return self._chat("gpt-4o", query)

    def gpt4_chat_completion(self, query):
        """Chat completion against gpt-4-1106-preview."""
        return self._chat("gpt-4-1106-preview", query)

    def gpt4_vision(self, query, image_path):
        """Multimodal completion: *query* text plus the image at *image_path*."""
        print(query)

        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": query},
                        {"type": "image_url", "image_url": {"url": image_path}},
                    ],
                }
            ],
            max_tokens=4096,
            stream=False,
        )
        return response.choices[0].message.content

    def LLM_chat_query(self, llm, query, image_path=None):
        """Route to the backend named *llm*; None for unknown names."""
        if llm == "gpt-4V":
            return self.gpt4_vision(query, image_path)

        dispatch = {
            'gpt-3.5-turbo': self.gpt35_chat_completion,
            "gpt-4": self.gpt4_chat_completion,
            "gpt-4o": self.gpt4o_chat_completion,
        }
        handler = dispatch.get(llm)
        if handler is not None:
            return handler(query)

    def get_llm_response(self, llm, query, image_path=None):
        """Wrap *query* as a system message (text models) and query *llm*."""
        if llm == "gpt-4V" and image_path:
            return self.LLM_chat_query(llm, query, image_path)

        messages = [{"role": "system", "content": query}]
        return self.LLM_chat_query(llm, messages)
187
+
188
+ # if __name__ == '__main__':
189
+
190
+ # llm_query_api = LLMProxyQueryAPI()
191
+
192
+ # def local_image_to_data_url(image_path):
193
+ # mime_type, _ = guess_type(image_path)
194
+ # if mime_type is None:
195
+ # mime_type = 'application/octet-stream'
196
+
197
+ # with open(image_path, "rb") as image_file:
198
+ # base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
199
+
200
+ # return f"data:{mime_type};base64,{base64_encoded_data}"
201
+
202
+ # tesseract = TesseractOCR()
203
+ # pdf = PDF(src="temp3.pdf", pages=[0, 0])
204
+ # extracted_tables = pdf.extract_tables(ocr=tesseract,
205
+ # implicit_rows=True,
206
+ # borderless_tables=True,)
207
+ # html_table = extracted_tables[0][0].html_repr()
208
+ # print(html_table)
209
+
210
+ # table_image_path = "./temp3.jpeg"
211
+ # table_image_data_url = local_image_to_data_url(table_image_path)
212
+ # print(table_image_data_url)
213
+ # query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
214
+ # html_table_refined = llm_query_api.get_llm_response("gpt-4V", query, table_image_data_url)
215
+ # print(html_table_refined)
216
+
.ipynb_checkpoints/matsa-checkpoint.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ from prompts import *
4
+ import ast
5
+ from bs4 import BeautifulSoup
6
+ from semantic_retrieval import *
7
+ from llm_query_api import *
8
+ import base64
9
+ from mimetypes import guess_type
10
+
11
class InputInstance:
    """Container for one attribution example: a table plus optional Q/A."""

    def __init__(self, id=None, html_table=None, question=None, answer=None):
        # ``id`` mirrors the dataset field name even though it shadows the
        # builtin; it is part of the public keyword interface.
        self.id = id
        self.html_table = html_table
        self.question = question
        self.answer = answer
19
+
20
class MATSA:
    """Multi-agent pipeline that attributes an answer to table rows/columns."""

    def __init__(self, llm="gpt-4"):
        self.llm = llm
        self.llm_query_api = LLMQueryAPI()  # alternatively LLMProxyQueryAPI()

    def table_formatting_agent(self, html_table=None, table_image_path=None):
        """Normalise a table to HTML and tag rows / header cells with ids.

        When *table_image_path* is given, the table is first OCR'd and then
        refined by the vision model. NOTE(review): that branch needs
        TesseractOCR/PDF from img2table, whose imports are commented out in
        llm_query_api — confirm before enabling it.
        """

        def local_image_to_data_url(image_path):
            # Inline the image as a base64 data URL for the vision model.
            mime_type, _ = guess_type(image_path)
            if mime_type is None:
                mime_type = 'application/octet-stream'

            with open(image_path, "rb") as image_file:
                base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')

            return f"data:{mime_type};base64,{base64_encoded_data}"

        if table_image_path is not None:
            tesseract = TesseractOCR()
            pdf = PDF(src=table_image_path, pages=[0, 0])
            extracted_tables = pdf.extract_tables(ocr=tesseract,
                                                  implicit_rows=True,
                                                  borderless_tables=True,)
            html_table = extracted_tables[0][0].html_repr()

            table_image_data_url = local_image_to_data_url(table_image_path)
            query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
            # BUG FIX: was ``llm_query_api.get_llm_response(...)`` — an
            # undefined local name (NameError); must use the instance attr.
            html_table = self.llm_query_api.get_llm_response("gpt-4V", query, table_image_data_url)

        soup = BeautifulSoup(html_table, 'html.parser')
        tr_tags = soup.find_all('tr')
        for row_idx, tr_tag in enumerate(tr_tags):
            tr_tag['id'] = f"row-{row_idx + 1}"  # unique 'row-i' ids

            if row_idx == 0:
                # Distinct loop variable: the original reused ``i`` here,
                # shadowing the outer row counter.
                for col_idx, th_tag in enumerate(tr_tag.find_all('th')):
                    th_tag['id'] = f"col-{col_idx + 1}"  # unique 'col-j' ids

        return str(soup)

    def description_augmentation_agent(self, html_table):
        """Augment the table with LLM-written column, row and trend notes."""
        query = col_description_prompt.replace("{{html_table}}", str(html_table))
        col_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        query = row_description_prompt.replace("{{html_table}}", str(col_augmented_html_table))
        row_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        query = trend_description_prompt.replace("{{html_table}}", str(row_augmented_html_table))
        trend_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        return trend_augmented_html_table

    def answer_decomposition_agent(self, answer):
        """Split *answer* into atomic facts.

        Returns a list of fact strings, or None when the LLM reply does not
        parse to a list.
        """
        query = answer_decomposition_prompt.replace("{{answer}}", answer)
        res = self.llm_query_api.get_llm_response(self.llm, query)
        res = ast.literal_eval(res)
        return res if isinstance(res, list) else None

    def semantic_retreival_agent(self, html_table, fact_list, topK=5):
        """Prune the table to rows/columns semantically close to the facts.

        (Method name keeps the original spelling for caller compatibility.)
        """
        return get_embedding_attribution(html_table, fact_list, topK)

    def sufficiency_attribution_agent(self, fact_list, attributed_html_table):
        """Ask the LLM to verify the facts and emit row/column citations.

        Returns the parsed attribution dict, or None on a non-dict reply.
        """
        fact_verification_list = []
        for i, fact in enumerate(fact_list):
            fact_verification_list.append({f"Fact {i + 1}:": str(fact)})

        # Skeleton the model is asked to fill in (keys are part of the
        # prompt contract, including the original "List of Fact" spelling).
        fact_verification_function = {
            "List of Fact": fact_verification_list,
            "Row Citations": "[..., ..., ...]",
            "Column Citations": "[..., ..., ...]",
            "Explanation": "...",
        }

        fact_verification_function_string = json.dumps(fact_verification_function)

        query = functional_attribution_prompt.replace(
            "{{attributed_html_table}}", str(attributed_html_table)
        ).replace(
            "{{fact_verification_function}}", fact_verification_function_string
        )
        attribution_fxn = self.llm_query_api.get_llm_response(self.llm, query)

        # Strip markdown code fences the model sometimes wraps around JSON.
        attribution_fxn = attribution_fxn.replace("```json", "")
        attribution_fxn = attribution_fxn.replace("```", "")
        print(attribution_fxn)
        attribution_fxn = json.loads(attribution_fxn)

        return attribution_fxn if isinstance(attribution_fxn, dict) else None
124
+
125
if __name__ == '__main__':

    # Smoke test: a table with a spanned header group and two data rows.
    html_table = """<table>
    <tr>
    <th rowspan="1">Sr. Number</th>
    <th colspan="3">Types</th>
    <th rowspan="1">Remark</th>
    </tr>
    <tr>
    <th> </th>
    <th>A</th>
    <th>B</th>
    <th>C</th>
    <th> </th>
    </tr>
    <tr>
    <td>1</td>
    <td>Mitten</td>
    <td>Kity</td>
    <td>Teddy</td>
    <td>Names of cats</td>
    </tr>
    <tr>
    <td>1</td>
    <td>Tommy</td>
    <td>Rudolph</td>
    <td>Jerry</td>
    <td>Names of dogs</td>
    </tr>
    </table>"""

    answer = "Tommy is a dog but Mitten is a cat."

    x = InputInstance(html_table=html_table, answer=answer)

    matsa_agent = MATSA()

    # Run the pipeline stage by stage, printing each intermediate result.
    x_reformulated = matsa_agent.table_formatting_agent(x.html_table)
    print(x_reformulated)

    x_descriptions = matsa_agent.description_augmentation_agent(x_reformulated)
    print(x_descriptions)

    fact_list = matsa_agent.answer_decomposition_agent(x.answer)
    print(fact_list)

    attributed_html_table, row_attribution_ids, col_attribution_ids = matsa_agent.semantic_retreival_agent(x_descriptions, fact_list)
    print(attributed_html_table)

    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_html_table)
    print(attribution_fxn)

    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]

    print(row_attribution_set)
    print(col_attribution_set)
.ipynb_checkpoints/semantic_retrieval-checkpoint.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from bs4 import BeautifulSoup
3
+ from sklearn.preprocessing import minmax_scale
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import numpy as np
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ sbert = SentenceTransformer("all-MiniLM-L6-v2")
8
+ from llm_query_api import *
9
+
10
def get_row_embedding(html_table):
    """Embed each <tr>'s description text.

    Returns ``(embeddings, row_ids)`` where embeddings is a numpy array with
    one SBERT vector per row.
    """

    def get_row_elements(html_table):
        # Collect {'id', 'text'} for every table row; the text comes from
        # the 'description' attribute written by the augmentation agent.
        soup = BeautifulSoup(html_table, 'html.parser')
        elements = []
        for tag in soup.find_all('tr'):
            text = " " + str(tag.get('description'))
            try:
                elements.append({'id': str(tag.get('id')), 'text': text})
            except:
                pass
        return elements

    rows = get_row_elements(html_table)

    sentences = [element['text'] for element in rows]
    element_ids = [element['id'] for element in rows]

    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
35
+
36
def get_col_embedding(html_table):
    """Embed each <th>'s description text.

    Returns ``(embeddings, col_ids)`` where embeddings is a numpy array with
    one SBERT vector per header cell.
    """

    def get_column_elements(html_table):
        # Collect {'id', 'text'} for every header cell; the text comes from
        # the 'description' attribute written by the augmentation agent.
        soup = BeautifulSoup(html_table, 'html.parser')
        elements = []
        for tag in soup.find_all('th'):
            text = " " + str(tag.get('description'))
            try:
                elements.append({'id': str(tag.get('id')), 'text': text})
            except:
                pass
        return elements

    cols = get_column_elements(html_table)

    sentences = [element['text'] for element in cols]
    element_ids = [element['id'] for element in cols]

    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
62
+
63
def normalize_list_numpy(list_numpy):
    """Min-max scale a 1-D array of scores into [0, 1].

    Pure-numpy replacement for sklearn.preprocessing.minmax_scale — drops
    the scikit-learn dependency. A constant input maps to all zeros,
    matching sklearn's handling of a zero data range.
    """
    arr = np.asarray(list_numpy, dtype=float)
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        # Zero range: every value is the minimum, so every scaled value is 0.
        return np.zeros_like(arr)
    return (arr - lo) / span
66
+
67
def get_answer_embedding(answer):
    """Encode a single answer/fact string into an SBERT vector array."""
    encoded = sbert.encode([answer], convert_to_tensor=True)
    return encoded.cpu().numpy()
69
+
70
def row_attribution(answer, html_table, topk=5, threshold = 0.7):
    """Rank row ids by cosine similarity between *answer* and row text.

    Keeps max(topk, 30% of the rows) candidates, capped at the row count.
    *threshold* is accepted for API compatibility but currently unused.
    """
    answer_embedding = get_answer_embedding(answer)
    embeddings, row_ids = get_row_embedding(html_table)

    scores = cosine_similarity(embeddings, answer_embedding.reshape(1, -1)).flatten()
    scores = normalize_list_numpy(scores)

    # if no of rows >= 5, take max of (5, 1/3 x rows);
    # else if no of rows < 5, take least of (5, rows)
    k = min(max(topk, int(0.3 * len(scores))), len(scores))

    top_k = np.argpartition(scores, -k)[-k:]
    ordered = top_k[np.argsort(scores[top_k])][::-1]
    return [row_ids[idx] for idx in ordered]
87
+
88
def col_attribution(answer, html_table, topk=5, threshold = 0.7):
    """Rank column ids by cosine similarity between *answer* and header text.

    Keeps max(topk, 30% of the columns) candidates, capped at the column
    count. *threshold* is accepted for API compatibility but unused.
    """
    answer_embedding = get_answer_embedding(answer)
    embeddings, col_ids = get_col_embedding(html_table)

    scores = cosine_similarity(embeddings, answer_embedding.reshape(1, -1)).flatten()
    scores = normalize_list_numpy(scores)

    # if no of cols >= 5, take max of (5, 1/3 x cols);
    # else if no of cols < 5, take least of (5, cols)
    k = min(max(topk, int(0.3 * len(scores))), len(scores))

    top_k = np.argpartition(scores, -k)[-k:]
    ordered = top_k[np.argsort(scores[top_k])][::-1]
    return [col_ids[idx] for idx in ordered]
105
+
106
def retain_rows_and_columns(augmented_html_table, row_ids, column_ids):
    """Drop every row/column whose id is not in the attribution sets.

    Column membership is decided by the header (<th>) ids of the first
    surviving row; the same cell positions are removed from every row.
    """
    soup = BeautifulSoup(augmented_html_table, 'html.parser')

    row_ids = set(row_ids)
    column_ids = set(column_ids)

    # Retain specified rows and remove others.
    for row in soup.find_all('tr'):
        if row.get('id') not in row_ids:
            row.decompose()

    # Retain specified columns and remove others. Compute the positions to
    # drop from the (surviving) header row *before* mutating any cells.
    surviving_rows = soup.find_all('tr')
    if surviving_rows:
        header_cells = surviving_rows[0].find_all(['th'])
        drop_positions = [i for i, col in enumerate(header_cells)
                          if col.get('id') not in column_ids]

        # BUG FIX: the original decomposed cells column-by-column while
        # re-querying each row's cells, so every removal shifted the indices
        # of later columns and the wrong cells were deleted. Deleting from
        # the highest index down keeps the remaining positions stable.
        for row in surviving_rows:
            cells = row.find_all(['td', 'th'])
            for i in sorted(drop_positions, reverse=True):
                if i < len(cells):
                    cells[i].decompose()

    return str(soup)
129
+
130
def get_embedding_attribution(augmented_html_table, decomposed_fact_list, topK=5, threshold = 0.7):
    """Collect row/column ids relevant to any decomposed fact, then prune.

    Returns ``(pruned_html_table, row_ids, col_ids)``; the id lists may
    contain duplicates across facts (deduplicated later during pruning).
    """
    row_attribution_ids = []
    col_attribution_ids = []

    for fact in decomposed_fact_list:
        row_attribution_ids.extend(row_attribution(fact, augmented_html_table, topK))
        col_attribution_ids.extend(col_attribution(fact, augmented_html_table, topK))

    attributed_html_table = retain_rows_and_columns(
        augmented_html_table, row_attribution_ids, col_attribution_ids)

    return attributed_html_table, row_attribution_ids, col_attribution_ids
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: Matsa Demo
3
- emoji: 👀
4
- colorFrom: pink
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.42.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Matsa-demo
3
+ app_file: demo.py
 
 
4
  sdk: gradio
5
+ sdk_version: 4.41.0
 
 
6
  ---
 
 
__pycache__/demo.cpython-310.pyc ADDED
Binary file (5.95 kB). View file
 
__pycache__/llm_query_api.cpython-310.pyc ADDED
Binary file (4.44 kB). View file
 
__pycache__/matsa.cpython-310.pyc ADDED
Binary file (5.21 kB). View file
 
__pycache__/prompts.cpython-310.pyc ADDED
Binary file (6.37 kB). View file
 
__pycache__/semantic_retrieval.cpython-310.pyc ADDED
Binary file (3.89 kB). View file
 
config_gpt35.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ API_BASE: https://dil-research.openai.azure.com/
2
+ API_KEY: "<REDACTED>"  # SECURITY: a live Azure OpenAI key was committed here — rotate it immediately and load it from an environment variable instead
3
+ API_VERSION: '2023-05-15'
config_gpt4.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ API_BASE: https://dil-research-sweden-central.openai.azure.com/
2
+ API_KEY: "<REDACTED>"  # SECURITY: a live Azure OpenAI key was committed here — rotate it immediately and load it from an environment variable instead
3
+ API_VERSION: "2023-12-01-preview"
demo.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ from bs4 import BeautifulSoup
4
+ from matsa import MATSA, InputInstance
5
+ import imgkit
6
+ import tempfile
7
+ import time
8
+ import threading
9
+
10
# Path to the bundled FeTaQA table instances (despite the name, a JSON file).
TABLE_FOLDER = "./tables_folder/MATSA_fetaqa.json"


def load_data():
    """Read the demo table instances from disk.

    Returns the parsed JSON payload (a list of table records, each with at
    least 'html_table' and 'answer_statement' keys, per the readers below).
    """
    # JSON is UTF-8 by specification; the original relied on the locale
    # default encoding, which breaks on non-UTF-8 systems.
    with open(TABLE_FOLDER, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)


# Global variable to store the loaded data
TABLE_DATA = load_data()
18
+
19
def get_table_names():
    """Return the dropdown labels "tab_1" … "tab_N" for the loaded tables."""
    return [f"tab_{idx}" for idx in range(1, len(TABLE_DATA) + 1)]
21
+
22
def html_to_image(html_content):
    """Render an HTML string to a PNG via imgkit and return the file path.

    The temporary file is created with delete=False so the path outlives this
    function; the caller (Gradio) reads it later.
    """
    temp_img = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
    try:
        imgkit.from_string(html_content, temp_img.name)
    finally:
        temp_img.close()
    return temp_img.name
26
+
27
def highlight_table(html_table, row_ids, col_ids):
    """Shade cited cells of an HTML table.

    Cells at the intersection of a cited column and a cited row turn light
    blue; remaining cells of cited columns turn light gray. Rows are matched
    by their <tr id="row-N">; columns by position parsed from "col-N" ids.
    """
    soup = BeautifulSoup(html_table, 'html.parser')

    # Resolve the cited row ids to their <tr> elements (ignoring misses).
    cited_rows = []
    for rid in row_ids:
        tr = soup.find('tr', id=rid)
        if tr:
            cited_rows.append(tr)

    for cid in col_ids:
        # "col-3" -> zero-based column index 2.
        idx = int(cid.split('-')[1]) - 1
        for tr in soup.find_all('tr'):
            cells = tr.find_all(['td', 'th'])
            if idx >= len(cells):
                continue  # short row (e.g. colspan header) — nothing to paint
            if tr in cited_rows:
                cells[idx]['style'] = 'background-color: rgba(173, 216, 230, 0.7);'
            else:
                cells[idx]['style'] = 'background-color: rgba(211, 211, 211, 0.6);'

    return str(soup)
49
+
50
def load_table_data(table_name):
    """Resolve a dropdown choice like "tab_3" to (image path, question, answer).

    Returns (None, "", "") when no table is selected.
    """
    if not table_name:
        return None, "", ""

    record = TABLE_DATA[int(table_name.split('_')[1]) - 1]

    # Some records carry no question (key missing or explicit null).
    question = record.get("question", "")
    if question is None:
        question = ""

    answer = record['answer_statement']
    image_path = html_to_image(record['html_table'])

    return image_path, question, answer
65
+
66
def process_input(table_name, question, answer):
    """Run the MATSA attribution pipeline on the selected table.

    Returns a JSON string with keys: highlighted_table (HTML), facts,
    row_citations, column_citations, Explanation. When no table is selected,
    returns a plain error message (callers must guard before json.loads).
    """
    import ast  # stdlib; safe literal parsing of LLM output (see below)

    if not table_name:
        return "Please select a table from the dropdown."

    index = int(table_name.split('_')[1]) - 1
    data = TABLE_DATA[index]
    html_content = data['html_table']

    print("html_content: ", html_content)
    print("question: ", question)
    print("answer: ", answer)

    matsa_agent = MATSA()
    instance = InputInstance(html_table=html_content, question=question, answer=answer)

    # Pipeline: augment the table with descriptions, decompose the answer
    # into atomic facts, retrieve candidate rows/columns, then ask the LLM
    # for the final citations.
    augmented_table = matsa_agent.description_augmentation_agent(instance.html_table)
    print("augmented_table: ", augmented_table)
    fact_list = matsa_agent.answer_decomposition_agent(instance.answer)
    print("fact_list: ", fact_list)
    attributed_table, _, _ = matsa_agent.semantic_retreival_agent(augmented_table, fact_list)
    print("attributed_table: ", attributed_table)
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_table)
    print("attribution_fxn: ", attribution_fxn)

    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]
    explanation = attribution_fxn.get("Explanation", "")

    # Citations may come back as a stringified list. SECURITY FIX: parse with
    # ast.literal_eval, not eval() — the payload originates from an LLM and
    # must never be executed as arbitrary code.
    if isinstance(row_attribution_set, str):
        row_ids = ast.literal_eval(row_attribution_set)
    else:
        row_ids = row_attribution_set

    if isinstance(col_attribution_set, str):
        col_ids = ast.literal_eval(col_attribution_set)
    else:
        col_ids = col_attribution_set

    highlighted_table = highlight_table(instance.html_table, row_ids, col_ids)

    result = {
        "highlighted_table": highlighted_table,
        # The sufficiency-agent template historically labeled this field
        # "List of Fact"; accept both spellings so facts are never dropped.
        "facts": attribution_fxn.get("List of Facts",
                                     attribution_fxn.get("List of Fact", [])),
        "row_citations": row_attribution_set,
        "column_citations": col_attribution_set,
        "Explanation": explanation,
    }
    return json.dumps(result)
128
+
129
# Define Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# MATSA: Table Question Answering with Attribution")
    gr.Markdown("Select a table from the dropdown to load the table image, question, and answer.")
    gr.Markdown("Attributions are provided as per answer. You may change the question/answer as per your need.")

    table_dropdown = gr.Dropdown(choices=get_table_names(), label="Select Table")
    original_table = gr.Image(type="filepath", label="Original Table")
    question_box = gr.Textbox(label="Question")
    answer_box = gr.Textbox(label="Answer")

    gr.Markdown("Click 'Process' to see the highlighted relevant parts. Click 'Reset' to start over.")

    process_button = gr.Button("Process")
    reset_button = gr.Button("Reset")
    processing_time = gr.Textbox(label="Processing Time", value="0 seconds")
    highlighted_table = gr.HTML(label="Highlighted Table")
    explanation_box = gr.Textbox(label="Explanation")

    def update_table_data(table_name):
        """Fill the image/question/answer widgets for the chosen table."""
        image_path, question, answer = load_table_data(table_name)
        return image_path, question, answer, gr.update(interactive=True)

    def reset_app():
        """Restore every widget to its initial empty state."""
        return (
            gr.update(value="", interactive=True),  # table_dropdown
            None,                                    # original_table
            "",                                      # question_box
            "",                                      # answer_box
            "",                                      # highlighted_table
            "",                                      # explanation_box
            gr.update(interactive=True),             # process_button
            "0 seconds",                             # processing_time
        )

    def process_and_disable(table_name, question, answer):
        """Run the pipeline while disabling inputs and timing the work.

        Generator handler: first yield disables the controls, second yield
        re-enables them and publishes the result.
        """
        processing = True
        counter = 0

        # BUG FIX: the original made update_counter a generator and used it as
        # a Thread target; the thread created the generator object and exited
        # immediately, so the counter always read 0. Use a plain loop instead.
        def tick():
            nonlocal counter
            while processing:
                time.sleep(1)
                counter += 1

        counter_thread = threading.Thread(target=tick, daemon=True)
        counter_thread.start()

        # Disable the dropdown and process button during processing.
        yield (
            gr.update(interactive=False),        # table_dropdown
            gr.update(interactive=False),        # process_button
            gr.update(value="Processing..."),    # processing_time
            gr.update(),                         # highlighted_table
            gr.update(),                         # explanation_box
        )

        result = process_input(table_name, question, answer)
        result_dict = json.loads(result)

        processing = False
        counter_thread.join()

        # BUG FIX: the original yielded a sixth value (a citations dict) with
        # no matching output component; the click handler declares exactly
        # five outputs, so Gradio would reject the update.
        yield (
            gr.update(interactive=True),                          # table_dropdown
            gr.update(interactive=True),                          # process_button
            f"Processed in {counter} seconds",                    # processing_time
            gr.update(value=result_dict['highlighted_table']),    # highlighted_table
            gr.update(value=result_dict.get('Explanation', '')),  # explanation_box
        )

    table_dropdown.change(update_table_data,
                          inputs=[table_dropdown],
                          outputs=[original_table, question_box, answer_box, process_button])

    process_button.click(process_and_disable,
                         inputs=[table_dropdown, question_box, answer_box],
                         outputs=[table_dropdown, process_button, processing_time, highlighted_table, explanation_box])

    reset_button.click(reset_app,
                       inputs=[],
                       outputs=[table_dropdown, original_table, question_box, answer_box, highlighted_table, explanation_box, process_button, processing_time])

# Launch the interface
iface.launch(share=True)
llm_query_api.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import copy
3
+ import io
4
+ import os
5
+ import random
6
+ import time
7
+ import re
8
+ import json
9
+ import argparse
10
+ import yaml
11
+ import openai
12
+ from openai import AzureOpenAI
13
+ from prompts import *
14
+ import base64
15
+ from mimetypes import guess_type
16
+ # from img2table.document import Image
17
+ # from img2table.document import PDF
18
+ # from img2table.ocr import TesseractOCR
19
+ # from img2table.ocr import EasyOCR
20
+ from PIL import Image as PILImage
21
+
22
class LLMQueryAPI:
    """Thin wrapper around Azure OpenAI chat-completion endpoints.

    Credentials are read per-call from config_gpt4.yaml / config_gpt35.yaml
    (keys: API_KEY, API_VERSION, API_BASE).
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def _load_config(path):
        """Read (API_KEY, API_VERSION, API_BASE) from a YAML config file.

        Extracted helper: this loading logic was duplicated in every
        completion method.
        """
        with open(path, 'r') as f:
            config = yaml.safe_load(f)
        return config.get('API_KEY'), config.get('API_VERSION'), config.get('API_BASE')

    def gpt4_chat_completion(self, query):
        """Send `query` (a chat message list) to the GPT-4 Azure deployment."""
        api_key, api_version, api_base = self._load_config('config_gpt4.yaml')

        client = AzureOpenAI(
            azure_endpoint=api_base,
            api_version=api_version,
            api_key=api_key,
        )

        deployment_name = 'gpt-4-2024-04-09'

        response = client.chat.completions.create(
            model=deployment_name,
            messages=query)

        return response.choices[0].message.content

    def gpt35_chat_completion(self, query):
        """Send `query` to the GPT-3.5 Azure deployment.

        NOTE(review): this uses the legacy `openai.ChatCompletion` interface,
        which was removed in openai>=1.0 (requirements.txt pins 1.37.0), so
        this path will raise at runtime until migrated to the 1.x client.
        """
        api_key, api_version, api_base = self._load_config('config_gpt35.yaml')

        response = openai.ChatCompletion.create(
            engine='gpt-35-turbo-0613',
            messages=query,
            request_timeout=60,
            api_key=api_key,
            api_version=api_version,
            api_type="azure",
            api_base=api_base,
        )

        return response['choices'][0]['message']

    def copilot_chat_completion(self, query):
        """Legacy GPT-4 path via the pre-1.0 openai module (see gpt35 note)."""
        api_key, api_version, api_base = self._load_config('config_gpt4.yaml')

        response = openai.ChatCompletion.create(
            engine='gpt-4-0613',
            messages=query,
            request_timeout=60,
            api_key=api_key,
            api_version=api_version,
            api_type="azure",
            api_base=api_base,
        )
        return response['choices'][0]['message']

    def LLM_chat_query(self, query, llm):
        """Dispatch to the model-specific completion method.

        Returns None for an unrecognized `llm` name (original behavior).
        """
        if llm == 'gpt-3.5-turbo':
            return self.gpt35_chat_completion(query)
        elif llm == "gpt-4":
            return self.gpt4_chat_completion(query)
        # return self.copilot_chat_completion(query)

    def get_llm_response(self, llm, query):
        """Wrap a plain-text `query` as a single system message and complete it."""
        chat_completion = [{"role": "system", "content": query}]
        return self.LLM_chat_query(chat_completion, llm)
104
+
105
class LLMProxyQueryAPI:
    """Chat wrapper for an OpenAI-compatible endpoint.

    `openai.Client()` picks up credentials from the environment
    (OPENAI_API_KEY / OPENAI_BASE_URL).
    """

    def __init__(self) -> None:
        pass

    def _chat(self, model, query):
        """Shared text chat-completion call.

        Extracted helper: the three model-specific methods below were
        identical except for the model name.
        """
        client = openai.Client()
        response = client.chat.completions.create(
            model=model,
            messages=query,
        )
        return response.choices[0].message.content

    def gpt35_chat_completion(self, query):
        """Complete `query` with gpt-3.5-turbo-16k."""
        return self._chat("gpt-3.5-turbo-16k", query)

    def gpt4o_chat_completion(self, query):
        """Complete `query` with gpt-4o."""
        return self._chat("gpt-4o", query)

    def gpt4_chat_completion(self, query):
        """Complete `query` with gpt-4-1106-preview."""
        return self._chat("gpt-4-1106-preview", query)

    def gpt4_vision(self, query, image_path):
        """Multimodal completion: `image_path` is a data URL or remote URL."""
        print(query)

        client = openai.Client()
        response = client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": query},
                        {"type": "image_url", "image_url": {"url": image_path}},
                    ],
                }
            ],
            max_tokens=4096,
            stream=False,
        )
        return response.choices[0].message.content

    def LLM_chat_query(self, llm, query, image_path=None):
        """Route by model name; returns None for an unknown `llm` (original behavior)."""
        if llm == 'gpt-3.5-turbo':
            return self.gpt35_chat_completion(query)
        elif llm == "gpt-4":
            return self.gpt4_chat_completion(query)
        elif llm == "gpt-4o":
            return self.gpt4o_chat_completion(query)
        elif llm == "gpt-4V":
            return self.gpt4_vision(query, image_path)

    def get_llm_response(self, llm, query, image_path=None):
        """Entry point: raw query+image for gpt-4V, system-message chat otherwise."""
        if llm == "gpt-4V" and image_path:
            return self.LLM_chat_query(llm, query, image_path)

        chat_completion = [{"role": "system", "content": query}]
        return self.LLM_chat_query(llm, chat_completion)
187
+
188
+ # if __name__ == '__main__':
189
+
190
+ # llm_query_api = LLMProxyQueryAPI()
191
+
192
+ # def local_image_to_data_url(image_path):
193
+ # mime_type, _ = guess_type(image_path)
194
+ # if mime_type is None:
195
+ # mime_type = 'application/octet-stream'
196
+
197
+ # with open(image_path, "rb") as image_file:
198
+ # base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
199
+
200
+ # return f"data:{mime_type};base64,{base64_encoded_data}"
201
+
202
+ # tesseract = TesseractOCR()
203
+ # pdf = PDF(src="temp3.pdf", pages=[0, 0])
204
+ # extracted_tables = pdf.extract_tables(ocr=tesseract,
205
+ # implicit_rows=True,
206
+ # borderless_tables=True,)
207
+ # html_table = extracted_tables[0][0].html_repr()
208
+ # print(html_table)
209
+
210
+ # table_image_path = "./temp3.jpeg"
211
+ # table_image_data_url = local_image_to_data_url(table_image_path)
212
+ # print(table_image_data_url)
213
+ # query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
214
+ # html_table_refined = llm_query_api.get_llm_response("gpt-4V", query, table_image_data_url)
215
+ # print(html_table_refined)
216
+
matsa.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ from prompts import *
4
+ import ast
5
+ from bs4 import BeautifulSoup
6
+ from semantic_retrieval import *
7
+ from llm_query_api import *
8
+ import base64
9
+ from mimetypes import guess_type
10
+
11
class InputInstance:
    """A single table-QA instance: an HTML table plus optional question/answer.

    All fields default to None so partially specified instances (e.g. answer
    attribution without a question) remain valid.
    """

    def __init__(self, id=None, html_table=None, question=None, answer=None):
        # `id` shadows the builtin but is kept for backward compatibility
        # with existing keyword-argument callers.
        self.id = id
        self.html_table = html_table
        self.question = question
        self.answer = answer

    def __repr__(self):
        return (f"{type(self).__name__}(id={self.id!r}, "
                f"question={self.question!r}, answer={self.answer!r})")
19
+
20
class MATSA:
    """Multi-agent pipeline that attributes an answer to table rows/columns."""

    def __init__(self, llm="gpt-4"):
        self.llm = llm
        self.llm_query_api = LLMQueryAPI()  # swap to LLMProxyQueryAPI() for the proxy

    def table_formatting_agent(self, html_table=None, table_image_path=None):
        """Normalize a table to HTML and stamp row-i / col-i ids on <tr>/<th>.

        If `table_image_path` is given, the table is first OCR'd from the
        image/PDF and refined with a vision model.
        NOTE(review): the OCR path needs TesseractOCR/PDF from img2table,
        whose imports are currently commented out in llm_query_api.py, and a
        gpt-4V-capable query API — confirm before enabling.
        """
        def local_image_to_data_url(image_path):
            # Encode a local image as a data URL for the vision model.
            mime_type, _ = guess_type(image_path)
            if mime_type is None:
                mime_type = 'application/octet-stream'
            with open(image_path, "rb") as image_file:
                base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8')
            return f"data:{mime_type};base64,{base64_encoded_data}"

        if table_image_path is not None:
            tesseract = TesseractOCR()
            pdf = PDF(src=table_image_path, pages=[0, 0])
            extracted_tables = pdf.extract_tables(ocr=tesseract,
                                                  implicit_rows=True,
                                                  borderless_tables=True)
            html_table = extracted_tables[0][0].html_repr()

            table_image_data_url = local_image_to_data_url(table_image_path)
            query = table_image_to_html_prompt.replace("{{html_table}}", html_table)
            # BUG FIX: was `llm_query_api.get_llm_response(...)` — an unbound
            # name that raised NameError; must go through the instance.
            html_table = self.llm_query_api.get_llm_response("gpt-4V", query, table_image_data_url)

        soup = BeautifulSoup(html_table, 'html.parser')
        for i, tr_tag in enumerate(soup.find_all('tr')):
            tr_tag['id'] = f"row-{i + 1}"  # unique 'row-i' id
            if i == 0:
                # BUG FIX: this inner loop used to shadow the outer `i`,
                # corrupting row numbering after the header row.
                for j, th_tag in enumerate(tr_tag.find_all('th')):
                    th_tag['id'] = f"col-{j + 1}"  # unique 'col-i' id

        return str(soup)

    def description_augmentation_agent(self, html_table):
        """Augment the table via three LLM passes: column, row, then trend descriptions."""
        query = col_description_prompt.replace("{{html_table}}", str(html_table))
        col_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        query = row_description_prompt.replace("{{html_table}}", str(col_augmented_html_table))
        row_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        query = trend_description_prompt.replace("{{html_table}}", str(row_augmented_html_table))
        trend_augmented_html_table = self.llm_query_api.get_llm_response(self.llm, query)

        return trend_augmented_html_table

    def answer_decomposition_agent(self, answer):
        """Decompose an answer passage into a list of atomic fact strings.

        Returns None when the model output does not parse to a list.
        """
        query = answer_decomposition_prompt.replace("{{answer}}", answer)
        res = self.llm_query_api.get_llm_response(self.llm, query)
        # Robustness: models sometimes wrap the list in Markdown code fences
        # (the sufficiency agent already strips these for its JSON output).
        res = res.replace("```python", "").replace("```", "").strip()
        res = ast.literal_eval(res)
        if isinstance(res, list):
            return res
        else:
            return None

    def semantic_retreival_agent(self, html_table, fact_list, topK=5):
        """Prune the table to the topK retrieved rows/columns per fact.

        (Method name keeps the historical misspelling for API compatibility.)
        """
        attributed_html_table, row_attribution_ids, col_attribution_ids = \
            get_embedding_attribution(html_table, fact_list, topK)
        return attributed_html_table, row_attribution_ids, col_attribution_ids

    def sufficiency_attribution_agent(self, fact_list, attributed_html_table):
        """Ask the LLM to fill a citation template for the given facts.

        Returns a dict with "List of Facts", "Row Citations",
        "Column Citations", and "Explanation", or None if the model's output
        is not a JSON object.
        """
        fact_verification_list = []
        for i, fact in enumerate(fact_list):
            fact_verification_list.append({"Fact " + str(i + 1) + ":": str(fact)})

        fact_verification_function = {
            # CONSISTENCY FIX: demo.py reads "List of Facts"; the template
            # previously emitted "List of Fact", so facts were silently lost.
            "List of Facts": fact_verification_list,
            "Row Citations": "[..., ..., ...]",
            "Column Citations": "[..., ..., ...]",
            "Explanation": "...",
        }

        fact_verification_function_string = json.dumps(fact_verification_function)

        query = functional_attribution_prompt \
            .replace("{{attributed_html_table}}", str(attributed_html_table)) \
            .replace("{{fact_verification_function}}", fact_verification_function_string)
        attribution_fxn = self.llm_query_api.get_llm_response(self.llm, query)

        # Strip Markdown code fences before parsing the JSON payload.
        attribution_fxn = attribution_fxn.replace("```json", "").replace("```", "")
        print(attribution_fxn)
        attribution_fxn = json.loads(attribution_fxn)

        if isinstance(attribution_fxn, dict):
            return attribution_fxn
        else:
            return None
124
+
125
if __name__ == '__main__':

    # Smoke test: run every MATSA agent end-to-end on a small hand-built
    # table whose header uses a colspan (nested header).
    html_table = """<table>
    <tr>
    <th rowspan="1">Sr. Number</th>
    <th colspan="3">Types</th>
    <th rowspan="1">Remark</th>
    </tr>
    <tr>
    <th> </th>
    <th>A</th>
    <th>B</th>
    <th>C</th>
    <th> </th>
    </tr>
    <tr>
    <td>1</td>
    <td>Mitten</td>
    <td>Kity</td>
    <td>Teddy</td>
    <td>Names of cats</td>
    </tr>
    <tr>
    <td>1</td>
    <td>Tommy</td>
    <td>Rudolph</td>
    <td>Jerry</td>
    <td>Names of dogs</td>
    </tr>
    </table>"""

    answer = "Tommy is a dog but Mitten is a cat."


    x = InputInstance(html_table=html_table, answer=answer)

    matsa_agent = MATSA()

    # 1) Stamp row-/col- ids onto the table.
    x_reformulated = matsa_agent.table_formatting_agent(x.html_table)
    print(x_reformulated)

    # 2) Augment with caption, column, row, and trend descriptions.
    x_descriptions = matsa_agent.description_augmentation_agent(x_reformulated)
    print(x_descriptions)

    # 3) Decompose the answer into atomic facts.
    fact_list = matsa_agent.answer_decomposition_agent(x.answer)
    print(fact_list)

    # 4) Embedding retrieval prunes the table to candidate rows/columns.
    attributed_html_table, row_attribution_ids, col_attribution_ids = matsa_agent.semantic_retreival_agent(x_descriptions, fact_list)
    print(attributed_html_table)

    # 5) The LLM selects the final row/column citations.
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_html_table)
    print(attribution_fxn)

    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]

    print(row_attribution_set)
    print(col_attribution_set)
prompts.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Prompt templates for the MATSA agents. Placeholders of the form
# {{name}} are filled via str.replace() by the callers in matsa.py.
# NOTE: the template text (including its typos) is runtime data sent to the
# LLM and is intentionally left byte-for-byte unchanged here.

# Adds a "description" attribute to every <tr> (used by
# MATSA.description_augmentation_agent, pass 2).
row_description_prompt = """You are a brilliant table assistant with the capabilities information retrieval, table parsing, and semantic understanding of the structural information of the table.

Here is a table in html format: \"{{html_table}}\"

We need to add a detailed description for each row of the table denoted by the "<tr>" tag element. The description should discuss the overall information inferred from the row. It should also mention all the elements, numbers, and figures present in the row. Also, include any hierarchical row/column header information.

Add this description as a "description" attribute in the <tr> tag. Repeat this process for ALL <tr> tags in the provided HTML. Do NOT delete any information that already exists in the other tags. Just print the html and do NOT output any other message. The HTML structure should NOT be changed.
"""

# Adds a <caption> plus a "description" attribute to every <th> (pass 1).
col_description_prompt = """You are a brilliant table assistant with the capabilities information retrieval, table parsing, and semantic understanding of the structural information of the table.

Here is a table in html format: \"{{html_table}}\"

First, write a caption for the entire table that captures the general information being presented by all rows and columns in this table. Add it using the <caption> tag in the HTML.

Next, need to add a detailed description for each column of the table denoted by the "<th>" tag element. The description should discuss what elements are present in the column and the overall information that can be inferred by the column, including any hierarchical column header information.

Add this description as a "description" in the <th> tag. Repeat this process for ALL <th> tags in the provided HTML. Do NOT delete any information that already exists in the other tags. Just print the html and do NOT output any other message. The HTML structure should NOT be changed.
"""

# Appends quantitative trend analysis to each <tr>'s description (pass 3).
trend_description_prompt = """You are a brilliant table assistant with the capabilities information retrieval, table parsing, semantic understanding, and trend analysis of the structural information of the table.

Here is a table in html format: \"{{html_table}}\".

We need to add a trend analysis on the elements in the given row compared to its own constituent cells and other rows in the table. The description should discuss semantic descriptions of numerical data, summarizing key quantitative characteristics and tendencies across the table row and across different columns.

Add this analysis in the \"description\" of the <tr> tag. Repeat this process for ALL <tr> tags in the provided HTML. Do NOT delete any information that already exists in the other tags. Just print the html and do NOT output any other message. The HTML strcuture should NOT be changed.
"""

# Asks the LLM to fill the citation JSON built by
# MATSA.sufficiency_attribution_agent ({{fact_verification_function}}).
functional_attribution_prompt = """You are a brilliant assistant with the capabilities information retrieval, fact checking, and semantic understanding of tabular data.

Here is the html table - \"{{attributed_html_table}}\"

We have a list of facts pertaining to this table present in this JSON structure - \"{{fact_verification_function}}"\.
The JSON structure contains three empty fields - "Row Citations", "Column Citations", and "Explanation" that need to filled with relevant information.

We want to identify all the ROWS in the table that are important to support these facts. In other words, which rows are needed to collectively verify the facts. Please copy the "row-id" of all relevant table rows in the "Row Citation" field of the JSON structure. All rows IDs should be added in the form of a LIST "[... , ... , ...]" with no value repeated. Here is a sample: "Row Citations": ["row-2", "row-3"].

Similar to rows, we want to identify all the COLUMNS in the table that are important to support these facts. In other words, which columns are needed to collectively verify the facts. Please copy the "col-id" of all relevant table rows in the "Column Citation" field of the JSON structure. All column IDs should be added in the form of a LIST "[... , ... , ...]" with no value repeated. Here is a sample: "Column Citations": ["col-1", "col-5", "col-7"]

"Explanation" field should contain a detailed explanation of how the rows and columns identified in the "Row Citations" and "Column Citations" fields respectively, are important to verify the facts present in the JSON structure. The explanation should be coherent and provide a clear rationale for the selection of rows and columns.

The final result should be a complete JSON structure ONLY. Do not print any extra information. Make sure to fill the JSON structure accurately as provided in the prompt. 'Row citations', 'Column Citations', and 'Explanation' should not be empty.
"""


# Splits an answer passage into atomic facts; the reply is parsed with
# ast.literal_eval in MATSA.answer_decomposition_agent, so it must be a
# Python list literal.
answer_decomposition_prompt = """

Here is a passage: \"{{answer}}\"

Convert the given passage into a list of short facts which specifically answer the given question.
Make sure that the facts can be found in the given passage.
The facts should be coherent and succinct sentences with clear and simple syntax.
Do not use pronouns as the subject or object in the syntax of each fact.
The facts should be independent to each other.
Do not create facts from the passage which are not answering the given question.
ONLY return a python LIST of strings seperated by comma (,). Do NOT output any extra explanation.
"""

# Refines a noisy OCR HTML table using a vision model and the table image
# (used by MATSA.table_formatting_agent's image path).
table_image_to_html_prompt = """ Here is an image of a table. Please convert this table image into a HTML representation with accurate table cell data.
In order to help you in this process, here is a noisy HTML representation of the table extracted from the image: \"{{html_table}}\"
You may use this as a noisy reference and further refine the HTML structure to accurately represent the table data.
You should also add any information pertaining to row and column spans. In case of nested rows/columns with multiple spans, take be very careful to leave blank cells to ensure the semantic structure is maintained.
Be careful to handle hierarchical and nested rows/columns.
Each HTML should start with "<table>" opening tag.
Each HTML should end with "</table>" closing tag.
Do NOT output any other explanation or text apart from the HTML code of the table.
"""
69
+
requirements.txt ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_f56ag8l7kr/croot/aiofiles_1683773599608/work
2
+ altair @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a8x4081_4h/croot/altair_1687526044471/work
3
+ annotated-types @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1fa2djihwb/croot/annotated-types_1709542925772/work
4
+ anyio @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a17a7759g2/croot/anyio_1706220182417/work
5
+ argcomplete==3.4.0
6
+ attrs @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_224434dqzl/croot/attrs_1695717839274/work
7
+ beautifulsoup4==4.12.3
8
+ Bottleneck @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_2bxpizxa3c/croot/bottleneck_1707864819812/work
9
+ Brotli @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_27zk0eqdh0/croot/brotli-split_1714483157007/work
10
+ certifi @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_70oty9s9jh/croot/certifi_1720453497032/work/certifi
11
+ charset-normalizer @ file:///croot/charset-normalizer_1721748349566/work
12
+ click==8.1.7
13
+ colorama @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_100_k35lkb/croot/colorama_1672386539781/work
14
+ contourpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_041uwyxdzo/croot/contourpy_1700583585236/work
15
+ cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
16
+ distro==1.9.0
17
+ exceptiongroup @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_b2258scr33/croot/exceptiongroup_1706031391815/work
18
+ fastapi @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_ab51gv9t42/croot/fastapi_1693295410365/work
19
+ ffmpy @ file:///home/conda/feedstock_root/build_artifacts/ffmpy_1659474992694/work
20
+ filelock @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d3quwmvouf/croot/filelock_1700591194006/work
21
+ fonttools @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_60c8ux4mkl/croot/fonttools_1713551354374/work
22
+ fsspec @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_19mkn689lo/croot/fsspec_1714461553219/work
23
+ gradio @ file:///home/conda/feedstock_root/build_artifacts/gradio_1721770785660/work
24
+ gradio_client @ file:///home/conda/feedstock_root/build_artifacts/gradio-client_1721697178984/work
25
+ h11 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_110bmw2coo/croot/h11_1706652289620/work
26
+ httpcore @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_fcxiho9nv7/croot/httpcore_1706728465004/work
27
+ httpx @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_727e6zfsxn/croot/httpx_1706887102687/work
28
+ huggingface-hub==0.24.2
29
+ idna @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a12xpo84t2/croot/idna_1714398852854/work
30
+ img2table==1.2.11
31
+ imgkit==1.2.3
32
+ importlib_resources @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_784gdd8gq5/croot/importlib_resources-suite_1720641109833/work
33
+ Jinja2 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_44yzu12j7f/croot/jinja2_1716993410427/work
34
+ joblib==1.4.2
35
+ jsonschema @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_27o3go8sqa/croot/jsonschema_1699041627313/work
36
+ jsonschema-specifications @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_d38pclgu95/croot/jsonschema-specifications_1699032390832/work
37
+ kiwisolver @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_93o8te804v/croot/kiwisolver_1672387163224/work
38
+ llvmlite==0.43.0
39
+ markdown-it-py @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_43l_4ajkho/croot/markdown-it-py_1684279912406/work
40
+ MarkupSafe @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a84ni4pci8/croot/markupsafe_1704206002077/work
41
+ matplotlib @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a8crvoz7ca/croot/matplotlib-suite_1713336381679/work
42
+ mdurl @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_0a8xm6w4wv/croots/recipe/mdurl_1659716035810/work
43
+ mpmath==1.3.0
44
+ networkx==3.3
45
+ numba==0.60.0
46
+ numexpr @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_45yefq0kt6/croot/numexpr_1696515289183/work
47
+ numpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a51i_mbs7m/croot/numpy_and_numpy_base_1708638620867/work/dist/numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl#sha256=c4b11b3c4d4fdb810039503fe01f311ade06cd1d675fcd6d208800a393f19b69
48
+ openai==1.37.0
49
+ opencv-contrib-python==4.10.0.84
50
+ orjson @ file:///var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_910h7wyrw8/croot/orjson_1711143065818/work/target/wheels/orjson-3.9.15-cp310-cp310-macosx_11_0_arm64.whl#sha256=fb0784f650f58e15827dd32f58284d983bb401ec3b85b38321eea14ebd2d01e9
51
+ packaging==24.1
52
+ pandas @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_b53hgou29t/croot/pandas_1718308972393/work/dist/pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl#sha256=7e70989ec6c6e08f2fd87d3940d106704e9feb936002297690d7462e35b5da35
53
+ pdfkit==1.0.0
54
+ pillow @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_617xe_y58w/croot/pillow_1721059447446/work
55
+ pipx==1.6.0
56
+ platformdirs==4.2.2
57
+ polars==1.2.1
58
+ pyarrow==17.0.0
59
+ pydantic @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_0ai8cvgm2c/croot/pydantic_1709577986211/work
60
+ pydantic_core @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_06smitnu98/croot/pydantic-core_1709573985903/work
61
+ pydub @ file:///home/conda/feedstock_root/build_artifacts/pydub_1615612442567/work
62
+ Pygments @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_29bs9f_dh9/croot/pygments_1684279974747/work
63
+ PyMuPDF==1.24.9
64
+ PyMuPDFb==1.24.9
65
+ pyparsing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_3b_3vxnd07/croots/recipe/pyparsing_1661452540919/work
66
+ PySocks @ file:///Users/ktietz/ci_310/pysocks_1643961536721/work
67
+ python-dateutil @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_66ud1l42_h/croot/python-dateutil_1716495741162/work
68
+ python-multipart @ file:///home/conda/feedstock_root/build_artifacts/python-multipart_1707760088566/work
69
+ pytz @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a4b76c83ik/croot/pytz_1713974318928/work
70
+ PyYAML @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a8_sdgulmz/croot/pyyaml_1698096054705/work
71
+ referencing @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_5cz64gsx70/croot/referencing_1699012046031/work
72
+ regex==2024.7.24
73
+ requests @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_70sm12ba9w/croot/requests_1721414707360/work
74
+ rich @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_bcb8hqb_r4/croot/rich_1720637498249/work
75
+ rpds-py @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f8jkozoefm/croot/rpds-py_1698945944860/work
76
+ ruff @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_e70v1wc31c/croot/ruff_1713372744076/work
77
+ safetensors==0.4.3
78
+ scikit-learn==1.5.1
79
+ scipy==1.14.0
80
+ semantic-version @ file:///tmp/build/80754af9/semantic_version_1613321057691/work
81
+ sentence-transformers==3.0.1
82
+ shellingham @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_9bf5wowles/croot/shellingham_1669142181600/work
83
+ six @ file:///tmp/build/80754af9/six_1644875935023/work
84
+ sniffio @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1573pknjrg/croot/sniffio_1705431298885/work
85
+ soupsieve==2.5
86
+ starlette @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_7feeoam7sd/croot/starlette-recipe_1692980426894/work
87
+ sympy==1.13.1
88
+ threadpoolctl==3.5.0
89
+ tokenizers==0.19.1
90
+ tomli==2.0.1
91
+ tomlkit @ file:///home/conda/feedstock_root/build_artifacts/tomlkit_1690458286251/work
92
+ toolz @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_362wyqvvgy/croot/toolz_1667464079070/work
93
+ torch==2.4.0
94
+ tqdm @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_f76_dxtcsh/croot/tqdm_1716395948224/work
95
+ transformers==4.43.2
96
+ typer @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e62p5mo8z5/croot/typer_1684251930377/work
97
+ typing_extensions @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_93dg13ilv4/croot/typing_extensions_1715268840722/work
98
+ tzdata @ file:///croot/python-tzdata_1690578112552/work
99
+ unicodedata2 @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a3epjto7gs/croot/unicodedata2_1713212955584/work
100
+ urllib3 @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_cao7_u9937/croot/urllib3_1718912649114/work
101
+ userpath==1.9.2
102
+ uvicorn @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_75zglvwlmy/croot/uvicorn-split_1678090090396/work
103
+ websockets @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_d8ij4gljy1/croot/websockets_1678966799107/work
104
+ XlsxWriter==3.2.0
semantic_retrieval.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from bs4 import BeautifulSoup
3
+ from sklearn.preprocessing import minmax_scale
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import numpy as np
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ sbert = SentenceTransformer("all-MiniLM-L6-v2")
8
+ from llm_query_api import *
9
+
10
def get_row_embedding(html_table):
    """Embed each table row's description text with the module-level SBERT model.

    Parses ``html_table``, reads the ``description`` attribute of every
    ``<tr>`` tag (prefixed with a space, as the original pipeline did), and
    encodes the resulting strings.

    Args:
        html_table: HTML string whose ``<tr>`` tags carry ``id`` and
            ``description`` attributes.

    Returns:
        tuple: ``(embeddings, element_ids)`` where ``embeddings`` is an
        ndarray of shape (n_rows, dim) and ``element_ids`` is a list of the
        row id strings, in document order.
    """
    soup = BeautifulSoup(html_table, 'html.parser')

    sentences = []
    element_ids = []
    for tag in soup.find_all('tr'):
        # str() coerces missing attributes to the literal "None", matching
        # the original behavior. (The original wrapped the append in a bare
        # try/except that could never fire; it is dropped here.)
        element_ids.append(str(tag.get('id')))
        sentences.append(" " + str(tag.get('description')))

    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
35
+
36
def get_col_embedding(html_table):
    """Embed each table column header's description text with SBERT.

    Parses ``html_table``, reads the ``description`` attribute of every
    ``<th>`` tag (prefixed with a space, as the original pipeline did), and
    encodes the resulting strings.

    Args:
        html_table: HTML string whose ``<th>`` tags carry ``id`` and
            ``description`` attributes.

    Returns:
        tuple: ``(embeddings, element_ids)`` where ``embeddings`` is an
        ndarray of shape (n_cols, dim) and ``element_ids`` is a list of the
        column id strings, in document order.
    """
    soup = BeautifulSoup(html_table, 'html.parser')

    sentences = []
    element_ids = []
    for tag in soup.find_all('th'):
        # str() coerces missing attributes to the literal "None", matching
        # the original behavior. (The original wrapped the append in a bare
        # try/except that could never fire; it is dropped here.)
        element_ids.append(str(tag.get('id')))
        sentences.append(" " + str(tag.get('description')))

    embeddings = sbert.encode(sentences, convert_to_tensor=True).cpu().numpy()
    return embeddings, element_ids
62
+
63
def normalize_list_numpy(list_numpy):
    """Min-max scale a 1-D sequence of scores into the range [0, 1].

    Behaviorally equivalent to ``sklearn.preprocessing.minmax_scale`` with
    the default feature range, but implemented with numpy so the function
    has no sklearn dependency: a constant input (zero range) maps to all
    zeros, mirroring sklearn's zero-range handling.

    Args:
        list_numpy: 1-D array-like of numeric scores.

    Returns:
        numpy.ndarray: float array of the same length, scaled to [0, 1].
    """
    arr = np.asarray(list_numpy, dtype=float)
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        # Zero data range: sklearn substitutes a scale of 1, which yields
        # all zeros for the default (0, 1) feature range.
        return np.zeros_like(arr)
    return (arr - lo) / span
66
+
67
def get_answer_embedding(answer):
    """Encode a single answer string into a (1, dim) SBERT embedding array."""
    encoded = sbert.encode([answer], convert_to_tensor=True)
    return encoded.cpu().numpy()
69
+
70
def row_attribution(answer, html_table, topk=5, threshold=0.7):
    """Return the ids of the table rows most semantically similar to ``answer``.

    Cosine similarity is computed between the answer embedding and each row
    embedding, min-max normalized, and the ids of the top
    ``max(topk, 30% of rows)`` rows (capped at the row count) are returned
    in descending similarity order.

    Args:
        answer: answer/fact text to attribute.
        html_table: augmented HTML table (rows carry ids and descriptions).
        topk: minimum number of rows to return (subject to the cap).
        threshold: accepted for interface compatibility.
            NOTE(review): currently unused by this implementation.

    Returns:
        list[str]: row ids, most similar first.
    """
    answer_vec = get_answer_embedding(answer)
    row_vecs, row_ids = get_row_embedding(html_table)

    scores = cosine_similarity(row_vecs, answer_vec.reshape(1, -1)).flatten()
    scores = normalize_list_numpy(scores)

    # Take at least `topk` rows — or 30% of all rows if that is larger —
    # but never more rows than exist.
    k = min(max(topk, int(0.3 * len(scores))), len(scores))
    candidate_idx = np.argpartition(scores, -k)[-k:]
    ordered_idx = candidate_idx[np.argsort(scores[candidate_idx])][::-1]
    return [row_ids[i] for i in ordered_idx]
87
+
88
def col_attribution(answer, html_table, topk=5, threshold=0.7):
    """Return the ids of the table columns most semantically similar to ``answer``.

    Cosine similarity is computed between the answer embedding and each
    column-header embedding, min-max normalized, and the ids of the top
    ``max(topk, 30% of columns)`` columns (capped at the column count) are
    returned in descending similarity order.

    Args:
        answer: answer/fact text to attribute.
        html_table: augmented HTML table (headers carry ids and descriptions).
        topk: minimum number of columns to return (subject to the cap).
        threshold: accepted for interface compatibility.
            NOTE(review): currently unused by this implementation.

    Returns:
        list[str]: column ids, most similar first.
    """
    answer_vec = get_answer_embedding(answer)
    col_vecs, col_ids = get_col_embedding(html_table)

    scores = cosine_similarity(col_vecs, answer_vec.reshape(1, -1)).flatten()
    scores = normalize_list_numpy(scores)

    # Take at least `topk` columns — or 30% of all columns if that is
    # larger — but never more columns than exist.
    k = min(max(topk, int(0.3 * len(scores))), len(scores))
    candidate_idx = np.argpartition(scores, -k)[-k:]
    ordered_idx = candidate_idx[np.argsort(scores[candidate_idx])][::-1]
    return [col_ids[i] for i in ordered_idx]
105
+
106
def retain_rows_and_columns(augmented_html_table, row_ids, column_ids):
    """Return the table HTML with only the given rows and columns kept.

    Rows whose ``<tr>`` id is not in ``row_ids`` are removed; then every
    cell belonging to a header (``<th>``) column whose id is not in
    ``column_ids`` is removed from all remaining rows.

    Args:
        augmented_html_table: HTML string whose ``<tr>``/``<th>`` tags
            carry ``id`` attributes.
        row_ids: iterable of row ids to keep (duplicates are fine).
        column_ids: iterable of column ids to keep.

    Returns:
        str: the pruned HTML.
    """
    soup = BeautifulSoup(augmented_html_table, 'html.parser')

    keep_rows = set(row_ids)
    keep_cols = set(column_ids)

    # Drop unwanted rows first.
    for row in soup.find_all('tr'):
        if row.get('id') not in keep_rows:
            row.decompose()

    # BUGFIX: re-query after decomposing — the earlier list may reference
    # dead (decomposed) tags, including the old first row.
    remaining_rows = soup.find_all('tr')
    if remaining_rows:
        header_cells = remaining_rows[0].find_all('th')
        # Column positions to drop, computed once from the header row.
        drop_positions = [i for i, cell in enumerate(header_cells)
                          if cell.get('id') not in keep_cols]
        for row in remaining_rows:
            # BUGFIX: snapshot the cells of this row before decomposing any.
            # The original re-ran find_all() per dropped column, so removing
            # one column shifted the indices of later columns and the wrong
            # cells were deleted when multiple columns had to go.
            cells = row.find_all(['td', 'th'])
            for i in drop_positions:
                if i < len(cells):
                    cells[i].decompose()

    return str(soup)
129
+
130
def get_embedding_attribution(augmented_html_table, decomposed_fact_list, topK=5, threshold=0.7):
    """Attribute each decomposed fact to table rows/columns via embeddings.

    For every fact, collects the top-K most similar row ids and column ids,
    then prunes the table down to the union of all attributed rows/columns.

    Args:
        augmented_html_table: HTML table whose rows/headers carry ids.
        decomposed_fact_list: list of fact strings to attribute.
        topK: minimum number of rows/columns attributed per fact.
        threshold: accepted for interface compatibility.
            NOTE(review): not forwarded to the attribution helpers.

    Returns:
        tuple: ``(attributed_html_table, row_attribution_ids,
        col_attribution_ids)`` — the pruned HTML plus the accumulated
        (possibly duplicated) row and column id lists.
    """
    row_attribution_ids = []
    col_attribution_ids = []

    for fact in decomposed_fact_list:
        row_attribution_ids.extend(row_attribution(fact, augmented_html_table, topK))
        col_attribution_ids.extend(col_attribution(fact, augmented_html_table, topK))

    attributed_html_table = retain_rows_and_columns(
        augmented_html_table, row_attribution_ids, col_attribution_ids
    )
    return attributed_html_table, row_attribution_ids, col_attribution_ids
tables_folder/MATSA_aitqa.json ADDED
The diff for this file is too large to render. See raw diff
 
tables_folder/MATSA_fetaqa.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:365905f7696b0c00498f7b88cf602769677998c5b32614823cb86127ccfd13f9
3
+ size 18562195
wkhtmltox_0.12.6-1.bionic_amd64.deb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:503a8a97fcf8fd397ed52c1789471e0f2513f5752f3e214d3a5eda30caa0354b
3
+ size 15729530