import os
import re
import getpass
from contextlib import contextmanager
from typing import List

from sqlalchemy import create_engine, text, inspect
from sqlalchemy.orm import sessionmaker
from dotenv import load_dotenv

from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain.chains import create_sql_query_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.pydantic_v1 import BaseModel, Field

# Load environment variables from .env file
load_dotenv()

# Set environment variables for API keys
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")

if not os.environ.get("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("Enter your LangChain API key: ")
    os.environ["LANGCHAIN_TRACING_V2"] = "true"

# Setup SQLite Database
db_path = os.path.join(os.path.dirname(__file__), "chinook.db")
engine = create_engine(f"sqlite:///{db_path}")
Session = sessionmaker(bind=engine)

db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
print(db.dialect)
print(db.get_usable_table_names())

with Session() as session:
    result = session.execute(text("SELECT * FROM artists LIMIT 10;")).fetchall()
    print(result)

# Initialize LLM
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

class Table(BaseModel):
    """Table in SQL database."""
    name: str = Field(description="Name of table in SQL database.")

# Function to get schema information
def get_schema_info():
    inspector = inspect(engine)
    schema_info = {}
    for table_name in inspector.get_table_names():
        columns = inspector.get_columns(table_name)
        schema_info[table_name] = [(column["name"], str(column["type"])) for column in columns]
    return schema_info

# Provide schema info to LLM
schema_info = get_schema_info()
formatted_schema_info = "\n".join(
    f"Table: {table}\nColumns: {', '.join([f'{col[0]} ({col[1]})' for col in cols])}"
    for table, cols in schema_info.items()
)

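# System prompt describing the schema and querying guidelines. It is not passed to
# create_sql_query_chain below, which uses its own default prompt; it is defined here
# for reference only.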
system = f"""You are an expert in querying SQL databases. The database schema is as follows:



{formatted_schema_info}



Given an input question, create a syntactically correct SQL query to run, then look at the results of the query and return the answer to the input question.

Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite.

You can order the results to return the most informative data in the database. Never query for all columns from a table.

You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.

Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist.

Also, pay attention to which column is in which table. Use the following format:



SQLQuery: """


table_names = "\n".join(db.get_usable_table_names())
system_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \

The tables are:



{table_names}



Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{input}"),
    ]
)

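# Bind the Table schema as an OpenAI tool so the model returns structured table
# selections, then parse the tool calls back into Table objects.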
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

table_chain = prompt | llm_with_tools | output_parser

# Function to get table names from the output
def get_table_names(output: List[Table]) -> List[str]:
    return [table.name for table in output]

# Create the SQL query chain
query_chain = create_sql_query_chain(llm, db)

# Combine table selection and query generation
full_chain = (
    RunnablePassthrough.assign(
        table_names_to_use=lambda x: get_table_names(table_chain.invoke({"input": x["question"]}))
    )
    | query_chain
)

# Function to strip markdown formatting from SQL query
def strip_markdown(text):
    # Remove code block formatting
    text = re.sub(r'```sql\s*|\s*```', '', text)
    # Remove any leading/trailing whitespace
    return text.strip()

# Context manager that yields a database session and always closes it
@contextmanager
def get_db_session():
    session = Session()
    try:
        yield session
    finally:
        session.close()

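# Execute a SQL query and return the result rows (or an error message) as a string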
def execute_sql_query(query: str) -> str:
    try:
        with get_db_session() as session:
            # Strip markdown formatting before executing
            clean_query = strip_markdown(query)
            result = session.execute(text(clean_query)).fetchall()
        return str(result)
    except Exception as e:
        return f"Error executing query: {str(e)}"

# Create the answer generation prompt
answer_prompt = ChatPromptTemplate.from_messages([
    ("system", """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

    If there was an error in executing the SQL query, please explain the error and suggest a correction.

    Do not include any SQL code formatting or markdown in your response."""),
    ("human", "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:")
])


# Assemble the final chain
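# Flow: question -> select relevant tables -> generate SQL -> execute it -> answer in natural language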
chain = (
    RunnablePassthrough.assign(query=lambda x: full_chain.invoke(x))
    .assign(result=lambda x: execute_sql_query(x["query"]))
    | answer_prompt
    | llm
    | StrOutputParser()
)

# Unit test function
def unit_test():
    print("Running unit test...")

    # Example query
    response = chain.invoke({"question": "How many employees are there?"})
    print("Final Answer:", response)

    print("Unit test completed.")

# Main function
def main():
    # Print schema information
    print("Database Schema Information:")
    print(formatted_schema_info)
    
    # Run unit test
    unit_test()

    # Continuously ask the user for queries until "quit" is entered
    while True:
        user_question = input("Please enter your query (or type 'quit' to exit): ")
        if user_question.lower() == 'quit':
            print("Exiting the program.")
            break

        # Process user's query
        response = chain.invoke({"question": user_question})
        print("Final Answer:", response)

if __name__ == "__main__":
    main()