File size: 4,266 Bytes
07ffad3
1b7e4b0
 
 
07ffad3
 
 
 
 
1b7e4b0
 
 
07ffad3
 
1b7e4b0
 
 
 
 
 
 
 
07ffad3
1b7e4b0
07ffad3
1b7e4b0
 
07ffad3
1b7e4b0
 
07ffad3
1b7e4b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07ffad3
1b7e4b0
 
 
 
 
 
 
 
 
 
07ffad3
 
 
 
1b7e4b0
 
07ffad3
1b7e4b0
 
07ffad3
 
 
 
 
 
 
 
 
 
 
1b7e4b0
 
07ffad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b7e4b0
07ffad3
 
 
 
 
 
 
1b7e4b0
 
07ffad3
 
 
1b7e4b0
07ffad3
 
1b7e4b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07ffad3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
from datasets import load_dataset, Dataset

# import faiss
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from ragatouille import RAGPretrainedModel
from datasets import load_dataset


# Hugging Face access token for the Hub downloads below (KeyError if unset).
token = os.environ["HF_TOKEN"]
_GEMMA_ID = "google/gemma-7b-it"

# Generator LLM in half precision, plus its matching tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    _GEMMA_ID, torch_dtype=torch.float16, token=token
)
tok = AutoTokenizer.from_pretrained(_GEMMA_ID, token=token)

# Place the generator's weights on the GPU; talk() sends its inputs here too.
device = torch.device("cuda")
model = model.to(device)

# ColBERT retriever used to build and query the search index.
# NOTE(review): DESCRIPTION links mixedbread-ai/mxbai-colbert-large-v1 while
# this loads mixedbread-ai/mxbai-colbert-v1 — confirm the intended checkpoint.
RAG = RAGPretrainedModel.from_pretrained("mixedbread-ai/mxbai-colbert-v1")

# prepare data
# the full dump is too big, so only the first NUM_ARTICLES articles are kept
NUM_ARTICLES = 3000

dataset = load_dataset(
    "wikimedia/wikipedia", "20231101.en", split="train", streaming=True
)
# Each streamed entry has the columns ['id', 'url', 'title', 'text'].
# Dataset.add_item returns a new Dataset rather than modifying the receiver,
# so the rows are taken from the stream first and the Dataset is built once.
data = Dataset.from_list(list(dataset.take(NUM_ARTICLES)))
# free memory
del dataset  # we keep data

# index data
documents = list(data["text"])
RAG.index(
    documents,
    # RAGatouille splits long documents into passages, so a hit's passage_id
    # is not a row number of `data`. Tagging every document with its row
    # number lets search() map each hit back to the article it came from.
    document_ids=[str(row) for row in range(len(documents))],
    index_name="wikipedia",
    use_faiss=True,
)
# free memory
del documents


def search(query, k: int = 5):
    """Return the row numbers in ``data`` of the articles best matching *query*.

    Args:
        query: Free-text search query.
        k: Number of passages to retrieve from the index.

    Returns:
        Row numbers into ``data``, best match first. Several passages can
        come from the same article, so duplicates are dropped and fewer than
        *k* rows may be returned.
    """
    results = RAG.search(query, k=k)
    # Each result has the keys 'content', 'score', 'rank', 'document_id' and
    # 'passage_id', and results are ordered by score (rank 1 first).
    # 'document_id' is the row-number string supplied to RAG.index above.
    rows = (int(result["document_id"]) for result in results)
    # dict.fromkeys drops duplicates while keeping first-seen (best) order.
    return list(dict.fromkeys(rows))


def prepare_prompt(query, indexes, data=data):
    """Build the retrieval-augmented prompt sent to the model.

    Args:
        query: The user's question.
        indexes: Row numbers into *data* of the retrieved articles.
        data: Row-indexable table (defaults to the module-level Wikipedia
            subset) whose rows have ``title``, ``text`` and ``url`` keys.

    Returns:
        A ``(prompt, (titles, urls))`` tuple, where ``titles`` and ``urls``
        are parallel lists in the same order as *indexes*.
    """
    parts = [
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    ]
    titles = []
    urls = []
    for i in indexes:
        # Fetch the whole row from `data` at once; this avoids materialising
        # full columns for every retrieved hit.
        row = data[int(i)]
        title = row["title"]
        titles.append(title)
        urls.append(row["url"])
        parts.append(f"Title: {title}, Text: {row['text']}\n")
    return "".join(parts), (titles, urls)


@spaces.GPU
def talk(message, history):
    """Answer *message* with the LLM, grounded on retrieved Wikipedia articles.

    Streams the growing answer, then yields it once more with a
    "RESOURCES" footer that links the articles placed in the prompt.

    Args:
        message: The latest user message.
        history: Previous turns as ``[user, assistant]`` pairs; the
            assistant entry may be ``None``.

    Yields:
        The answer text accumulated so far.
    """
    indexes = search(message)
    prompt, (titles, urls) = prepare_prompt(message, indexes)
    # titles and urls are parallel lists, so pair them element-wise.
    resources = "\nRESOURCES:\n" + "".join(
        f"[{title}]({url}),  " for title, url in zip(titles, urls)
    )

    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            # Earlier answers carry the RESOURCES footer added below; strip it
            # so the model only sees its own previous answer text.
            cleaned_past = item[1].split("\nRESOURCES:\n")[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": prompt})

    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # Tokenize the rendered chat string and move it to the model's device.
    model_inputs = tok([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # generate() runs in a worker thread while this generator consumes the
    # decoded text pieces from the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted, so generation has finished; reap the thread.
    t.join()
    partial_text += resources
    yield partial_text


TITLE = "RAG"

# Markdown description passed to the chat interface below.
DESCRIPTION = """
## Resources used to build this project
* https://huggingface.co/mixedbread-ai/mxbai-colbert-large-v1
* me 😎
## Models
the models used in this space are : 
* google/gemma-7b-it
* mixedbread-ai/mxbai-colbert-v1
"""

# Chat UI wired to the streaming talk() generator.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()