not-lain's picture
🌘w🌖
95140c0
raw
history blame
5.59 kB
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import faiss
from usearch.index import Index
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
token = os.environ["HF_TOKEN"]

# Generator LLM. It stays bound to `model` because `talk` calls
# `model.generate`; the query-embedding model further down gets its own name
# (`embedding_model`) so it no longer overwrites this one.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
)
tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)

# Load titles and texts
title_text_dataset = load_dataset(
    "mixedbread-ai/wikipedia-data-en-2023-11", split="train", num_proc=4
).select_columns(["title", "text"])

# Load the int8 and binary indices. Int8 is loaded as a view to save memory,
# as we never actually perform search with it (it is only used for rescoring).
int8_view = Index.restore(
    "https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/resolve/main/wikipedia_int8_usearch_1m.index",
    view=True,
)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary(
    "https://huggingface.co/spaces/sentence-transformers/quantized-retrieval/resolve/main/wikipedia_ubinary_faiss_1m.index"
)

# SentenceTransformer model for embedding the queries; the "retrieval" prompt
# is prepended to every query by default.
embedding_model = SentenceTransformer(
    "mixedbread-ai/mxbai-embed-large-v1",
    prompts={
        "retrieval": "Represent this sentence for searching relevant passages: ",
    },
    default_prompt_name="retrieval",
)


def search(
    query, top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
):
    """Retrieve the ``top_k`` most relevant Wikipedia passages for ``query``.

    Candidates are fetched from the binary index and then rescored with the
    float32 query embedding against the stored int8 document embeddings.

    Args:
        query: the natural-language query.
        top_k: number of results to return.
        rescore_multiplier: the binary index is asked for
            ``top_k * rescore_multiplier`` candidates, which are then rescored.
        use_approx: accepted for backward compatibility but currently ignored;
            only the exact binary index is loaded.

    Returns:
        A column-oriented dict with keys "Score", "Title" and "Text", each
        holding up to ``top_k`` entries ordered by descending rescored score.
    """
    # 1. Embed the query as float32
    query_embedding = embedding_model.encode(query)
    # 2. Quantize the query to ubinary so it can be matched against the binary index
    query_embedding_ubinary = quantize_embeddings(
        query_embedding.reshape(1, -1), "ubinary"
    )
    # 3. Search the (exact) binary index for an oversampled candidate set
    _scores, binary_ids = binary_index.search(
        query_embedding_ubinary, top_k * rescore_multiplier
    )
    binary_ids = binary_ids[0]
    # 4. Load the corresponding int8 embeddings, widened before the dot product
    int8_embeddings = int8_view[binary_ids].astype(int)
    # 5. Rescore the candidates using the float32 query embedding and the int8 document embeddings
    scores = query_embedding @ int8_embeddings.T
    # 6. Sort the scores (descending) and keep the top_k
    indices = scores.argsort()[::-1][:top_k]
    top_k_indices = binary_ids[indices]
    top_k_scores = scores[indices]
    # Fetch each dataset row once instead of once per column.
    rows = [title_text_dataset[idx] for idx in top_k_indices.tolist()]
    return {
        "Score": [round(value, 2) for value in top_k_scores],
        "Title": tuple(row["title"] for row in rows),
        "Text": tuple(row["text"] for row in rows),
    }
def prepare_prompt(query, df):
    """Build the user prompt that asks the LLM to answer ``query`` from search results.

    Args:
        query: the user's question.
        df: column-oriented search results as returned by ``search``, i.e. a
            mapping whose "Title" and "Text" entries are parallel sequences.

    Returns:
        The query header followed by one "Title: ..., Text: ..." line per result.
    """
    header = (
        f"Query: {query}\n"
        "Continue to answer the query by using the Search Results:\n"
    )
    # `df` is a dict of columns, not a list of rows: iterating it directly
    # yields the column names, so walk the Title/Text columns in lockstep.
    results = "".join(
        f"Title: {title}, Text: {text}\n"
        for title, text in zip(df["Title"], df["Text"])
    )
    return header + results
_RESOURCES_HEADER = "\nRESOURCES:\n"


def _history_to_chat(history):
    """Convert Gradio (user, assistant) history pairs into chat-template messages.

    Past assistant turns are cut at the resources footer that ``talk`` appends,
    so the footer is not fed back to the model. A ``None`` assistant turn is
    skipped.
    """
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            cleaned_past = item[1].split(_RESOURCES_HEADER)[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    return chat


def _format_resources(titles):
    """Build the markdown footer that lists the given article titles as links."""
    links = ", ".join(
        f"[{title}](https://huggingface.co/spaces/not-lain/RAG)" for title in titles
    )
    return _RESOURCES_HEADER + links


@spaces.GPU
def talk(message, history):
    """Stream a retrieval-augmented answer to ``message``.

    Args:
        message: the user's new query.
        history: Gradio chat history as (user, assistant) pairs; the assistant
            entry may be ``None``.

    Yields:
        The answer generated so far; the final value additionally carries the
        resources footer listing the top three retrieved titles.
    """
    df = search(message)
    prompt = prepare_prompt(message, df)
    resources = _format_resources(df["Title"][:3])

    chat = _history_to_chat(history)
    chat.append({"role": "user", "content": prompt})
    prompt_text = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    model_inputs = tok([prompt_text], return_tensors="pt").to(device)

    # `generate` blocks until done, so run it in a background thread and read
    # decoded text from the streamer as it is produced.
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted once generation ends; join so the worker has
    # fully finished before this GPU-decorated call returns.
    generation_thread.join()
    yield partial_text + resources
TITLE = "RAG"
# Markdown shown under the title in the chat UI.
DESCRIPTION = """
## Resources used to build this project
* https://huggingface.co/learn/cookbook/rag_with_hugging_face_gemma_mongodb
* https://huggingface.co/spaces/sentence-transformers/quantized-retrieval
## Retrieval parameters
```python
top_k: int = 10, rescore_multiplier: int = 1, use_approx: bool = False
```
## Models
The models used in this space are:
* google/gemma-7b-it
* mixedbread-ai/mxbai-embed-large-v1
## Dataset
* mixedbread-ai/wikipedia-data-en-2023-11
"""
# Chat UI wired to the streaming `talk` generator; TITLE and DESCRIPTION were
# previously defined but never shown.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()