simonraj committed on
Commit
eebea6d
1 Parent(s): 52b7aa8

Upload 5 files

Browse files
Files changed (5) hide show
  1. .gitignore +48 -0
  2. app.py +169 -0
  3. app2.py +192 -0
  4. chinook.db +0 -0
  5. requirements.txt +11 -0
.gitignore ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python related
2
+ __pycache__/
3
+ *.py[cod]
4
+ *.pyc
5
+ *.pyo
6
+ *.pyd
7
+ .Python
8
+ build/
9
+ develop-eggs/
10
+ dist/
11
+ downloads/
12
+ eggs/
13
+ .eggs/
14
+ lib/
15
+ lib64/
16
+ parts/
17
+ sdist/
18
+ var/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environment
24
+ venv/
25
+ ENV/
26
+
27
+ # IDEs and editors
28
+ .vscode/
29
+ .idea/
30
+ *.swp
31
+ *.bak
32
+ *.sublime-workspace
33
+
34
+ # OS generated files
35
+ .DS_Store
36
+ Thumbs.db
37
+
38
+ # Jupyter Notebook
39
+ .ipynb_checkpoints
40
+
41
+ # pytest
42
+ .pytest_cache/
43
+
44
+ # mypy
45
+ .mypy_cache/
46
+
47
+ # Environment variables
48
+ .env
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import gradio as gr
4
+ from dotenv import load_dotenv
5
+ from langchain_community.utilities import SQLDatabase
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain.chains import create_sql_query_chain
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.prompts import ChatPromptTemplate
10
+ from langchain_core.runnables import RunnablePassthrough
11
+ from langchain_core.output_parsers.openai_tools import PydanticToolsParser
12
+ from langchain_core.pydantic_v1 import BaseModel, Field
13
+ from typing import List
14
+ import sqlite3
15
+
16
+ # Load environment variables from .env file
17
+ load_dotenv()
18
+
19
+ # Set up the database connection
20
+ db_path = os.path.join(os.path.dirname(__file__), "chinook.db")
21
+ db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
22
+
23
+ # Function to get table info
24
def get_table_info(db_path):
    """Map every table in a SQLite database file to its column names.

    Args:
        db_path: Path of the SQLite database file to inspect.

    Returns:
        A dict keyed by table name (every ``type='table'`` row of
        ``sqlite_master``) whose values are the table's column names in the
        order ``PRAGMA table_info`` reports them.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        table_info = {}
        for table_name in table_names:
            # Quote the identifier (doubling embedded quotes) so that names
            # containing spaces or SQL keywords do not break the PRAGMA.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name})")
            # Index 1 of each PRAGMA row is read as the column name.
            table_info[table_name] = [column[1] for column in cursor.fetchall()]
        return table_info
    finally:
        # Release the connection even when one of the queries above raises.
        conn.close()
42
+
43
+ # Get table info
44
+ table_info = get_table_info(db_path)
45
+
46
+ # Format table info for display
47
def format_table_info(table_info):
    """Render a table-to-columns mapping as plain text for display.

    Args:
        table_info: Mapping of table name to an iterable of column names.

    Returns:
        A string with a table-count header followed by one section per
        table, each listing its columns as ``-`` bullet lines.
    """
    pieces = [
        f"Total number of tables: {len(table_info)}\n\n",
        "Tables and their columns:\n\n",
    ]
    for table_name, column_names in table_info.items():
        pieces.append(f"{table_name}:\n")
        pieces.extend(f" - {column_name}\n" for column_name in column_names)
        # Blank line separating this table's section from the next.
        pieces.append("\n")
    return "".join(pieces)
56
+
57
+ # Initialize the language model
58
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
59
+
60
+ class Table(BaseModel):
61
+ """Table in SQL database."""
62
+ name: str = Field(description="Name of table in SQL database.")
63
+
64
+ # Create the table selection prompt
65
+ table_names = "\n".join(db.get_usable_table_names())
66
+ system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
67
+ The tables are:
68
+
69
+ {table_names}
70
+
71
+ Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
72
+
73
+ table_prompt = ChatPromptTemplate.from_messages([
74
+ ("system", system),
75
+ ("human", "{input}"),
76
+ ])
77
+
78
+ llm_with_tools = llm.bind_tools([Table])
79
+ output_parser = PydanticToolsParser(tools=[Table])
80
+
81
+ table_chain = table_prompt | llm_with_tools | output_parser
82
+
83
+ # Function to get table names from the output
84
def get_table_names(output: List[Table]) -> List[str]:
    """Return the ``name`` of each parsed ``Table`` object, in input order."""
    names: List[str] = []
    for selected_table in output:
        names.append(selected_table.name)
    return names
86
+
87
+ # Create the SQL query chain
88
+ query_chain = create_sql_query_chain(llm, db)
89
+
90
+ # Combine table selection and query generation
91
+ full_chain = (
92
+ RunnablePassthrough.assign(
93
+ table_names_to_use=lambda x: get_table_names(table_chain.invoke({"input": x["question"]}))
94
+ )
95
+ | query_chain
96
+ )
97
+
98
+ # Function to strip markdown formatting from SQL query
99
def strip_markdown(text):
    """Extract the bare SQL statement from a model response.

    Handles the wrappers a chat model may put around a query:

    * a fenced code block, with or without an ``sql``/``sqlite`` language tag
      in any letter case; prose before or after the block is dropped,
    * an unterminated opening fence,
    * a leading ``SQLQuery:`` label.

    Args:
        text: Raw text that contains a SQL statement.

    Returns:
        The SQL text with fences, label and surrounding whitespace removed.
    """
    fenced = re.search(
        r'```(?:sql(?:ite3?)?)?\s*(.*?)\s*```',
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if fenced:
        # Keep only the contents of the first complete code block.
        text = fenced.group(1)
    else:
        # No complete block: remove any stray fence marker (and its tag).
        text = re.sub(r'```(?:sql(?:ite3?)?)?', '', text, flags=re.IGNORECASE)
    # Drop a leading "SQLQuery:" label left over from the prompt format.
    text = re.sub(r'^SQLQuery:\s*', '', text.strip(), flags=re.IGNORECASE)
    # Remove any leading/trailing whitespace
    return text.strip()
104
+
105
+ # Function to execute SQL query
106
def execute_query(query: str) -> str:
    """Run a generated SQL statement and return the outcome as text.

    Args:
        query: SQL text, possibly wrapped in markdown code fences.

    Returns:
        ``str()`` of whatever ``db.run`` returns, or a message starting with
        ``"Error executing query: "`` if anything raises. Errors are returned
        rather than raised so the answer prompt can explain them.
    """
    # NOTE(review): the statement is model-generated from user input and is
    # executed as-is; confirm the database is read-only or add a guard
    # against data-modifying statements.
    try:
        # Strip markdown formatting before executing
        return str(db.run(strip_markdown(query)))
    except Exception as e:
        return f"Error executing query: {str(e)}"
114
+
115
+ # Create the answer generation prompt
116
+ answer_prompt = ChatPromptTemplate.from_messages([
117
+ ("system", """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
118
+ If there was an error in executing the SQL query, please explain the error and suggest a correction.
119
+ Do not include any SQL code formatting or markdown in your response.
120
+
121
+ Here is the database schema for reference:
122
+ {table_info}"""),
123
+ ("human", "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:")
124
+ ])
125
+
126
+ # Assemble the final chain
127
+ chain = (
128
+ RunnablePassthrough.assign(query=lambda x: full_chain.invoke(x))
129
+ .assign(result=lambda x: execute_query(x["query"]))
130
+ | answer_prompt
131
+ | llm
132
+ | StrOutputParser()
133
+ )
134
+
135
+ # Function to process user input and generate response
136
def process_input(message, history, table_info_str):
    """Answer one chat message by running it through the full chain.

    Args:
        message: The user's question.
        history: Prior chat turns passed in by the chat interface; not used.
        table_info_str: Schema text passed to the chain as ``table_info``.

    Returns:
        The value produced by ``chain.invoke`` for this question.
    """
    chain_inputs = {"question": message, "table_info": table_info_str}
    return chain.invoke(chain_inputs)
139
+
140
+ # Formatted table info
141
+ formatted_table_info = format_table_info(table_info)
142
+
143
+ # Create Gradio interface
144
+ iface = gr.ChatInterface(
145
+ fn=process_input,
146
+ title="SQL Q&A Chatbot for Chinook Database",
147
+ description="Ask questions about the Chinook music store database and get answers!",
148
+ examples=[
149
+ ["Who are the top 5 artists with the most albums in the database?"],
150
+ ["What is the total sales amount for each country?"],
151
+ ["Which employee has made the highest total sales, and what is the amount?"],
152
+ ["What are the top 10 longest tracks in the database, and who are their artists?"],
153
+ ["How many customers are there in each country, and what is the total sales for each?"]
154
+ ],
155
+ additional_inputs=[
156
+ gr.Textbox(
157
+ label="Database Schema",
158
+ value=formatted_table_info,
159
+ lines=10,
160
+ max_lines=20,
161
+ interactive=False
162
+ )
163
+ ],
164
+ theme="soft"
165
+ )
166
+
167
+ # Launch the interface
168
+ if __name__ == "__main__":
169
+ iface.launch()
app2.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import getpass
4
+ from contextlib import contextmanager
5
+ from typing import List
6
+ from operator import itemgetter
7
+
8
+ from sqlalchemy import create_engine, text, inspect
9
+ from sqlalchemy.orm import sessionmaker
10
+ from dotenv import load_dotenv
11
+
12
+ from langchain_community.utilities import SQLDatabase
13
+ from langchain_openai import ChatOpenAI
14
+ from langchain_core.output_parsers.openai_tools import PydanticToolsParser
15
+ from langchain.chains import create_sql_query_chain
16
+ from langchain_core.output_parsers import StrOutputParser
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+ from langchain_core.runnables import RunnablePassthrough
19
+ from langchain_core.pydantic_v1 import BaseModel, Field
20
+
21
+ # Load environment variables from .env file
22
+ load_dotenv()
23
+
24
+ # Set environment variables for API keys
25
+ if not os.environ.get("OPENAI_API_KEY"):
26
+ os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")
27
+
28
+ if not os.environ.get("LANGCHAIN_API_KEY"):
29
+ os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("Enter your LangChain API key: ")
30
+ os.environ["LANGCHAIN_TRACING_V2"] = "true"
31
+
32
+ # Setup SQLite Database
33
+ db_path = os.path.join(os.path.dirname(__file__), "chinook.db")
34
+ engine = create_engine(f"sqlite:///{db_path}")
35
+ Session = sessionmaker(bind=engine)
36
+
37
+ db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
38
+ print(db.dialect)
39
+ print(db.get_usable_table_names())
40
+
41
+ with Session() as session:
42
+ result = session.execute(text("SELECT * FROM artists LIMIT 10;")).fetchall()
43
+ print(result)
44
+
45
+ # Initialize LLM
46
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
47
+
48
+ class Table(BaseModel):
49
+ """Table in SQL database."""
50
+ name: str = Field(description="Name of table in SQL database.")
51
+
52
+ # Function to get schema information
53
def get_schema_info():
    """Describe every table reachable through the module-level ``engine``.

    Returns:
        A dict mapping each table name to a list of
        ``(column name, column type as str)`` tuples.
    """
    inspector = inspect(engine)
    return {
        table_name: [
            (column["name"], str(column["type"]))
            for column in inspector.get_columns(table_name)
        ]
        for table_name in inspector.get_table_names()
    }
60
+
61
+ # Provide schema info to LLM
62
+ schema_info = get_schema_info()
63
+ formatted_schema_info = "\n".join(
64
+ f"Table: {table}\nColumns: {', '.join([f'{col[0]} ({col[1]})' for col in cols])}"
65
+ for table, cols in schema_info.items()
66
+ )
67
+
68
+ system = f"""You are an expert in querying SQL databases. The database schema is as follows:
69
+
70
+ {formatted_schema_info}
71
+
72
+ Given an input question, create a syntactically correct SQL query to run, then look at the results of the query and return the answer to the input question.
73
+ Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite.
74
+ You can order the results to return the most informative data in the database. Never query for all columns from a table.
75
+ You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
76
+ Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist.
77
+ Also, pay attention to which column is in which table. Use the following format:
78
+
79
+ SQLQuery: """
80
+
81
+
82
+ table_names = "\n".join(db.get_usable_table_names())
83
+ system_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
84
+ The tables are:
85
+
86
+ {table_names}
87
+
88
+ Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""
89
+
90
+ prompt = ChatPromptTemplate.from_messages(
91
+ [
92
+ ("system", system_prompt),
93
+ ("human", "{input}"),
94
+ ]
95
+ )
96
+
97
+ llm_with_tools = llm.bind_tools([Table])
98
+ output_parser = PydanticToolsParser(tools=[Table])
99
+
100
+ table_chain = prompt | llm_with_tools | output_parser
101
+
102
+ # Function to get table names from the output
103
def get_table_names(output: List[Table]) -> List[str]:
    """Return the ``name`` of each parsed ``Table`` object, in input order."""
    names: List[str] = []
    for selected_table in output:
        names.append(selected_table.name)
    return names
105
+
106
+ # Create the SQL query chain
107
+ query_chain = create_sql_query_chain(llm, db)
108
+
109
+ # Combine table selection and query generation
110
+ full_chain = (
111
+ RunnablePassthrough.assign(
112
+ table_names_to_use=lambda x: get_table_names(table_chain.invoke({"input": x["question"]}))
113
+ )
114
+ | query_chain
115
+ )
116
+
117
+ # Function to strip markdown formatting from SQL query
118
def strip_markdown(text):
    """Extract the bare SQL statement from a model response.

    Handles the wrappers a chat model may put around a query:

    * a fenced code block, with or without an ``sql``/``sqlite`` language tag
      in any letter case; prose before or after the block is dropped,
    * an unterminated opening fence,
    * a leading ``SQLQuery:`` label.

    Args:
        text: Raw text that contains a SQL statement.

    Returns:
        The SQL text with fences, label and surrounding whitespace removed.
    """
    fenced = re.search(
        r'```(?:sql(?:ite3?)?)?\s*(.*?)\s*```',
        text,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if fenced:
        # Keep only the contents of the first complete code block.
        text = fenced.group(1)
    else:
        # No complete block: remove any stray fence marker (and its tag).
        text = re.sub(r'```(?:sql(?:ite3?)?)?', '', text, flags=re.IGNORECASE)
    # Drop a leading "SQLQuery:" label left over from the prompt format.
    text = re.sub(r'^SQLQuery:\s*', '', text.strip(), flags=re.IGNORECASE)
    # Remove any leading/trailing whitespace
    return text.strip()
123
+
124
+ # Context manager that provides a database session and closes it afterwards
125
@contextmanager
def get_db_session():
    """Yield a new ``Session`` and close it when the ``with`` block exits."""
    # The session's own context manager closes it on exit, including when
    # the caller's block raises.
    with Session() as db_session:
        yield db_session
132
+
133
def execute_sql_query(query: str) -> str:
    """Run a generated SQL statement in a short-lived session.

    Args:
        query: SQL text, possibly wrapped in markdown code fences.

    Returns:
        ``str()`` of the fetched rows, or a message starting with
        ``"Error executing query: "`` if anything raises. Errors are returned
        rather than raised so the answer prompt can explain them.
    """
    # NOTE(review): the statement is model-generated from user input and is
    # executed as-is; confirm the database is read-only or add a guard
    # against data-modifying statements.
    try:
        # Strip markdown formatting before executing
        statement = text(strip_markdown(query))
        with get_db_session() as session:
            rows = session.execute(statement).fetchall()
        return str(rows)
    except Exception as e:
        return f"Error executing query: {str(e)}"
142
+
143
+ # Create the answer generation prompt
144
+ answer_prompt = ChatPromptTemplate.from_messages([
145
+ ("system", """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
146
+ If there was an error in executing the SQL query, please explain the error and suggest a correction.
147
+ Do not include any SQL code formatting or markdown in your response."""),
148
+ ("human", "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:")
149
+ ])
150
+
151
+
152
+ # Assemble the final chain
153
+ chain = (
154
+ RunnablePassthrough.assign(query=lambda x: full_chain.invoke(x))
155
+ .assign(result=lambda x: execute_sql_query(x["query"]))
156
+ | answer_prompt
157
+ | llm
158
+ | StrOutputParser()
159
+ )
160
+
161
+ # Unit test function
162
def unit_test():
    """Run one fixed example question through the chain and print the answer."""
    print("Running unit test...")

    # Example query
    sample_question = "How many employees are there?"
    answer = chain.invoke({"question": sample_question})
    print("Final Answer:", answer)

    print("Unit test completed.")
170
+
171
+ # Main function
172
def main():
    """Print the schema, run the example query, then answer questions in a loop.

    The loop ends when the user types ``quit`` (any letter case, surrounding
    whitespace ignored) or closes/interrupts the input stream. Blank input is
    skipped, and a failure while answering one question is reported without
    ending the session.
    """
    # Print schema information
    print("Database Schema Information:")
    print(formatted_schema_info)

    # Run unit test
    unit_test()

    # Continuously ask the user for queries until "quit" is entered
    while True:
        try:
            user_question = input("Please enter your query (or type 'quit' to exit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: leave cleanly instead of a traceback.
            print()
            print("Exiting the program.")
            break

        if not user_question:
            # Nothing to ask; do not spend a model call on an empty question.
            continue
        if user_question.lower() == 'quit':
            print("Exiting the program.")
            break

        # Process user's query
        try:
            response = chain.invoke({"question": user_question})
        except Exception as e:
            # Top-level boundary: report the failure and keep the session alive.
            print(f"Error processing query: {e}")
            continue
        print("Final Answer:", response)
190
+
191
+ if __name__ == "__main__":
192
+ main()
chinook.db ADDED
Binary file (885 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-community
3
+ langchain-core
4
+ langchain-openai
5
+ langgraph
6
+ openai
7
+ faiss-cpu
8
+ SQLAlchemy
9
+ python-dotenv
10
+ gradio
11
+ langsmith