# Matsa-demo / demo.py
# Uploaded by puneetm with huggingface_hub (commit 35d31f5, verified).
import ast
import json
import os
import tempfile
import threading
import time

import gradio as gr
import imgkit
from bs4 import BeautifulSoup

from matsa import MATSA, InputInstance
# Path of the JSON file holding the demo tables (a file, despite the name).
TABLE_FOLDER = "./tables_folder/MATSA_fetaqa.json"


def load_data(path=None):
    """Load the demo tables from a JSON file.

    Args:
        path: Location of the JSON file to read. When omitted, the
            module-level ``TABLE_FOLDER`` path is used, so existing
            zero-argument callers behave as before.

    Returns:
        The decoded JSON document. The rest of this file indexes it as a
        list of per-table dicts.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    source = TABLE_FOLDER if path is None else path
    # Decode explicitly as UTF-8: relying on the platform default encoding
    # makes non-ASCII table cells fail or come out garbled on some systems.
    with open(source, encoding="utf-8") as json_file:
        return json.load(json_file)
# Tables are loaded once, when this module is first executed, and then read by
# get_table_names(), load_table_data() and process_input().
TABLE_DATA = load_data()
def get_table_names():
    """Return the dropdown labels ``tab_1`` ... ``tab_N``, one per loaded table."""
    table_count = len(TABLE_DATA)
    return ["tab_" + str(position) for position in range(1, table_count + 1)]
def html_to_image(html_content):
    """Render an HTML snippet to a PNG file and return that file's path.

    The file is created with ``delete=False`` and is intentionally left on
    disk on success, because the caller hands the path to a Gradio image
    component. If rendering fails, the placeholder file is removed so that
    failed requests do not accumulate empty PNGs.

    Args:
        html_content: HTML markup passed to ``imgkit.from_string``.

    Returns:
        Path of the generated ``.png`` file.

    Raises:
        Whatever ``imgkit.from_string`` raises; the exception is re-raised
        after the temporary file has been cleaned up.
    """
    # Only the unique name is needed; release our own handle before the
    # renderer writes to the same path.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
        image_path = temp_img.name

    try:
        imgkit.from_string(html_content, image_path)
    except Exception:
        # Do not leave an orphaned placeholder behind when rendering fails.
        if os.path.exists(image_path):
            os.remove(image_path)
        raise
    return image_path
def highlight_table(html_table, row_ids, col_ids):
    """Return ``html_table`` with the cited columns shaded.

    For every cited column, the cell in that column is given an inline
    background colour in each ``<tr>``: light blue when the row is also
    cited, light grey otherwise. Rows are only used to choose the colour;
    if ``col_ids`` is empty nothing is highlighted.

    Args:
        html_table: HTML markup of the table.
        row_ids: Iterable of ``id`` attribute values of the cited ``<tr>``
            elements. IDs that match no row are ignored.
        col_ids: Iterable of column IDs of the form ``"<prefix>-<n>"`` with
            a 1-based ``n`` (e.g. ``"col-1"`` is the first cell of a row).
            Malformed or non-positive IDs are skipped.

    Returns:
        The modified table serialised back to an HTML string.
    """
    cited_style = 'background-color: rgba(173, 216, 230, 0.7);'
    column_only_style = 'background-color: rgba(211, 211, 211, 0.6);'

    soup = BeautifulSoup(html_table, 'html.parser')

    # Track cited rows by object identity. Comparing Tag objects with
    # ``in`` on a list performs a deep structural comparison, which is slow
    # and would also treat two distinct but identical-looking rows as equal.
    cited_rows = set()
    for row_id in row_ids:
        row = soup.find('tr', id=row_id)
        if row is not None:
            cited_rows.add(id(row))

    # Convert "col-1" -> 0, "col-2" -> 1, ... once, up front.
    col_indexes = []
    for col_id in col_ids:
        try:
            col_index = int(col_id.split('-')[1]) - 1
        except (AttributeError, IndexError, ValueError):
            # Not of the expected "<prefix>-<n>" form: skip this citation
            # rather than failing the whole request.
            continue
        # A negative index would silently wrap around to the last cell.
        if col_index >= 0:
            col_indexes.append(col_index)

    # Single pass over the rows; each row's cells are collected only once.
    for row in soup.find_all('tr'):
        cells = row.find_all(['td', 'th'])
        style = cited_style if id(row) in cited_rows else column_only_style
        for col_index in col_indexes:
            if col_index < len(cells):
                cells[col_index]['style'] = style

    return str(soup)
def load_table_data(table_name):
    """Fetch the image, question and answer shown for a dropdown selection.

    Args:
        table_name: Dropdown label of the form ``tab_<n>`` (1-based), or a
            falsy value when nothing is selected.

    Returns:
        A ``(image_path, question, answer)`` tuple. With no selection the
        tuple is ``(None, "", "")``. A stored question of ``None`` (or a
        missing question) is returned as an empty string.
    """
    if not table_name:
        return None, "", ""

    # "tab_3" -> position 2 in TABLE_DATA.
    position = int(table_name.split('_')[1]) - 1
    entry = TABLE_DATA[position]

    table_html = entry['html_table']
    stored_question = entry.get("question", "")
    question_text = "" if stored_question is None else stored_question
    answer_text = entry['answer_statement']

    rendered_path = html_to_image(table_html)
    return rendered_path, question_text, answer_text
def _parse_citation_ids(citations):
    """Return citation IDs as a Python object, decoding a stringified list."""
    if not isinstance(citations, str):
        return citations
    if not citations.strip():
        return []
    # SECURITY: this text is produced by the model pipeline and must be
    # treated as untrusted. The previous implementation passed it to eval(),
    # which executes arbitrary code; literal_eval only accepts Python
    # literals such as "['row-1', 'row-2']".
    return ast.literal_eval(citations)


def process_input(table_name, question, answer):
    """Run the MATSA attribution pipeline for one table/question/answer.

    Args:
        table_name: Dropdown label of the form ``tab_<n>`` (1-based).
        question: Question text (possibly edited by the user).
        answer: Answer text whose facts are attributed to the table.

    Returns:
        When ``table_name`` is falsy, the plain prompt string
        ``"Please select a table from the dropdown."``. Otherwise a JSON
        string with the keys ``highlighted_table``, ``facts``,
        ``row_citations``, ``column_citations`` and ``Explanation``.

    Raises:
        ValueError, SyntaxError: If a citation set arrives as a string that
            is not a valid Python literal.
        KeyError: If the attribution result lacks ``"Row Citations"`` or
            ``"Column Citations"``.
    """
    if not table_name:
        return "Please select a table from the dropdown."

    # Get the data for the selected table ("tab_3" -> index 2).
    index = int(table_name.split('_')[1]) - 1
    data = TABLE_DATA[index]
    html_content = data['html_table']
    print("html_content: ", html_content)
    print("question: ", question)
    print("answer: ", answer)

    matsa_agent = MATSA()
    instance = InputInstance(html_table=html_content, question=question, answer=answer)

    # Apply the MATSA pipeline stage by stage, logging each intermediate.
    augmented_table = matsa_agent.description_augmentation_agent(instance.html_table)
    print("augmented_table: ", augmented_table)
    fact_list = matsa_agent.answer_decomposition_agent(instance.answer)
    print("fact_list: ", fact_list)
    attributed_table, _, _ = matsa_agent.semantic_retreival_agent(augmented_table, fact_list)
    print("attributed_table: ", attributed_table)
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_table)
    print("attribution_fxn: ", attribution_fxn)

    # Get row and column attributions.
    row_attribution_set = attribution_fxn["Row Citations"]
    col_attribution_set = attribution_fxn["Column Citations"]
    explanation = attribution_fxn.get("Explanation", "")
    print("row_attribution_set: ", row_attribution_set)
    print("col_attribution_set: ", col_attribution_set)
    print("Explanation: ", explanation)

    # The citation sets may arrive either as lists or as their string form.
    row_ids = _parse_citation_ids(row_attribution_set)
    col_ids = _parse_citation_ids(col_attribution_set)

    highlighted_table = highlight_table(instance.html_table, row_ids, col_ids)

    # The citations are reported exactly as the pipeline returned them.
    result = {
        "highlighted_table": highlighted_table,
        "facts": attribution_fxn.get("List of Facts", []),
        "row_citations": row_attribution_set,
        "column_citations": col_attribution_set,
        "Explanation": explanation,
    }
    return json.dumps(result)
# Define the Gradio interface: components first, then the event handlers,
# then the wiring between them.
with gr.Blocks() as iface:
    gr.Markdown("# MATSA: Table Question Answering with Attribution")
    gr.Markdown("Select a table from dropdown load table image, question, and answer.")
    gr.Markdown("Attributions are provided as per answer. You may change the question/answer as per your need.")
    table_dropdown = gr.Dropdown(choices=get_table_names(), label="Select Table")
    original_table = gr.Image(type="filepath", label="Original Table")
    question_box = gr.Textbox(label="Question")
    answer_box = gr.Textbox(label="Answer")
    gr.Markdown("Click 'Process' to see the highlighted relevant parts. Click 'Reset' to start over.")
    process_button = gr.Button("Process")
    reset_button = gr.Button("Reset")
    processing_time = gr.Textbox(label="Processing Time", value="0 seconds")
    highlighted_table = gr.HTML(label="Highlighted Table")
    explanation_box = gr.Textbox(label="Explanation")

    def update_table_data(table_name):
        """Load the selected table and make sure the Process button is enabled.

        Returns one value per output wired to ``table_dropdown.change``:
        image path, question, answer, process-button update.
        """
        image_path, question, answer = load_table_data(table_name)
        return image_path, question, answer, gr.update(interactive=True)

    def reset_app():
        """Return the values that put every wired component back to its initial state."""
        return (
            gr.update(value="", interactive=True),  # table_dropdown
            None,  # original_table
            "",  # question_box
            "",  # answer_box
            "",  # highlighted_table
            "",  # explanation_box
            gr.update(interactive=True),  # process_button
            "0 seconds",  # processing_time
        )

    def process_and_disable(table_name, question, answer):
        """Run the pipeline while the dropdown and Process button are disabled.

        This is a generator handler: each ``yield`` is one UI update and
        must contain exactly one value per component in the ``outputs``
        list of ``process_button.click`` (dropdown, button, processing
        time, highlighted table, explanation).
        """
        # process_input() answers a missing selection with a plain prompt
        # string rather than JSON, so handle that case before decoding.
        if not table_name:
            yield (
                gr.update(),  # table_dropdown
                gr.update(),  # process_button
                "Please select a table from the dropdown.",  # processing_time
                gr.update(),  # highlighted_table
                gr.update(),  # explanation_box
            )
            return

        # Measure elapsed time directly. The earlier background "counter"
        # thread targeted a generator function, whose body never executes
        # when merely called, so the reported time was always 0.
        started = time.monotonic()

        # Disable the dropdown and process button during processing.
        yield (
            gr.update(interactive=False),  # table_dropdown
            gr.update(interactive=False),  # process_button
            gr.update(value="Processing..."),  # processing_time
            gr.update(),  # highlighted_table
            gr.update(),  # explanation_box
        )

        try:
            result_dict = json.loads(process_input(table_name, question, answer))
        except Exception:
            # Re-enable the controls before surfacing the error; otherwise
            # the UI would stay locked after a failed run.
            elapsed = time.monotonic() - started
            yield (
                gr.update(interactive=True),  # table_dropdown
                gr.update(interactive=True),  # process_button
                f"Failed after {elapsed:.1f} seconds",  # processing_time
                gr.update(),  # highlighted_table
                gr.update(),  # explanation_box
            )
            raise

        elapsed = time.monotonic() - started

        # Re-enable the controls and publish the result. Exactly five values:
        # the previous version yielded a sixth (a citations dict) that had no
        # matching output component.
        yield (
            gr.update(interactive=True),  # table_dropdown
            gr.update(interactive=True),  # process_button
            f"Processed in {elapsed:.1f} seconds",  # processing_time
            gr.update(value=result_dict['highlighted_table']),  # highlighted_table
            gr.update(value=result_dict.get('Explanation', '')),  # explanation_box
        )

    table_dropdown.change(
        update_table_data,
        inputs=[table_dropdown],
        outputs=[original_table, question_box, answer_box, process_button],
    )
    process_button.click(
        process_and_disable,
        inputs=[table_dropdown, question_box, answer_box],
        outputs=[table_dropdown, process_button, processing_time, highlighted_table, explanation_box],
    )
    reset_button.click(
        reset_app,
        inputs=[],
        outputs=[
            table_dropdown,
            original_table,
            question_box,
            answer_box,
            highlighted_table,
            explanation_box,
            process_button,
            processing_time,
        ],
    )

# Launch the interface
iface.launch(share=True)