teaevo committed on
Commit 9a55232 · 1 Parent(s): 6651d18

Update app.py

Files changed (1)
  1. app.py +17 -6
app.py CHANGED
@@ -97,13 +97,23 @@ def chat(input, history=[]):
     return response, history
 
 
-def sqlquery(input):
+def sqlquery(input, history=[]):
 
     #input_text = " ".join(conversation_history) + " " + input
     sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
-    sql_outputs = sql_model.generate(**sql_encoding)
-    sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
+    #sql_outputs = sql_model.generate(**sql_encoding)
+    #sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
 
+    # append the new user input tokens to the chat history
+    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
+
+    # generate a response
+    history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
+
+    # convert the tokens to text, and then split the responses into the right format
+    response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
+    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]  # convert to tuples of list
+
     '''
     global conversation_history
 
@@ -114,7 +124,7 @@ def sqlquery(input):
     output = " ".join(conversation_history)
     return output
     '''
-    return sql_response
+    return response, history
 
 
 chat_interface = gr.Interface(
@@ -131,8 +141,9 @@ sql_interface = gr.Interface(
     fn=sqlquery,
     theme="default",
     #inputs=gr.Textbox(prompt="You:"),
-    inputs=gr.Textbox(prompt="You:"),
-    outputs=gr.Textbox(),
+    #outputs=gr.Textbox(),
+    inputs=["text", "state"],
+    outputs=["chatbot", "state"],
     live=True,
     capture_session=True,
     title="ST SQL Chat",