File size: 4,213 Bytes
23c1edb
f65b03e
9d6743e
 
5782838
9d6743e
 
 
 
d7a34dd
 
 
9d6743e
 
 
 
 
ca38751
9d6743e
 
 
 
e030ac0
d7a34dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d6743e
 
 
 
 
 
829e215
 
 
9d6743e
d7a34dd
829e215
f65b03e
9d6743e
d7a34dd
9d6743e
d7a34dd
69beb29
5782838
0deb7d9
 
5782838
 
 
0deb7d9
5782838
 
 
 
 
 
 
 
 
 
 
 
 
 
0deb7d9
f65b03e
23c1edb
b92090c
5782838
 
 
 
 
4e86ef1
5782838
 
 
 
f65b03e
 
 
830c2c9
 
9d6743e
 
23c1edb
4e86ef1
5782838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f65b03e
5782838
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import gradio as gr

import numpy as np
import time
import os

import pyodbc

#import pkg_resources

'''
# Get a list of installed packages and their versions
installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}

# Print the list of packages
for package, version in installed_packages.items():
    print(f"{package}=={version}")
'''

'''
# Replace the connection parameters with your SQL Server information
server = 'your_server'
database = 'your_database'
username = 'your_username'
password = 'your_password'
driver = 'SQL Server'  # This depends on the ODBC driver installed on your system

# Create the connection string
connection_string = f'DRIVER={{{driver}}};SERVER={server};DATABASE={database};UID={username};PWD={password}'

# Connect to the SQL Server
conn = pyodbc.connect(connection_string)

#============================================================================
# Replace "your_query" with your SQL query to fetch data from the database
query = 'SELECT * FROM your_table_name'

# Use pandas to read data from the SQL Server and store it in a DataFrame
df = pd.read_sql_query(query, conn)

# Close the SQL connection
conn.close()
'''

data = {
    "year": [1896, 1900, 1904, 2004, 2008, 2012],
    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
}
table = pd.DataFrame.from_dict(data)


# Load the chatbot model
chatbot_model_name = "microsoft/DialoGPT-medium" 
tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)

# Load the SQL Model
sql_model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)

#sql_response = None

def predict(input, history=[]):

    #global sql_response
    # Check if the user input is a question
    #is_question = "?" in input

    '''
    if is_question: 
        sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
        sql_outputs = sql_model.generate(**sql_encoding)
        sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)

    else:
    '''
    
    # tokenize the new input sentence
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
    
    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
      
    # generate a response
    history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()

    # convert the tokens to text, and then split the responses into the right format
    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]  # convert to tuples of list
    
    return response, history


def sqlquery(input):

    sql_encoding = sql_tokenizer(table=table, query=input + sql_tokenizer.eos_token, return_tensors="pt")
    sql_outputs = sql_model.generate(**sql_encoding)
    sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)

    return sql_response


chat_interface = gr.Interface(
    fn=predict,
    theme="default",
    css=".footer {display:none !important}",
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
)

sql_interface = gr.Interface(
    fn=sqlquery,
    theme="default",
    inputs=gr.Textbox(prompt="You:"),
    outputs=gr.Textbox(),
    live=True,
    capture_session=True,
    title="ST SQL Chat",
    description="Type your message in the box above, and the chatbot will respond.",
)

combine_interface = gr.TabbedInterface(
    interface_list=[
        chat_interface,
        sql_interface
    ],
    tab_names=['Chatbot' ,'SQL Chat'],
)

if __name__ == '__main__':
    combine_interface.launch()