# from dotenv import find_dotenv, load_dotenv
# _ = load_dotenv(find_dotenv())
import solara
import polars as pl
df = pl.read_csv(
    "https://drive.google.com/uc?export=download&id=1uD3h7xYxr9EoZ0Ggoh99JtQXa3AxtxyU"
)
import string
df = df.with_columns(
    pl.Series("Album", [string.capwords(album) for album in df["Album"]])
)
df = df.with_columns(pl.Series("Song", [string.capwords(song) for song in df["Song"]]))
df = df.with_columns(pl.col("Lyrics").fill_null("None"))
df = df.with_columns(
    text=pl.lit("# ")
    + pl.col("Album")
    + pl.lit(": ")
    + pl.col("Song")
    + pl.lit("\n\n")
    + pl.col("Lyrics")
)
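# The derived "text" column is the document that gets embedded below: a
# markdown-style header (album: song) followed by the lyrics. A hypothetical
# row (illustrative values, not actual data) would look like:
#   # Load: Some Song Title
#
#   <lyrics>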
import shutil
import lancedb
shutil.rmtree("test_lancedb", ignore_errors=True)
db = lancedb.connect("test_lancedb")
from lancedb.embeddings import get_registry
embeddings = (
    get_registry()
    .get("sentence-transformers")
    .create(name="TaylorAI/gte-tiny", device="cpu")
)
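# TaylorAI/gte-tiny is a small sentence-transformers model, so CPU inference
# is acceptable here. Its output dimension (exposed via `embeddings.ndims()`)
# is what sizes the vector column in the schema below.
# print(embeddings.ndims())  # 384 for gte-tiny, if memory serves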
from lancedb.pydantic import LanceModel, Vector
class Songs(LanceModel):
    Song: str
    Lyrics: str
    Album: str
    Artist: str
    text: str = embeddings.SourceField()
    vector: Vector(embeddings.ndims()) = embeddings.VectorField()
table = db.create_table("Songs", schema=Songs)
table.add(data=df)
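# Quick sanity check of the vector index (commented out so it doesn't run on
# import; the query string is just an example):
# print(table.search("songs about war").limit(3).to_polars()["Song"])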
import os
from typing import Optional
from langchain_community.chat_models import ChatOpenAI
class ChatOpenRouter(ChatOpenAI):
    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(
        self,
        model_name: str,
        openai_api_key: Optional[str] = None,
        openai_api_base: str = "https://openrouter.ai/api/v1",
        **kwargs,
    ):
        # Fall back to the environment variable only when no key is passed in.
        openai_api_key = openai_api_key or os.getenv("OPENROUTER_API_KEY")
        super().__init__(
            openai_api_base=openai_api_base,
            openai_api_key=openai_api_key,
            model_name=model_name,
            **kwargs,
        )
llm_openrouter = ChatOpenRouter(model_name="meta-llama/llama-3.1-405b-instruct", temperature=0.1)
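# Smoke test for the OpenRouter-backed model (commented out: it makes a real,
# billable API call and needs OPENROUTER_API_KEY to be set):
# print(llm_openrouter.invoke(input="Reply with one word: ready?").content)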
def get_relevant_texts(query, table=table):
    # Retrieve up to 5 nearest documents by vector similarity; iterating over
    # the actual results avoids an index error if fewer than 5 rows match.
    results = table.search(query).limit(5).to_polars()
    return " ".join([text + "\n\n---\n\n" for text in results["text"]])
def generate_prompt(query, table=table):
    return (
        "Answer the question based only on the following context:\n\n"
        + get_relevant_texts(query, table)
        + "\n\nQuestion: "
        + query
    )
def generate_response(query, table=table):
    prompt = generate_prompt(query, table)
    response = llm_openrouter.invoke(input=prompt)
    return response.content
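# End-to-end vector-RAG example (commented out; the question is illustrative):
# print(generate_response("Which songs mention the ocean?"))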
import kuzu
shutil.rmtree("test_kuzudb", ignore_errors=True)
db = kuzu.Database("test_kuzudb")
conn = kuzu.Connection(db)
# Create schema
conn.execute("CREATE NODE TABLE ARTIST(name STRING, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE ALBUM(name STRING, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE SONG(ID SERIAL, name STRING, lyrics STRING, PRIMARY KEY(ID))")
conn.execute("CREATE REL TABLE IN_ALBUM(FROM SONG TO ALBUM)")
conn.execute("CREATE REL TABLE FROM_ARTIST(FROM ALBUM TO ARTIST)");
# Insert nodes
for artist in df["Artist"].unique():
conn.execute(f"CREATE (artist:ARTIST {{name: '{artist}'}})")
for album in df["Album"].unique():
conn.execute(f"""CREATE (album:ALBUM {{name: "{album}"}})""")
for song, lyrics in df.select(["Song", "text"]).unique().rows():
replaced_lyrics = lyrics.replace('"', "'")
conn.execute(
f"""CREATE (song:SONG {{name: "{song}", lyrics: "{replaced_lyrics}"}})"""
)
# Insert edges
for song, album, lyrics in df.select(["Song", "Album", "text"]).rows():
    replaced_lyrics = lyrics.replace('"', "'")
    conn.execute(
        f"""
        MATCH (song:SONG), (album:ALBUM)
        WHERE song.name = "{song}" AND song.lyrics = "{replaced_lyrics}" AND album.name = "{album}"
        CREATE (song)-[:IN_ALBUM]->(album)
        """
    )
for album, artist in df.select(["Album", "Artist"]).unique().rows():
    conn.execute(
        f"""
        MATCH (album:ALBUM), (artist:ARTIST) WHERE album.name = "{album}" AND artist.name = "{artist}"
        CREATE (album)-[:FROM_ARTIST]->(artist)
        """
    )
response = conn.execute(
    """
    MATCH (a:ALBUM {name: 'The Black Album'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name
    """
)
df_response = response.get_as_pl()
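# df_response should now list the tracks on "The Black Album" (assuming the
# CSV uses that album name); print(df_response) to inspect.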
from langchain_community.graphs import KuzuGraph
graph = KuzuGraph(db)
def generate_kuzu_prompt(user_query):
return """Task: Generate Kùzu Cypher statement to query a graph database.
Instructions:
Generate the Kùzu dialect of Cypher with the following rules in mind:
1. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`.
2. Do not include triple backticks ``` in your response. Return only Cypher.
3. Do not return any notes or comments in your response.
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:\n""" + graph.get_schema + """\nExample:
The question is:\n"Which songs does the load album have?"
MATCH (a:ALBUM {name: 'Load'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.
The question is:\n""" + user_query
def generate_final_prompt(query, cypher_query, col_name, _values):
    return f"""You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
Here is an example:
Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
Follow this example when generating answers.
If the provided information is empty, say that you don't know the answer.
Query:\n{cypher_query}
Information:
[{col_name}: {_values}]
Question: {query}
Helpful Answer:
"""
def generate_kg_response(query):
    # Text-to-Cypher: ask the LLM for a Cypher query, run it against Kùzu,
    # then have the LLM phrase the raw result as a natural-language answer.
    prompt = generate_kuzu_prompt(query)
    cypher_query_response = llm_openrouter.invoke(input=prompt)
    cypher_query = cypher_query_response.content
    response = conn.execute(cypher_query)
    # Use a local name so the module-level `df` isn't shadowed confusingly.
    result_df = response.get_as_pl()
    col_name = result_df.columns[0]
    _values = result_df[col_name].to_list()
    final_prompt = generate_final_prompt(query, cypher_query, col_name, _values)
    final_response = llm_openrouter.invoke(input=final_prompt)
    return final_response.content, cypher_query
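# Graph-RAG example (commented out; the question is illustrative). For a
# question like this, the LLM is expected to emit Cypher along the lines of
#   MATCH (a:ALBUM {name: 'Load'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name
# answer, cypher = generate_kg_response("Which songs are on the Load album?")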
def get_classification(query):
    prompt = "Answer only YES or NO. Is the question '" + query + "' related to the content of a song?"
    response = llm_openrouter.invoke(input=prompt)
    return response.content
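# Router sanity check (commented out): lyric-content questions should come
# back "YES", catalog questions (albums, counts) "NO".
# print(get_classification("Which songs talk about war?"))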
query = solara.reactive("How many songs does the black album have?")
@solara.component
def Page():
    with solara.Column(margin=10):
        solara.Markdown("# Metallica Song Finder Graph RAG")
        solara.InputText("Enter some query:", query, continuous_update=False)
        if query.value != "":
            query_class = get_classification(query.value)
            # Route lyric-content questions to vector RAG and everything else
            # to the knowledge graph; tolerate variants like "YES." or "yes".
            if query_class.strip().upper().startswith("YES"):
                df_results = table.search(query.value).limit(5).to_polars()
                df_results = df_results.select(["Song", "Album", "_distance", "Lyrics", "Artist"])
                response = generate_response(query.value)
                solara.Markdown("## Answer:")
                solara.Markdown(response)
                solara.Markdown("## Context:")
                solara.DataFrame(df_results, items_per_page=5)
            else:
                response, cypher_query = generate_kg_response(query.value)
                solara.Markdown("## Answer:")
                solara.Markdown(response)
                solara.Markdown("## Cypher query:")
                solara.Markdown(cypher_query)
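
# Launch with Solara's CLI (assuming this file is saved as app.py and
# OPENROUTER_API_KEY is set in the environment):
#   solara run app.py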