File size: 5,586 Bytes
07ffad3
 
 
 
 
 
 
 
 
 
 
 
 
95140c0
07ffad3
 
 
 
 
 
 
 
 
 
 
 
 
95140c0
07ffad3
95140c0
07ffad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95140c0
07ffad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95140c0
07ffad3
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import faiss
from usearch.index import Index
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread

# Token passed to from_pretrained for the LLM and its tokenizer (KeyError if HF_TOKEN is unset).
token = os.environ["HF_TOKEN"]

# Generative LLM used by `talk` via `model.generate`. The query-embedding model
# below has its own name so it does not rebind `model` and replace this LLM.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
)
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)

# Load titles and texts. search() indexes this dataset with the ids returned by
# the binary index, so rows are assumed to be aligned with index ids.
title_text_dataset = load_dataset(
    "mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4
).select_columns(["title", "text"])

# Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
int8_view = Index.restore("https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/resolve/main/wikipedia_int8_usearch_1m.index", view=True)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
    "https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/resolve/main/wikipedia_ubinary_faiss_1m.index"
)

# SentenceTransformer used by search() to embed queries.
embedding_model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={
        "retrieval": "Represent this sentence for searching relevant passages: ",
    },
    default_prompt_name="retrieval",
)


def search(
    query, top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
):
    """Retrieve the best-matching Wikipedia passages for ``query``.

    The query is embedded as float32 and quantized to ubinary. The binary faiss
    index then supplies ``top_k * rescore_multiplier`` candidates, which are
    rescored with the float32 query embedding against their int8 document
    embeddings.

    Args:
        query: Text to search for.
        top_k: Number of results to return.
        rescore_multiplier: How many times ``top_k`` candidates to fetch from
            the binary index before rescoring.
        use_approx: Currently ignored. Only the exact binary index is loaded.
            The parameter is kept for API compatibility.

    Returns:
        A dict with keys "Score", "Title" and "Text". Each value holds at most
        ``top_k`` entries, sorted by descending rescored score.
    """
    # 1. Embed the query as float32
    query_embedding = embedding_model.encode(query)

    # 2. Quantize the query to ubinary
    query_embedding_ubinary = quantize_embeddings(
        query_embedding.reshape(1, -1), "ubinary"
    )

    # 3. Search the binary index (exact search only)
    _scores, binary_ids = binary_index.search(
        query_embedding_ubinary, top_k * rescore_multiplier
    )
    binary_ids = binary_ids[0]
    # faiss pads the result with -1 when fewer than k vectors are available.
    binary_ids = binary_ids[binary_ids >= 0]
    if binary_ids.size == 0:
        return {"Score": [], "Title": (), "Text": ()}

    # 4. Load the corresponding int8 embeddings
    int8_embeddings = int8_view[binary_ids].astype(int)

    # 5. Rescore the candidates using the float32 query embedding and the int8 document embeddings
    scores = query_embedding @ int8_embeddings.T

    # 6. Sort the scores and return the top_k
    indices = scores.argsort()[::-1][:top_k]
    top_k_indices = binary_ids[indices]
    top_k_scores = scores[indices]
    # One dataset lookup per hit (the original looked each row up twice).
    rows = [title_text_dataset[idx] for idx in top_k_indices.tolist()]

    return {
        "Score": [round(float(value), 2) for value in top_k_scores],
        "Title": tuple(row["title"] for row in rows),
        "Text": tuple(row["text"] for row in rows),
    }

def prepare_prompt(query, df):
    """Build the user prompt that grounds the LLM's answer in search results.

    Args:
        query: The user's question.
        df: Mapping with parallel "Title" and "Text" sequences, in the shape
            returned by ``search()``.

    Returns:
        A header containing the query, followed by one
        "Title: ..., Text: ..." line per search result.
    """
    # ``df`` is a dict of columns. Iterating it directly yields the column
    # names ("Score", "Title", ...) rather than rows, so pair the columns instead.
    parts = [f"Query: {query}\nContinue to answer the query by using the Search Results:\n"]
    parts.extend(
        f"Title: {title}, Text: {text}\n"
        for title, text in zip(df["Title"], df["Text"])
    )
    return "".join(parts)

@spaces.GPU
def talk(message, history):
    """Answer ``message`` with retrieval-augmented generation, streaming the reply.

    Passages are retrieved with ``search()`` and folded into the final user
    turn via ``prepare_prompt()``. ``history`` is replayed as alternating
    user/assistant turns, and the reply is generated by ``model.generate`` on a
    background thread while this function streams the text.

    Args:
        message: Latest user message.
        history: Sequence of turns where ``turn[0]`` is the user text and
            ``turn[1]`` is the assistant reply or None. Past replies have their
            trailing "RESOURCES" section stripped before being replayed.

    Yields:
        The reply accumulated so far. The final yield additionally carries
        links for the first three retrieved titles.
    """
    results = search(message)
    augmented_message = prepare_prompt(message, results)
    footer = "\nRESOURCES:\n" + "".join(
        f"[{title}](https://huggingface.co/spaces/not-lain/RAG), "
        for title in results["Title"][:3]
    )

    conversation = []
    for turn in history:
        conversation.append({"role": "user", "content": turn[0]})
        past_reply = turn[1]
        if past_reply is None:
            continue
        conversation.append(
            {"role": "assistant", "content": past_reply.split("\nRESOURCES:\n")[0]}
        )
    conversation.append({"role": "user", "content": augmented_message})

    prompt_text = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    encoded = tok([prompt_text], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    worker = Thread(
        target=model.generate,
        kwargs=dict(
            encoded,
            streamer=streamer,
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.95,
            top_k=1000,
            temperature=0.75,
            num_beams=1,
        ),
    )
    worker.start()

    reply = ""
    for piece in streamer:
        reply += piece
        yield reply
    yield reply + footer





TITLE = "RAG"

DESCRIPTION = """
## Resources used to build this project
* https://huggingface.co/learn/cookbook/rag_with_hugging_face_gemma_mongodb
* https://huggingface.co/spaces/sentence-transformers/quantized-retrieval
## Retrieval parameters
```python
top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
```
## Models and data
The models and dataset used in this space are:
* google/gemma-7b-it (answer generation)
* mixedbread-ai/mxbai-embed-large-v1 (query embeddings)
* mixedbread-ai/wikipedia-data-en-2023-11 (Wikipedia passages)
"""

# Chat UI that streams answers from `talk`. TITLE and DESCRIPTION are passed in
# so the header shown in the UI comes from the constants defined above.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()