Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
|
|
2 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
from transformers import TapexTokenizer, BartForConditionalGeneration
|
4 |
import pandas as pd
|
5 |
-
|
6 |
#import pkg_resources
|
7 |
|
8 |
'''
|
@@ -15,7 +15,7 @@ for package, version in installed_packages.items():
|
|
15 |
'''
|
16 |
|
17 |
# Load the chatbot model
|
18 |
-
chatbot_model_name = "
|
19 |
chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
|
20 |
chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
|
21 |
|
@@ -24,8 +24,9 @@ chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
|
|
24 |
#wikisql take longer to process
|
25 |
#model_name = "microsoft/tapex-large-finetuned-wikisql" # You can change this to any other model from the list above
|
26 |
#model_name = "microsoft/tapex-base-finetuned-wikisql"
|
27 |
-
model_name = "microsoft/tapex-large-finetuned-wtq"
|
28 |
#model_name = "microsoft/tapex-base-finetuned-wtq"
|
|
|
|
|
29 |
sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
|
30 |
sql_model = BartForConditionalGeneration.from_pretrained(model_name)
|
31 |
|
@@ -35,7 +36,11 @@ data = {
|
|
35 |
}
|
36 |
table = pd.DataFrame.from_dict(data)
|
37 |
|
|
|
|
|
38 |
def chatbot_response(user_message):
|
|
|
|
|
39 |
# Check if the user input is a question
|
40 |
is_question = "?" in user_message
|
41 |
|
@@ -47,10 +52,24 @@ def chatbot_response(user_message):
|
|
47 |
response = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
48 |
else:
|
49 |
# Generate chatbot response using the chatbot model
|
|
|
50 |
inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
|
51 |
outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
|
52 |
response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
return response
|
55 |
|
56 |
# Define the chatbot and SQL execution interfaces using Gradio
|
|
|
2 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
3 |
from transformers import TapexTokenizer, BartForConditionalGeneration
|
4 |
import pandas as pd
|
5 |
+
import torch
|
6 |
#import pkg_resources
|
7 |
|
8 |
'''
|
|
|
15 |
'''
|
16 |
|
17 |
# Load the chatbot model
|
18 |
+
chatbot_model_name = "microsoft/DialoGPT-medium" #"gpt2"
|
19 |
chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
|
20 |
chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
|
21 |
|
|
|
24 |
#wikisql take longer to process
|
25 |
#model_name = "microsoft/tapex-large-finetuned-wikisql" # You can change this to any other model from the list above
|
26 |
#model_name = "microsoft/tapex-base-finetuned-wikisql"
|
|
|
27 |
#model_name = "microsoft/tapex-base-finetuned-wtq"
|
28 |
+
#model_name = "microsoft/tapex-large-finetuned-wtq"
|
29 |
+
model_name = "microsoft/tapex-large-finetuned-wtq"  # NOTE: "google/tapas-base-finetuned-wtq" is a TAPAS (BERT-style) checkpoint and cannot be loaded with TapexTokenizer/BartForConditionalGeneration; TAPAS needs TapasTokenizer/TapasForQuestionAnswering
|
30 |
sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
|
31 |
sql_model = BartForConditionalGeneration.from_pretrained(model_name)
|
32 |
|
|
|
36 |
}
|
37 |
table = pd.DataFrame.from_dict(data)
|
38 |
|
39 |
+
new_chat = True
|
40 |
+
|
41 |
def chatbot_response(user_message):
|
42 |
+
|
43 |
+
global new_chat, chat_history_ids  # chat_history_ids is assigned below and read on later calls; without 'global' the second call raises UnboundLocalError
|
44 |
# Check if the user input is a question
|
45 |
is_question = "?" in user_message
|
46 |
|
|
|
52 |
response = sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
53 |
else:
|
54 |
# Generate chatbot response using the chatbot model
|
55 |
+
'''
|
56 |
inputs = chatbot_tokenizer.encode("User: " + user_message, return_tensors="pt")
|
57 |
outputs = chatbot_model.generate(inputs, max_length=100, num_return_sequences=1)
|
58 |
response = chatbot_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
59 |
+
'''
|
60 |
+
# encode the new user input, add the eos_token and return a tensor in Pytorch
|
61 |
+
new_user_input_ids = chatbot_tokenizer.encode(user_message + chatbot_tokenizer.eos_token, return_tensors='pt')  # use the message Gradio passed in; input() would block the server on stdin
|
62 |
+
|
63 |
+
# append the new user input tokens to the chat history
|
64 |
+
bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if new_chat is False else new_user_input_ids
|
65 |
+
|
66 |
+
# generated a response while limiting the total chat history to 1000 tokens,
|
67 |
+
chat_history_ids = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=chatbot_tokenizer.eos_token_id)
|
68 |
+
|
69 |
+
response = chatbot_tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)  # 'tokenizer' is undefined here; the loaded name is chatbot_tokenizer
|
70 |
+
|
71 |
+
new_chat = False
|
72 |
+
|
73 |
return response
|
74 |
|
75 |
# Define the chatbot and SQL execution interfaces using Gradio
|