Spaces:
Sleeping
Sleeping
import os | |
import re | |
import getpass | |
from contextlib import contextmanager | |
from typing import List | |
from operator import itemgetter | |
from sqlalchemy import create_engine, text, inspect | |
from sqlalchemy.orm import sessionmaker | |
from dotenv import load_dotenv | |
from langchain_community.utilities import SQLDatabase | |
from langchain_openai import ChatOpenAI | |
from langchain_core.output_parsers.openai_tools import PydanticToolsParser | |
from langchain.chains import create_sql_query_chain | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.pydantic_v1 import BaseModel, Field | |
# Pull settings from a local .env file into the process environment
load_dotenv()


def _require_secret(var_name, prompt_text):
    """Ask for *var_name* via a no-echo prompt when it is unset or empty."""
    if not os.environ.get(var_name):
        os.environ[var_name] = getpass.getpass(prompt_text)


# API credentials: only prompt for the ones the environment does not provide
_require_secret("OPENAI_API_KEY", "Enter your OpenAI API key: ")
_require_secret("LANGCHAIN_API_KEY", "Enter your LangChain API key: ")
# NOTE(review): the original indentation was lost; tracing is assumed to be
# switched on unconditionally rather than only when the key was prompted for.
os.environ["LANGCHAIN_TRACING_V2"] = "true"

# SQLite database: chinook.db is expected to sit next to this file
db_path = os.path.join(os.path.dirname(__file__), "chinook.db")
_db_url = f"sqlite:///{db_path}"
engine = create_engine(_db_url)
Session = sessionmaker(bind=engine)
db = SQLDatabase.from_uri(_db_url)

# Startup diagnostics: dialect, visible tables, and a sample of rows
print(db.dialect)
print(db.get_usable_table_names())
with Session() as session:
    result = session.execute(text("SELECT * FROM artists LIMIT 10;")).fetchall()
print(result)

# Chat model shared by every chain below
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
# Schema for a single table pick; handed to llm.bind_tools() and
# PydanticToolsParser further down in this file.
# NOTE(review): the docstring and field description presumably reach the model
# as tool metadata, so their wording is kept unchanged.
class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(
        description="Name of table in SQL database.",
    )
# Collect column names and types for every table reachable through `engine`
def get_schema_info():
    """Return a mapping of table name to a list of (column name, column type) pairs.

    Column types are converted to their string form. Tables appear in the
    order reported by the SQLAlchemy inspector.
    """
    db_inspector = inspect(engine)
    return {
        table_name: [
            (column["name"], str(column["type"]))
            for column in db_inspector.get_columns(table_name)
        ]
        for table_name in db_inspector.get_table_names()
    }
# Snapshot the schema once and flatten it into prompt-ready text
schema_info = get_schema_info()


def _render_table(table, cols):
    """Format one table as a 'Table: ...' line followed by a 'Columns: ...' line."""
    column_list = ", ".join(f"{col_name} ({col_type})" for col_name, col_type in cols)
    return f"Table: {table}\nColumns: {column_list}"


formatted_schema_info = "\n".join(
    _render_table(table, cols) for table, cols in schema_info.items()
)

# Query-writing instructions with the schema embedded.
# NOTE(review): nothing in the visible code references `system`; it is kept so
# the module-level name stays available.
system = f"""You are an expert in querying SQL databases. The database schema is as follows:
{formatted_schema_info}
Given an input question, create a syntactically correct SQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite.
You can order the results to return the most informative data in the database. Never query for all columns from a table.
You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist.
Also, pay attention to which column is in which table. Use the following format:
SQLQuery: """
# One table name per line, as listed by the SQLDatabase wrapper
table_names = "\n".join(db.get_usable_table_names())

# Instructions for the table-selection step: favour recall over precision
system_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:
{table_names}
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

_table_selection_messages = [
    ("system", system_prompt),
    ("human", "{input}"),
]
prompt = ChatPromptTemplate.from_messages(_table_selection_messages)

# The model answers through `Table` tool calls, which the parser turns back
# into `Table` objects.
_table_tools = [Table]
llm_with_tools = llm.bind_tools(_table_tools)
output_parser = PydanticToolsParser(tools=_table_tools)
table_chain = prompt | llm_with_tools | output_parser
# Reduce parsed Table objects to their bare names
def get_table_names(output: List[Table]) -> List[str]:
    """Return the ``name`` of every item in *output*, preserving order."""
    names = []
    for selected in output:
        names.append(selected.name)
    return names
# Text-to-SQL chain built by LangChain for this model and database
query_chain = create_sql_query_chain(llm, db)


def _pick_relevant_tables(inputs):
    """Run the table-selection chain on the question and return the chosen table names."""
    selected = table_chain.invoke({"input": inputs["question"]})
    return get_table_names(selected)


# First attach `table_names_to_use` to the input, then generate the query
full_chain = (
    RunnablePassthrough.assign(table_names_to_use=_pick_relevant_tables)
    | query_chain
)
# Matches an opening fence with an optional SQL language tag (```sql, ```SQL,
# ```sqlite, ```sqlite3) plus trailing whitespace, or any bare ``` fence plus
# the whitespace before it. `sqlite3?` is tried before `sql` so that a
# ```sqlite tag is consumed whole instead of leaving "ite" in the query.
_SQL_FENCE_RE = re.compile(r'```(?:sqlite3?|sql)\s*|\s*```', re.IGNORECASE)


# Function to strip markdown formatting from SQL query
def strip_markdown(text):
    """Remove markdown code-fence markers from *text* and trim whitespace.

    Args:
        text: String that may wrap a SQL statement in a fenced code block.

    Returns:
        The input with every fence marker (and an attached ``sql`` /
        ``sqlite`` / ``sqlite3`` tag, in any letter case) removed and
        leading/trailing whitespace stripped.
    """
    # Remove code block formatting
    unfenced = _SQL_FENCE_RE.sub('', text)
    # Remove any leading/trailing whitespace
    return unfenced.strip()
# Context manager that hands out a database session and always closes it
@contextmanager
def get_db_session():
    """Yield a new session created from ``Session`` and close it on exit.

    The ``@contextmanager`` decorator is required: without it this is a plain
    generator function, and ``with get_db_session() as session`` fails because
    a generator object does not implement the context-manager protocol.

    Yields:
        The freshly created session object.
    """
    session = Session()
    try:
        yield session
    finally:
        # Runs whether the with-body finished normally or raised.
        session.close()
# Run a (possibly markdown-wrapped) SQL statement and report the outcome as text
def execute_sql_query(query: str) -> str:
    """Execute *query* and return the fetched rows, or an error message, as a string.

    Any exception raised while opening the session, cleaning the query, or
    running it is caught and returned as ``"Error executing query: ..."``
    instead of propagating.

    NOTE(review): the statement is executed as given; in this file it comes
    from model output, so consider restricting it to read-only queries.
    """
    try:
        with get_db_session() as db_session:
            # The query may arrive inside a markdown fence; unwrap it first.
            statement = text(strip_markdown(query))
            rows = db_session.execute(statement).fetchall()
            return str(rows)
    except Exception as exc:
        return f"Error executing query: {exc}"
# Templates for the final answer-writing step
_ANSWER_SYSTEM_MESSAGE = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
If there was an error in executing the SQL query, please explain the error and suggest a correction.
Do not include any SQL code formatting or markdown in your response."""
_ANSWER_HUMAN_MESSAGE = "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:"

# Expects `question`, `query` and `result` in its input mapping
answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _ANSWER_SYSTEM_MESSAGE),
        ("human", _ANSWER_HUMAN_MESSAGE),
    ]
)
def _generate_query(inputs):
    """Produce the SQL query for the incoming question mapping."""
    return full_chain.invoke(inputs)


def _run_query(inputs):
    """Execute the query stored under the ``query`` key."""
    return execute_sql_query(inputs["query"])


# End-to-end pipeline: add `query`, then `result`, then let the model phrase
# the answer and reduce it to a plain string.
_with_query = RunnablePassthrough.assign(query=_generate_query)
_with_result = _with_query.assign(result=_run_query)
chain = _with_result | answer_prompt | llm | StrOutputParser()
# Smoke check: push one fixed question through the whole pipeline
def unit_test():
    """Invoke ``chain`` with a sample question and print the answer.

    Nothing is asserted; the function only prints progress lines and the
    response.
    """
    print("Running unit test...")
    sample_input = {"question": "How many employees are there?"}
    answer = chain.invoke(sample_input)
    print("Final Answer:", answer)
    print("Unit test completed.")
# Main function
def main():
    """Print the schema, run the smoke test, then answer questions interactively.

    The loop ends when the user types ``quit`` (any letter case, surrounding
    whitespace ignored) or when the prompt is interrupted by end-of-input or
    Ctrl-C. Blank input is skipped rather than sent to the chain.
    """
    # Print schema information
    print("Database Schema Information:")
    print(formatted_schema_info)

    # Run unit test
    unit_test()

    # Continuously ask the user for queries until "quit" is entered
    while True:
        try:
            user_question = input("Please enter your query (or type 'quit' to exit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin was closed or the user pressed Ctrl-C: leave cleanly
            # instead of crashing with a traceback.
            print()
            print("Exiting the program.")
            break

        if not user_question:
            # Nothing to ask; avoid a pointless model call.
            continue
        if user_question.lower() == 'quit':
            print("Exiting the program.")
            break

        # Process user's query
        response = chain.invoke({"question": user_question})
        print("Final Answer:", response)


# Run the interactive session only when executed as a script
if __name__ == "__main__":
    main()