File size: 3,098 Bytes
23c1edb
f65b03e
9d6743e
 
 
 
 
 
 
 
 
 
 
 
ca38751
9d6743e
 
 
 
e030ac0
9d6743e
 
 
 
 
 
829e215
 
 
9d6743e
 
 
 
 
 
c17ba77
829e215
 
f65b03e
9d6743e
829e215
9d6743e
8a4fd9e
69beb29
0deb7d9
 
0279b82
42fd7c5
0279b82
42fd7c5
59c5eff
0279b82
0deb7d9
 
 
 
0279b82
0deb7d9
 
 
cee9f1f
0deb7d9
 
 
 
 
 
 
 
 
 
f65b03e
23c1edb
b92090c
f65b03e
4e86ef1
f65b03e
 
 
 
830c2c9
 
9d6743e
 
23c1edb
4e86ef1
f65b03e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import torch

import numpy as np
import time
import os
#import pkg_resources

'''
# Get a list of installed packages and their versions
installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}

# Print the list of packages
for package, version in installed_packages.items():
    print(f"{package}=={version}")
'''

# Load the conversational (chit-chat) model: DialoGPT generates free-form
# dialogue replies for inputs that are not questions.
chatbot_model_name = "microsoft/DialoGPT-medium" 
tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)

# Load the table-QA model: TAPEX (BART fine-tuned on WikiTableQuestions)
# answers natural-language questions against the pandas table below.
sql_model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)

# Demo table queried by the TAPEX model (Olympic host cities by year).
data = {
    "year": [1896, 1900, 1904, 2004, 2008, 2012],
    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"]
}
table = pd.DataFrame.from_dict(data)

# Declared `global` inside predict() but never actually assigned there;
# NOTE(review): appears to be leftover scaffolding — confirm before removing.
sql_response = None

def predict(input, history=None):
    """Route a user message to the table-QA model (questions) or the
    chit-chat model (everything else).

    Args:
        input: The raw user message.
        history: DialoGPT token-id history from the previous turn, or
            None / [] on the first turn (gradio's "state" starts as None).

    Returns:
        A ``(response, history)`` pair: ``response`` is a list of
        (user, bot) message tuples for the gradio "chatbot" component,
        ``history`` is the updated token-id history for "state".
    """
    global sql_response

    # Fix the mutable-default-argument pitfall: a `history=[]` default is
    # shared across calls, leaking one conversation into the next.
    if history is None:
        history = []

    # Crude routing heuristic: any message containing "?" goes to TAPEX.
    is_question = "?" in input

    if is_question:
        # TAPEX takes the plain query text; the DialoGPT eos_token the
        # original appended is not part of TAPEX's input format and would
        # leak "</s>" into the question string.
        sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
        sql_outputs = sql_model.generate(**sql_encoding)
        answers = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
        sql_response = answers[0] if answers else ""

        # Return (user, bot) tuples — the same shape the chit-chat branch
        # produces and what the "chatbot" output component renders.
        # Do NOT append decoded strings to `history`: it holds DialoGPT
        # token ids, and mixing strings in would crash
        # torch.LongTensor(history) on the next non-question turn.
        response = [(input, sql_response)]
    else:
        # Tokenize the new user message, terminated by DialoGPT's eos token.
        new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')

        # Append the new tokens to the running history; skip the cat on the
        # first turn, where an empty LongTensor has the wrong rank to
        # concatenate with the (1, n) input ids.
        if history:
            bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
        else:
            bot_input_ids = new_user_input_ids

        # Generate a reply conditioned on the whole conversation so far.
        history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()

        # Decode, split turns on the eos marker, and pair consecutive
        # (user, bot) messages for the chatbot component.
        response = tokenizer.decode(history[0]).split("<|endoftext|>")
        response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]

    return response, history


import gradio as gr

# NOTE(review): this import sits mid-file; PEP 8 wants it grouped with the
# other imports at the top of the module.

# Wire predict() into a simple gradio UI. "state" round-trips the token-id
# history between calls; "chatbot" renders the (user, bot) tuple list.
interface = gr.Interface(
    fn=predict,
    theme="default",
    # Hide gradio's footer branding.
    css=".footer {display:none !important}",
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
)

# Launch the web UI only when run as a script, not on import.
if __name__ == '__main__':
    interface.launch()