teaevo commited on
Commit
4e86ef1
·
1 Parent(s): b5d991e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -8
app.py CHANGED
@@ -1,22 +1,30 @@
1
  import gradio as gr
 
2
  from transformers import TapexTokenizer, BartForConditionalGeneration
3
  import pandas as pd
4
- import pkg_resources
5
 
 
6
  # Get a list of installed packages and their versions
7
  installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
8
 
9
  # Print the list of packages
10
  for package, version in installed_packages.items():
11
  print(f"{package}=={version}")
 
12
 
 
 
 
 
 
 
 
13
  #wikisql take longer to process
14
  #model_name = "microsoft/tapex-large-finetuned-wikisql" # You can change this to any other model from the list above
15
  #model_name = "microsoft/tapex-base-finetuned-wikisql"
16
-
17
  model_name = "microsoft/tapex-large-finetuned-wtq"
18
  #model_name = "microsoft/tapex-base-finetuned-wtq"
19
-
20
  tokenizer = TapexTokenizer.from_pretrained(model_name)
21
  model = BartForConditionalGeneration.from_pretrained(model_name)
22
 
@@ -27,26 +35,48 @@ data = {
27
  table = pd.DataFrame.from_dict(data)
28
 
29
  def chatbot_response(user_message):
 
 
 
 
 
 
 
 
30
 
31
- #inputs = tokenizer.encode("User: " + user_message, return_tensors="pt")
32
- inputs = user_message
33
  encoding = tokenizer(table=table, query=inputs, return_tensors="pt")
34
  outputs = model.generate(**encoding)
35
  response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
36
 
37
  return response
38
 
39
- # Define the chatbot interface using Gradio
40
- iface = gr.Interface(
41
  fn=chatbot_response,
42
  inputs=gr.Textbox(prompt="You:"),
43
  outputs=gr.Textbox(),
44
  live=True,
45
  capture_session=True,
 
 
 
 
 
 
 
 
 
 
 
46
  title="ST SQL Chatbot",
47
  description="Type your message in the box above, and the chatbot will respond.",
48
  )
49
 
 
 
 
50
  # Launch the Gradio interface
51
  if __name__ == "__main__":
52
- iface.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from transformers import TapexTokenizer, BartForConditionalGeneration
4
  import pandas as pd
5
+ #import pkg_resources
6
 
7
+ '''
8
  # Get a list of installed packages and their versions
9
  installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
10
 
11
  # Print the list of packages
12
  for package, version in installed_packages.items():
13
  print(f"{package}=={version}")
14
+ '''
15
 
16
+ # Load the chatbot model
17
+ chatbot_model_name = "gpt2"
18
+ chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
19
+ chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
20
+
21
+
22
+ # Load the SQL Model
23
  #wikisql take longer to process
24
  #model_name = "microsoft/tapex-large-finetuned-wikisql" # You can change this to any other model from the list above
25
  #model_name = "microsoft/tapex-base-finetuned-wikisql"
 
26
  model_name = "microsoft/tapex-large-finetuned-wtq"
27
  #model_name = "microsoft/tapex-base-finetuned-wtq"
 
28
  tokenizer = TapexTokenizer.from_pretrained(model_name)
29
  model = BartForConditionalGeneration.from_pretrained(model_name)
30
 
 
35
  table = pd.DataFrame.from_dict(data)
36
 
37
  def chatbot_response(user_message):
38
+ # Generate chatbot response using the chatbot model
39
+ inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
40
+ outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
41
+ response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
42
+
43
+ return response
44
+
45
+ def sql_response(user_query):
46
 
47
+ #inputs = tokenizer.encode("User: " + user_query, return_tensors="pt")
48
+ inputs = user_query
49
  encoding = tokenizer(table=table, query=inputs, return_tensors="pt")
50
  outputs = model.generate(**encoding)
51
  response = tokenizer.batch_decode(outputs, skip_special_tokens=True)
52
 
53
  return response
54
 
55
+ # Define the chatbot and SQL execution interfaces using Gradio
56
+ chatbot_interface = gr.Interface(
57
  fn=chatbot_response,
58
  inputs=gr.Textbox(prompt="You:"),
59
  outputs=gr.Textbox(),
60
  live=True,
61
  capture_session=True,
62
+ title="Chatbot",
63
+ description="Type your message in the box above, and the chatbot will respond.",
64
+ )
65
+
66
+ # Define the chatbot interface using Gradio
67
+ sql_interface = gr.Interface(
68
+ fn=sql_response,
69
+ inputs=gr.Textbox(prompt="You:"),
70
+ outputs=gr.Textbox(),
71
+ live=True,
72
+ capture_session=True,
73
  title="ST SQL Chatbot",
74
  description="Type your message in the box above, and the chatbot will respond.",
75
  )
76
 
77
+ # Combine the chatbot and SQL execution interfaces
78
+ combined_interface = gr.Interface([chatbot_interface, sql_interface], layout="horizontal")
79
+
80
  # Launch the Gradio interface
81
  if __name__ == "__main__":
82
+ combined_interface.launch()