Update app.py
Browse files
app.py
CHANGED
@@ -64,7 +64,7 @@ sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
|
|
64 |
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)
|
65 |
|
66 |
#sql_response = None
|
67 |
-
|
68 |
|
69 |
def chat(input, history=[]):
|
70 |
|
@@ -97,35 +97,24 @@ def chat(input, history=[]):
|
|
97 |
return response, history
|
98 |
|
99 |
|
100 |
-
def sqlquery(input
|
101 |
|
102 |
#input_text = " ".join(conversation_history) + " " + input
|
103 |
sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
|
104 |
-
|
105 |
-
|
106 |
|
107 |
-
sql_input_ids = sql_encoding["input_ids"]
|
108 |
-
# append the new user input tokens to the chat history
|
109 |
-
bot_input_ids = torch.cat([torch.LongTensor(history), sql_input_ids], dim=-1)
|
110 |
-
|
111 |
-
# generate a response
|
112 |
-
history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
|
113 |
-
|
114 |
-
# convert the tokens to text, and then split the responses into the right format
|
115 |
-
response = sql_tokenizer.batch_decode(history[0]).split("<|endoftext|>")
|
116 |
-
response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
|
117 |
|
118 |
-
'''
|
119 |
global conversation_history
|
120 |
|
121 |
# Maintain the conversation history
|
122 |
-
conversation_history.append("User: " + input)
|
123 |
-
conversation_history.append("Bot: " + " ".join(sql_response) )
|
124 |
|
125 |
output = " ".join(conversation_history)
|
126 |
return output
|
127 |
-
|
128 |
-
return
|
129 |
|
130 |
|
131 |
chat_interface = gr.Interface(
|
@@ -141,10 +130,8 @@ chat_interface = gr.Interface(
|
|
141 |
sql_interface = gr.Interface(
|
142 |
fn=sqlquery,
|
143 |
theme="default",
|
144 |
-
|
145 |
-
|
146 |
-
inputs=["text", "state"],
|
147 |
-
outputs=["chatbot", "state"],
|
148 |
live=True,
|
149 |
capture_session=True,
|
150 |
title="ST SQL Chat",
|
|
|
64 |
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)
|
65 |
|
66 |
#sql_response = None
|
67 |
+
conversation_history = []  # module-level transcript shared/mutated by sqlquery; grows unbounded across calls
|
68 |
|
69 |
def chat(input, history=[]):
|
70 |
|
|
|
97 |
return response, history
|
98 |
|
99 |
|
100 |
+
def sqlquery(input):
    """Answer a natural-language question against `table` using the TAPEX SQL model.

    Encodes the (table, question) pair, generates an answer with the BART-based
    TAPEX model, records both turns in the module-level `conversation_history`,
    and returns the whole transcript as a single string for the Gradio textbox.

    Parameters
    ----------
    input : str
        The user's question about the table. (Name shadows the builtin but is
        kept for interface compatibility with the Gradio `fn=sqlquery` wiring.)

    Returns
    -------
    str
        All conversation turns so far, joined by single spaces.
    """
    # Declare the global up front, before any use of the name in this scope.
    global conversation_history

    # NOTE(review): TAPEX's tokenizer adds its own special tokens; appending
    # `eos_token` to the query is probably redundant, but it is preserved here
    # because it matches the deployed behavior — confirm before removing.
    sql_encoding = sql_tokenizer(
        table=table,
        query=input + sql_tokenizer.eos_token,
        return_tensors="pt",
    )
    sql_outputs = sql_model.generate(**sql_encoding)
    # skip_special_tokens=True strips </s> etc. from the decoded answer.
    sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)

    # Maintain the conversation history. The previous version appended a
    # literal "<|endoftext|>" (a GPT-2 chat-loop leftover — TAPEX/BART does
    # not use that token) to each turn, which leaked verbatim into the
    # user-visible output; it is dropped here.
    conversation_history.append("User: " + input)
    conversation_history.append("Bot: " + " ".join(sql_response))

    return " ".join(conversation_history)
|
118 |
|
119 |
|
120 |
chat_interface = gr.Interface(
|
|
|
130 |
sql_interface = gr.Interface(
|
131 |
fn=sqlquery,
|
132 |
theme="default",
|
133 |
+
inputs=gr.Textbox(prompt="You:"),
|
134 |
+
outputs=gr.Textbox(),
|
|
|
|
|
135 |
live=True,
|
136 |
capture_session=True,
|
137 |
title="ST SQL Chat",
|