import os
import queue
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face access token used to download the model/tokenizer below;
# raises KeyError at startup if HF_TOKEN is not set.
token = os.environ["HF_TOKEN"]

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=torch.float16,
    token=token,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda")
model = model.to(device)

# Embedding model used to encode user queries for nearest-neighbour search.
RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Number of passages retrieved per query.
TOP_K = 1
# Separates the generated answer from the appended source links; also used to
# strip those links from earlier answers before they are fed back as history.
HEADER = "\n# RESOURCES:\n"

# prepare data
# a small (3,000 row) Wikipedia subset whose "embedding" column is precomputed
data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")

# index dataset
data.add_faiss_index("embedding")


def search(query: str, k: int = TOP_K):
    """Return the ``k`` dataset rows nearest to ``query`` in embedding space.

    The result is the column dict returned by ``Dataset.get_nearest_examples``
    (e.g. ``"title"``, ``"text"``, ``"url"``), each mapping to a list of
    ``k`` values.
    """
    embedded_query = RAG.encode(query)
    _scores, retrieved_examples = data.get_nearest_examples(
        "embedding", embedded_query, k=k
    )
    return retrieved_examples


def prepare_prompt(query, retrieved_examples):
    """Build the RAG prompt for ``query`` from the retrieved passages.

    Args:
        query: the user's question.
        retrieved_examples: column dict as returned by :func:`search`.

    Returns:
        A ``(prompt, metadata)`` pair where ``metadata`` iterates over
        ``(title, url)`` tuples aligned with the passages in the prompt.
    """
    # Reverse every column the same way so titles, texts and urls stay aligned.
    titles = retrieved_examples["title"][::-1]
    texts = retrieved_examples["text"][::-1]
    urls = retrieved_examples["url"][::-1]
    prompt = (
        f"Query: {query}\nContinue to answer the query in short sentences by using the Search Results:\n"
    )
    # Use every passage actually retrieved rather than assuming TOP_K rows.
    prompt += "".join(f"* {text}\n" for text in texts)
    return prompt, zip(titles, urls)


def _format_resources(metadata):
    """Render ``(title, url)`` pairs as a markdown links section headed by HEADER."""
    return HEADER + ", ".join(f"[{title}]({url})" for title, url in metadata)


def _build_chat(history, message):
    """Convert Gradio ``(user, assistant)`` history pairs plus ``message`` into chat-template messages."""
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        # Drop the resources section appended to earlier answers; an entry may
        # be None (e.g. a failed turn), so fall back to an empty string.
        cleaned_past = (item[1] or "").split(HEADER)[0]
        chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})
    return chat


@spaces.GPU(duration=150)
def talk(message, history):
    """Answer ``message`` with retrieval-augmented generation, streaming the text.

    Args:
        message: the user's new question.
        history: previous chat turns as ``(user, assistant)`` pairs.

    Yields:
        The answer generated so far; the final value has the markdown
        resources section appended.
    """
    print("history, ", history)
    print("message ", message)
    print("searching dataset ...")
    retrieved_examples = search(message)
    print("preparing prompt ...")
    prompt, metadata = prepare_prompt(message, retrieved_examples)
    print("preparing metadata ...")
    resources = _format_resources(metadata)
    print("preparing chat template ...")
    chat = _build_chat(history, prompt)
    messages = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print("chat template prepared, ", messages)
    print("tokenizing input ...")
    # The rendered chat template already carries the model's special tokens
    # (per the transformers chat-templating docs), so don't add a second BOS.
    model_inputs = tokenizer(
        [messages], return_tensors="pt", add_special_tokens=False
    ).to(device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    print("initializing thread ...")
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    retries = 0
    # Drain the streamer until its end-of-generation signal. Iterating it
    # directly (instead of polling thread.is_alive()) never drops output
    # from a fast generation and never blocks again once generation is done.
    while True:
        try:
            for new_text in streamer:
                partial_text += new_text
                yield partial_text
            break
        except queue.Empty as e:
            # No token arrived within the streamer timeout. If the generation
            # thread has died (e.g. it raised), no end signal will ever come.
            if not thread.is_alive():
                break
            print(f"retry number {retries}\n LOGS:\n")
            retries += 1
            print(e, e.args)
    thread.join()

    partial_text += resources
    yield partial_text


TITLE = "# RAG"
DESCRIPTION = """
A rag pipeline with a chatbot feature

Resources used to build this project :
* embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset : https://huggingface.co/datasets/not-lain/wikipedia-small-3000-embedded (used mxbai-colbert-large-v1 to create the embedding column )
* faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
* chatbot : https://huggingface.co/google/gemma-7b-it

If you want to support my work consider clicking on the heart react button ❤️🤗
"""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what's anarchy ?"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch(debug=True)