teaevo commited on
Commit
2714773
1 Parent(s): 0343785

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -33
app.py CHANGED
@@ -1,45 +1,150 @@
 
 
 
 
1
  import gradio as gr
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TapasForQuestionAnswering, TapasTokenizer
3
 
4
- # Load the models and tokenizers
5
- tapas_model_name = "microsoft/tapex-large-finetuned-wtq"
6
- dialogpt_model_name = "microsoft/DialoGPT-medium"
7
 
8
- tapas_tokenizer = TapasTokenizer.from_pretrained(tapas_model_name)
9
- tapas_model = BartForConditionalGeneration.from_pretrained(tapas_model_name)
10
 
11
- dialogpt_tokenizer = AutoTokenizer.from_pretrained(dialogpt_model_name)
12
- dialogpt_model = AutoModelForSeqCausalLM.from_pretrained(dialogpt_model_name)
13
 
14
- def answer_table_question(table, question):
15
- encoding = tapas_tokenizer(table=table, query=question, return_tensors="pt")
16
- outputs = tapas_model.generate(**encoding)
17
- response = tapas_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
18
- return response
19
 
20
- def generate_dialog_response(prompt, conversation_history):
21
- bot_input = dialogpt_tokenizer.encode(prompt + dialogpt_tokenizer.eos_token, return_tensors="pt")
22
- chat_history_ids = dialogpt_model.generate(bot_input, max_length=1000, pad_token_id=dialogpt_tokenizer.eos_token_id)
23
- response = dialogpt_tokenizer.decode(chat_history_ids[:, bot_input.shape[-1]:][0], skip_special_tokens=True)
24
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- def chatbot_interface(user_input, table=gr.inputs.Textbox()):
27
  global conversation_history
28
 
29
  conversation_history.append(user_input)
 
 
 
 
 
30
 
31
- # Check if user asks a question related to the table
32
- if "table" in user_input:
33
- question = user_input
34
- answer = answer_table_question(table, question)
35
- conversation_history.append(answer)
36
- return "Bot (TAPAS): " + answer
37
- else:
38
- dialog_prompt = "User: " + " ".join(conversation_history) + "\nBot:"
39
- response = generate_dialog_response(dialog_prompt, conversation_history)
40
- conversation_history.append(response)
41
- return "Bot (DialoGPT): " + response
42
 
43
- conversation_history = []
44
- iface = gr.Interface(fn=chatbot_interface, inputs=["text", "text"], outputs="text", live=True)
45
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from transformers import TapexTokenizer, BartForConditionalGeneration
4
+ import pandas as pd
5
  import gradio as gr
 
6
 
7
+ import numpy as np
8
+ import time
9
+ import os
10
 
11
+ #import pyodbc
 
12
 
13
+ #import pkg_resources
 
14
 
15
+ '''
16
+ # Get a list of installed packages and their versions
17
+ installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
 
 
18
 
19
+ # Print the list of packages
20
+ for package, version in installed_packages.items():
21
+ print(f"{package}=={version}")
22
+ '''
23
+
24
+ '''
25
+ # Replace the connection parameters with your SQL Server information
26
+ server = 'your_server'
27
+ database = 'your_database'
28
+ username = 'your_username'
29
+ password = 'your_password'
30
+ driver = 'SQL Server' # This depends on the ODBC driver installed on your system
31
+
32
+ # Create the connection string
33
+ connection_string = f'DRIVER={{{driver}}};SERVER={server};DATABASE={database};UID={username};PWD={password}'
34
+
35
+ # Connect to the SQL Server
36
+ conn = pyodbc.connect(connection_string)
37
+
38
+ #============================================================================
39
+ # Replace "your_query" with your SQL query to fetch data from the database
40
+ query = 'SELECT * FROM your_table_name'
41
+
42
+ # Use pandas to read data from the SQL Server and store it in a DataFrame
43
+ df = pd.read_sql_query(query, conn)
44
+
45
+ # Close the SQL connection
46
+ conn.close()
47
+ '''
48
+
49
+ data = {
50
+ "year": [1896, 1900, 1904, 2004, 2008, 2012],
51
+ "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
52
+ }
53
+ table = pd.DataFrame.from_dict(data)
54
+
55
+
56
+ # Load the chatbot model
57
+ chatbot_model_name = "microsoft/DialoGPT-medium"
58
+ tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
59
+ model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
60
+
61
+ # Load the SQL Model
62
+ sql_model_name = "microsoft/tapex-large-finetuned-wtq"
63
+ sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
64
+ sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)
65
+
66
+ #sql_response = None
67
+ conversation_history = []
68
+
69
+ def predict(input, conversation_history): #history=[]):
70
+
71
+ #global sql_response
72
+ # Check if the user input is a question
73
+ #is_question = "?" in input
74
+
75
+ '''
76
+ if is_question:
77
+ sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
78
+ sql_outputs = sql_model.generate(**sql_encoding)
79
+ sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
80
+
81
+ else:
82
+ '''
83
 
 
84
  global conversation_history
85
 
86
  conversation_history.append(user_input)
87
+
88
+ bot_input = dialogpt_tokenizer.encode(input + " ".join(conversation_history)+ tokenizer.eos_token, return_tensors="pt")
89
+ chat_history_ids = model.generate(bot_input, max_length=1000, pad_token_id=tokenizer.eos_token_id)
90
+ response = tokenizer.decode(chat_history_ids[:, bot_input.shape[-1]:][0], skip_special_tokens=True)
91
+ #return response
92
 
93
+ '''
94
+ # tokenize the new input sentence
95
+ new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
96
+
97
+ # append the new user input tokens to the chat history
98
+ bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
99
+
100
+ # generate a response
101
+ history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
 
 
102
 
103
+ # convert the tokens to text, and then split the responses into the right format
104
+ response = tokenizer.decode(history[0]).split("<|endoftext|>")
105
+ response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)] # convert to tuples of list
106
+ '''
107
+
108
+ return response #, history
109
+
110
+
111
+ def sqlquery(input):
112
+
113
+ sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
114
+ sql_outputs = sql_model.generate(**sql_encoding)
115
+ sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
116
+
117
+ return sql_response
118
+
119
+
120
+ chat_interface = gr.Interface(
121
+ fn=predict,
122
+ theme="default",
123
+ css=".footer {display:none !important}",
124
+ inputs=["text", "state"],
125
+ outputs="text", #["chatbot", "state"],
126
+ title="ST Chatbot",
127
+ description="Type your message in the box above, and the chatbot will respond.",
128
+ )
129
+
130
+ sql_interface = gr.Interface(
131
+ fn=sqlquery,
132
+ theme="default",
133
+ inputs=gr.Textbox(prompt="You:"),
134
+ outputs=gr.Textbox(),
135
+ live=True,
136
+ capture_session=True,
137
+ title="ST SQL Chat",
138
+ description="Type your message in the box above, and the chatbot will respond.",
139
+ )
140
+
141
+ combine_interface = gr.TabbedInterface(
142
+ interface_list=[
143
+ chat_interface,
144
+ sql_interface
145
+ ],
146
+ tab_names=['Chatbot' ,'SQL Chat'],
147
+ )
148
+
149
+ if __name__ == '__main__':
150
+ combine_interface.launch()