"""Gradio RAG chatbot.

Answers chat messages with ``google/gemma-7b-it``. Each answer is grounded in
passages retrieved from a FAISS-indexed Wikipedia subset, using
``mixedbread-ai/mxbai-embed-large-v1`` to embed the query.
"""

import os
from threading import Thread

import gradio as gr
from datasets import load_dataset
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from sentence_transformers import SentenceTransformer

# Raises KeyError at startup when the token is not configured.
token = os.environ["HF_TOKEN"]

# Generator LLM; loaded in float16 and moved to CUDA unconditionally.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)

# Embedding model used to turn the user query into a vector for retrieval.
RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Number of passages retrieved per query by default.
TOP_K = 3

# Marker separating the generated answer from the appended source links; it is
# also used to strip those links from past assistant turns.
_RESOURCES_HEADER = "\nRESOURCES:\n"

# prepare data
# since data is too big we will only select the first 3K lines
data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")

# index dataset
data.add_faiss_index("embedding")


@spaces.GPU
def search(query: str, k: int = TOP_K):
    """Return the ``k`` dataset rows nearest to ``query`` in embedding space.

    Args:
        query: Free-text user question.
        k: Number of rows to retrieve.

    Returns:
        The retrieved examples as returned by ``data.get_nearest_examples``
        (indexed by column name, e.g. ``["title"]``, ``["text"]``, ``["url"]``).
    """
    # The query must be embedded with the SentenceTransformer (``RAG``); the
    # causal LM bound to ``model`` has no ``encode`` method.
    embedded_query = RAG.encode(query)
    _scores, retrieved_examples = data.get_nearest_examples(
        "embedding", embedded_query, k=k
    )
    return retrieved_examples


def prepare_prompt(query, retrieved_examples):
    """Build the LLM prompt from the query and the retrieved passages.

    Args:
        query: The user's question.
        retrieved_examples: Mapping with parallel ``"title"``, ``"text"`` and
            ``"url"`` sequences, one entry per retrieved passage.

    Returns:
        A ``(prompt, (titles, urls))`` tuple: the prompt string, plus the
        titles and URLs in the same order the passages appear in the prompt.
    """
    # Reverse all three columns together so index i always refers to the same
    # article (presumably to put the best match last in the prompt — TODO
    # confirm the ordering returned by the index).
    titles = retrieved_examples["title"][::-1]
    texts = retrieved_examples["text"][::-1]
    urls = retrieved_examples["url"][::-1]

    header = (
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    )
    # Iterate over what was actually retrieved rather than assuming TOP_K rows.
    passages = "".join(
        f"Title: {title}, Text: {text}\n" for title, text in zip(titles, texts)
    )
    return header + passages, (titles, urls)


@spaces.GPU
def talk(message, history):
    """Stream an answer to ``message`` and finish with the sources used.

    Args:
        message: The latest user message.
        history: Previous turns as ``(user_text, assistant_text)`` pairs;
            ``assistant_text`` may be ``None``.

    Yields:
        The accumulated answer text after every streamed chunk, then once more
        with the ``RESOURCES`` footer (markdown links to the sources) appended.
    """
    retrieved_examples = search(message)
    message, metadata = prepare_prompt(message, retrieved_examples)

    # ``metadata`` is a (titles, urls) pair of parallel lists; pair them up
    # element-wise instead of iterating over the 2-tuple itself.
    titles, urls = metadata
    resources = _RESOURCES_HEADER + "".join(
        f"[{title}]({url}), " for title, url in zip(titles, urls)
    )

    chat = []
    for user_text, assistant_text in history:
        chat.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            # Drop the sources footer added to earlier answers so it is not
            # fed back to the model as if it had generated it.
            cleaned_past = assistant_text.split(_RESOURCES_HEADER)[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})

    messages = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    # Tokenize the messages string
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # Run generation in a background thread so this generator can consume the
    # streamer and yield partial text while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Accumulate the streamed chunks, yielding the text so far after each one.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    partial_text += resources
    yield partial_text


TITLE = "RAG"

DESCRIPTION = """
## Resources used to build this project
* embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset : https://huggingface.co/datasets/not-lain/wikipedia-small-3000-embedded (used mxbai-colbert-large-v1 to create the embedding column )
* faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
* chatbot : google/gemma-7b-it

If you want to support my work please click on the heart react button ❤️🤗
psst, I am still open for work if please reach me out at https://not-lain.github.io/
"""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()