Spaces:
Sleeping
Sleeping
File size: 4,524 Bytes
31b6e27 b78468a 31b6e27 66df48d 31b6e27 66df48d 31b6e27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import pickle
import json
import dotenv
import gradio as gr
import numpy as np
import random
from typarse import BaseParser
from core import get_one_embedding, Chunk, Dataset
from openai import OpenAI
from prompts import get_initial_messages
# random.seed(42)
class Parser(BaseParser):
    """Startup options, read by typarse's ``BaseParser`` when ``Parser()`` is instantiated."""

    # Path to the pickled ``Dataset`` (its ``embeddings`` and ``chunks``) searched for answers.
    data_path: str = "data4k.pkl"
    # Path to a JSON list of questions; a few are sampled as UI examples.
    questions_path: str = "questions.json"
def cosine_similarity(query: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
    """Compute the cosine similarity between *query* and every row of *embeddings*.

    Args:
        query: 1-D vector of length d (anything ``np.asarray`` accepts).
        embeddings: 2-D array of shape (n, d), one candidate vector per row.

    Returns:
        1-D float array of length n. A row gets similarity 0.0 when it or the
        query has zero norm, instead of the NaN/inf a plain division would give.
    """
    query = np.asarray(query)
    embeddings = np.asarray(embeddings)
    dot_product = np.dot(embeddings, query)
    denominator = np.linalg.norm(query) * np.linalg.norm(embeddings, axis=1)
    # A zero-norm vector has no direction, so its similarity is undefined.
    # Divide only where the denominator is non-zero and leave 0.0 elsewhere.
    # This also avoids divide-by-zero RuntimeWarnings.
    similarity = np.zeros_like(dot_product, dtype=float)
    np.divide(dot_product, denominator, out=similarity, where=denominator != 0)
    return similarity
def rank_chunks(
    client: OpenAI,
    question: str,
    dataset: Dataset,
    model: str = "text-embedding-3-small",
) -> list[Chunk]:
    """Return every chunk of *dataset*, ordered from most to least similar to *question*.

    The question is embedded with ``get_one_embedding(client, question, model)``.
    It is then scored against ``dataset.embeddings`` by cosine similarity.
    ``dataset.chunks[i]`` is assumed to correspond to row ``i`` of
    ``dataset.embeddings`` (TODO confirm against ``core.Dataset``).

    Args:
        client: OpenAI client used to embed the question.
        question: Free-text query.
        dataset: Dataset whose ``embeddings`` and ``chunks`` are ranked.
        model: Embedding model name forwarded to ``get_one_embedding``.

    Returns:
        All chunks, best match first. Chunks with equal scores keep dataset order.
    """
    q_embedding = get_one_embedding(client, question, model)
    similarities = np.asarray(
        cosine_similarity(q_embedding, dataset.embeddings), dtype=float
    )
    # np.argsort sorts NaN to the end, so a reversed ascending sort would rank
    # undefined (NaN) scores as the BEST match. Demote them so they rank last.
    similarities = np.where(np.isnan(similarities), -np.inf, similarities)
    # Negate for descending order. A stable sort keeps ties in dataset order,
    # which makes the ranking deterministic.
    order = np.argsort(-similarities, kind="stable")
    return [dataset.chunks[i] for i in order]
if __name__ == "__main__":
    from collections.abc import Iterator
    from typing import Optional

    dotenv.load_dotenv()
    args = Parser()

    # NOTE(review): pickle.load can execute arbitrary code. Only point
    # data_path at files produced by a trusted pipeline.
    with open(args.data_path, "rb") as f:
        data: Dataset = pickle.load(f)
    with open(args.questions_path, "r", encoding="utf-8") as f:
        questions = json.load(f)

    # Chat model used to answer a question from a single retrieved chunk.
    CHAT_MODEL = "gpt-4o"
    # A reply containing this marker is treated as "this chunk does not answer
    # the question" (presumably requested by prompts.get_initial_messages).
    UNKNOWN_MARKER = "<|UNKNOWN|>"
    NOT_FOUND_MESSAGE = "Couldn't find the answer."

    # random.sample raises ValueError when asked for more items than exist.
    select_questions = random.sample(questions, min(3, len(questions)))
    select_questions = [
        "Which guest worked at Abercrombie and Fitch?",
        "Who failed making pastries as a teenager?",
    ] + select_questions

    def _ask_about_chunk(client: OpenAI, query: str, chunk: Chunk) -> Optional[str]:
        """Ask CHAT_MODEL to answer *query* from *chunk*.

        Returns the reply text, or None when the reply is empty or
        contains UNKNOWN_MARKER.
        """
        completion = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=get_initial_messages(query, chunk),
        )
        # The SDK types message.content as optional. Treat None like an empty reply.
        reply = completion.choices[0].message.content or ""
        if not reply or UNKNOWN_MARKER in reply:
            return None
        return reply

    def get_answer(api_key: str, query: str) -> tuple[str, str]:
        """Answer *query* using only the single most similar chunk.

        Returns:
            (answer, context): the model's answer (or NOT_FOUND_MESSAGE)
            and a line naming the video that was consulted.
        """
        client = OpenAI(api_key=api_key)
        sorted_chunks = rank_chunks(client, query, data)
        if not sorted_chunks:
            return NOT_FOUND_MESSAGE, "No chunks available to search"
        best_chunk = sorted_chunks[0]
        print(f"Looking at chunk from video {best_chunk.title}")
        context = f"Looking at the video titled {best_chunk.title}"
        answer = _ask_about_chunk(client, query, best_chunk)
        return (answer if answer is not None else NOT_FOUND_MESSAGE), context

    def get_answer_better(
        api_key: str, query: str, max_chunks: Optional[int] = None
    ) -> Iterator[tuple[Optional[str], str]]:
        """Stream (answer, context) updates while scanning chunks best-first.

        For each chunk, first yields (None, context) to show which video is
        being read, then either the answer (and stops) or a "still looking"
        status. If no chunk answers, the last update is NOT_FOUND_MESSAGE.

        Args:
            api_key: OpenAI API key supplied by the user.
            query: The trivia question.
            max_chunks: Optional cap on how many top-ranked chunks are queried
                (one chat completion each). None scans every chunk.
        """
        client = OpenAI(api_key=api_key)
        print(f"Looking for answer to question: {query}")
        sorted_chunks = rank_chunks(client, query, data)
        context = "No chunks available to search"
        for chunk in sorted_chunks[:max_chunks]:
            print(f"Looking at chunk from video {chunk.title}")
            context = f"Looking at the video titled {chunk.title}"
            yield None, context
            answer = _ask_about_chunk(client, query, chunk)
            if answer is not None:
                yield answer, context
                return
            yield "Not sure, still looking", context
        # Every candidate was rejected. Replace the transient "still looking"
        # status with a final verdict so the UI does not appear stuck.
        yield NOT_FOUND_MESSAGE, context

    def trivia_app(
        api_key: str, query: str, use_multiple: bool
    ) -> Iterator[tuple[Optional[str], str]]:
        """Handle "Get Answer" clicks by streaming (answer, context) into the two output boxes."""
        if use_multiple:
            print("Using multiple chunks")
            yield from get_answer_better(api_key, query)
        else:
            print("Using single chunk")
            yield get_answer(api_key, query)

    with gr.Blocks() as interface:
        gr.Markdown("# Trivia Question Answering App")
        with gr.Row():
            with gr.Column():
                api_key_box = gr.Textbox(
                    lines=1, placeholder="Enter your OpenAI API key here...", type="password"
                )
                question_box = gr.Textbox(
                    lines=2, placeholder="Enter your trivia question here..."
                )
                answer_button = gr.Button("Get Answer")
                examples = gr.Examples(
                    select_questions, label="Example Questions", inputs=[question_box]
                )
                use_multiple = gr.Checkbox(
                    label="Search across multiple chunks", key="better"
                )
            with gr.Column():
                answer_box = gr.Markdown("The answer will appear here...")
                context_box = gr.Textbox(label="Context")
        answer_button.click(
            fn=trivia_app,
            inputs=[api_key_box, question_box, use_multiple],
            outputs=[answer_box, context_box],
        )
    interface.launch()
|