# gradio / app.py
# Page header captured with the file: teaevo, "Update app.py", 8233187, 9.18 kB.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import gradio as gr
import numpy as np
import time
import os
import random
#import pyodbc
'''
import pkg_resources
# Get a list of installed packages and their versions
installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
# Print the list of packages
for package, version in installed_packages.items():
print(f"{package}=={version}")
'''
'''
# Replace the connection parameters with your SQL Server information
server = 'your_server'
database = 'your_database'
username = 'your_username'
password = 'your_password'
driver = 'SQL Server' # This depends on the ODBC driver installed on your system
# Create the connection string
connection_string = f'DRIVER={{{driver}}};SERVER={server};DATABASE={database};UID={username};PWD={password}'
# Connect to the SQL Server
conn = pyodbc.connect(connection_string)
#============================================================================
# Replace "your_query" with your SQL query to fetch data from the database
query = 'SELECT * FROM your_table_name'
# Use pandas to read data from the SQL Server and store it in a DataFrame
df = pd.read_sql_query(query, conn)
# Close the SQL connection
conn.close()
'''
# Synthetic demo table: `num_columns` integer columns of random values in
# [0, 100), plus a random "year" and a random "city" for every record.
num_records = 3000
num_columns = 20

# Pools the "year" and "city" columns are sampled from.
years = list(range(2000, 2023))
cities = ["New York", "Los Angeles", "Chicago", "Houston", "Miami"]

data = {
    name: np.random.randint(0, 100, num_records)
    for name in map("column_{}".format, range(num_columns))
}
data["year"] = list(map(random.choice, [years] * num_records))
data["city"] = list(map(random.choice, [cities] * num_records))

table = pd.DataFrame(data)

# Small hand-written sample; it replaces `data` but is not turned into a
# DataFrame here, so `table` above stays the synthetic one.
data = {
    "year": [1896, 1900, 1904, 2004, 2008, 2012],
    "city": ["athens", "paris", "st. louis", "athens", "beijing", "london"],
}
# Checkpoint identifiers for the two models this app serves.
chatbot_model_name = "microsoft/DialoGPT-medium"
sql_model_name = "microsoft/tapex-large-finetuned-wtq"

# Conversational model used by `chat`.
tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)

# Table question-answering model used by `sqlquery`.
sql_tokenizer = TapexTokenizer.from_pretrained(sql_model_name)
sql_model = BartForConditionalGeneration.from_pretrained(sql_model_name)

# Module-level transcript that `sqlquery` appends to on every call.
conversation_history = list()
def chat(input, history=None):
    """Generate the next chatbot reply and return the updated conversation.

    Args:
        input: The user's new message as plain text.
        history: Token ids of the conversation so far, in the nested-list
            form returned by a previous call (``model.generate(...).tolist()``).
            ``None`` or an empty list starts a new conversation.

    Returns:
        A ``(response, history)`` pair. ``response`` is a list of
        ``(user_text, bot_text)`` tuples decoded from the whole conversation;
        ``history`` is the generated token ids as a nested list, suitable for
        passing back in on the next call.
    """
    # A `None` sentinel replaces the former mutable `[]` default so no list
    # object is shared between calls.
    if history is None:
        history = []

    # Tokenize the new message, terminated by the end-of-sequence token.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')

    # Append the new tokens to the earlier turns. On the first turn there is
    # nothing to prepend, so skip concatenating with an empty tensor.
    if history:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Generate a reply; the output contains the prompt followed by the reply.
    history = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()

    # Decode everything and split on the end-of-sequence marker, which
    # separates the alternating user/bot turns.
    turns = tokenizer.decode(history[0]).split("<|endoftext|>")

    # Pair consecutive turns as (user, bot); a trailing unpaired piece (such
    # as the empty string after the final marker) is dropped by zip.
    response = list(zip(turns[::2], turns[1::2]))
    return response, history
def sqlquery(input):
    """Answer a question about the module-level ``table`` and return the transcript.

    The previous body looped over slices of ``table`` and read
    ``record["question"]``. Iterating a DataFrame yields its column labels
    and ``table`` has no ``question`` column, so every call raised. It also
    ignored ``input`` and stored history entries as plain strings that the
    final ``(sender, msg)`` unpacking could not split. This version sends the
    user's question together with the table to the TAPEX model.

    Args:
        input: The user's question as plain text.

    Returns:
        The whole conversation so far as one string, one ``"Sender: message"``
        line per entry, including the new question and its answer. If
        ``input`` is empty or blank, the transcript is returned unchanged.

    Note:
        The encoded table is truncated to the tokenizer's maximum input
        length, so for a large table the answer is presumably based only on
        the rows that fit -- TODO confirm this is acceptable for the real data.
    """
    question = input.strip() if input else ""

    if question:
        # Cell values are converted to text before tokenizing, and
        # `truncation=True` keeps the encoded table within the model's input
        # size.
        sql_encoding = sql_tokenizer(
            table=table.astype(str),
            query=question,
            truncation=True,
            return_tensors="pt",
        )

        # Inference only: no gradients are needed.
        with torch.no_grad():
            sql_outputs = sql_model.generate(**sql_encoding)

        decoded = sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
        answer = ", ".join(text.strip() for text in decoded)

        # Entries are (sender, message) pairs so the transcript can be
        # rendered uniformly below.
        conversation_history.append(("User", question))
        conversation_history.append(("Bot", answer))

    return "\n".join(f"{sender}: {msg}" for sender, msg in conversation_history)
# Chatbot tab: a text box plus session state in, chatbot transcript plus
# updated state out, served by `chat`.
chat_interface = gr.Interface(
    chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
    theme="default",
    css=".footer {display:none !important}",
)
# SQL chat tab: a single text box in, the running transcript from `sqlquery`
# out as plain text.
sql_interface = gr.Interface(
    fn=sqlquery,
    theme="default",
    css=".footer {display:none !important}",
    # `Textbox` has no `prompt` parameter; `label` is the caption shown with
    # the input box.
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    title="ST SQL Chat",
    description="Type your message in the box above, and the chatbot will respond.",
)
'''
iface = gr.Interface(sqlquery, "text", "html", css="""
.chatbox {display:flex;flex-direction:column}
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.user_msg {background-color:cornflowerblue;color:white;align-self:start}
.resp_msg {background-color:lightgray;align-self:self-end}
""", allow_screenshot=False, allow_flagging=False)
'''
# Combine both interfaces into one app, with the SQL chat as the first tab.
combine_interface = gr.TabbedInterface(
    [sql_interface, chat_interface],
    tab_names=['SQL Chat', 'Chatbot'],
)

# Start the web app only when this file is run as a script.
if __name__ == '__main__':
    combine_interface.launch()
'''
batch_size = 10 # Number of records in each batch
num_records = 3000 # Total number of records in the dataset
for start_idx in range(0, num_records, batch_size):
end_idx = min(start_idx + batch_size, num_records)
# Get a batch of records
batch_data = dataset[start_idx:end_idx] # Replace with your dataset
# Tokenize the batch
tokenized_batch = tokenizer.batch_encode_plus(
batch_data, padding=True, truncation=True, return_tensors="pt"
)
# Perform inference
with torch.no_grad():
output = model.generate(
input_ids=tokenized_batch["input_ids"],
max_length=1024,
pad_token_id=tokenizer.eos_token_id,
)
# Decode the output and process the responses
responses = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
# Process responses and maintain conversation context
# ...
'''