Spaces:
Running
Running
import os | |
import gradio as gr | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain_community.utilities.sql_database import SQLDatabase | |
from langchain_community.agent_toolkits import create_sql_agent | |
from langchain_openai import AzureChatOpenAI | |
# --- Database and LLM setup ---------------------------------------------------
ccms_db_loc = 'ccms.db'
# Open the SQLite file read-only via an SQLite URI filename (mode=ro; SQLAlchemy
# needs uri=true in the query string to pass it through as a URI). The agent
# turns untrusted user questions into SQL and the system prompt merely *asks*
# it not to issue DML; the read-only connection is what actually enforces it.
# mode=ro also turns a missing database file into a startup error instead of
# SQLite silently creating a new, empty ccms.db.
ccms_db = SQLDatabase.from_uri(f"sqlite:///file:{ccms_db_loc}?mode=ro&uri=true")
# Azure OpenAI chat model driving the SQL agent (note: the variable name says
# gpt4o, but the configured model is gpt-4o-mini). No azure_deployment is given,
# so the model name presumably doubles as the Azure deployment name —
# NOTE(review): confirm the deployment is named 'gpt-4o-mini'.
# Credentials come from the environment; a missing variable raises KeyError at
# startup. temperature=0 keeps SQL generation as repeatable as the service allows.
gpt4o_azure = AzureChatOpenAI(
    model_name='gpt-4o-mini',
    api_key=os.environ["AZURE_OPENAI_KEY"],
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_version="2024-02-01",
    temperature=0,
)
# --- Prompt ------------------------------------------------------------------
# Table descriptions (CREATE TABLE statements plus whatever sample rows
# SQLDatabase includes) are embedded directly into the system prompt.
context = ccms_db.get_context()
database_schema = context['table_info']
# ChatPromptTemplate parses each message string as an f-string template, so a
# literal "{" or "}" anywhere in the schema text (e.g. inside a sample row
# value) would be read as a template variable and break prompt formatting.
# Double the braces so the schema reaches the model verbatim.
_escaped_schema = database_schema.replace("{", "{{").replace("}", "}}")
system_message = f"""You are a SQLite expert agent designed to interact with a SQLite database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database..
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
Only use the following tables:
{_escaped_schema}
"""
# "{input}" receives the user's question; "agent_scratchpad" is the placeholder
# the tool-calling agent fills with its intermediate tool calls and results.
full_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_message),
        ("human", '{input}'),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
# --- Agent -------------------------------------------------------------------
# Tool-calling ("openai-tools") SQL agent bound to ccms_db and the custom prompt.
# - handle_parsing_errors: malformed model output is reported back to the model
#   rather than raised out of the executor.
# - max_iterations: caps the number of agent steps per question at 10.
# - verbose: echoes each agent step to stdout (visible in the Space logs).
sqlite_agent = create_sql_agent(
    db=ccms_db,
    llm=gpt4o_azure,
    agent_type="openai-tools",
    prompt=full_prompt,
    max_iterations=10,
    verbose=True,
    agent_executor_kwargs=dict(handle_parsing_errors=True),
)
def predict(user_input):
    """Answer a natural-language question about the credit card database.

    Args:
        user_input: Question text from the Gradio textbox (may be empty/None).

    Returns:
        The agent's final answer, or a short explanatory message when the
        input is blank or the agent fails. Always a ``str``, so the Gradio
        "text" output never receives a raw exception object.
    """
    import logging

    question = (user_input or "").strip()
    if not question:
        # Don't spend an LLM round-trip on an empty submission.
        return "Please enter a question about the credit card database."
    try:
        return sqlite_agent.invoke({"input": question})['output']
    except Exception as exc:  # UI boundary: report the failure instead of crashing
        # Full traceback goes to the server log; the user sees only the message.
        logging.getLogger(__name__).exception("SQL agent failed for %r", question)
        return f"Sorry, the question could not be answered: {exc}"
# UI
textbox = gr.Textbox(placeholder="Enter your query here", lines=6)
# Shown under the interface as the Gradio "article": schema diagrams + source.
schema = 'The schema for the database is presented below: \n <img src="https://cdn-uploads.huggingface.co/production/uploads/64118e60756b9e455c7eddd6/S1alVt_D88qatd-N4Dkjd.png" > \n<img src="https://cdn-uploads.huggingface.co/production/uploads/64118e60756b9e455c7eddd6/81ggHEjrt6wFrMyXJtHVS.png" > (Source: https://github.com/shrivastavasatyam/Credit-Card-Management-System)'
demo = gr.Interface(
    inputs=textbox,
    fn=predict,
    outputs="text",
    title="Query a Credit Card Database",
    description="This web API presents an interface to ask questions on information stored in a credit card database.",
    article=schema,
    # One value per example row: the interface has exactly one input component.
    examples=[
        ["Who are the top 5 merchants by total transactions?"],
        ["Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?"],
        ["Which is the highest spend month and amount for each card type?"],
        ["Which was the city with the lowest percentage spend for the Gold card type?"],
        ["What was the percentage contribution of spends by females for each card type?"],
        ["Which city has the highest spend to transaction ratio on weekends?"],
        ["Which was the city to reach 500 transactions the fastest?"],
    ],
    # Don't pre-run the examples; each run would be a paid LLM call.
    cache_examples=False,
    theme=gr.themes.Base(),
    # At most 8 predict() calls run at the same time.
    concurrency_limit=8,
)
demo.queue()
# Read the password the same way as the Azure credentials: fail fast with a
# KeyError when it is missing, rather than handing Gradio a None password that
# no login attempt could ever match.
demo.launch(auth=("demouser", os.environ['PASSWD']), ssr_mode=False)