File size: 3,600 Bytes
23c1edb
 
 
 
 
ca38751
 
 
 
23c1edb
f24bed6
23c1edb
 
 
4e86ef1
23c1edb
 
 
 
4e86ef1
23c1edb
 
 
 
e030ac0
c17ba77
23c1edb
 
 
 
 
 
 
 
 
c17ba77
23c1edb
 
 
 
 
69beb29
23c1edb
 
 
 
 
 
 
ca38751
23c1edb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca38751
23c1edb
c17ba77
23c1edb
 
5c4d0f2
23c1edb
5c4d0f2
c17ba77
23c1edb
c17ba77
23c1edb
 
 
4e86ef1
23c1edb
 
 
 
 
 
 
 
 
 
4e86ef1
23c1edb
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import torch

import numpy as np
import time
import os
#import pkg_resources

# Optional debugging aid: uncomment (together with the pkg_resources import
# above) to print every installed package as "name==version".
#
# installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
# for package, version in installed_packages.items():
#     print(f"{package}=={version}")

# ---------------------------------------------------------------------------
# Conversational model (alternative previously used here: "gpt2").
# ---------------------------------------------------------------------------
chatbot_model_name = "microsoft/DialoGPT-medium"
chatbot_tokenizer, chatbot_model = (
    AutoTokenizer.from_pretrained(chatbot_model_name),
    AutoModelForCausalLM.from_pretrained(chatbot_model_name),
)

# ---------------------------------------------------------------------------
# Table question-answering model used for messages that look like questions.
# Other checkpoints tried here:
#   microsoft/tapex-large-finetuned-wikisql  (noted as slower to process)
#   microsoft/tapex-base-finetuned-wikisql
#   microsoft/tapex-base-finetuned-wtq
#   google/tapas-base-finetuned-wtq  (a TAPAS checkpoint would need the TAPAS
#                                     classes, not TapexTokenizer/BART)
# ---------------------------------------------------------------------------
model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
sql_model = BartForConditionalGeneration.from_pretrained(model_name)

# Demo table that table-QA questions are answered against.
data = dict(
    year=[1896, 1900, 1904, 2004, 2008, 2012],
    city=["athens", "paris", "st. louis", "athens", "beijing", "london"],
)
table = pd.DataFrame(data)

# DialoGPT conversation state, shared across calls to chatbot_response().
chat_history_ids = bot_input_ids = None


# Maximum total length (history + new reply) passed to DialoGPT's generate().
_CHAT_MAX_LENGTH = 1000
# Number of most recent history tokens kept as context. The remainder of
# _CHAT_MAX_LENGTH stays free for the reply; without this cap the history
# eventually fills max_length and generate() can no longer add any tokens.
_CHAT_HISTORY_TOKENS = 800


def chatbot_response(user_message):
    """Answer *user_message* with either the table-QA model or the chat model.

    Messages containing a ``?`` are treated as questions about the module-level
    ``table`` and answered with TAPEX (``sql_tokenizer`` / ``sql_model``).
    Any other message goes to DialoGPT; the running conversation is kept in
    the module-level ``chat_history_ids`` so follow-up messages have context.

    Args:
        user_message: Text typed by the user.

    Returns:
        The reply as a single string ("" for a blank message).
    """
    global chat_history_ids
    global bot_input_ids

    # Nothing to answer; also avoids appending a bare EOS turn to the history.
    if not user_message or not user_message.strip():
        return ""

    if "?" in user_message:
        # Question: answer it from the table with TAPEX.
        encoding = sql_tokenizer(table=table, query=user_message, return_tensors="pt")
        outputs = sql_model.generate(**encoding)
        # batch_decode returns one string per generated sequence; a single
        # query produces a single sequence, so return that string, not the list.
        return sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()

    # Chat: encode the new turn; DialoGPT separates turns with the EOS token.
    new_user_input_ids = chatbot_tokenizer.encode(
        user_message + chatbot_tokenizer.eos_token, return_tensors="pt"
    )

    # Append the new turn to the conversation so far.
    if chat_history_ids is not None:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Keep only the most recent tokens so the prompt never fills max_length.
    bot_input_ids = bot_input_ids[:, -_CHAT_HISTORY_TOKENS:]

    chat_history_ids = chatbot_model.generate(
        bot_input_ids,
        max_length=_CHAT_MAX_LENGTH,
        pad_token_id=chatbot_tokenizer.eos_token_id,
    )

    # generate() returns prompt + reply; decode only the newly generated tail.
    return chatbot_tokenizer.decode(
        chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True
    )

# Build the Gradio UI around chatbot_response.
# Submit-based (live=False): chatbot_response appends every call to the shared
# DialoGPT history and runs model generation, so live mode would feed it each
# partial keystroke. `prompt=` is not a Textbox parameter (label is), and
# `capture_session` is no longer accepted by gr.Interface.
chatbot_interface = gr.Interface(
    fn=chatbot_response,
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    live=False,
    title="ST Chatbot",
    description="Type your message in the box below, and the chatbot will respond.",
)

# Launch the Gradio interface only when run as a script, not on import.
if __name__ == "__main__":
    chatbot_interface.launch()