File size: 4,410 Bytes
31b6e27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import pickle
import json

import dotenv
import gradio as gr
import numpy as np
import random
from typarse import BaseParser

from core import get_one_embedding, Chunk, Dataset

from openai import OpenAI

from prompts import get_initial_messages

# Fixed seed so the random.sample call in the __main__ section picks the same
# example questions on every run.
random.seed(42)

class Parser(BaseParser):
    """Command-line options for the trivia app.

    Presumably typarse exposes each annotated field below as a CLI option with
    the given default — TODO confirm against the typarse docs.
    """
    # Path to the pickled Dataset (its .chunks and .embeddings are used below).
    data_path: str = "data4k.pkl"
    # Path to a JSON file holding a sequence of example question strings.
    questions_path: str = "questions.json"

def cosine_similarity(query: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
    """Return the cosine similarity between *query* and every row of *embeddings*.

    Args:
        query: 1-D vector of length ``d``.
        embeddings: 2-D array of shape ``(n, d)``, one candidate vector per row.

    Returns:
        1-D array of ``n`` similarities. Rows with zero norm (and every row,
        if *query* itself has zero norm) get a similarity of ``0.0`` instead
        of ``NaN``.
    """
    dot_product = np.dot(embeddings, query)
    denominator = np.linalg.norm(query) * np.linalg.norm(embeddings, axis=1)
    # A zero norm would yield NaN. np.argsort places NaN last, and callers
    # that reverse the order for "best first" would then rank the degenerate
    # row as the top match. Score those rows as 0 (orthogonal) instead.
    nonzero = denominator > 0
    safe_denominator = np.where(nonzero, denominator, 1.0)
    return np.where(nonzero, dot_product / safe_denominator, 0.0)

def rank_chunks(
    client: OpenAI,
    question: str,
    dataset: Dataset,
    model: str = "text-embedding-3-small",
) -> list[Chunk]:
    """Return every chunk of *dataset*, ordered most-similar-to-*question* first.

    The question is embedded via ``get_one_embedding`` with *model* and scored
    against ``dataset.embeddings`` by cosine similarity. Row ``i`` of the
    embedding matrix is presumably the embedding of ``dataset.chunks[i]``;
    the indexing below relies on that alignment.
    """
    matrix = dataset.embeddings
    candidates = dataset.chunks

    scores = cosine_similarity(get_one_embedding(client, question, model), matrix)

    # argsort is ascending; flip it so the best-scoring chunk comes first.
    best_first = np.flip(np.argsort(scores))
    return [candidates[idx] for idx in best_first]

if __name__ == "__main__":
    dotenv.load_dotenv()

    args = Parser()

    # NOTE(security): pickle.load can execute arbitrary code. Only point
    # data_path at a file you created or otherwise trust.
    with open(args.data_path, "rb") as f:
        data: Dataset = pickle.load(f)

    # Explicit encoding: the default is locale-dependent and can fail on
    # non-ASCII question text (e.g. on Windows).
    with open(args.questions_path, "r", encoding="utf-8") as f:
        questions = json.load(f)

    # Up to three random example questions. random.sample raises ValueError
    # if asked for more items than the population holds, so cap k.
    select_questions = random.sample(questions, min(3, len(questions)))

    # Hand-picked examples are always shown first, then the random sample.
    select_questions = [
        "Which guest worked at Abercrombie and Fitch?",
        "Who failed making pastries as a teenager?",
    ] + select_questions

    def get_answer(api_key: str, query: str) -> tuple[str, str]:
        """Answer *query* from the single highest-ranked chunk only.

        Returns an ``(answer, context)`` pair: the raw gpt-4o reply and a
        line naming the video the chunk came from.
        """
        openai_client = OpenAI(api_key=api_key)
        top_chunk = rank_chunks(openai_client, query, data)[0]
        print(f"Looking at chunk from video {top_chunk.title}")

        response = openai_client.chat.completions.create(
            model="gpt-4o",
            messages=get_initial_messages(query, top_chunk),
        )

        answer = response.choices[0].message.content
        return answer, f"Looking at the video titled {top_chunk.title}"

    from typing import Iterator, Optional

    def get_answer_better(
        api_key: str, query: str
    ) -> Iterator[tuple[Optional[str], str]]:
        """Walk the ranked chunks until gpt-4o gives an answer; stream progress.

        Yields ``(answer, context)`` pairs for the two Gradio outputs:

        * ``(None, context)`` before each LLM call, so the UI shows which video
          is being inspected.
        * ``("Not sure, still looking", context)`` after a reply containing the
          ``<|UNKNOWN|>`` marker (or an empty ``None`` reply). Presumably the
          prompt from ``get_initial_messages`` asks the model to emit that
          marker when the chunk lacks the answer — TODO confirm.
        * ``(answer, context)`` for the first usable reply, after which the
          generator stops.
        * A final "could not find" message if every chunk is exhausted (or
          there are no chunks), so the UI never ends on "still looking".

        Note: this makes one gpt-4o call per chunk until an answer is found,
        which can be many calls for an unanswerable question.
        """
        client = OpenAI(api_key=api_key)
        print(f"Looking for answer to question: {query}")
        sorted_chunks = rank_chunks(client, query, data)

        context = ""
        for chunk in sorted_chunks:
            print(f"Looking at chunk from video {chunk.title}")
            context = f"Looking at the video titled {chunk.title}"

            # Update the context box before the (slow) LLM round-trip.
            yield None, context

            messages = get_initial_messages(query, chunk)

            completion = client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
            )

            # The SDK types message.content as Optional[str]; treat None as
            # "no answer here" instead of crashing on the `in` test.
            res = completion.choices[0].message.content

            if res is not None and "<|UNKNOWN|>" not in res:
                yield res, context
                return
            yield "Not sure, still looking", context

        # Every chunk was tried without a usable answer (or there were none).
        yield "Could not find an answer in any of the videos.", context

    from typing import Iterator, Optional

    def trivia_app(
        api_key: str, query: str, use_multiple: bool
    ) -> Iterator[tuple[Optional[str], str]]:
        """Gradio click handler: route to the multi-chunk or single-chunk answerer.

        This is always a generator (even for the single-chunk path) so that
        each yielded ``(answer, context)`` pair updates the answer and context
        outputs as it arrives.

        Args:
            api_key: OpenAI API key typed into the UI.
            query: The trivia question.
            use_multiple: Checkbox value; True walks ranked chunks via
                ``get_answer_better``, False uses only the top chunk.
        """
        if use_multiple:
            print("Using multiple chunks")
            yield from get_answer_better(api_key, query)
        else:
            print("Using single chunk")
            yield get_answer(api_key, query)

    # Build the Gradio UI. Inside gr.Blocks, components are rendered in the
    # order they are created within the Row/Column context managers, so the
    # statement order below defines the page layout.
    with gr.Blocks() as interface:
        gr.Markdown("# Trivia Question Answering App")
        with gr.Row():
            # Left column: user inputs.
            with gr.Column():
                # Masked field; the key is passed to OpenAI(api_key=...) per request.
                api_key_box = gr.Textbox(
                    lines=1, placeholder="Enter your OpenAI API key here...", type="password"
                )
                question_box = gr.Textbox(
                    lines=2, placeholder="Enter your trivia question here..."
                )
                answer_button = gr.Button("Get Answer")
                # Clickable examples that populate question_box.
                examples = gr.Examples(
                    select_questions, label="Example Questions", inputs=[question_box]
                )
                # Toggles trivia_app between get_answer_better (checked) and
                # get_answer (unchecked).
                use_multiple = gr.Checkbox(
                    label="Search across multiple chunks", key="better"
                )
            # Right column: outputs.
            with gr.Column():
                answer_box = gr.Markdown("The answer will appear here...")
                context_box = gr.Textbox(label="Context")

        # trivia_app yields (answer, context) pairs; each pair is written to
        # (answer_box, context_box) in that order.
        answer_button.click(
            fn=trivia_app,
            inputs=[api_key_box, question_box, use_multiple],
            outputs=[answer_box, context_box],
        )

    # Start the Gradio web server.
    interface.launch()