File size: 8,226 Bytes
35d31f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import ast
import json
import os
import tempfile
import threading
import time
import traceback

import gradio as gr
import imgkit
from bs4 import BeautifulSoup

from matsa import MATSA, InputInstance

# Path of the JSON file holding the table records (despite the name, this is
# a file, not a folder).
TABLE_FOLDER = "./tables_folder/MATSA_fetaqa.json"


def load_data(path=TABLE_FOLDER):
    """Load and return the parsed contents of the JSON file at *path*.

    Args:
        path: JSON file to read; defaults to ``TABLE_FOLDER``.

    Returns:
        Whatever ``json.load`` produces for the file. The rest of this module
        indexes it by position and reads keys such as ``html_table`` from each
        entry, so it is presumably a list of dicts -- confirm against the data.
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # would mis-decode non-ASCII table content.
    with open(path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)


# Records loaded once at import time; read by get_table_names, load_table_data
# and process_input.
TABLE_DATA = load_data()

def get_table_names():
    """Return one dropdown label per loaded record: "tab_1", "tab_2", ...

    Labels are 1-based; load_table_data/process_input turn them back into
    0-based indices into TABLE_DATA.
    """
    record_count = len(TABLE_DATA)
    return ["tab_" + str(number) for number in range(1, record_count + 1)]

def html_to_image(html_content):
    """Render *html_content* to a PNG file via imgkit and return its path.

    The file is created with ``delete=False`` because the path is handed to a
    ``gr.Image`` component; nothing in this module deletes it afterwards.

    Raises:
        Whatever ``imgkit.from_string`` raises; the placeholder file is
        removed before the error propagates.
    """
    # Only reserve a unique name here and close the handle right away, so
    # imgkit writes to a path that is not held open by this process; on
    # Windows an open NamedTemporaryFile's name cannot be opened a second time.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
        image_path = temp_img.name
    try:
        imgkit.from_string(html_content, image_path)
    except Exception:
        # Don't leave an empty placeholder behind when rendering fails.
        try:
            os.remove(image_path)
        except OSError:
            pass  # best-effort cleanup; surface the rendering error instead
        raise
    return image_path

def _column_index(col_id):
    """Return the 0-based cell index for a column id like "col-3", or None.

    The part after the first "-" must be a positive integer. Anything else
    returns None, so a malformed citation is skipped instead of crashing the
    request, and "col-0" cannot turn into index -1 (the last cell).
    """
    try:
        number = int(str(col_id).split('-')[1])
    except (IndexError, ValueError):
        return None
    return number - 1 if number >= 1 else None


def highlight_table(html_table, row_ids, col_ids):
    """Return *html_table* with the cited columns shaded.

    Every cell in each cited column gets an inline ``style``. Rows whose ``id``
    is in *row_ids* get light blue (rgba 173,216,230); all other rows get light
    grey (rgba 211,211,211). Rows are coloured only where they intersect a
    cited column. Any existing inline ``style`` on those cells is replaced.

    Args:
        html_table: HTML markup containing ``<tr>`` rows.
        row_ids: ``id`` attribute values of the cited ``<tr>`` elements. Ids not
            present in the table are ignored; ``None`` is treated as empty.
        col_ids: column ids of the form ``"col-N"`` (1-based). Malformed ids
            are ignored; ``None`` is treated as empty.

    Returns:
        The modified table serialised back to an HTML string.
    """
    soup = BeautifulSoup(html_table, 'html.parser')

    # Remember cited rows by object identity. bs4's Tag.__eq__ compares name,
    # attributes and contents, so `row in some_list` would also match any
    # other row with identical markup (e.g. duplicated data rows).
    selected_rows = set()
    for row_id in row_ids or ():
        row = soup.find('tr', id=row_id)
        if row is not None:
            selected_rows.add(id(row))

    col_indices = [idx for idx in map(_column_index, col_ids or ()) if idx is not None]

    # Walk the rows once (rather than once per column) and colour cited cells.
    for row in soup.find_all('tr'):
        cells = row.find_all(['td', 'th'])
        if id(row) in selected_rows:
            style = 'background-color: rgba(173, 216, 230, 0.7);'
        else:
            style = 'background-color: rgba(211, 211, 211, 0.6);'
        for col_index in col_indices:
            if col_index < len(cells):
                cells[col_index]['style'] = style

    return str(soup)

def load_table_data(table_name):
    """Fetch the display data for a dropdown label such as "tab_3".

    Returns ``(image_path, question, answer)``:
    - a PNG rendering of the record's ``html_table``;
    - its ``question``, or ``""`` when the key is missing or ``None``;
    - its ``answer_statement``.

    An empty or ``None`` label yields ``(None, "", "")``.
    """
    if table_name:
        position = int(table_name.split('_')[1]) - 1
        record = TABLE_DATA[position]

        table_html = record['html_table']
        question = record.get("question")
        if question is None:
            question = ""
        answer = record['answer_statement']

        return html_to_image(table_html), question, answer
    return None, "", ""

def _parse_citations(value):
    """Normalise a citation field from the attribution agent into a list of ids.

    The agent may return a list, or a string representation of one (e.g.
    "['row-1', 'row-3']"). Strings are parsed with ``ast.literal_eval``, which
    accepts only Python literals. The text is model output and must never be
    executed, so plain ``eval`` is not safe here. A string that is not a
    literal (e.g. "row-1, row-3") is split on commas as a fallback. ``None``
    yields ``[]``, and a single scalar id is wrapped in a one-element list.
    """
    if isinstance(value, str):
        try:
            value = ast.literal_eval(value.strip())
        except (ValueError, SyntaxError, TypeError):
            parts = (part.strip().strip("'\"") for part in value.strip().strip("[]").split(","))
            return [part for part in parts if part]
    if value is None:
        return []
    if isinstance(value, (list, tuple, set, frozenset)):
        return list(value)
    # A lone id such as 'row-1': wrap it instead of iterating its characters.
    return [value]


def process_input(table_name, question, answer):
    """Run the MATSA attribution pipeline for one table and answer.

    Args:
        table_name: dropdown label of the form "tab_N" (1-based index into
            TABLE_DATA).
        question: question text stored on the ``InputInstance``.
        answer: answer text that is decomposed into facts and attributed.

    Returns:
        Normally a JSON string with these keys:
        - "highlighted_table": the table HTML with cited cells shaded;
        - "facts": the agent's "List of Facts";
        - "row_citations" / "column_citations": the parsed citation lists;
        - "Explanation": the agent's explanation text.

        When no table is selected, it instead returns the plain (non-JSON)
        message "Please select a table from the dropdown.".
    """
    if not table_name:
        return "Please select a table from the dropdown."

    # Get the data for the selected table
    index = int(table_name.split('_')[1]) - 1
    data = TABLE_DATA[index]

    html_content = data['html_table']

    print("html_content: ", html_content)
    print("question: ", question)
    print("answer: ", answer)

    # Initialize MATSA
    matsa_agent = MATSA()

    # Create input instance
    instance = InputInstance(html_table=html_content, question=question, answer=answer)

    # Apply MATSA pipeline
    augmented_table = matsa_agent.description_augmentation_agent(instance.html_table)
    print("augmented_table: ", augmented_table)
    fact_list = matsa_agent.answer_decomposition_agent(instance.answer)
    print("fact_list: ", fact_list)
    attributed_table, _, _ = matsa_agent.semantic_retreival_agent(augmented_table, fact_list)
    print("attributed_table: ", attributed_table)
    attribution_fxn = matsa_agent.sufficiency_attribution_agent(fact_list, attributed_table)
    print("attribution_fxn: ", attribution_fxn)

    # Get row and column attributions; a missing field counts as "no citations"
    row_attribution_set = attribution_fxn.get("Row Citations")
    col_attribution_set = attribution_fxn.get("Column Citations")
    explanation = attribution_fxn.get("Explanation", "")
    print("row_attribution_set: ", row_attribution_set)
    print("col_attribution_set: ", col_attribution_set)
    print("Explanation: ", explanation)

    # Convert the (possibly stringified) citation fields to plain lists
    row_ids = _parse_citations(row_attribution_set)
    col_ids = _parse_citations(col_attribution_set)

    # Highlight the table
    highlighted_table = highlight_table(instance.html_table, row_ids, col_ids)

    result = {
        "highlighted_table": highlighted_table,
        "facts": attribution_fxn.get("List of Facts", []),
        # Parsed lists (not the raw agent values) so json.dumps never sees a set
        "row_citations": row_ids,
        "column_citations": col_ids,
        "Explanation": explanation,
    }

    return json.dumps(result)

# Define Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# MATSA: Table Question Answering with Attribution")
    gr.Markdown("Select a table from the dropdown to load the table image, question, and answer.")
    gr.Markdown("Attributions are provided as per answer. You may change the question/answer as per your need.")

    table_dropdown = gr.Dropdown(choices=get_table_names(), label="Select Table")
    original_table = gr.Image(type="filepath", label="Original Table")
    question_box = gr.Textbox(label="Question")
    answer_box = gr.Textbox(label="Answer")

    gr.Markdown("Click 'Process' to see the highlighted relevant parts. Click 'Reset' to start over.")

    process_button = gr.Button("Process")
    reset_button = gr.Button("Reset")
    processing_time = gr.Textbox(label="Processing Time", value="0 seconds")
    highlighted_table = gr.HTML(label="Highlighted Table")
    explanation_box = gr.Textbox(label="Explanation")

    def update_table_data(table_name):
        """Load the selected record into the image/question/answer widgets and enable Process."""
        image_path, question, answer = load_table_data(table_name)
        return image_path, question, answer, gr.update(interactive=True)

    def reset_app():
        """Clear every widget; tuple order matches the reset button's ``outputs`` list."""
        return (
            # None clears the selection instead of setting a value ("") that
            # is not one of the dropdown's choices.
            gr.update(value=None, interactive=True),  # table_dropdown
            None,  # original_table
            "",  # question_box
            "",  # answer_box
            "",  # highlighted_table
            "",  # explanation_box
            gr.update(interactive=True),  # process_button
            "0 seconds",  # processing_time
        )

    def process_and_disable(table_name, question, answer):
        """Run the MATSA pipeline with the inputs locked, then show the results.

        This is a generator handler. Every ``yield`` supplies exactly five
        values, one per component in the Process button's ``outputs`` list:
        dropdown, button, processing time, highlighted table, explanation.
        """
        if not table_name:
            # process_input() answers this case with a plain (non-JSON)
            # message, so report it directly instead of trying to parse it.
            yield (
                gr.update(interactive=True),  # table_dropdown
                gr.update(interactive=True),  # process_button
                "0 seconds",  # processing_time
                gr.update(),  # highlighted_table
                "Please select a table from the dropdown.",  # explanation_box
            )
            return

        # Disable the dropdown and process button during processing
        yield (
            gr.update(interactive=False),  # table_dropdown
            gr.update(interactive=False),  # process_button
            gr.update(value="Processing..."),  # processing_time
            gr.update(),  # highlighted_table
            gr.update(),  # explanation_box
        )

        # process_input runs synchronously here, so measuring wall time around
        # the call is enough; no background counter thread is needed.
        start = time.perf_counter()
        try:
            result_dict = json.loads(process_input(table_name, question, answer))
        except Exception as err:
            # UI boundary: log the traceback and re-enable the controls so the
            # user can retry instead of being left with a locked form.
            traceback.print_exc()
            elapsed = time.perf_counter() - start
            yield (
                gr.update(interactive=True),  # table_dropdown
                gr.update(interactive=True),  # process_button
                f"Failed after {elapsed:.1f} seconds",  # processing_time
                gr.update(value=""),  # highlighted_table
                gr.update(value=f"Error: {err}"),  # explanation_box
            )
            return
        elapsed = time.perf_counter() - start

        # Re-enable the dropdown and process button, update processing time, and show the result
        yield (
            gr.update(interactive=True),  # table_dropdown
            gr.update(interactive=True),  # process_button
            f"Processed in {elapsed:.1f} seconds",  # processing_time
            gr.update(value=result_dict['highlighted_table']),  # highlighted_table
            gr.update(value=result_dict.get('Explanation', '')),  # explanation_box
        )

    table_dropdown.change(update_table_data,
                          inputs=[table_dropdown],
                          outputs=[original_table, question_box, answer_box, process_button])

    process_button.click(process_and_disable,
                         inputs=[table_dropdown, question_box, answer_box],
                         outputs=[table_dropdown, process_button, processing_time, highlighted_table, explanation_box])

    reset_button.click(reset_app,
                       inputs=[],
                       outputs=[table_dropdown, original_table, question_box, answer_box, highlighted_table, explanation_box, process_button, processing_time])

# Launch the interface only when run as a script, so importing this module
# (e.g. from tests or another app) does not start a server.
if __name__ == "__main__":
    iface.launch(share=True)