Spaces:
Runtime error
Runtime error
import gradio as gr | |
from datasets import load_dataset, Dataset | |
# import faiss | |
import os | |
import spaces | |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
import torch | |
from threading import Thread | |
from ragatouille import RAGPretrainedModel | |
from datasets import load_dataset | |
# --- Models -----------------------------------------------------------------
LLM_ID = "google/gemma-7b-it"
# Published ColBERT checkpoint from mixedbread.ai (the same one linked in the
# app description).
RETRIEVER_ID = "mixedbread-ai/mxbai-colbert-large-v1"
# Number of Wikipedia articles to materialize and index.
NUM_DOCUMENTS = 3000

# Read from the environment; raises KeyError if HF_TOKEN is not set.
token = os.environ["HF_TOKEN"]
model = AutoModelForCausalLM.from_pretrained(
    LLM_ID,
    torch_dtype=torch.float16,
    token=token,
)
tok = AutoTokenizer.from_pretrained(LLM_ID, token=token)
device = torch.device("cuda")
model = model.to(device)
RAG = RAGPretrainedModel.from_pretrained(RETRIEVER_ID)

# --- Data -------------------------------------------------------------------
# The full dump is too big, so it is streamed and only the first
# NUM_DOCUMENTS articles are kept. Each row has the columns
# ['id', 'url', 'title', 'text'].
dataset = load_dataset(
    "wikimedia/wikipedia", "20231101.en", split="train", streaming=True
)
# Build the in-memory Dataset in one shot. Calling add_item per row would
# rebuild the table each time, which is quadratic.
data = Dataset.from_list(list(dataset.take(NUM_DOCUMENTS)))
# free memory
del dataset  # we keep data

# --- Index ------------------------------------------------------------------
documents = data["text"]
# Index each article as ONE passage so the `passage_id` returned by
# RAG.search is the article's row number in `data`, which is what
# prepare_prompt indexes with. With the default splitting, passage ids count
# chunks, not rows.
# NOTE(review): long articles are presumably truncated by the encoder to its
# maximum document length; confirm this is acceptable for retrieval quality.
RAG.index(
    documents,
    index_name="wikipedia",
    use_faiss=True,
    split_documents=False,
)
# free memory
del documents
def search(query, k: int = 5):
    """Retrieve the ``k`` best-scoring passages for ``query`` from the ColBERT index.

    ``RAG.search`` yields one dict per hit, already ordered by score (best
    first). Each hit carries the keys ``content``, ``score``, ``rank``,
    ``document_id`` and ``passage_id``.

    Args:
        query: Free-text user question.
        k: How many hits to return.

    Returns:
        The ``passage_id`` of each hit, in ranking order.

    NOTE(review): downstream code treats these ids as row numbers of the
    indexed collection. That only holds if every document was indexed as a
    single passage (no chunk splitting). Confirm against the ``RAG.index``
    call.
    """
    ranked_hits = RAG.search(query, k=k)
    passage_ids = [hit["passage_id"] for hit in ranked_hits]
    return passage_ids
def prepare_prompt(query, indexes, data=data):
    """Build the retrieval-augmented prompt for ``query``.

    Args:
        query: The user's question.
        indexes: Row positions in ``data`` of the retrieved articles, in the
            order they should appear in the prompt (as returned by ``search``).
        data: Row-indexable collection whose rows expose ``"title"``,
            ``"text"`` and ``"url"``. Defaults to the module-level Wikipedia
            subset.

    Returns:
        ``(prompt, (titles, urls))``. ``prompt`` is the query header followed
        by one ``Title: ..., Text: ...`` line per retrieved article.
        ``titles`` and ``urls`` are parallel lists with one entry per index.
    """
    header = (
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    )
    titles = []
    urls = []
    sections = []
    for i in indexes:
        # Look rows up in `data` (the parameter). The previous code read a
        # leaked global loop variable. Fetch a single row rather than
        # data["title"], which returns the whole column.
        row = data[int(i)]
        titles.append(row["title"])
        urls.append(row["url"])
        sections.append(f"Title: {row['title']}, Text: {row['text']}\n")
    prompt = header + "".join(sections)
    return prompt, (titles, urls)
@spaces.GPU
def talk(message, history):
    """Stream an answer to ``message`` grounded in retrieved Wikipedia articles.

    ``spaces.GPU`` requests a GPU for the duration of the call on ZeroGPU
    hardware. Elsewhere it is a no-op.

    Args:
        message: The latest user message.
        history: Previous turns as ``(user_text, assistant_text)`` pairs
            (Gradio tuple-style history). ``assistant_text`` may be None.

    Yields:
        The answer generated so far. The final yield additionally appends a
        ``RESOURCES`` footer with markdown links to the retrieved articles.
    """
    indexes = search(message)
    prompt, (titles, urls) = prepare_prompt(message, indexes)
    # Pair titles with urls element-wise. Iterating the (titles, urls) tuple
    # directly would try to unpack each whole list into two names.
    resources = "\nRESOURCES:\n" + "".join(
        f"[{title}]({url}), " for title, url in zip(titles, urls)
    )

    chat = []
    for turn in history:
        user_text, bot_text = turn[0], turn[1]
        if bot_text is None:
            # Gemma's chat template requires strictly alternating
            # user/assistant roles, so an unanswered turn is dropped rather
            # than producing two consecutive user messages.
            continue
        chat.append({"role": "user", "content": user_text})
        # Strip the RESOURCES footer this function appended to earlier answers.
        cleaned_past = bot_text.split("\nRESOURCES:\n")[0]
        chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": prompt})

    prompt_text = tok.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    # The rendered chat template already starts with <bos> (per the Gemma
    # model card usage), so don't let the tokenizer prepend a second one.
    model_inputs = tok(
        [prompt_text], return_tensors="pt", add_special_tokens=False
    ).to(device)

    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # Generation runs in a background thread while this generator drains the
    # streamer and yields text as it arrives.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # The streamer is exhausted once generation finishes; reap the thread.
    worker.join()

    partial_text += resources
    yield partial_text
TITLE = "RAG"
# Markdown shown under the title. The retriever listed under "Models" is the
# same checkpoint as the resource link above it.
DESCRIPTION = """
## Resources used to build this project
* https://huggingface.co/mixedbread-ai/mxbai-colbert-large-v1
* me π
## Models
the models used in this space are :
* google/gemma-7b-it
* mixedbread-ai/mxbai-colbert-large-v1
"""
# Chat UI: ChatInterface calls `talk(message, history)` and displays each
# partial string the generator yields.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()