File size: 11,126 Bytes
f6f97d8
 
7faa846
f6f97d8
 
 
 
 
 
 
2afcc05
f6f97d8
 
 
 
 
 
03f8ee8
 
 
 
 
 
 
 
 
 
f6f97d8
9654ebc
f6f97d8
7494e0f
9654ebc
f6f97d8
9654ebc
f6f97d8
9654ebc
 
 
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
2ef23da
 
 
 
 
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9654ebc
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9611943
 
 
 
 
 
 
 
 
 
 
 
 
 
f6f97d8
e75ae3a
ab8d087
e75ae3a
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9654ebc
 
f6f97d8
 
 
 
 
 
 
35c4230
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35c4230
b44c042
f6f97d8
b44c042
f6f97d8
 
b44c042
f6f97d8
 
 
 
ab8d087
cf22e71
81f1d37
0eca8c1
cf22e71
6e02b9f
35c4230
e727295
f6f97d8
 
89029fd
 
 
 
f6f97d8
 
3474cf6
f6f97d8
 
 
 
9654ebc
b44c042
2afcc05
7913a66
9654ebc
f6f97d8
b44c042
f6f97d8
 
 
 
 
 
7faa846
f6f97d8
3474cf6
7faa846
b44c042
7faa846
f6f97d8
 
9611943
9654ebc
89029fd
63c9137
9654ebc
 
9611943
1ed6866
9611943
 
e727295
9611943
 
 
 
 
eed742c
9654ebc
 
9611943
 
b44c042
9611943
756edf5
f6f97d8
 
7faa846
f6f97d8
9654ebc
 
9611943
 
 
9654ebc
 
 
 
eed742c
4568f13
9654ebc
 
 
 
b44c042
f6f97d8
 
 
9654ebc
 
4568f13
9654ebc
f6f97d8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import json
import os
import uuid
import pandas as pd
import streamlit as st
import argparse
import traceback
from typing import Dict
import requests
from utils.utils import load_data_split
from utils.normalizer import post_process_sql
from nsql.database import NeuralDB
from nsql.nsql_exec import NSQLExecutor
from nsql.nsql_exec_python import NPythonExecutor
from generation.generator import Generator
import time

# Streamlit page-level configuration; must be the first st.* call in the script.
st.set_page_config(
    page_title="Binder Demo",
    page_icon="πŸ”—",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'About': "Check out our [website](https://lm-code-binder.github.io/) for more details!"
    }
)

# Directory containing this script; used to resolve bundled resources (e.g. the gpt2 tokenizer files).
ROOT_DIR = os.path.join(os.path.dirname(__file__), "./")
# todo: Add more binder questions, need careful cherry-picks
# Maps an example table's page title to a pair of
# (index of the item in the validation split, default question pre-filled in the text box).
EXAMPLE_TABLES = {
    "Estonia men's national volleyball team": (558, "what is the number of players from france?"),
    # 'how old is kert toobal'
    "Highest mountain peaks of California": (5, "which is the lowest mountain?"),
    # 'which mountain is in the most north place?'
    "2010–11 UAB Blazers men's basketball team": (1, "how many players come from alabama?"),
    # 'how many players are born after 1996?'
    "Nissan SR20DET": (438, "which car has power more than 170 kw?"),
    # ''
}


@st.cache
def load_data():
    """Load the validation split of the demo dataset (cached by Streamlit)."""
    validation_items = load_data_split("missing_squall", "validation")
    return validation_items


@st.cache
def get_key():
    """Fetch one API key from the springboard key server (cached by Streamlit).

    Also prints this demo machine's public IP to the server console, since
    only whitelisted demo machines may fetch keys.

    Returns:
        str: a single API key.

    Raises:
        requests.RequestException: if either HTTP call fails, times out,
            or returns a non-2xx status.
    """
    # print the public IP of the demo machine
    # timeout= keeps a dead network from hanging the whole Streamlit app.
    ip = requests.get('https://checkip.amazonaws.com', timeout=10).text.strip()
    print(ip)

    URL = "http://54.242.37.195:8080/api/predict"
    # The springboard machine we built to protect the key, 20217 is the birthday of Tianbao's girlfriend
    # we will only let the demo machine have the access to the keys

    # Fail fast with a clear HTTP error instead of crashing later on a
    # malformed JSON body when the key server misbehaves.
    response = requests.post(url=URL, json={"data": "Hi, binder server. Give me a key!"}, timeout=10)
    response.raise_for_status()
    one_key = response.json()['data'][0]
    return one_key


def read_markdown(path):
    """Read a markdown file and render it into the Streamlit page.

    Args:
        path: path to the markdown file, relative to the working directory.
    """
    # Explicit UTF-8: the demo resources contain non-ASCII characters
    # (e.g. en dashes in table titles), so the platform-default encoding
    # could fail or mangle text on some systems.
    with open(path, "r", encoding="utf-8") as f:
        output = f.read()
    st.markdown(output, unsafe_allow_html=True)


def generate_binder_program(_args, _generator, _data_item):
    """Generate one Binder program (NSQL/NPython) for a data item via the API.

    Builds a few-shot prompt followed by the item-specific generation prompt,
    then shrinks the number of shots until the combined prompt fits the
    model's input token budget, and finally issues a single generation call.

    Args:
        _args: parsed CLI args (uses prompt_file, n_shots, generate_type,
            max_api_total_tokens, max_generation_tokens, verbose).
        _generator: Generator instance that builds prompts and calls the API.
        _data_item: dict with at least "question", "table" and "title".

    Returns:
        str: the first generated program text.

    Raises:
        RuntimeError: if even the zero-shot prompt exceeds the token budget.
    """
    n_shots = _args.n_shots
    few_shot_prompt = _generator.build_few_shot_prompt_from_file(
        file_path=_args.prompt_file,
        n_shots=n_shots
    )
    generate_prompt = _generator.build_generate_prompt(
        data_item=_data_item,
        generate_type=(_args.generate_type,)
    )
    prompt = few_shot_prompt + "\n\n" + generate_prompt

    # Ensure the input length fit Codex max input tokens by shrinking the n_shots
    max_prompt_tokens = _args.max_api_total_tokens - _args.max_generation_tokens
    # Local import: transformers is heavy and only needed here for token counting.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=os.path.join(ROOT_DIR, "utils", "gpt2"))
    while len(tokenizer.tokenize(prompt)) >= max_prompt_tokens:
        n_shots -= 1
        # Explicit error instead of `assert`: asserts are stripped under -O,
        # which would let the loop spin forever with a negative n_shots.
        if n_shots < 0:
            raise RuntimeError(
                "Prompt exceeds the input token budget even with zero shots."
            )
        few_shot_prompt = _generator.build_few_shot_prompt_from_file(
            file_path=_args.prompt_file,
            n_shots=n_shots
        )
        prompt = few_shot_prompt + "\n\n" + generate_prompt

    response_dict = _generator.generate_one_pass(
        prompts=[("0", prompt)],  # the "0" is the place taker, take effect only when there are multi threads
        verbose=_args.verbose
    )
    print(response_dict)
    return response_dict["0"][0][0]


def remove_row_id(table):
    """Strip the leading 'row_id' column from a table dict for display.

    Intermediate tables produced during execution prepend a `row_id` column;
    the first column is assumed to be `row_id` whenever the header contains
    one.

    Args:
        table: dict with "header" (list of column names) and "rows"
            (list of row lists).

    Returns:
        dict: a new {"header", "rows"} dict without the first column, or the
        original table object unchanged when there is no 'row_id' column.
    """
    header: list = table['header']
    rows = table['rows']

    # Nothing to strip: hand the input back untouched.
    if 'row_id' not in header:
        return table

    return {
        "header": header[1:],
        "rows": [row[1:] for row in rows],
    }


# Set up
# NOTE(review): mid-file import; kept here to preserve the original script flow.
import nltk

# Download the 'punkt' tokenizer models needed by downstream text utilities.
nltk.download('punkt')
parser = argparse.ArgumentParser()

parser.add_argument('--prompt_file', type=str, default='templates/prompts/prompt_wikitq_v3.txt')
# Binder program generation options
parser.add_argument('--prompt_style', type=str, default='create_table_select_3_full_table',
                    choices=['create_table_select_3_full_table',
                             'create_table_select_full_table',
                             'create_table_select_3',
                             'create_table',
                             'create_table_select_3_full_table_w_all_passage_image',
                             'create_table_select_3_full_table_w_gold_passage_image',
                             'no_table'])
parser.add_argument('--generate_type', type=str, default='nsql',
                    choices=['nsql', 'sql', 'answer', 'npython', 'python'])
parser.add_argument('--n_shots', type=int, default=14)
parser.add_argument('--seed', type=int, default=42)

# Codex options
# todo: Allow adjusting Codex parameters
parser.add_argument('--engine', type=str, default="code-davinci-002")
parser.add_argument('--max_generation_tokens', type=int, default=512)
parser.add_argument('--max_api_total_tokens', type=int, default=8001)
parser.add_argument('--temperature', type=float, default=0.)
parser.add_argument('--sampling_n', type=int, default=1)
parser.add_argument('--top_p', type=float, default=1.0)
parser.add_argument('--stop_tokens', type=str, default='\n\n',
                    help='Split stop tokens by ||')
parser.add_argument('--qa_retrieve_pool_file', type=str, default='templates/qa_retrieve_pool.json')

# debug options
# NOTE(review): store_false makes verbose default to True and '-v' turn it
# OFF — likely intended as store_true; confirm intent before changing.
parser.add_argument('-v', '--verbose', action='store_false')
args = parser.parse_args()
# Fetch (and cache) a single API key from the key server at startup.
keys = [get_key()]

# The title
st.markdown("# Binder Playground")

# Demo description
read_markdown('resources/demo_description.md')

# Upload tables/Switch tables

st.markdown('### Try Binder!')
col1, _ = st.columns(2)
with col1:
    selected_table_title = st.selectbox(
        "Select an example table (We use WikiTQ examples for this demo. But task inputs can include free-form texts and images as well)",
        (
            "Estonia men's national volleyball team",
            "Highest mountain peaks of California",
            "2010–11 UAB Blazers men's basketball team",
            "Nissan SR20DET",
        )
    )

# Here we just use ourselves'
# Resolve the chosen example title to a concrete validation item and its table.
data_items = load_data()
data_item = data_items[EXAMPLE_TABLES[selected_table_title][0]]
table = data_item['table']
header, rows, title = table['header'], table['rows'], table['page_title']
db = NeuralDB(
    [{"title": title, "table": table}])  # todo: try to cache this db instead of re-creating it again and again.
df = db.get_table_df()
st.markdown("Title: {}".format(title))
st.dataframe(df)

# Let user input the question
with col1:
    selected_language = st.selectbox(
        "Select a target Binder program",
        ("Binder-SQL", "Binder-Python"),
    )
# Route the chosen program type to its prompt file and generation mode.
if selected_language == 'Binder-SQL':
    args.prompt_file = 'templates/prompts/prompt_wikitq_v3.txt'
    args.generate_type = 'nsql'
elif selected_language == 'Binder-Python':
    args.prompt_file = 'templates/prompts/prompt_wikitq_python_simplified_v4.txt'
    args.generate_type = 'npython'
else:
    raise ValueError(f'{selected_language} language is not supported.')

question = st.text_input(
    "Ask a question about the table:",
    value=EXAMPLE_TABLES[selected_table_title][1],
)

# Streamlit re-runs the script top-to-bottom on every interaction; halt here
# until the user actually clicks the button.
button = st.button("Run Binder!")
if not button:
    st.stop()

# Print the question we just input.
st.subheader("Question")
st.markdown("{}".format(question))

# Generate Binder Program
generator = Generator(args, keys=keys)
with st.spinner("Generating Binder program to solve the question...will be finished in 10s, please refresh the page if not"):
    binder_program = generate_binder_program(args, generator,
                                             {"question": question, "table": db.get_table_df(), "title": title})

# Do execution
st.subheader("Binder program")
if selected_language == 'Binder-SQL':
    # Post process
    binder_program = post_process_sql(binder_program, df, selected_table_title, True)
    st.markdown('```sql\n' + binder_program + '\n```')
    executor = NSQLExecutor(args, keys=keys)
elif selected_language == 'Binder-Python':
    st.code(binder_program, language='python')
    executor = NPythonExecutor(args, keys=keys)
    # The Python executor operates on a plain DataFrame instead of the NeuralDB.
    db = db.get_table_df()
else:
    raise ValueError(f'{selected_language} language is not supported.')
try:
    # Unique id for this run's temp files so concurrent sessions don't clash.
    stamp = '{}'.format(uuid.uuid4())
    os.makedirs('tmp_for_vis/', exist_ok=True)
    with st.spinner("Executing... will be finished in 30s, please refresh the page if not"):
        exec_answer = executor.nsql_exec(stamp, binder_program, db)
    if selected_language == 'Binder-SQL':
        # The SQL executor dumps a per-step trace under tmp_for_vis/; replay it
        # here as a step-by-step visualization of the execution.
        with open("tmp_for_vis/{}_tmp_for_vis_steps.txt".format(stamp), "r") as f:
            steps = json.load(f)
        for i, step in enumerate(steps):
            col1, _, _ = st.columns([7, 1, 2])
            with col1:
                st.markdown(f'**Step #{i + 1}**')
            col1, col1_25, col1_5, col2, col3 = st.columns([4, 1, 2, 1, 2])
            with col1:
                st.markdown('```sql\n' + step + '\n```')
            with col1_25:
                st.markdown("executes\non")
            with col1_5:
                # The last step runs over the full table; earlier steps run on
                # the sub-table(s) the executor recorded as that step's input.
                if i == len(steps) - 1:
                    st.markdown("Full table")
                else:
                    with open("tmp_for_vis/{}_result_step_{}_input.txt".format(stamp, i), "r") as f:
                        sub_tables_input = json.load(f)
                    for sub_table in sub_tables_input:
                        sub_table_to_print = remove_row_id(sub_table)
                        st.table(pd.DataFrame(sub_table_to_print['rows'], columns=sub_table_to_print['header']))
            with col2:
                st.markdown('$\\rightarrow$')
                if i == len(steps) - 1:
                    # The final step
                    st.markdown("{} Interpreter".format(selected_language.replace("Binder-", "")))
                else:
                    st.markdown("GPT3 Codex")
            # Cosmetic pause so the steps appear to stream in one at a time.
            with st.spinner('...'):
                time.sleep(1)
            with open("tmp_for_vis/{}_result_step_{}.txt".format(stamp, i), "r") as f:
                result_in_this_step = json.load(f)
            with col3:
                if isinstance(result_in_this_step, Dict):

                    rows = remove_row_id(result_in_this_step)["rows"]
                    header = remove_row_id(result_in_this_step)["header"]
                    # Replace auto-generated 'col_*' names with the step's SQL text.
                    if isinstance(header, list):
                        for idx in range(len(header)):
                            if header[idx].startswith('col_'):
                                header[idx] = step
                    st.table(pd.DataFrame(rows, columns=header))
                    # hard code here, use use_container_width after the huggingface update their streamlit version
                else:
                    st.markdown(result_in_this_step)
            with st.spinner('...'):
                time.sleep(1)
    elif selected_language == 'Binder-Python':
        # No per-step visualization is produced for Python programs.
        pass
    # Unwrap a single-element answer list for cleaner display.
    if isinstance(exec_answer, list) and len(exec_answer) == 1:
        exec_answer = exec_answer[0]
    # st.subheader(f'Execution answer')
    st.text('')
    st.markdown(f"**Execution answer:** {exec_answer}")
    # todo: Remove tmp files
except Exception as e:
    # NOTE(review): broad best-effort catch for the demo — failures are only
    # printed to the server console; the user gets no error message in the UI.
    traceback.print_exc()