# Hugging Face Space (status shown when this file was captured: "Runtime error").
import os
import urllib.request
from threading import Thread

import faiss
import gradio as gr
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from usearch.index import Index
# Hugging Face access token, passed to the Gemma model/tokenizer loaders below.
# Raises KeyError at import time when HF_TOKEN is not set.
token = os.environ["HF_TOKEN"]

# Text-generation model and tokenizer. `model`, `tok` and `device` are the
# names talk() relies on, so `model` must keep pointing at the causal LM.
# Weights are loaded in float16 and moved to CUDA unconditionally.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
)
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)

# Load titles and texts
title_text_dataset = load_dataset(
    "mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4
).select_columns(["title", "text"])


def _fetch_index(url: str) -> str:
    """Download *url* into the working directory once and return the local path.

    The index loaders below read from the local filesystem, so the remote
    files have to be materialised on disk first. The download goes to a
    temporary ``.part`` file and is renamed afterwards so that an interrupted
    transfer is never mistaken for a complete index on the next start.
    """
    local_path = os.path.basename(url)
    if not os.path.exists(local_path):
        partial_path = local_path + ".part"
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, local_path)
    return local_path


# Load the int8 and binary indices. Int8 is loaded as a view to save memory,
# as we never actually perform search with it.
int8_view = Index.restore(
    _fetch_index(
        "https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/resolve/main/wikipedia_int8_usearch_1m.index"
    ),
    view=True,
)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
    _fetch_index(
        "https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/resolve/main/wikipedia_ubinary_faiss_1m.index"
    )
)

# Load the SentenceTransformer model for embedding the queries. It gets its own
# name so that it does not overwrite the causal LM bound to `model` above.
embedding_model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={
        "retrieval": "Represent this sentence for searching relevant passages: ",
    },
    default_prompt_name="retrieval",
)


def search(
    query, top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
):
    """Retrieve the best matching passages for *query*.

    Candidates are fetched from the binary index and then re-scored with the
    float32 query embedding against the candidates' int8 embeddings.

    Args:
        query: Text to search for.
        top_k: Number of results to return.
        rescore_multiplier: ``top_k * rescore_multiplier`` candidates are
            fetched from the binary index before rescoring.
        use_approx: Accepted for interface compatibility; currently unused,
            the binary index is always the one that is searched.

    Returns:
        A column-oriented dict with keys ``"Score"`` (list of scores rounded
        to two decimals), ``"Title"`` and ``"Text"`` (tuples of strings),
        ordered from highest to lowest score.
    """
    # 1. Embed the query as float32
    query_embedding = embedding_model.encode(query)

    # 2. Quantize the query to ubinary (reshaped to a single-row batch)
    query_embedding_ubinary = quantize_embeddings(
        query_embedding.reshape(1, -1), "ubinary"
    )

    # 3. Search the binary index for the candidate ids
    _scores, binary_ids = binary_index.search(
        query_embedding_ubinary, top_k * rescore_multiplier
    )
    binary_ids = binary_ids[0]  # only one query row was submitted

    # 4. Load the corresponding int8 embeddings
    int8_embeddings = int8_view[binary_ids].astype(int)

    # 5. Rescore the candidates using the float32 query embedding and the
    #    int8 document embeddings
    scores = query_embedding @ int8_embeddings.T

    # 6. Sort the scores (descending) and keep the top_k
    indices = scores.argsort()[::-1][:top_k]
    top_k_indices = binary_ids[indices]
    top_k_scores = scores[indices]

    # Fetch each dataset row once; building the columns separately also keeps
    # an empty result from failing on tuple unpacking.
    rows = [title_text_dataset[idx] for idx in top_k_indices.tolist()]
    return {
        "Score": [round(value, 2) for value in top_k_scores],
        "Title": tuple(row["title"] for row in rows),
        "Text": tuple(row["text"] for row in rows),
    }
def prepare_prompt(query, df):
    """Build the retrieval-augmented prompt that is sent to the chat model.

    Args:
        query: The user's question.
        df: Search results. Either column-oriented, as returned by ``search``
            (a dict whose ``"Title"`` and ``"Text"`` values are parallel
            sequences), or an iterable of row dicts that each contain
            ``"Title"`` and ``"Text"``.

    Returns:
        The prompt string: a header containing the query, followed by one
        ``Title: ..., Text: ...`` line per search result.
    """
    if isinstance(df, dict):
        # Iterating the dict itself would yield its column names, not rows,
        # so walk the two columns in lockstep instead.
        rows = zip(df["Title"], df["Text"])
    else:
        rows = ((row["Title"], row["Text"]) for row in df)

    header = f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    return header + "".join(
        f"Title: {title}, Text: {text}\n" for title, text in rows
    )
def talk(message, history):
    """Stream a retrieval-augmented answer to *message*.

    The message is used as a search query, the results are folded into the
    prompt, and the chat model's output is streamed back. This is a generator:
    each yield is the full response text produced so far, and the final yield
    additionally carries a "RESOURCES" footer linking the first three result
    titles.

    Args:
        message: The latest user message.
        history: Previous turns; each item is indexed as ``[0]`` = user text
            and ``[1]`` = assistant text (or ``None``).

    Yields:
        The accumulated response text.
    """
    results = search(message)
    augmented_message = prepare_prompt(message, results)

    # Footer appended after generation finishes (first three titles only).
    footer = "\nRESOURCES:\n" + "".join(
        f"[{title}](https://huggingface.co/spaces/not-lain/RAG), "
        for title in results["Title"][:3]
    )

    # Rebuild the conversation in chat-template form, dropping the footer that
    # earlier responses carried so it is not fed back to the model.
    conversation = []
    for turn in history:
        conversation.append({"role": "user", "content": turn[0]})
        if turn[1] is None:
            continue
        previous_answer = turn[1].partition("\nRESOURCES:\n")[0]
        conversation.append({"role": "assistant", "content": previous_answer})
    conversation.append({"role": "user", "content": augmented_message})

    prompt_text = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    # Tokenize the rendered prompt and move it to the model's device.
    model_inputs = tok([prompt_text], return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(
        tok, timeout=10., skip_prompt=True, skip_special_tokens=True
    )
    generation_args = dict(model_inputs)
    generation_args.update(
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )

    # Generation runs in a background thread so the streamer can be consumed
    # here while tokens are still being produced. `model` has to be the
    # causal LM (it needs a `generate` method).
    worker = Thread(target=model.generate, kwargs=generation_args)
    worker.start()

    response = ""
    for chunk in streamer:
        response = response + chunk
        yield response

    response = response + footer
    yield response
# Heading and Markdown description shown above the chat interface.
TITLE = "RAG"
DESCRIPTION = """
## Resources used to build this project
* https://huggingface.co/learn/cookbook/rag_with_hugging_face_gemma_mongodb
* https://huggingface.co/spaces/sentence-transformers/quantized-retrieval
## Retrival paramaters
```python
top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
```
## Models
the models used in this space are :
* google/gemma-7b-it
* mixedbread-ai/wikipedia-data-en-2023-11
"""

# Chat UI around talk(). TITLE and DESCRIPTION are passed through here; they
# were previously defined but never shown (the title was a hard-coded
# "Text Streaming" and no description was rendered).
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()