gradio / app.py
teaevo's picture
Update app.py
875202b
raw
history blame
3.84 kB
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import torch
import numpy as np
import time
import os
#import pkg_resources
'''
# Get a list of installed packages and their versions
installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
# Print the list of packages
for package, version in installed_packages.items():
print(f"{package}=={version}")
'''
# --- Model loading -------------------------------------------------------
# DialoGPT produces the free-form chat replies; the TAPEX checkpoint is
# queried with `table` to answer questions about that table.
chatbot_model_name = "microsoft/DialoGPT-medium"
sql_model_name = "microsoft/tapex-large-finetuned-wtq"

tokenizer, model = (
    AutoTokenizer.from_pretrained(chatbot_model_name),
    AutoModelForCausalLM.from_pretrained(chatbot_model_name),
)
sql_tokenizer, sql_model = (
    TapexTokenizer.from_pretrained(sql_model_name),
    BartForConditionalGeneration.from_pretrained(sql_model_name),
)
# Demo table of (year, city) rows that table questions are answered against.
data = dict(
    year=[1896, 1900, 1904, 2004, 2008, 2012],
    city=["athens", "paris", "st. louis", "athens", "beijing", "london"],
)
table = pd.DataFrame(data)

# Most recent table-QA answer; predict() reassigns it through `global`.
sql_response = None
def _history_to_ids(history):
    """Return a new flat list of DialoGPT token IDs built from a state value.

    Accepts ``None``/empty (start a new chat), a flat list of ints, or the
    nested ``[[...]]`` shape produced by ``model.generate(...).tolist()``.
    Always returns a copy, so the caller's list is never mutated.
    """
    if not history:
        return []
    if isinstance(history[0], (list, tuple)):
        return list(history[0])
    return list(history)


def _pair_turns(token_ids):
    """Decode DialoGPT token IDs into ``(user, bot)`` tuples for ``gr.Chatbot``."""
    turns = tokenizer.decode(token_ids).split(tokenizer.eos_token)
    # Every turn is terminated by EOS, so the split leaves a trailing empty
    # string; the stride-2 range stops before it.
    return [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]


def predict(input, history=None):
    """Handle one chat turn for the Gradio interface.

    Messages containing "?" are treated as questions about ``table`` and are
    answered by the TAPEX model. Every other message is answered by DialoGPT.
    Both kinds of exchange are appended to one DialoGPT token history, so the
    transcript (and the chatbot's context) stays in user/bot order.

    Args:
        input: The user's message text.
        history: State returned by the previous call: DialoGPT token IDs,
            either flat or nested as ``[[...]]``. ``None`` or empty starts a
            new conversation. The argument is never mutated.

    Returns:
        ``(response, history)``. ``response`` is a list of
        ``(user_message, bot_message)`` tuples. ``history`` is the updated
        token IDs as a nested ``[[...]]`` list, to be passed back next turn.
    """
    global sql_response  # last TAPEX answer, kept at module level as before

    token_ids = _history_to_ids(history)
    user_ids = tokenizer.encode(input + tokenizer.eos_token)

    if "?" in input:
        # Pass the bare question, as in the Hugging Face TAPEX usage examples;
        # the tokenizer adds its own special tokens.
        sql_encoding = sql_tokenizer(table=table, query=input, return_tensors="pt")
        sql_outputs = sql_model.generate(**sql_encoding)
        # batch_decode returns one string per generated sequence; there is one.
        sql_response = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)[0]
        answer_ids = tokenizer.encode(sql_response + tokenizer.eos_token)
        # Record the question and the table answer as a normal user/bot turn.
        token_ids = token_ids + user_ids + answer_ids
    else:
        bot_input_ids = torch.LongTensor([token_ids + user_ids])
        # NOTE(review): max_length bounds the whole sequence, history included;
        # long conversations leave no room for a reply. Consider trimming turns.
        token_ids = model.generate(
            bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
        )[0].tolist()

    return _pair_turns(token_ids), [token_ids]
import gradio as gr

# Web UI: the text box and a "state" slot feed predict(); its two return
# values fill the chat transcript and the updated state. The CSS rule hides
# elements with the `footer` class.
interface = gr.Interface(
    fn=predict,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
    theme="default",
    css=".footer {display:none !important}",
)

# Launch the server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()