Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import torch
|
|
2 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
from transformers import TapexTokenizer, BartForConditionalGeneration
|
4 |
import pandas as pd
|
5 |
-
import
|
6 |
|
7 |
import numpy as np
|
8 |
import time
|
@@ -42,47 +42,41 @@ def predict(input, history=[]):
|
|
42 |
# Check if the user input is a question
|
43 |
is_question = "?" in input
|
44 |
|
|
|
45 |
if is_question:
|
46 |
sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
# append the new user input tokens to the chat history
|
51 |
-
bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
|
52 |
-
|
53 |
-
# generate a response
|
54 |
-
history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
|
55 |
-
|
56 |
-
# convert the tokens to text, and then split the responses into the right format
|
57 |
-
response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
|
58 |
-
response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
|
59 |
-
|
60 |
-
'''
|
61 |
-
bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
|
62 |
-
history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
|
63 |
-
response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
|
64 |
-
response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
|
65 |
-
'''
|
66 |
else:
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
-
# convert the tokens to text, and then split the responses into the right format
|
77 |
-
response = tokenizer.decode(history[0]).split("<|endoftext|>")
|
78 |
-
response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
|
79 |
-
|
80 |
return response, history
|
81 |
|
82 |
|
83 |
-
|
|
|
|
|
|
|
|
|
84 |
|
85 |
-
|
|
|
|
|
|
|
86 |
fn=predict,
|
87 |
theme="default",
|
88 |
css=".footer {display:none !important}",
|
@@ -92,5 +86,24 @@ interface = gr.Interface(
|
|
92 |
description="Type your message in the box above, and the chatbot will respond.",
|
93 |
)
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
if __name__ == '__main__':
|
96 |
-
|
|
|
2 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
from transformers import TapexTokenizer, BartForConditionalGeneration
|
4 |
import pandas as pd
|
5 |
+
import gradio as gr
|
6 |
|
7 |
import numpy as np
|
8 |
import time
|
|
|
42 |
# Check if the user input is a question
|
43 |
is_question = "?" in input
|
44 |
|
45 |
+
'''
|
46 |
if is_question:
|
47 |
sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
|
48 |
+
sql_outputs = sql_model.generate(**sql_encoding)
|
49 |
+
sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
|
50 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
else:
|
52 |
+
'''
|
53 |
+
|
54 |
+
# tokenize the new input sentence
|
55 |
+
new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
|
56 |
+
|
57 |
+
# append the new user input tokens to the chat history
|
58 |
+
bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
|
59 |
+
|
60 |
+
# generate a response
|
61 |
+
history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
|
62 |
+
|
63 |
+
# convert the tokens to text, and then split the responses into the right format
|
64 |
+
response = tokenizer.decode(history[0]).split("<|endoftext|>")
|
65 |
+
response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
|
66 |
|
|
|
|
|
|
|
|
|
67 |
return response, history
|
68 |
|
69 |
|
70 |
+
def sqlquery(input):
|
71 |
+
|
72 |
+
sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
|
73 |
+
sql_outputs = sql_model.generate(**sql_encoding)
|
74 |
+
sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
|
75 |
|
76 |
+
return sql_response
|
77 |
+
|
78 |
+
|
79 |
+
chat_interface = gr.Interface(
|
80 |
fn=predict,
|
81 |
theme="default",
|
82 |
css=".footer {display:none !important}",
|
|
|
86 |
description="Type your message in the box above, and the chatbot will respond.",
|
87 |
)
|
88 |
|
89 |
+
sql_interface = gr.Interface(
|
90 |
+
fn=sqlquery,
|
91 |
+
theme="default",
|
92 |
+
inputs=gr.Textbox(prompt="You:"),
|
93 |
+
outputs=gr.Textbox(),
|
94 |
+
live=True,
|
95 |
+
capture_session=True,
|
96 |
+
title="ST SQL Chat",
|
97 |
+
description="Type your message in the box above, and the chatbot will respond.",
|
98 |
+
)
|
99 |
+
|
100 |
+
combine_interface = gr.TabbedInterface(
|
101 |
+
interface_list=[
|
102 |
+
chat_interface,
|
103 |
+
sql_interface
|
104 |
+
],
|
105 |
+
tab_names=['Chatbot' ,'SQL Chat'],
|
106 |
+
)
|
107 |
+
|
108 |
if __name__ == '__main__':
|
109 |
+
combine_interface.launch()
|