Spaces:
Sleeping
Sleeping
import json
import pickle
import random
from typing import Iterator, Optional

import dotenv
import gradio as gr
import numpy as np
from openai import OpenAI
from typarse import BaseParser

from core import get_one_embedding, Chunk, Dataset
from prompts import get_initial_messages
# Seed the global RNG with a fixed value so later random draws (such as the
# sample of example questions taken at start-up) are the same on every run.
random.seed(42)
class Parser(BaseParser):
    """Start-up options for the trivia app.

    The fields are populated by ``typarse.BaseParser`` when the class is
    instantiated (presumably from command-line arguments -- see the typarse
    documentation for the exact flag syntax).
    """

    # Path of the pickled Dataset that is loaded at start-up.
    data_path: str = "data4k.pkl"
    # Path of the JSON file from which example questions are sampled.
    questions_path: str = "questions.json"
def cosine_similarity(query: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
    """Return the cosine similarity between ``query`` and each row of ``embeddings``.

    Args:
        query: The query vector (expected to be 1-D, same length as a row of
            ``embeddings``).
        embeddings: 2-D array with one embedding vector per row.

    Returns:
        A 1-D array with one similarity score per row of ``embeddings``.
        Rows for which the similarity is undefined (the row or the query has
        zero norm) get a score of ``0.0`` instead of NaN/inf.
    """
    dot_product = np.dot(embeddings, query)
    denominator = np.linalg.norm(query) * np.linalg.norm(embeddings, axis=1)

    # A plain division would produce NaN for zero-norm vectors. NaN sorts
    # *last* in np.argsort, so a caller that reverses the order to get
    # "best first" would rank those degenerate rows at the very top.
    similarities = np.zeros(
        dot_product.shape, dtype=np.result_type(dot_product, denominator)
    )
    np.divide(dot_product, denominator, out=similarities, where=denominator > 0)
    return similarities
def rank_chunks(
    client: OpenAI,
    question: str,
    dataset: Dataset,
    model: str = "text-embedding-3-small",
) -> list[Chunk]:
    """Order every chunk of ``dataset`` from most to least similar to ``question``.

    The question is embedded with ``get_one_embedding`` (using ``client`` and
    ``model``), scored against ``dataset.embeddings`` with cosine similarity,
    and the entries of ``dataset.chunks`` are returned in descending score
    order.
    """
    vectors, chunks = dataset.embeddings, dataset.chunks
    scores = cosine_similarity(get_one_embedding(client, question, model), vectors)
    # argsort is ascending; reverse it so the best match comes first.
    return [chunks[idx] for idx in np.argsort(scores)[::-1]]
if __name__ == "__main__":
    # Script entry point: load the dataset and example questions, define the
    # answer callbacks (which close over ``data``), then build and launch the UI.
    dotenv.load_dotenv()
    args = Parser()

    # NOTE(security): pickle.load can execute arbitrary code stored in the
    # file. Only point args.data_path at a dataset file you trust.
    with open(args.data_path, "rb") as f:
        data: Dataset = pickle.load(f)

    with open(args.questions_path, "r", encoding="utf-8") as f:
        questions = json.load(f)

    # Two fixed examples followed by up to three sampled ones. The sample size
    # is clamped so a short questions file does not raise ValueError.
    select_questions = random.sample(questions, min(3, len(questions)))
    select_questions = [
        "Which guest worked at Abercrombie and Fitch?",
        "Who failed making pastries as a teenager?",
    ] + select_questions

    def get_answer(api_key: str, query: str) -> tuple[str, str]:
        """Answer ``query`` using only the single best-ranked chunk.

        Returns:
            ``(answer, context)`` where ``context`` names the video the chunk
            came from. ``answer`` is an empty string if the model returned no
            text content.
        """
        client = OpenAI(api_key=api_key)
        best_chunk = rank_chunks(client, query, data)[0]
        print(f"Looking at chunk from video {best_chunk.title}")

        completion = client.chat.completions.create(
            model="gpt-4o",
            messages=get_initial_messages(query, best_chunk),
        )
        context = f"Looking at the video titled {best_chunk.title}"
        # message.content may be None; keep the declared ``str`` contract.
        return completion.choices[0].message.content or "", context

    def get_answer_better(
        api_key: str, query: str
    ) -> Iterator[tuple[Optional[str], str]]:
        """Search the ranked chunks one by one until one yields an answer.

        This is a generator: each yielded ``(answer, context)`` pair is a
        progress update for the UI. For every chunk it first yields
        ``(None, context)``, then either the final answer (and stops) or a
        "still looking" message. If no chunk produces an answer, a final
        message saying so is yielded.
        """
        client = OpenAI(api_key=api_key)
        print(f"Looking for answer to question: {query}")

        for chunk in rank_chunks(client, query, data):
            print(f"Looking at chunk from video {chunk.title}")
            context = f"Looking at the video titled {chunk.title}"
            yield None, context

            completion = client.chat.completions.create(
                model="gpt-4o",
                messages=get_initial_messages(query, chunk),
            )
            res = completion.choices[0].message.content

            # ``res`` may be None (no text content); treat that like an
            # unknown answer instead of crashing on the ``in`` test.
            if res and "<|UNKNOWN|>" not in res:
                yield res, context
                return
            yield "Not sure, still looking", context

        # Every chunk was tried without success: do not leave the UI stuck on
        # the "still looking" message.
        yield (
            "Sorry, none of the chunks contained an answer.",
            "Searched all available chunks",
        )

    def trivia_app(
        api_key: str, query: str, use_multiple: bool
    ) -> Iterator[tuple[Optional[str], str]]:
        """Gradio callback: stream ``(answer, context)`` updates for ``query``.

        With ``use_multiple`` set, chunks are searched in ranked order until
        one answers; otherwise only the best-ranked chunk is consulted and a
        single update is produced.
        """
        if use_multiple:
            print("Using multiple chunks")
            yield from get_answer_better(api_key, query)
        else:
            print("Using single chunk")
            yield get_answer(api_key, query)

    with gr.Blocks() as interface:
        gr.Markdown("# Trivia Question Answering App")
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                api_key_box = gr.Textbox(
                    lines=1,
                    placeholder="Enter your OpenAI API key here...",
                    type="password",
                )
                question_box = gr.Textbox(
                    lines=2, placeholder="Enter your trivia question here..."
                )
                answer_button = gr.Button("Get Answer")
                gr.Examples(
                    select_questions, label="Example Questions", inputs=[question_box]
                )
                use_multiple = gr.Checkbox(
                    label="Search across multiple chunks", key="better"
                )
            # Right column: outputs.
            with gr.Column():
                answer_box = gr.Markdown("The answer will appear here...")
                context_box = gr.Textbox(label="Context")

        # trivia_app is a generator, so each yielded pair updates both outputs.
        answer_button.click(
            fn=trivia_app,
            inputs=[api_key_box, question_box, use_multiple],
            outputs=[answer_box, context_box],
        )

    interface.launch()