hot-ones-trivia / run_trivia.py
RedTachyon's picture
Initial commit from GitHub repository without history
31b6e27
raw
history blame
4.24 kB
import pickle
import json
import dotenv
import gradio as gr
import numpy as np
import random
from typarse import BaseParser
from core import get_one_embedding, Chunk, Dataset
from openai import OpenAI
from prompts import get_initial_messages
# Fixed seed so the example questions sampled at startup are reproducible.
random.seed(a=42)
class Parser(BaseParser):
    """Command-line options for the trivia app, declared as typarse fields."""

    # Pickled Dataset (chunks plus their embedding matrix) loaded at startup.
    data_path: str = "data4k.pkl"
    # JSON file of questions; a few are sampled as example questions in the UI.
    questions_path: str = "questions.json"
def cosine_similarity(query: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
    """Return the cosine similarity between *query* and each row of *embeddings*.

    Args:
        query: Vector of length ``d``.
        embeddings: Matrix expected to have shape ``(n, d)``, one candidate
            vector per row.

    Returns:
        Array of ``n`` similarities (nominally in ``[-1, 1]``), with the same
        floating dtype a plain division would produce. Entries whose row, or
        the query itself, has zero norm have no direction and are reported
        as ``0`` rather than ``nan``.
    """
    dot_product = np.dot(embeddings, query)
    denominator = np.linalg.norm(query) * np.linalg.norm(embeddings, axis=1)
    # A plain division gives nan (and a RuntimeWarning) for zero norms. np.argsort
    # sorts nan last, so after reversing for a best-first ranking those
    # degenerate rows would come out on top. Leave them at 0 instead.
    result = np.zeros(
        np.shape(dot_product), dtype=np.result_type(dot_product, denominator)
    )
    return np.divide(dot_product, denominator, out=result, where=denominator != 0)
def rank_chunks(
    client: OpenAI,
    question: str,
    dataset: Dataset,
    model: str = "text-embedding-3-small",
) -> list[Chunk]:
    """Order all chunks of *dataset* by cosine similarity to *question*.

    The question is embedded with *model* via ``get_one_embedding`` and
    compared against ``dataset.embeddings``. The matching entries of
    ``dataset.chunks`` are returned most-similar first.
    """
    vectors, chunks = dataset.embeddings, dataset.chunks
    question_vector = get_one_embedding(client, question, model)
    scores = cosine_similarity(question_vector, vectors)
    # argsort is ascending; flip it so the highest-scoring index comes first.
    best_first = np.flip(np.argsort(scores))
    return [chunks[idx] for idx in best_first]
if __name__ == "__main__":
    from typing import Iterator, Optional

    dotenv.load_dotenv()
    args = Parser()
    client = OpenAI()

    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only point data_path at a dataset you built yourself.
    with open(args.data_path, "rb") as f:
        data: Dataset = pickle.load(f)
    with open(args.questions_path, "r", encoding="utf-8") as f:
        questions = json.load(f)

    # Two hand-picked examples followed by up to three sampled ones
    # (random.sample raises ValueError if asked for more items than exist).
    select_questions = random.sample(questions, min(3, len(questions)))
    select_questions = [
        "Which guest worked at Abercrombie and Fitch?",
        "Who failed making pastries as a teenager?",
    ] + select_questions

    # Marker the model's reply contains when the chunk lacks the answer
    # (presumably requested by the prompt from get_initial_messages - confirm).
    UNKNOWN_MARKER = "<|UNKNOWN|>"
    NOT_FOUND_MESSAGE = "Couldn't find the answer."
    # Every searched chunk costs one gpt-4o request, so the multi-chunk search
    # stops after this many top-ranked chunks instead of walking the dataset.
    MAX_CHUNKS_TO_SEARCH = 10

    def _ask_model(query: str, chunk: Chunk) -> Optional[str]:
        """Ask gpt-4o to answer *query* from *chunk*.

        Returns the reply text, or None when the reply has no content or
        contains UNKNOWN_MARKER.
        """
        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=get_initial_messages(query, chunk),
        )
        answer = completion.choices[0].message.content
        # content can be None; a substring test on None would raise TypeError.
        if answer is None or UNKNOWN_MARKER in answer:
            return None
        return answer

    def get_answer(query: str) -> tuple[str, str]:
        """Answer *query* using only the most similar chunk.

        Returns:
            ``(answer, context)`` where ``context`` names the video consulted
            (empty when the dataset has no chunks) and ``answer`` falls back
            to NOT_FOUND_MESSAGE when the model could not answer.
        """
        sorted_chunks = rank_chunks(client, query, data)
        if not sorted_chunks:
            return NOT_FOUND_MESSAGE, ""
        best_chunk = sorted_chunks[0]
        print(f"Looking at chunk from video {best_chunk.title}")
        context = f"Looking at the video titled {best_chunk.title}"
        answer = _ask_model(query, best_chunk)
        return (NOT_FOUND_MESSAGE if answer is None else answer), context

    def get_answer_better(query: str) -> Iterator[tuple[Optional[str], str]]:
        """Try the best-ranked chunks one at a time, streaming progress.

        Yields ``(answer, context)`` pairs: ``answer`` is None while a chunk
        is being read, "Not sure, still looking" after a miss, the model's
        reply on success, and NOT_FOUND_MESSAGE once the first
        MAX_CHUNKS_TO_SEARCH chunks are exhausted without an answer.
        """
        print(f"Looking for answer to question: {query}")
        sorted_chunks = rank_chunks(client, query, data)
        context = ""
        for chunk in sorted_chunks[:MAX_CHUNKS_TO_SEARCH]:
            print(f"Looking at chunk from video {chunk.title}")
            context = f"Looking at the video titled {chunk.title}"
            yield None, context
            answer = _ask_model(query, chunk)
            if answer is not None:
                yield answer, context
                return
            yield "Not sure, still looking", context
        # None of the searched chunks had the answer (or there were none).
        yield NOT_FOUND_MESSAGE, context

    def trivia_app(
        query: str, use_multiple: bool
    ) -> Iterator[tuple[Optional[str], str]]:
        """Button callback: dispatch to the single- or multi-chunk strategy.

        Always a generator; each yielded ``(answer, context)`` pair is shown
        in the answer and context boxes.
        """
        if use_multiple:
            print("Using multiple chunks")
            yield from get_answer_better(query)
        else:
            print("Using single chunk")
            yield get_answer(query)

    with gr.Blocks() as interface:
        gr.Markdown("# Trivia Question Answering App")
        with gr.Row():
            with gr.Column():
                question_box = gr.Textbox(
                    lines=2, placeholder="Enter your trivia question here..."
                )
                answer_button = gr.Button("Get Answer")
                examples = gr.Examples(
                    select_questions, label="Example Questions", inputs=[question_box]
                )
                use_multiple = gr.Checkbox(
                    label="Search across multiple chunks", key="better"
                )
            with gr.Column():
                answer_box = gr.Markdown("The answer will appear here...")
                context_box = gr.Textbox(label="Context")
        answer_button.click(
            fn=trivia_app,
            inputs=[question_box, use_multiple],
            outputs=[answer_box, context_box],
        )
    interface.launch()