File size: 3,967 Bytes
ec9ef8b
4e86ef1
54210ca
 
5fb63f2
4e86ef1
b5d991e
4e86ef1
b5d991e
 
 
 
 
 
4e86ef1
f24bed6
4e86ef1
5fb63f2
4e86ef1
 
 
 
 
d002017
d087072
d002017
436b052
 
c8d5ecf
 
e030ac0
54210ca
 
 
 
 
e030ac0
9a7d447
7f91b7d
c8d5ecf
 
 
 
54210ca
4e86ef1
5fb63f2
 
 
c8d5ecf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e86ef1
c8d5ecf
 
4e86ef1
 
 
54210ca
4e86ef1
 
c8d5ecf
 
 
54210ca
e030ac0
 
4e86ef1
 
e030ac0
 
 
 
 
04883bf
4e86ef1
 
 
 
 
 
7e20218
4e86ef1
 
 
4f6e66f
e030ac0
 
f24bed6
abad0fd
5fb63f2
 
 
7e20218
4004cf7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import torch
#import pkg_resources

# To debug dependency versions, uncomment `import pkg_resources` above and run:
#
#   # Get a list of installed packages and their versions
#   installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
#
#   # Print the list of packages
#   for package, version in installed_packages.items():
#       print(f"{package}=={version}")

# Load the chatbot model
chatbot_model_name = "microsoft/DialoGPT-medium"
chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)


# Load the SQL model (TAPEX, queried by sql_response against `table` below).
# wikisql variants take longer to process
#model_name = "microsoft/tapex-large-finetuned-wikisql"  # You can change this to any other model from the list above
#model_name = "microsoft/tapex-base-finetuned-wikisql"
model_name = "microsoft/tapex-large-finetuned-wtq"
#model_name = "microsoft/tapex-base-finetuned-wtq"
sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
sql_model = BartForConditionalGeneration.from_pretrained(model_name)

# Demo table that sql_response asks its questions about.
data = {
    "year": [1896, 1900, 1904, 2004, 2008, 2012],
    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
}
table = pd.DataFrame.from_dict(data)

# NOTE(review): appears unused -- chatbot_response binds its own local
# `bot_input_ids`. Kept so any external importer of this name still works.
bot_input_ids = None

# Names used by chatbot_response. These previously re-loaded DialoGPT-medium a
# second time (same checkpoint as above), doubling memory use and start-up time;
# alias the instances already loaded instead.
tokenizer = chatbot_tokenizer
model = chatbot_model

def chatbot_response(user_message):
    """Generate a single-turn DialoGPT reply to the text typed in the UI.

    The previous implementation ignored ``user_message`` and read from
    ``input()`` instead, which blocks the web server waiting on stdin.

    Args:
        user_message: Text submitted from the Gradio textbox.

    Returns:
        The decoded reply (only the newly generated tokens), or ``""`` when
        the message is ``None``, empty, or whitespace-only.
    """
    # Nothing to reply to; avoid running generation on an empty prompt.
    if not user_message or not user_message.strip():
        return ""

    # Each dialogue turn is terminated with the EOS token before encoding.
    bot_input_ids = tokenizer.encode(
        user_message + tokenizer.eos_token, return_tensors="pt"
    )

    # Limit the total sequence (prompt + reply) to 1000 tokens; pad with EOS.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; slice off the prompt tokens.
    return tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0],
        skip_special_tokens=True,
    )

def sql_response(user_query):
    """Answer a natural-language question about ``table`` using TAPEX.

    Args:
        user_query: Question text submitted from the Gradio textbox.

    Returns:
        The decoded answer as a plain string. If several sequences are decoded,
        they are stripped and joined with ``", "``. Returns ``""`` when the query
        is ``None``, empty, or whitespace-only.
    """
    # Skip running the (large) model when there is no question.
    if not user_query or not user_query.strip():
        return ""

    encoding = sql_tokenizer(table=table, query=user_query, return_tensors="pt")
    outputs = sql_model.generate(**encoding)
    # batch_decode returns a list of strings; returning the list made the
    # output Textbox display its Python repr (e.g. "[' 2008']").
    answers = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return ", ".join(answer.strip() for answer in answers)

# Define the chatbot and SQL interfaces using Gradio.
# - `gr.Textbox` has no `prompt=` argument; `label=` is the caption parameter.
# - `capture_session` is not a supported `gr.Interface` argument in current Gradio.
# - `live=True` was dropped: it re-ran the large models on every keystroke;
#   the default submit button runs them once per question instead.
chatbot_interface = gr.Interface(
    fn=chatbot_response,
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
)

# Define the SQL (table question answering) interface using Gradio
sql_interface = gr.Interface(
    fn=sql_response,
    inputs=gr.Textbox(label="Enter your SQL Qus:"),
    outputs=gr.Textbox(),
    title="ST SQL Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
)

# Launch both interfaces as tabs of a single Gradio app.
# Previously chatbot_interface.launch() (which blocks by default when run as a
# script) was followed by sql_interface.launch() outside the __main__ guard, so
# the SQL UI only started after the first server stopped, and merely importing
# this module started a server.
if __name__ == "__main__":
    gr.TabbedInterface(
        [chatbot_interface, sql_interface],
        ["Chatbot", "SQL"],
    ).launch()