Spaces:
Build error
Build error
import ast
import json
import os
import tempfile
import threading
import time

import gradio as gr
import imgkit
from bs4 import BeautifulSoup

from matsa import MATSA, InputInstance
# Path of the JSON file that holds the sample tables (a single file, not a
# folder, despite the name). load_data() reads it.
TABLE_FOLDER = "./tables_folder/MATSA_fetaqa.json"
def load_data(path=None):
    """Load the table records from a JSON file.

    Args:
        path: JSON file to read. Defaults to ``TABLE_FOLDER``, so existing
            zero-argument callers behave exactly as before.

    Returns:
        The decoded JSON document. The rest of this app indexes it as a list
        of dicts with ``html_table`` / ``question`` / ``answer_statement`` keys.

    Raises:
        FileNotFoundError: if the file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    if path is None:
        path = TABLE_FOLDER
    # Explicit encoding: the default depends on the platform locale (e.g.
    # cp1252 on Windows), and table text may contain non-ASCII characters.
    with open(path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)
# All table records, read once when this module is imported. Any error raised
# by load_data() (e.g. a missing or malformed JSON file) therefore fails the
# import. Dropdown label "tab_N" refers to TABLE_DATA[N - 1].
TABLE_DATA = load_data()
def get_table_names():
    """Return the dropdown labels "tab_1" ... "tab_N", one per record in TABLE_DATA."""
    total = len(TABLE_DATA)
    return ["tab_%d" % number for number in range(1, total + 1)]
def html_to_image(html_content):
    """Render an HTML string to a PNG file and return the file's path.

    The file is created with ``delete=False`` and is not removed here, so it
    outlives this call; the caller only receives its path.

    Raises:
        Whatever ``imgkit.from_string`` raises (e.g. if wkhtmltoimage is not
        installed). In that case the empty temp file is removed first.
    """
    # Close the temp file BEFORE imgkit writes to it: imgkit shells out to the
    # wkhtmltoimage binary, and on Windows a NamedTemporaryFile that is still
    # open cannot be opened a second time by another process.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
        image_path = temp_img.name
    try:
        imgkit.from_string(html_content, image_path)
    except Exception:
        # Don't leave an empty .png behind for every failed render.
        os.remove(image_path)
        raise
    return image_path
def highlight_table(html_table, row_ids, col_ids):
    """Return *html_table* with the cited columns shaded via inline styles.

    Every cell of a cited column gets a light-grey background. Cells in a
    cited column that also lie in a cited row get a light-blue background
    instead. Cited rows whose columns are not cited are left unstyled.

    Args:
        html_table: HTML source containing ``<tr>`` rows.
        row_ids: ``id`` attribute values of the cited ``<tr>`` elements.
            Ids that are not found in the table are ignored.
        col_ids: column labels of the form ``"<prefix>-<n>"`` with a 1-based
            ``n`` (``"col-1"`` is the first ``td``/``th`` of each row). The
            index counts cells positionally, so ``colspan`` is not considered.
            Out-of-range columns (including ``n <= 0``) are ignored.

    Returns:
        The re-serialised HTML string.

    Raises:
        IndexError / ValueError: if a column label has no ``-`` or its suffix
            is not an integer.
    """
    soup = BeautifulSoup(html_table, 'html.parser')

    # Track the cited rows by object identity. A list membership test would
    # use Tag.__eq__, which compares markup structurally: it is slower, runs
    # for every cell, and can match a different row with identical markup.
    selected_rows = set()
    for row_id in row_ids:
        row = soup.find('tr', id=row_id)
        if row is not None:
            selected_rows.add(id(row))

    # Only attributes change below, so the row list can be fetched once.
    all_rows = soup.find_all('tr')
    for col_id in col_ids:
        col_index = int(col_id.split('-')[1]) - 1  # "col-1" -> 0, "col-2" -> 1, ...
        for row in all_rows:
            cells = row.find_all(['td', 'th'])
            # The lower bound matters: "col-0" gives -1, and a negative index
            # would silently shade the LAST cell of the row.
            if not 0 <= col_index < len(cells):
                continue
            if id(row) in selected_rows:
                cells[col_index]['style'] = 'background-color: rgba(173, 216, 230, 0.7);'
            else:
                cells[col_index]['style'] = 'background-color: rgba(211, 211, 211, 0.6);'
    return str(soup)
def load_table_data(table_name):
    """Fetch what the UI displays for the selected dropdown entry.

    Args:
        table_name: a dropdown label such as ``"tab_3"`` (1-based position in
            TABLE_DATA). A falsy value means nothing is selected.

    Returns:
        ``(image_path, question, answer)``:
        - the path of a PNG rendering of the record's HTML table;
        - the stored question, or "" when it is absent or null;
        - the stored answer statement.
        Returns ``(None, "", "")`` when nothing is selected.
    """
    if not table_name:
        return None, "", ""

    record = TABLE_DATA[int(table_name.split('_')[1]) - 1]
    html = record['html_table']
    question = record.get("question", "")
    question = "" if question is None else question
    answer = record['answer_statement']
    # Render last, after every field lookup has succeeded.
    return html_to_image(html), question, answer
def _parse_citations(value):
    """Normalise a citation field from the attribution result into a sequence.

    A sequence is returned unchanged. A string such as
    ``"['row-1', 'row-3']"`` is parsed as a Python literal. A bare string
    literal (``"'row-1'"``) becomes a one-element list, so callers never
    iterate over its characters. ``None`` becomes ``[]``.

    Raises:
        ValueError / SyntaxError: if a string value is not a Python literal.
    """
    if value is None:
        return []
    if isinstance(value, str):
        # SECURITY: this text is model output steered by the user-editable
        # question/answer. eval() would execute any expression it contains;
        # ast.literal_eval only accepts literals (lists, tuples, strings, ...).
        value = ast.literal_eval(value)
        if isinstance(value, str):
            value = [value]
    return value


def process_input(table_name, question, answer):
    """Run the MATSA attribution pipeline on a stored table.

    Args:
        table_name: dropdown label such as ``"tab_3"`` (1-based index into
            TABLE_DATA). A falsy value means nothing is selected.
        question: question text (the user may have edited it).
        answer: answer text to attribute to the table.

    Returns:
        A plain-text prompt (NOT JSON) when *table_name* is falsy. Otherwise,
        a JSON string with these keys:
        - ``highlighted_table``
        - ``facts``
        - ``row_citations`` and ``column_citations`` (as returned by the agent)
        - ``Explanation``
    """
    if not table_name:
        return "Please select a table from the dropdown."

    index = int(table_name.split('_')[1]) - 1
    html_content = TABLE_DATA[index]['html_table']
    print("html_content: ", html_content)
    print("question: ", question)
    print("answer: ", answer)

    matsa_agent = MATSA()
    instance = InputInstance(html_table=html_content, question=question, answer=answer)

    # Agent stages, each consuming the previous stage's output.
    augmented_table = matsa_agent.description_augmentation_agent(instance.html_table)
    print("augmented_table: ", augmented_table)
    fact_list = matsa_agent.answer_decomposition_agent(instance.answer)
    print("fact_list: ", fact_list)
    # "retreival" is how the matsa API spells this method.
    attributed_table, _, _ = matsa_agent.semantic_retreival_agent(augmented_table, fact_list)
    print("attributed_table: ", attributed_table)
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_table)
    print("attribution_fxn: ", attribution_fxn)

    # A missing citation key yields no highlighting rather than a KeyError
    # that would abort the whole request.
    row_attribution_set = attribution_fxn.get("Row Citations", [])
    col_attribution_set = attribution_fxn.get("Column Citations", [])
    explanation = attribution_fxn.get("Explanation", "")
    print("row_attribution_set: ", row_attribution_set)
    print("col_attribution_set: ", col_attribution_set)
    print("Explanation: ", explanation)

    row_ids = _parse_citations(row_attribution_set)
    col_ids = _parse_citations(col_attribution_set)

    result = {
        "highlighted_table": highlight_table(instance.html_table, row_ids, col_ids),
        "facts": attribution_fxn.get("List of Facts", []),
        "row_citations": row_attribution_set,
        "column_citations": col_attribution_set,
        "Explanation": explanation,
    }
    return json.dumps(result)
# Define Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# MATSA: Table Question Answering with Attribution")
    gr.Markdown("Select a table from dropdown load table image, question, and answer.")
    gr.Markdown("Attributions are provided as per answer. You may change the question/answer as per your need.")
    table_dropdown = gr.Dropdown(choices=get_table_names(), label="Select Table")
    original_table = gr.Image(type="filepath", label="Original Table")
    question_box = gr.Textbox(label="Question")
    answer_box = gr.Textbox(label="Answer")
    gr.Markdown("Click 'Process' to see the highlighted relevant parts. Click 'Reset' to start over.")
    process_button = gr.Button("Process")
    reset_button = gr.Button("Reset")
    processing_time = gr.Textbox(label="Processing Time", value="0 seconds")
    highlighted_table = gr.HTML(label="Highlighted Table")
    explanation_box = gr.Textbox(label="Explanation")

    def update_table_data(table_name):
        """Show the selected table's image, question and answer; enable Process.

        Return order matches the outputs of table_dropdown.change below.
        """
        image_path, question, answer = load_table_data(table_name)
        return image_path, question, answer, gr.update(interactive=True)

    def reset_app():
        """Clear every input and output. Order matches reset_button's outputs."""
        return (
            # None clears the selection; "" is not one of the dropdown's choices.
            gr.update(value=None, interactive=True),  # table_dropdown
            None,  # original_table
            "",  # question_box
            "",  # answer_box
            "",  # highlighted_table
            "",  # explanation_box
            gr.update(interactive=True),  # process_button
            "0 seconds",  # processing_time
        )

    def process_and_disable(table_name, question, answer):
        """Run process_input while the controls are locked, streaming UI states.

        Generator: every yielded 5-tuple maps onto
        (table_dropdown, process_button, processing_time,
        highlighted_table, explanation_box).
        """
        if not table_name:
            # process_input returns its plain-text prompt (not JSON) in this
            # case; show it instead of feeding it to json.loads.
            yield (
                gr.update(),  # table_dropdown
                gr.update(),  # process_button
                gr.update(),  # processing_time
                gr.update(),  # highlighted_table
                process_input(table_name, question, answer),  # explanation_box
            )
            return

        # Disable the dropdown and process button during processing.
        yield (
            gr.update(interactive=False),  # table_dropdown
            gr.update(interactive=False),  # process_button
            gr.update(value="Processing..."),  # processing_time
            gr.update(),  # highlighted_table
            gr.update(),  # explanation_box
        )

        # Measure wall-clock time directly. (The previous counter thread
        # targeted a generator function, so its loop never ran and the
        # reported time was always 0.)
        start = time.monotonic()
        try:
            result_dict = json.loads(process_input(table_name, question, answer))
        except Exception:
            # Unlock the controls before surfacing the error; otherwise they
            # would stay disabled for the rest of the session.
            yield (
                gr.update(interactive=True),  # table_dropdown
                gr.update(interactive=True),  # process_button
                f"Failed after {time.monotonic() - start:.1f} seconds",  # processing_time
                gr.update(),  # highlighted_table
                gr.update(),  # explanation_box
            )
            raise
        elapsed = time.monotonic() - start

        # Exactly five values, one per declared output.
        yield (
            gr.update(interactive=True),  # table_dropdown
            gr.update(interactive=True),  # process_button
            f"Processed in {elapsed:.1f} seconds",  # processing_time
            gr.update(value=result_dict['highlighted_table']),  # highlighted_table
            gr.update(value=result_dict.get('Explanation', '')),  # explanation_box
        )

    table_dropdown.change(update_table_data,
                          inputs=[table_dropdown],
                          outputs=[original_table, question_box, answer_box, process_button])
    process_button.click(process_and_disable,
                         inputs=[table_dropdown, question_box, answer_box],
                         outputs=[table_dropdown, process_button, processing_time, highlighted_table, explanation_box])
    reset_button.click(reset_app,
                       inputs=[],
                       outputs=[table_dropdown, original_table, question_box, answer_box, highlighted_table, explanation_box, process_button, processing_time])

# Launch the interface
iface.launch(share=True)