File size: 2,579 Bytes
ec9ef8b 4e86ef1 54210ca 69beb29 4e86ef1 b5d991e 4e86ef1 b5d991e 4e86ef1 f24bed6 4e86ef1 69beb29 4e86ef1 d002017 d087072 d002017 436b052 c8d5ecf e030ac0 54210ca e030ac0 54210ca 69beb29 c8d5ecf 4e86ef1 e030ac0 04883bf 4e86ef1 abad0fd 5fb63f2 69beb29 4004cf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
#import torch
#import pkg_resources
# Debugging tip: to list installed package versions, run `pip freeze`, or in
# Python iterate `importlib.metadata.distributions()` (pkg_resources is
# deprecated and is not imported in this file).

# Load the chatbot (causal language) model.
# Previously tried alternative: "microsoft/DialoGPT-medium".
chatbot_model_name = "gpt2"
chatbot_tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)

# Load the table question-answering (TAPEX) model.
# Checkpoints that can be swapped into model_name:
#   "microsoft/tapex-large-finetuned-wtq"      (current choice)
#   "microsoft/tapex-base-finetuned-wtq"
#   "microsoft/tapex-large-finetuned-wikisql"  (wikisql variants took longer to process)
#   "microsoft/tapex-base-finetuned-wikisql"
model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer = TapexTokenizer.from_pretrained(model_name)
sql_model = BartForConditionalGeneration.from_pretrained(model_name)
# Demo table of (year, city) rows; chatbot_response hands it to the TAPEX
# tokenizer whenever the user asks a question.
data = dict(
    year=[1896, 1900, 1904, 2004, 2008, 2012],
    city=["athens", "paris", "st. louis", "athens", "beijing", "london"],
)
table = pd.DataFrame(data)
def chatbot_response(user_message):
    """Return a reply to *user_message* from one of two models.

    Messages containing "?" are treated as questions about the module-level
    ``table`` and answered with the TAPEX model (``sql_tokenizer`` /
    ``sql_model``). Any other message is continued by the causal language
    model (``chatbot_tokenizer`` / ``chatbot_model``).

    Args:
        user_message: Text typed by the user. ``None``, empty, or
            whitespace-only input (e.g. a cleared box in a live interface)
            yields "" without running either model.

    Returns:
        The reply as a single string.
    """
    if not user_message or not user_message.strip():
        return ""

    if "?" in user_message:
        # Table question answering: TAPEX encodes the table and the question
        # together and generates the answer text.
        encoding = sql_tokenizer(table=table, query=user_message, return_tensors="pt")
        outputs = sql_model.generate(**encoding)
        # batch_decode returns one string per generated sequence; take the only
        # one so the output Textbox receives a plain string, not a list.
        return sql_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()

    # Free-form chat: continue a "User: ..." prompt with the causal LM.
    # Calling the tokenizer (rather than .encode) also yields an attention mask.
    inputs = chatbot_tokenizer("User: " + user_message, return_tensors="pt")
    outputs = chatbot_model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        # GPT-2 defines no pad token; reuse EOS so generate() need not guess.
        pad_token_id=chatbot_tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens so
    # the reply does not echo the user's message back.
    prompt_length = inputs["input_ids"].shape[-1]
    return chatbot_tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
# Define the chatbot interface using Gradio. A single interface serves both
# chat messages and table questions; chatbot_response routes between them.
chatbot_interface = gr.Interface(
    fn=chatbot_response,
    # gr.Textbox has no `prompt` argument; `label` is the caption shown with the box.
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    # NOTE(review): live=True re-runs chatbot_response as the text changes,
    # which means repeated large-model generations; consider live=False.
    live=True,
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
)

# Launch the Gradio interface only when this file is run as a script.
if __name__ == "__main__":
    chatbot_interface.launch()
|