Draichi's picture
chore: organize the code into `/backup, /db, /notebooks, /regression-models` and `/SQL`
b4f5da6 unverified
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
select,
insert,
text
)
import logging
import sys
from llama_index.core import SQLDatabase, VectorStoreIndex
from llama_index.core.query_engine import NLSQLTableQueryEngine, RetrieverQueryEngine
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings
from llama_index.core.retrievers import NLSQLRetriever
from llama_index.core.indices.struct_store.sql_query import (
SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
SQLTableNodeMapping,
ObjectIndex,
SQLTableSchema,
)
from rich import print
# Build the Ollama-backed "phi3" LLM (360-second request timeout) and register
# it on the LlamaIndex Settings object in the same statement, then register the
# HuggingFace "BAAI/bge-small-en-v1.5" model as the embedding model.
Settings.llm = llm = Ollama(
    model="phi3",
    request_timeout=360.0,
)
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
# Start Create Database Schema
# The demo runs against an in-memory SQLite database containing one table,
# city_stats, keyed by city name.
table_name = "city_stats"
metadata_obj = MetaData()
engine = create_engine("sqlite:///:memory:")
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),  # primary key
    Column("population", Integer),
    Column("country", String(16), nullable=False),  # NOT NULL
)
# Emit CREATE TABLE for every table registered on metadata_obj.
metadata_obj.create_all(bind=engine)
# Finish Create Database Schema
# Start Define SQL Database
# Wrap the engine for LlamaIndex, exposing only the city_stats table. The table
# name comes from `table_name` so it is spelled in exactly one place.
sql_database = SQLDatabase(engine, include_tables=[table_name])
# Seed data: one dict per row, keyed by column name.
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
# Insert every row inside a single transaction: engine.begin() commits when the
# block exits normally and rolls back on error, so the table is never left
# half-seeded. Passing the list of dicts as the parameter set makes this one
# executemany-style call instead of one transaction per row.
with engine.begin() as connection:
    connection.execute(insert(city_stats_table), rows)
# view current table
# stmt = select(
# city_stats_table.c.city_name,
# city_stats_table.c.population,
# city_stats_table.c.country,
# ).select_from(city_stats_table)
# with engine.connect() as connection:
# results = connection.execute(stmt).fetchall()
# print(f"results: {results}")
# Finish Define SQL Database
# --------------------------------
# Part 1: Text-to-SQL Query Engine
# query_engine = NLSQLTableQueryEngine(
# sql_database=sql_database, tables=["city_stats"], llm=llm
# )
# query_str = "Which city has the highest population?"
# response = query_engine.query(query_str)
# print("\n-----------Part 1---------------\n")
# print(f"Question: {query_str}\n")
# print(f"Response: {response}")
# --------------------------------
# Part 2: Query-Time Retrieval of Tables for Text-to-SQL
# Send INFO-level log records to stdout (change the level to logging.DEBUG for
# more detailed output). basicConfig() already attaches a stdout StreamHandler
# to the root logger, so no second handler is added here: an extra
# StreamHandler on the same stream would print every record twice.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# manually set context text
# city_stats_text = (
# "This table gives information regarding the population and country of a"
# " given city.\nThe user will query with codewords, where 'foo' corresponds"
# " to population and 'bar' corresponds to city."
# )
# table_node_mapping = SQLTableNodeMapping(sql_database)
# table_schema_objs = [
# (SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
# ] # add a SQLTableSchema for each table
# obj_index = ObjectIndex.from_objects(
# table_schema_objs,
# table_node_mapping,
# VectorStoreIndex,
# )
# query_engine = SQLTableRetrieverQueryEngine(
# sql_database, obj_index.as_retriever(similarity_top_k=1, llm=llm)
# )
# query_str = "Which city has the highest population?"
# response = query_engine.query(query_str)
# print("[dark_magenta on grey7]\n\n---------Part 2-----------------\n[/dark_magenta on grey7]")
# print(f"[chartreuse3 on grey7]Question: {query_str}[/chartreuse3 on grey7]\n")
# print(
# f"[bold chartreuse1 on grey7]Response: {response}[/bold chartreuse1 on grey7]\n")
# print(
# f"[chartreuse3 on grey7]metadata: {response.metadata}[/chartreuse3 on grey7]\n")
# --------------------------------
# Part 3: Text-to-SQL Retriever
#
# So far our text-to-SQL capability is packaged in a query engine and
# consists of both retrieval and synthesis.
# You can use the SQL retriever on its own.
# We show you some different parameters you can try, and also show how to
# plug it into our RetrieverQueryEngine to get roughly the same results.

# default retrieval (return_raw=True)
nl_sql_retriever = NLSQLRetriever(
    sql_database,
    tables=["city_stats"],
    return_raw=True,
)
# results = nl_sql_retriever.retrieve(
# "Return the top 5 cities (along with their populations) with the highest population."
# )
# print(len(results))
# for n in results:
# print(
# f"[bold chartreuse1 on grey7]> n: {n.metadata}[/bold chartreuse1 on grey7]\n")

# Plug into our RetrieverQueryEngine
#
# We compose our SQL Retriever with our standard RetrieverQueryEngine to
# synthesize a response. The result is roughly similar to our packaged
# Text-to-SQL query engines.
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)
query_str = (
    "Return the top 5 cities (along with their populations) with the highest population."
)
response = query_engine.query(query_str)
# `print` is rich's print here, so the bracketed tags are rendered as styling.
print(
    f"[bold chartreuse1 on grey7]> Response: {response!s}[/bold chartreuse1 on grey7]\n"
)
# Output recorded alongside the script (presumably from an earlier run):
#
# > Response: Tokyo - 13,960,000
# Seoul - 9,776,000
# Toronto - 2,930,000
# Chicago - 2,679,000