Update app.py
Browse files
app.py
CHANGED
@@ -97,13 +97,23 @@ def chat(input, history=[]):
|
|
97 |
return response, history
|
98 |
|
99 |
|
100 |
-
def sqlquery(input):
|
101 |
|
102 |
#input_text = " ".join(conversation_history) + " " + input
|
103 |
sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
|
104 |
-
sql_outputs = sql_model.generate(**sql_encoding)
|
105 |
-
sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
'''
|
108 |
global conversation_history
|
109 |
|
@@ -114,7 +124,7 @@ def sqlquery(input):
|
|
114 |
output = " ".join(conversation_history)
|
115 |
return output
|
116 |
'''
|
117 |
-
return
|
118 |
|
119 |
|
120 |
chat_interface = gr.Interface(
|
@@ -131,8 +141,9 @@ sql_interface = gr.Interface(
|
|
131 |
fn=sqlquery,
|
132 |
theme="default",
|
133 |
#inputs=gr.Textbox(prompt="You:"),
|
134 |
-
|
135 |
-
|
|
|
136 |
live=True,
|
137 |
capture_session=True,
|
138 |
title="ST SQL Chat",
|
|
|
97 |
return response, history
|
98 |
|
99 |
|
100 |
+
def sqlquery(input, history=[]):
|
101 |
|
102 |
#input_text = " ".join(conversation_history) + " " + input
|
103 |
sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
|
104 |
+
#sql_outputs = sql_model.generate(**sql_encoding)
|
105 |
+
#sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
|
106 |
|
107 |
+
# append the new user input tokens to the chat history
|
108 |
+
bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
|
109 |
+
|
110 |
+
# generate a response
|
111 |
+
history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
|
112 |
+
|
113 |
+
# convert the tokens to text, and then split the responses into the right format
|
114 |
+
response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
|
115 |
+
response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
|
116 |
+
|
117 |
'''
|
118 |
global conversation_history
|
119 |
|
|
|
124 |
output = " ".join(conversation_history)
|
125 |
return output
|
126 |
'''
|
127 |
+
return response, history
|
128 |
|
129 |
|
130 |
chat_interface = gr.Interface(
|
|
|
141 |
fn=sqlquery,
|
142 |
theme="default",
|
143 |
#inputs=gr.Textbox(prompt="You:"),
|
144 |
+
#outputs=gr.Textbox(),
|
145 |
+
inputs=["text", "state"],
|
146 |
+
outputs=["chatbot", "state"],
|
147 |
live=True,
|
148 |
capture_session=True,
|
149 |
title="ST SQL Chat",
|