Spaces:
Runtime error
Runtime error
File size: 4,612 Bytes
07ffad3 42df98c 1b7e4b0 07ffad3 eaca477 1b7e4b0 43ae797 07ffad3 1b7e4b0 eaca477 1b7e4b0 07ffad3 eaca477 7d9a21e ec493d8 1b7e4b0 07ffad3 5ea07e3 1b7e4b0 eaca477 b0b771c 1b7e4b0 42df98c eaca477 e4b2161 eaca477 1b7e4b0 eaca477 1b7e4b0 8577be5 07ffad3 1b7e4b0 eaca477 23e06d0 fdd8ddb 07ffad3 cc1edc1 07ffad3 e82c570 eaca477 e82c570 eaca477 ec493d8 e82c570 eaca477 1b7e4b0 e82c570 07ffad3 ec493d8 07ffad3 42df98c 3afa221 07ffad3 eaca477 07ffad3 eaca477 1b7e4b0 07ffad3 3afa221 07ffad3 43ae797 3afa221 89b12a6 8577be5 89b12a6 e4f812c fdd8ddb 07ffad3 18b530b 07ffad3 e4b2161 31630cf 18b530b 42df98c afb805b 07ffad3 18b530b ef4a283 8577be5 ef4a283 18b530b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import gradio as gr
from datasets import load_dataset
import os
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import time
# Hugging Face access token used to download the Gemma checkpoint and
# tokenizer. Indexing os.environ raises KeyError at import time if
# HF_TOKEN is not set.
token = os.environ["HF_TOKEN"]

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)

# The generator is loaded in half precision and moved onto the CUDA device.
device = torch.device("cuda")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    token=token,
    torch_dtype=torch.float16,
).to(device)

# Sentence-embedding model used to encode user queries for retrieval.
RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Number of passages retrieved per query (default k for search()).
TOP_K = 1
# Marker that precedes the source links appended to each answer; past
# assistant turns are split on it so the links are not fed back to the model.
HEADER = "\n# RESOURCES:\n"

# Pre-embedded Wikipedia subset (the dataset name suggests ~3,000 rows).
# A FAISS index over its "embedding" column enables nearest-neighbour search.
data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")
data.add_faiss_index("embedding")
def search(query: str, k: int = TOP_K):
    """Return the ``k`` dataset rows nearest to ``query`` in embedding space.

    The query is encoded with the ``RAG`` sentence-transformer and looked up
    in the FAISS index built over the ``embedding`` column of ``data``. The
    similarity scores are discarded; only the retrieved examples returned by
    ``Dataset.get_nearest_examples`` are handed back.
    """
    query_vector = RAG.encode(query)
    _scores, neighbours = data.get_nearest_examples("embedding", query_vector, k=k)
    return neighbours
def prepare_prompt(query, retrieved_examples):
    """Build the LLM prompt for ``query`` and collect the source metadata.

    Args:
        query: The user's question, inserted verbatim into the prompt.
        retrieved_examples: Mapping with at least the ``"title"``, ``"text"``
            and ``"url"`` keys, each holding a sliceable sequence with one
            entry per retrieved row (e.g. the result of ``search``). The
            three sequences are expected to be index-aligned.

    Returns:
        A ``(prompt, metadata)`` tuple. ``prompt`` is the instruction header
        followed by one ``"* <text>\\n"`` line per retrieved passage.
        ``metadata`` is a one-shot iterator of ``(title, url)`` pairs in the
        same order as the passages appear in the prompt.
    """
    prompt = (
        f"Query: {query}\nContinue to answer the query in short sentences by using the Search Results:\n"
    )
    # Results are listed in reverse retrieval order (presumably so the best
    # match sits last, nearest to where generation starts). Every column is
    # reversed exactly once so titles, texts and urls stay row-aligned.
    titles = retrieved_examples["title"][::-1]
    texts = retrieved_examples["text"][::-1]
    urls = retrieved_examples["url"][::-1]
    # Use every passage that was actually retrieved rather than a fixed
    # count, so fewer results cannot raise IndexError and extra results
    # are not silently dropped.
    prompt += "".join(f"* {text}\n" for text in texts)
    return prompt, zip(titles, urls)
def _format_resources(metadata):
    """Return ``HEADER`` followed by comma-separated markdown links for each (title, url) pair."""
    return HEADER + ", ".join(f"[{title}]({url})" for title, url in metadata)


def _history_to_chat(history, message):
    """Return chat-template messages: the prior turns in ``history`` plus ``message``.

    Each history item is read as ``item[0]`` (user text) and ``item[1]``
    (assistant text); presumably this is Gradio's tuple-format history, so
    confirm against the installed Gradio version. Everything from ``HEADER``
    onward is stripped from past assistant turns so the appended resource
    links are not fed back to the model.
    """
    chat = []
    for item in history:
        user_turn, assistant_turn = item[0], item[1]
        # NOTE(review): an aborted previous turn may presumably leave None here.
        cleaned_past = (assistant_turn or "").split(HEADER)[0]
        chat.append({"role": "user", "content": user_turn})
        chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})
    return chat


@spaces.GPU(duration=150)
def talk(message, history):
    """Answer ``message`` with retrieval-augmented generation, streaming the reply.

    The message is used to retrieve passages (``search``), wrapped into a RAG
    prompt (``prepare_prompt``), appended to the prior conversation and sent
    to ``model.generate`` on a background thread that feeds a
    ``TextIteratorStreamer``.

    Args:
        message: The user's new message.
        history: Previous conversation turns supplied by the chat UI.

    Yields:
        The accumulated answer text after each streamed chunk, then the full
        answer with ``HEADER`` and the source links appended.
    """
    from queue import Empty  # raised by the streamer when no chunk arrives in time

    print("history, ", history)
    print("message ", message)
    print("searching dataset ...")
    retrieved_examples = search(message)
    print("preparing prompt ...")
    message, metadata = prepare_prompt(message, retrieved_examples)
    print("preparing metadata ...")
    resources = _format_resources(metadata)
    print("preparing chat template ...")
    chat = _history_to_chat(history, message)
    messages = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    print("chat template prepared, ", messages)
    print("tokenizing input ...")
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    print("initializing thread ...")
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # The loop is driven by the streamer itself, not by thread.is_alive().
    # Checking liveness first would skip the whole answer when generation
    # finishes quickly. It would also re-enter an exhausted streamer and
    # wait out a spurious timeout.
    partial_text = ""
    retries = 0
    stream = iter(streamer)
    while True:
        try:
            new_text = next(stream)
        except StopIteration:
            break  # the streamer received its end-of-generation signal
        except Empty as e:
            # No chunk within the streamer timeout (e.g. a slow first token).
            # Keep waiting while generation is still running. Stop once the
            # generating thread has exited without ending the stream, e.g.
            # because generate() raised.
            if not thread.is_alive():
                print("generation thread exited before the stream ended", e, e.args)
                break
            print(f"retry number {retries}\n LOGS:\n")
            retries += 1
            print(e, e.args)
            continue
        if new_text is not None:
            partial_text += new_text
            yield partial_text
    thread.join()
    yield partial_text + resources
TITLE = "# RAG"
# Markdown shown under the title. The blank lines are required. Without them
# the first two sentences render as a single paragraph, and the closing
# sentence is absorbed into the last list item as a lazy continuation line.
# NOTE(review): this text states the dataset's embedding column was built
# with mxbai-colbert-large-v1, while user queries are encoded with
# mxbai-embed-large-v1. Nearest-neighbour search is only meaningful when both
# come from the same model, so confirm which statement is accurate.
DESCRIPTION = """
A RAG pipeline with a chatbot feature.

Resources used to build this project:

* embedding model: https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset: https://huggingface.co/datasets/not-lain/wikipedia-small-3000-embedded (used mxbai-colbert-large-v1 to create the embedding column)
* faiss docs: https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
* chatbot: https://huggingface.co/google/gemma-7b-it

If you want to support my work, consider clicking the heart react button ❤️🤗
"""
# Chat UI: each user message is handled by talk(), a generator whose
# successive yields update the displayed answer.
# NOTE(review): likeable / bubble_full_width / show_share_button /
# show_copy_button are Gradio-version-specific Chatbot arguments; confirm
# they are accepted by the pinned Gradio version.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what's anarchy ?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests or tooling) does not launch the app.
if __name__ == "__main__":
    demo.launch(debug=True)
|