teaevo committed
Commit 5782838
1 Parent(s): b948232

Update app.py

Files changed (1): app.py +48 -35
app.py CHANGED
@@ -2,7 +2,7 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import TapexTokenizer, BartForConditionalGeneration
 import pandas as pd
-import torch
+import gradio as gr
 
 import numpy as np
 import time
@@ -42,47 +42,41 @@ def predict(input, history=[]):
     # Check if the user input is a question
     is_question = "?" in input
 
+    '''
     if is_question:
         sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
-        #sql_outputs = sql_model.generate(**sql_encoding)
-        #sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
-
-        # append the new user input tokens to the chat history
-        bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
-
-        # generate a response
-        history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
-
-        # convert the tokens to text, and then split the responses into the right format
-        response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
-        response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]  # convert to tuples of list
-
-        '''
-        bot_input_ids = torch.cat([torch.LongTensor(history), sql_encoding], dim=-1)
-        history = sql_model.generate(bot_input_ids, max_length=1000, pad_token_id=sql_tokenizer.eos_token_id).tolist()
-        response = sql_tokenizer.decode(history[0]).split("<|endoftext|>")
-        response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
-        '''
+        sql_outputs = sql_model.generate(**sql_encoding)
+        sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
+
     else:
-        # tokenize the new input sentence
-        new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
-
-        # append the new user input tokens to the chat history
-        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
-
-        # generate a response
-        history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
-
-        # convert the tokens to text, and then split the responses into the right format
-        response = tokenizer.decode(history[0]).split("<|endoftext|>")
-        response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]  # convert to tuples of list
-
+    '''
+
+    # tokenize the new input sentence
+    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
+
+    # append the new user input tokens to the chat history
+    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
+
+    # generate a response
+    history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
+
+    # convert the tokens to text, and then split the responses into the right format
+    response = tokenizer.decode(history[0]).split("<|endoftext|>")
+    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]  # convert to tuples of list
+
     return response, history
 
 
-import gradio as gr
-
-interface = gr.Interface(
+def sqlquery(input):
+
+    sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
+    sql_outputs = sql_model.generate(**sql_encoding)
+    sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
+
+    return sql_response
+
+
+chat_interface = gr.Interface(
     fn=predict,
     theme="default",
     css=".footer {display:none !important}",
@@ -92,5 +86,24 @@ interface = gr.Interface(
     description="Type your message in the box above, and the chatbot will respond.",
 )
 
+sql_interface = gr.Interface(
+    fn=sqlquery,
+    theme="default",
+    inputs=gr.Textbox(prompt="You:"),
+    outputs=gr.Textbox(),
+    live=True,
+    capture_session=True,
+    title="ST SQL Chat",
+    description="Type your message in the box above, and the chatbot will respond.",
+)
+
+combine_interface = gr.TabbedInterface(
+    interface_list=[
+        chat_interface,
+        sql_interface
+    ],
+    tab_names=['Chatbot', 'SQL Chat'],
+)
+
 if __name__ == '__main__':
-    interface.launch()
+    combine_interface.launch()
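For reference, here is a minimal runnable sketch of the app this diff converges on. The checkpoint names, the example `table`, and the chat interface's `inputs`/`outputs` and title are assumptions, since all of them are defined outside the changed hunks. Two arguments from the diff are also dropped: `capture_session` appears to be a legacy `gr.Interface` option that current Gradio releases no longer accept, and `prompt` is not a documented `gr.Textbox` parameter, so the sketch uses `label` instead. `live=True` is omitted as well, since re-running a seq2seq `generate` call on every keystroke is expensive.

```python
import pandas as pd
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TapexTokenizer,
    BartForConditionalGeneration,
)

# Assumed checkpoints -- app.py loads its models outside the changed hunks.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
sql_tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-base-finetuned-wtq")
sql_model = BartForConditionalGeneration.from_pretrained("microsoft/tapex-base-finetuned-wtq")

# Assumed table -- TAPEX expects a DataFrame whose cells are strings.
table = pd.DataFrame({"year": ["2020", "2021"], "revenue": ["100", "150"]})


def predict(user_input, history=None):
    # DialoGPT-style loop: append the new turn to the running token history,
    # generate, then split the stream into (user, bot) pairs for the Chatbot widget.
    history = history or []
    new_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
    bot_input_ids = torch.cat([torch.LongTensor(history), new_ids], dim=-1)
    history = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    ).tolist()
    turns = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]
    return response, history


def sqlquery(user_input):
    # TAPEX answers the question directly against the table; unlike the
    # committed code, no EOS token is appended to the query.
    encoding = sql_tokenizer(table=table, query=user_input, return_tensors="pt")
    outputs = sql_model.generate(**encoding)
    return sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


chat_interface = gr.Interface(
    fn=predict,
    inputs=["text", "state"],      # assumed; not shown in the diff
    outputs=["chatbot", "state"],  # assumed; not shown in the diff
    title="ST Chatbot",            # assumed; not shown in the diff
    description="Type your message in the box above, and the chatbot will respond.",
)

sql_interface = gr.Interface(
    fn=sqlquery,
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    title="ST SQL Chat",
    description="Type your message in the box above, and the chatbot will respond.",
)

# TabbedInterface also accepts these two arguments positionally.
combine_interface = gr.TabbedInterface(
    interface_list=[chat_interface, sql_interface],
    tab_names=["Chatbot", "SQL Chat"],
)

if __name__ == "__main__":
    combine_interface.launch()
```

As in the committed code, `predict` keeps the whole conversation as a single token stream and splits it on `<|endoftext|>` into (user, bot) tuples, which is the format the tuple-style Chatbot output component expects.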