Spaces:
Runtime error
Runtime error
File size: 5,570 Bytes
07ffad3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import faiss
from usearch.index import Index
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
token = os.environ["HF_TOKEN"]
# NOTE(review): hard-codes CUDA; presumably relies on the `spaces` ZeroGPU runtime
# to make a GPU available — confirm before running elsewhere.
device = torch.device("cuda")

# Chat LLM used by `talk` for generation. It must stay bound to `model` because
# `talk` calls `model.generate`; the embedding model below gets its own name so it
# no longer shadows (overwrites) this binding.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
).to(device)
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)

# Load titles and texts
title_text_dataset = load_dataset(
    "mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4
).select_columns(["title", "text"])

# Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we
# never actually perform search with it.
int8_view = Index.restore("wikipedia_int8_usearch_50m.index", view=True)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
    "wikipedia_ubinary_faiss_50m.index"
)
binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary(
    "wikipedia_ubinary_ivf_faiss_50m.index"
)

# SentenceTransformer used only to embed search queries.
embedding_model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={
        "retrieval": "Represent this sentence for searching relevant passages: ",
    },
    default_prompt_name="retrieval",
)


def search(
    query, top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
):
    """Retrieve the Wikipedia passages most relevant to ``query``.

    Binary-quantized search finds ``top_k * rescore_multiplier`` candidates. These
    are rescored with the float32 query embedding against their int8 document
    embeddings, and the best ``top_k`` are kept.

    Args:
        query: Query text to embed.
        top_k: Number of passages to return.
        rescore_multiplier: How many times ``top_k`` candidates to fetch from the
            binary index before rescoring.
        use_approx: Search the IVF (approximate) binary index instead of the
            exact flat one.

    Returns:
        A dict of columns ``{"Score": [...], "Title": (...), "Text": (...)}``,
        ordered by descending score. The columns are empty if nothing was found.
    """
    # 1. Embed the query as float32
    query_embedding = embedding_model.encode(query)
    # 2. Quantize the query to ubinary
    query_embedding_ubinary = quantize_embeddings(
        query_embedding.reshape(1, -1), "ubinary"
    )
    # 3. Search the binary index (either exact or approximate)
    index = binary_ivf if use_approx else binary_index
    _scores, binary_ids = index.search(
        query_embedding_ubinary, top_k * rescore_multiplier
    )
    binary_ids = binary_ids[0]
    # faiss pads the result with -1 when fewer than k neighbours are found (e.g.
    # IVF probing too few lists); those are not real row ids, so drop them.
    binary_ids = binary_ids[binary_ids != -1]
    if binary_ids.size == 0:
        return {"Score": [], "Title": (), "Text": ()}
    # 4. Load the corresponding int8 embeddings
    int8_embeddings = int8_view[binary_ids].astype(int)
    # 5. Rescore the candidates using the float32 query embedding and the int8
    # document embeddings
    scores = query_embedding @ int8_embeddings.T
    # 6. Sort the scores (descending) and keep the top_k
    indices = scores.argsort()[::-1][:top_k]
    top_k_indices = binary_ids[indices]
    top_k_scores = scores[indices]
    top_k_titles, top_k_texts = zip(
        *[
            (title_text_dataset[idx]["title"], title_text_dataset[idx]["text"])
            for idx in top_k_indices.tolist()
        ]
    )
    return {
        "Score": [round(value, 2) for value in top_k_scores],
        "Title": top_k_titles,
        "Text": top_k_texts,
    }
def prepare_prompt(query, df):
    """Build the LLM prompt from the user query and the retrieved passages.

    Args:
        query: The user's question.
        df: Column-oriented search results as returned by ``search``: a mapping
            with parallel ``"Title"`` and ``"Text"`` sequences.

    Returns:
        The query header followed by one ``Title: ..., Text: ...`` line per
        retrieved passage, in result order.
    """
    header = f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    # `df` is a dict of columns, so iterating it directly would yield the key
    # strings; pair the Title/Text columns row by row instead.
    passages = "".join(
        f"Title: {title}, Text: {text}\n"
        for title, text in zip(df["Title"], df["Text"])
    )
    return header + passages
@spaces.GPU
def talk(message, history):
    """Answer ``message`` with retrieval-augmented generation, streaming the reply.

    Steps:
        1. Retrieve passages with ``search`` and embed them in the prompt via
           ``prepare_prompt``.
        2. Replay the prior conversation to the chat LLM.
        3. Stream the generated tokens.
        4. Once generation finishes, append a ``RESOURCES`` footer linking the
           top-3 retrieved titles.

    Args:
        message: The user's latest message.
        history: Previous turns; each item is indexed as ``[user, assistant]``,
            where the assistant entry may be ``None`` (presumably Gradio's
            tuple-style chat history — confirm against the Gradio version used).

    Yields:
        The accumulated reply after each streamed chunk, then the full reply
        with the resources footer appended.
    """
    df = search(message)
    prompt = prepare_prompt(message, df)
    links = ", ".join(
        f"[{title}](https://huggingface.co/spaces/not-lain/RAG)"
        for title in df["Title"][:3]
    )
    resources = "\nRESOURCES:\n" + links

    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            # Strip the RESOURCES footer from earlier replies so it is not fed
            # back to the model as part of the conversation.
            cleaned_past = item[1].split("\nRESOURCES:\n")[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": prompt})

    prompt_text = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # The chat template already inserts the special tokens (e.g. BOS); don't let
    # the tokenizer add them a second time.
    model_inputs = tok([prompt_text], return_tensors="pt", add_special_tokens=False).to(device)

    streamer = TextIteratorStreamer(
        tok, timeout=10., skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # Generation runs in a worker thread so tokens can be yielded as they arrive.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted once generation ends; reap the worker thread.
    worker.join()
    yield partial_text + resources
# UI metadata shown on the chat page.
TITLE = "RAG"
DESCRIPTION = """
## Resources used to build this project
* https://huggingface.co/learn/cookbook/rag_with_hugging_face_gemma_mongodb
* https://huggingface.co/spaces/sentence-transformers/quantized-retrieval
## Retrieval parameters
```python
top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
```
## Models
The models used in this space are:
* google/gemma-7b-it
* mixedbread-ai/mxbai-embed-large-v1
## Dataset
* mixedbread-ai/wikipedia-data-en-2023-11
"""

# Streaming chat UI backed by `talk`; TITLE/DESCRIPTION were previously defined
# but never passed to the interface.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["Write me a poem about Machine Learning."]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()
|