# song-finder-bot / app.py
#from dotenv import find_dotenv, load_dotenv
#_ = load_dotenv(find_dotenv())
import solara
# Remove any LanceDB data left over from a previous run
import shutil
shutil.rmtree("./data", ignore_errors=True)
import polars as pl
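# Load the songs dataset (Song, Album, Artist, Lyrics) from Google Drive into a Polars DataFrame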
df = pl.read_csv(
    "https://drive.google.com/uc?export=download&id=1uD3h7xYxr9EoZ0Ggoh99JtQXa3AxtxyU"
)
import string
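# Normalize capitalization of album and song titles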
df = df.with_columns(
    pl.Series("Album", [string.capwords(album) for album in df["Album"]])
)
df = df.with_columns(pl.Series("Song", [string.capwords(song) for song in df["Song"]]))
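# Songs without lyrics get the literal string "None"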
df = df.with_columns(pl.col("Lyrics").fill_null("None"))
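# Build the text column to embed: a "# Album: Song" header followed by the lyrics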
df = df.with_columns(
    text=pl.lit("# ")
    + pl.col("Album")
    + pl.lit(": ")
    + pl.col("Song")
    + pl.lit("\n\n")
    + pl.col("Lyrics")
    # text = pl.col("Lyrics")
)
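# Create/open a local LanceDB database under ./data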
import lancedb
db = lancedb.connect("data/")
from lancedb.embeddings import get_registry
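# Embedding function: the small TaylorAI/gte-tiny sentence-transformers model, run on CPU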
embeddings = (
    get_registry()
    .get("sentence-transformers")
    .create(name="TaylorAI/gte-tiny", device="cpu")
)
from lancedb.pydantic import LanceModel, Vector
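# Table schema: text is the embedding source field, vector is filled in automatically by LanceDB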
class Songs(LanceModel):
    Song: str
    Lyrics: str
    Album: str
    Artist: str
    text: str = embeddings.SourceField()
    vector: Vector(embeddings.ndims()) = embeddings.VectorField()
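# Create the table and ingest the DataFrame (embeddings are generated on insert)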
table = db.create_table("Songs", schema=Songs)
table.add(data=df)
import os
from typing import Optional
from langchain_community.chat_models import ChatOpenAI
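# ChatOpenAI subclass pointed at OpenRouter's OpenAI-compatible API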
class ChatOpenRouter(ChatOpenAI):
    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(
        self,
        model_name: str,
        openai_api_key: Optional[str] = None,
        openai_api_base: str = "https://openrouter.ai/api/v1",
        **kwargs,
    ):
        # Fall back to the OPENROUTER_API_KEY environment variable when no key is given
        openai_api_key = openai_api_key or os.getenv("OPENROUTER_API_KEY")
        super().__init__(
            openai_api_base=openai_api_base,
            openai_api_key=openai_api_key,
            model_name=model_name,
            **kwargs,
        )
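# Llama 3.1 405B Instruct served through OpenRouter; requires OPENROUTER_API_KEY to be set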
llm_openrouter = ChatOpenRouter(model_name="meta-llama/llama-3.1-405b-instruct")
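# Retrieve the most similar songs and join their text blocks, separated by horizontal rules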
def get_relevant_texts(query, table):
    # Vector search for the five closest songs, returned as a Polars DataFrame
    results = table.search(query).limit(5).to_polars()
    return " ".join(results["text"][i] + "\n\n---\n\n" for i in range(len(results)))
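# RAG-style prompt: the retrieved lyrics as context, followed by the user's question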
def generate_prompt(query, table):
    return (
        "Answer the question based only on the following context:\n\n"
        + get_relevant_texts(query, table)
        + "\n\nQuestion: "
        + query
    )
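# Build the prompt, query the model, and return only the message content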
def generate_response(query, table):
    prompt = generate_prompt(query, table)
    response = llm_openrouter.invoke(input=prompt)
    return response.content
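# Reactive query string; the page re-renders whenever it changes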
query = solara.reactive("Which song is about a boy who is having nightmares?")
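# Solara UI: query input, the model's answer, and the retrieved songs shown as context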
@solara.component
def Page():
    with solara.Column(margin=10):
        solara.Markdown("# Metallica Song Finder Bot")
        solara.InputText("Enter some query:", query, continuous_update=False)
        if query.value != "":
            # Show the generated answer plus the retrieved songs (with similarity distance)
            df_results = table.search(query.value).limit(5).to_polars()
            df_results = df_results.select(["Song", "Album", "_distance", "Lyrics", "Artist"])
            solara.Markdown("## Answer:")
            solara.Markdown(generate_response(query.value, table))
            solara.Markdown("## Context:")
            solara.DataFrame(df_results, items_per_page=5)