# from dotenv import find_dotenv, load_dotenv
# _ = load_dotenv(find_dotenv())
import solara

# Clean up all the directories used in this notebook
import shutil

shutil.rmtree("./data", ignore_errors=True)
import polars as pl

# Load the Metallica songs dataset (Song, Lyrics, Album, Artist)
df = pl.read_csv(
    "https://drive.google.com/uc?export=download&id=1uD3h7xYxr9EoZ0Ggoh99JtQXa3AxtxyU"
)
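
# A quick peek at the data (optional); the frame is expected to contain at
# least the columns used below: Song, Lyrics, Album, Artist.
print(df.head())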
import string

# Normalize the capitalization of album and song titles
df = df.with_columns(
    pl.Series("Album", [string.capwords(album) for album in df["Album"]])
)
df = df.with_columns(pl.Series("Song", [string.capwords(song) for song in df["Song"]]))

# Replace missing lyrics with the string "None"
df = df.with_columns(pl.col("Lyrics").fill_null("None"))

# Build the text that will be embedded: a "# Album: Song" header followed by the lyrics
df = df.with_columns(
    text=pl.lit("# ")
    + pl.col("Album")
    + pl.lit(": ")
    + pl.col("Song")
    + pl.lit("\n\n")
    + pl.col("Lyrics")
    # text = pl.col("Lyrics")
)
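
# Sanity check (optional): each `text` entry should now look roughly like
# "# <Album>: <Song>" followed by a blank line and the lyrics.
print(df["text"][0])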
import lancedb

db = lancedb.connect("data/")

from lancedb.embeddings import get_registry

# Small sentence-transformers embedding model, run on CPU
embeddings = (
    get_registry()
    .get("sentence-transformers")
    .create(name="TaylorAI/gte-tiny", device="cpu")
)
from lancedb.pydantic import LanceModel, Vector


# SourceField/VectorField tell LanceDB to embed the `text` column automatically on insert
class Songs(LanceModel):
    Song: str
    Lyrics: str
    Album: str
    Artist: str
    text: str = embeddings.SourceField()
    vector: Vector(embeddings.ndims()) = embeddings.VectorField()


table = db.create_table("Songs", schema=Songs)
table.add(data=df)
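
# Optional sanity check: exact results depend on the embedding model, but a
# semantic search should already return plausible songs, ranked by `_distance`.
sample_hits = table.search("a boy having nightmares").limit(3).to_polars()
print(sample_hits.select(["Song", "Album", "_distance"]))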
import os
from typing import Optional

from langchain_community.chat_models import ChatOpenAI


class ChatOpenRouter(ChatOpenAI):
    """ChatOpenAI pointed at the OpenRouter API."""

    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(
        self,
        model_name: str,
        openai_api_key: Optional[str] = None,
        openai_api_base: str = "https://openrouter.ai/api/v1",
        **kwargs,
    ):
        # Fall back to the OPENROUTER_API_KEY environment variable if no key is passed
        openai_api_key = openai_api_key or os.getenv("OPENROUTER_API_KEY")
        super().__init__(
            openai_api_base=openai_api_base,
            openai_api_key=openai_api_key,
            model_name=model_name,
            **kwargs,
        )


llm_openrouter = ChatOpenRouter(model_name="meta-llama/llama-3.1-405b-instruct")
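
# Optional smoke test for the OpenRouter connection (makes a paid API call and
# assumes OPENROUTER_API_KEY is set); the reply text is in `.content`:
# print(llm_openrouter.invoke(input="Reply with the single word 'ok'.").content)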
def get_relevant_texts(query, table):
    # Retrieve the 5 closest song texts and join them with a visual separator
    results = table.search(query).limit(5).to_polars()
    return "\n\n---\n\n".join(results["text"].to_list())
def generate_prompt(query, table):
    return (
        "Answer the question based only on the following context:\n\n"
        + get_relevant_texts(query, table)
        + "\n\nQuestion: "
        + query
    )


def generate_response(query, table):
    prompt = generate_prompt(query, table)
    response = llm_openrouter.invoke(input=prompt)
    return response.content
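
# Example of the full retrieve-then-generate pipeline outside the UI
# (uncomment to try; the exact wording of the answer will vary by model):
# print(generate_response("Which song is about a boy who is having nightmares?", table))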
query = solara.reactive("Which song is about a boy who is having nightmares?")


@solara.component
def Page():
    with solara.Column(margin=10):
        solara.Markdown("# Metallica Song Finder Bot")
        solara.InputText("Enter some query:", query, continuous_update=False)
        if query.value != "":
            # Show the generated answer alongside the retrieved songs
            df_results = table.search(query.value).limit(5).to_polars()
            df_results = df_results.select(
                ["Song", "Album", "_distance", "Lyrics", "Artist"]
            )
            solara.Markdown("## Answer:")
            solara.Markdown(generate_response(query.value, table))
            solara.Markdown("## Context:")
            solara.DataFrame(df_results, items_per_page=5)
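
# To serve the app, save this script (e.g. as app.py) and start it with the
# Solara CLI: `solara run app.py`. Solara picks up the component named `Page`
# and renders it in the browser (by default at http://localhost:8765).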