teaevo commited on
Commit
0bb8f5d
1 Parent(s): b54b3e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -23
app.py CHANGED
@@ -64,7 +64,7 @@ sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
64
  sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)
65
 
66
  #sql_response = None
67
- #conversation_history = []
68
 
69
  def chat(input, history=[]):
70
 
@@ -97,35 +97,24 @@ def chat(input, history=[]):
97
  return response, history
98
 
99
 
100
- def sqlquery(input, history=[]):
101
 
102
  #input_text = " ".join(conversation_history) + " " + input
103
  sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
104
- #sql_outputs = sql_model.generate(**sql_encoding)
105
- #sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
106
 
107
- sql_input_ids = sql_encoding["input_ids"]
108
- # append the new user input tokens to the chat history
109
- bot_input_ids = torch.cat([torch.LongTensor(history), sql_input_ids], dim=-1)
110
-
111
- # generate a response
112
- history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
113
-
114
- # convert the tokens to text, and then split the responses into the right format
115
- response = sql_tokenizer.batch_decode(history[0]).split("<|endoftext|>")
116
- response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
117
 
118
- '''
119
  global conversation_history
120
 
121
  # Maintain the conversation history
122
- conversation_history.append("User: " + input)
123
- conversation_history.append("Bot: " + " ".join(sql_response) )
124
 
125
  output = " ".join(conversation_history)
126
  return output
127
- '''
128
- return response, history
129
 
130
 
131
  chat_interface = gr.Interface(
@@ -141,10 +130,8 @@ chat_interface = gr.Interface(
141
  sql_interface = gr.Interface(
142
  fn=sqlquery,
143
  theme="default",
144
- #inputs=gr.Textbox(prompt="You:"),
145
- #outputs=gr.Textbox(),
146
- inputs=["text", "state"],
147
- outputs=["chatbot", "state"],
148
  live=True,
149
  capture_session=True,
150
  title="ST SQL Chat",
 
64
  sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)
65
 
66
  #sql_response = None
67
+ conversation_history = []
68
 
69
  def chat(input, history=[]):
70
 
 
97
  return response, history
98
 
99
 
100
+ def sqlquery(input):
101
 
102
  #input_text = " ".join(conversation_history) + " " + input
103
  sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
104
+ sql_outputs = sql_model.generate(**sql_encoding)
105
+ sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
106
 
 
 
 
 
 
 
 
 
 
 
107
 
 
108
  global conversation_history
109
 
110
  # Maintain the conversation history
111
+ conversation_history.append("User: " + input + "<|endoftext|>")
112
+ conversation_history.append("Bot: " + " ".join(sql_response) + "<|endoftext|>" )
113
 
114
  output = " ".join(conversation_history)
115
  return output
116
+
117
+ #return sql_response
118
 
119
 
120
  chat_interface = gr.Interface(
 
130
  sql_interface = gr.Interface(
131
  fn=sqlquery,
132
  theme="default",
133
+ inputs=gr.Textbox(prompt="You:"),
134
+ outputs=gr.Textbox(),
 
 
135
  live=True,
136
  capture_session=True,
137
  title="ST SQL Chat",