teaevo commited on
Commit
37e8e73
1 Parent(s): 5758bb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -64,6 +64,7 @@ sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
64
  sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)
65
 
66
  #sql_response = None
 
67
 
68
  def predict(input, history=[]):
69
 
@@ -96,12 +97,18 @@ def predict(input, history=[]):
96
  return response, history
97
 
98
 
99
- def sqlquery(input):
100
-
101
- sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
 
102
  sql_outputs = sql_model.generate(**sql_encoding)
103
  sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
104
 
 
 
 
 
 
105
  return sql_response
106
 
107
 
@@ -118,7 +125,8 @@ chat_interface = gr.Interface(
118
  sql_interface = gr.Interface(
119
  fn=sqlquery,
120
  theme="default",
121
- inputs=gr.Textbox(prompt="You:"),
 
122
  outputs=gr.Textbox(),
123
  live=True,
124
  capture_session=True,
 
64
  sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)
65
 
66
  #sql_response = None
67
+ conversation_history = []
68
 
69
  def predict(input, history=[]):
70
 
 
97
  return response, history
98
 
99
 
100
+ def sqlquery(input, conversation_history):
101
+
102
+ input_text = " ".join(conversation_history) + " " + input
103
+ sql_encoding = sql_tokenizer(table=table, query=input_text + sql_tokenizer.eos_token, return_tensors="pt")
104
  sql_outputs = sql_model.generate(**sql_encoding)
105
  sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
106
 
107
+ global conversation_history
108
+
109
+ # Maintain the conversation history
110
+ conversation_history.append(user_input)
111
+
112
  return sql_response
113
 
114
 
 
125
  sql_interface = gr.Interface(
126
  fn=sqlquery,
127
  theme="default",
128
+ #inputs=gr.Textbox(prompt="You:"),
129
+ inputs=["text", "state"],
130
  outputs=gr.Textbox(),
131
  live=True,
132
  capture_session=True,