|
|
|
|
|
|
|
import solara |
|
|
|
import polars as pl |
|
|
|
df = pl.read_csv( |
|
"https://drive.google.com/uc?export=download&id=1uD3h7xYxr9EoZ0Ggoh99JtQXa3AxtxyU" |
|
) |
|
|
|
import string |
|
|
|
df = df.with_columns( |
|
pl.Series("Album", [string.capwords(album) for album in df["Album"]]) |
|
) |
|
df = df.with_columns(pl.Series("Song", [string.capwords(song) for song in df["Song"]])) |
|
df = df.with_columns(pl.col("Lyrics").fill_null("None")) |
|
|
|
df = df.with_columns( |
|
text=pl.lit("# ") |
|
+ pl.col("Album") |
|
+ pl.lit(": ") |
|
+ pl.col("Song") |
|
+ pl.lit("\n\n") |
|
+ pl.col("Lyrics") |
|
) |
|
|
|
import shutil |
|
import lancedb |
|
|
|
shutil.rmtree("test_lancedb", ignore_errors=True) |
|
db = lancedb.connect("test_lancedb") |
|
|
|
from lancedb.embeddings import get_registry |
|
|
|
embeddings = ( |
|
get_registry() |
|
.get("sentence-transformers") |
|
.create(name="TaylorAI/gte-tiny", device="cpu") |
|
) |
|
|
|
from lancedb.pydantic import LanceModel, Vector |
|
|
|
|
|
class Songs(LanceModel): |
|
Song: str |
|
Lyrics: str |
|
Album: str |
|
Artist: str |
|
text: str = embeddings.SourceField() |
|
vector: Vector(embeddings.ndims()) = embeddings.VectorField() |
|
|
|
table = db.create_table("Songs", schema=Songs) |
|
table.add(data=df) |
|
|
|
import os |
|
from typing import Optional |
|
|
|
from langchain_community.chat_models import ChatOpenAI |
|
|
|
class ChatOpenRouter(ChatOpenAI): |
|
openai_api_base: str |
|
openai_api_key: str |
|
model_name: str |
|
|
|
def __init__( |
|
self, |
|
model_name: str, |
|
openai_api_key: Optional[str] = None, |
|
openai_api_base: str = "https://openrouter.ai/api/v1", |
|
**kwargs, |
|
): |
|
openai_api_key = os.getenv("OPENROUTER_API_KEY") |
|
super().__init__( |
|
openai_api_base=openai_api_base, |
|
openai_api_key=openai_api_key, |
|
model_name=model_name, |
|
**kwargs, |
|
) |
|
|
|
llm_openrouter = ChatOpenRouter(model_name="meta-llama/llama-3.1-405b-instruct", temperature=0.1) |
|
|
|
def get_relevant_texts(query, table=table): |
|
results = ( |
|
table.search(query) |
|
.limit(5) |
|
.to_polars() |
|
) |
|
return " ".join([results["text"][i] + "\n\n---\n\n" for i in range(5)]) |
|
|
|
def generate_prompt(query, table=table): |
|
return ( |
|
"Answer the question based only on the following context:\n\n" |
|
+ get_relevant_texts(query, table) |
|
+ "\n\nQuestion: " |
|
+ query |
|
) |
|
|
|
def generate_response(query, table=table): |
|
prompt = generate_prompt(query, table) |
|
response = llm_openrouter.invoke(input=prompt) |
|
return response.content |
|
|
|
import kuzu |
|
|
|
shutil.rmtree("test_kuzudb", ignore_errors=True) |
|
db = kuzu.Database("test_kuzudb") |
|
conn = kuzu.Connection(db) |
|
|
|
conn.execute("CREATE NODE TABLE ARTIST(name STRING, PRIMARY KEY (name))") |
|
conn.execute("CREATE NODE TABLE ALBUM(name STRING, PRIMARY KEY (name))") |
|
conn.execute("CREATE NODE TABLE SONG(ID SERIAL, name STRING, lyrics STRING, PRIMARY KEY(ID))") |
|
conn.execute("CREATE REL TABLE IN_ALBUM(FROM SONG TO ALBUM)") |
|
conn.execute("CREATE REL TABLE FROM_ARTIST(FROM ALBUM TO ARTIST)"); |
|
|
|
|
|
for artist in df["Artist"].unique(): |
|
conn.execute(f"CREATE (artist:ARTIST {{name: '{artist}'}})") |
|
|
|
for album in df["Album"].unique(): |
|
conn.execute(f"""CREATE (album:ALBUM {{name: "{album}"}})""") |
|
|
|
for song, lyrics in df.select(["Song", "text"]).unique().rows(): |
|
replaced_lyrics = lyrics.replace('"', "'") |
|
conn.execute( |
|
f"""CREATE (song:SONG {{name: "{song}", lyrics: "{replaced_lyrics}"}})""" |
|
) |
|
|
|
|
|
for song, album, lyrics in df.select(["Song", "Album", "text"]).rows(): |
|
replaced_lyrics = lyrics.replace('"', "'") |
|
conn.execute( |
|
f""" |
|
MATCH (song:SONG), (album:ALBUM) |
|
WHERE song.name = "{song}" AND song.lyrics = "{replaced_lyrics}" AND album.name = "{album}" |
|
CREATE (song)-[:IN_ALBUM]->(album) |
|
""" |
|
) |
|
|
|
for album, artist in df.select(["Album", "Artist"]).unique().rows(): |
|
conn.execute( |
|
f""" |
|
MATCH (album:ALBUM), (artist:ARTIST) WHERE album.name = "{album}" AND artist.name = "{artist}" |
|
CREATE (album)-[:FROM_ARTIST]->(artist) |
|
""" |
|
) |
|
|
|
response = conn.execute( |
|
""" |
|
MATCH (a:ALBUM {name: 'The Black Album'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name |
|
""" |
|
) |
|
|
|
df_response = response.get_as_pl() |
|
|
|
from langchain_community.graphs import KuzuGraph |
|
|
|
graph = KuzuGraph(db) |
|
|
|
def generate_kuzu_prompt(user_query): |
|
return """Task: Generate Kùzu Cypher statement to query a graph database. |
|
|
|
Instructions: |
|
Generate the Kùzu dialect of Cypher with the following rules in mind: |
|
1. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`. |
|
2. Do not include triple backticks ``` in your response. Return only Cypher. |
|
3. Do not return any notes or comments in your response. |
|
|
|
|
|
Use only the provided relationship types and properties in the schema. |
|
Do not use any other relationship types or properties that are not provided. |
|
|
|
Schema:\n""" + graph.get_schema + """\nExample: |
|
The question is:\n"Which songs does the load album have?" |
|
MATCH (a:ALBUM {name: 'Load'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name |
|
|
|
Note: Do not include any explanations or apologies in your responses. |
|
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement. |
|
Do not include any text except the generated Cypher statement. |
|
|
|
The question is:\n""" + user_query |
|
|
|
|
|
def generate_final_prompt(query,cypher_query,col_name,_values): |
|
return f"""You are an assistant that helps to form nice and human understandable answers. |
|
The information part contains the provided information that you must use to construct an answer. |
|
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it. |
|
Make the answer sound as a response to the question. Do not mention that you based the result on the given information. |
|
Here is an example: |
|
|
|
Question: Which managers own Neo4j stocks? |
|
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC] |
|
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks. |
|
|
|
Follow this example when generating answers. |
|
If the provided information is empty, say that you don't know the answer. |
|
Query:\n{cypher_query} |
|
Information: |
|
[{col_name}: {_values}] |
|
|
|
Question: {query} |
|
Helpful Answer: |
|
""" |
|
|
|
def generate_kg_response(query): |
|
prompt = generate_kuzu_prompt(query) |
|
cypher_query_response = llm_openrouter.invoke(input=prompt) |
|
cypher_query = cypher_query_response.content |
|
response = conn.execute( |
|
f""" |
|
{cypher_query} |
|
""" |
|
) |
|
df = response.get_as_pl() |
|
col_name = df.columns[0] |
|
_values = df[col_name].to_list() |
|
final_prompt = generate_final_prompt(query,cypher_query,col_name,_values) |
|
final_response = llm_openrouter.invoke(input=final_prompt) |
|
final_response = final_response.content |
|
return final_response, cypher_query |
|
|
|
def get_classification(query): |
|
prompt = "Answer only YES or NO. Is the question '" + query + "' related to the content of a song?" |
|
response = llm_openrouter.invoke(input=prompt) |
|
return response.content |
|
|
|
query = solara.reactive("How many songs does the black album have?") |
|
@solara.component |
|
def Page(): |
|
with solara.Column(margin=10): |
|
solara.Markdown("# Metallica Song Finder Graph RAG") |
|
solara.InputText("Enter some query:", query, continuous_update=False) |
|
if query.value != "": |
|
query_class = get_classification(query.value) |
|
if query_class == 'YES' or query_class == 'YES.': |
|
df_results = table.search(query.value).limit(5).to_polars() |
|
df_results = df_results.select(['Song', 'Album', '_distance', 'Lyrics', 'Artist']) |
|
response = generate_response(query.value) |
|
solara.Markdown("## Answer:") |
|
solara.Markdown(response) |
|
solara.Markdown("## Context:") |
|
solara.DataFrame(df_results, items_per_page=5) |
|
else: |
|
response, cypher_query = generate_kg_response(query.value) |
|
solara.Markdown("## Answer:") |
|
solara.Markdown(response) |
|
solara.Markdown("## Cypher query:") |
|
solara.Markdown(cypher_query) |
|
|