|
import logging
import sys

from llama_index.core import SQLDatabase, VectorStoreIndex
from llama_index.core import Settings
from llama_index.core.indices.struct_store.sql_query import (
    SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core.query_engine import NLSQLTableQueryEngine, RetrieverQueryEngine
from llama_index.core.retrievers import NLSQLRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.ollama import Ollama
from rich import print
from rich.markup import escape
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    insert,
    text,
)
|
|
|
|
|
# Global LlamaIndex configuration.
# LLM: the "phi3" model served through Ollama, with a request_timeout of 360.0
# so that slow local generations are not cut short. The same object is bound
# to the module-level name `llm` and installed as the default Settings.llm.
Settings.llm = llm = Ollama(request_timeout=360.0, model="phi3")

# Default embedding model for LlamaIndex components that need embeddings.
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
|
|
|
|
|
|
|
# Schema for the demo table: one row per city, keyed by its name.
table_name = "city_stats"
metadata_obj = MetaData()
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),  # every city needs a country
)

# Throwaway in-memory SQLite database; issue CREATE TABLE for every table
# registered on metadata_obj (here, only city_stats).
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(bind=engine)
|
|
|
|
|
|
|
# LlamaIndex wrapper around the engine, restricted to the demo table.
sql_database = SQLDatabase(engine, include_tables=[table_name])

# Seed data for the demo table.
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {
        "city_name": "Chicago",
        "population": 2679000,
        "country": "United States",
    },
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]

# Insert all rows in ONE transaction. Passing a list of parameter dicts to
# execute() runs a single executemany. The previous version opened and
# committed a separate transaction per row, and kept an unused result cursor.
with engine.begin() as connection:
    connection.execute(insert(city_stats_table), rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Send INFO-level (and higher) log records to stdout.
# basicConfig() already attaches a StreamHandler(sys.stdout) to the root
# logger. Adding a second StreamHandler(sys.stdout) on top of it, as this
# script previously did, made every log record print twice.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
So far our text-to-SQL capability is packaged in a |
|
query engine and consists of both retrieval and synthesis. |
|
|
|
You can use the SQL retriever on its own. |
|
We show you some different parameters you can try, |
|
and also show how to plug it into our RetrieverQueryEngine |
|
to get roughly the same results. |
|
""" |
|
|
|
|
|
|
|
# Standalone text-to-SQL retriever, limited to the demo table. It reuses
# `table_name` so the table name is defined in exactly one place.
# NOTE(review): return_raw=True is understood to return the SQL result as
# raw text rather than one node per row — confirm against the NLSQLRetriever
# docs for the installed llama_index version.
nl_sql_retriever = NLSQLRetriever(
    sql_database, tables=[table_name], return_raw=True
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
# Plug into our RetrieverQueryEngine |
|
|
|
We compose our SQL Retriever with our standard RetrieverQueryEngine |
|
to synthesize a response. The result is roughly similar |
|
to our packaged Text-to-SQL query engines. |
|
""" |
|
|
|
# Wrap the SQL retriever in a standard RetrieverQueryEngine so the retrieved
# SQL result is turned into a natural-language answer. No LLM is passed
# here, so presumably the global Settings.llm is used — confirm.
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)

response = query_engine.query(
    "Return the top 5 cities (along with their populations) with the highest population."
)

# The answer is LLM-generated text. Escape it before embedding it in rich
# markup, otherwise square-bracketed fragments in the answer would be parsed
# as console markup tags (silently dropped or raising a MarkupError).
print(
    f"[bold chartreuse1 on grey7]> Response: {escape(str(response))}[/bold chartreuse1 on grey7]\n"
)
|
"""" |
|
> Response: Tokyo - 13,960,000 |
|
Seoul - 9,776,000 |
|
Toronto - 2,930,000 |
|
Chicago - 2,679,000 |
|
""" |
|
|