File size: 5,235 Bytes
2714773
 
 
 
5782838
9d6743e
2714773
 
 
d7a34dd
2714773
d7a34dd
2714773
9d6743e
2714773
 
 
ca38751
2714773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
978fd4d
2714773
978fd4d
2714773
 
 
6651d18
2714773
 
 
 
 
 
 
 
 
5758bb4
2714773
 
 
 
 
 
 
 
0f46be8
2714773
 
 
 
5758bb4
2714773
 
9a55232
37e8e73
9bb7983
 
9a55232
 
2714773
9a55232
 
 
 
 
 
 
 
 
 
978fd4d
b15f680
37e8e73
 
9bb7983
928fc23
95934ef
 
 
978fd4d
9a55232
2714773
 
 
978fd4d
5758bb4
 
 
 
 
 
2714773
 
 
 
 
37e8e73
9a55232
 
 
2714773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import gradio as gr

import numpy as np
import time
import os

#import pyodbc

#import pkg_resources

# --- Disabled snippet: list installed packages and their versions -----------
# (requires the commented-out `import pkg_resources` above)
#
# # Get a list of installed packages and their versions
# installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
#
# # Print the list of packages
# for package, version in installed_packages.items():
#     print(f"{package}=={version}")

# --- Disabled snippet: load the table from SQL Server -----------------------
# (requires the commented-out `import pyodbc` above)
#
# # Replace the connection parameters with your SQL Server information
# server = 'your_server'
# database = 'your_database'
# username = 'your_username'
# password = 'your_password'
# driver = 'SQL Server'  # This depends on the ODBC driver installed on your system
#
# # Create the connection string
# connection_string = f'DRIVER={{{driver}}};SERVER={server};DATABASE={database};UID={username};PWD={password}'
#
# # Connect to the SQL Server
# conn = pyodbc.connect(connection_string)
#
# # Replace "your_query" with your SQL query to fetch data from the database
# query = 'SELECT * FROM your_table_name'
#
# # Use pandas to read data from the SQL Server and store it in a DataFrame
# df = pd.read_sql_query(query, conn)
#
# # Close the SQL connection
# conn.close()

# Small in-memory demo table (year / host city) that the table-QA handler
# queries via the module-level name `table`.
data = {
    "year": [1896, 1900, 1904, 2004, 2008, 2012],
    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"],
}
table = pd.DataFrame(data)


# Checkpoint identifiers for the two models used by this app.
chatbot_model_name = "microsoft/DialoGPT-medium"
sql_model_name = "microsoft/tapex-large-finetuned-wtq"

# Conversational model: tokenizer + causal LM used by `chat`.
tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)

# Table question-answering model: tokenizer + seq2seq model used by `sqlquery`.
sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)

# Globals from an earlier design, kept disabled for reference:
#sql_response = None
#conversation_history = []

def chat(input, history=None):
    """Generate the chatbot's next reply and return the updated conversation.

    Args:
        input: The user's new message. The parameter name shadows the
            builtin ``input`` but is kept so keyword callers keep working.
        history: Token ids of the conversation so far, exactly as returned
            by a previous call (the nested list produced by ``.tolist()``).
            ``None`` or an empty list starts a new conversation.

    Returns:
        A ``(response, history)`` tuple. ``response`` is a list of
        ``(user_message, bot_reply)`` string pairs for the whole
        conversation; ``history`` is the generated token-id sequence as a
        nested list, to be passed back in on the next call.
    """
    # A ``None`` default avoids a mutable default argument and also covers a
    # session state that has not been initialised yet.
    previous_ids = torch.LongTensor(history if history else [])

    # Tokenize the new user message, terminated by the end-of-sequence token
    # that delimits conversation turns.
    new_user_input_ids = tokenizer.encode(
        input + tokenizer.eos_token, return_tensors="pt"
    )

    # Append the new user tokens to the conversation so far.
    bot_input_ids = torch.cat([previous_ids, new_user_input_ids], dim=-1)

    # Generate a reply; the output holds the whole conversation including
    # the new reply, so it becomes the next history.
    history = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    ).tolist()

    # Split the decoded conversation on the same token that was appended to
    # each message, then pair the turns up as (user, bot).
    turns = tokenizer.decode(history[0]).split(tokenizer.eos_token)
    response = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]

    return response, history


def sqlquery(input, history=None):
    """Answer a natural-language question about ``table`` with the TAPEX model.

    Args:
        input: The user's question. The parameter name shadows the builtin
            ``input`` but is kept so keyword callers keep working.
        history: List of ``(question, answer)`` pairs returned by earlier
            calls. ``None`` or an empty list starts a new conversation.

    Returns:
        A ``(response, history)`` tuple. Both elements are the updated list
        of ``(question, answer)`` string pairs: the first is for display,
        the second is the state to pass back in on the next call.
    """
    # Copy so the caller's list is never mutated; ``None`` means no history.
    past_turns = list(history) if history else []

    # Nothing to ask: return the conversation unchanged instead of running
    # the model on an empty query.
    query = input.strip() if input else ""
    if not query:
        return past_turns, past_turns

    # Encode the table together with the plain-text question and let the
    # seq2seq model generate the answer. The previous implementation built
    # this encoding but then called generate() on an undefined variable.
    sql_encoding = sql_tokenizer(table=table, query=query, return_tensors="pt")
    sql_outputs = sql_model.generate(**sql_encoding)
    sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
    answer = " ".join(part.strip() for part in sql_response)

    updated_turns = past_turns + [(query, answer)]
    return updated_turns, updated_turns


# Chat tab: a text box plus per-session state in, a chatbot transcript plus
# updated state out, all handled by `chat`.
chat_interface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
    theme="default",
    css=".footer {display:none !important}",
)

# Table-QA tab: same input/output layout as the chat tab, handled by
# `sqlquery`.
# - `live=True` is intentionally not set: live mode reruns the handler on
#   every input change, which would run the model per keystroke and feed
#   partial questions into the session state.
# - The deprecated `capture_session` argument has been dropped.
sql_interface = gr.Interface(
    fn=sqlquery,
    theme="default",
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="ST SQL Chat",
    description="Type your message in the box above, and the chatbot will respond.",
)

# Single app exposing both interfaces as tabs, in the order listed.
combine_interface = gr.TabbedInterface(
    interface_list=[chat_interface, sql_interface],
    tab_names=["Chatbot", "SQL Chat"],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    combine_interface.launch()