# RAG_Test / Test_RAG.py
# JiakaiDu's picture
# Upload folder using huggingface_hub
# bd9c220 verified
# raw
# history blame
# 33 kB
import os
os.environ["GIT_CLONE_PROTECTION_ACTIVE"] = "false"
from pathlib import Path
import requests
import shutil
import io
from pathlib import Path
import openvino as ov
import torch
from transformers import (
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList,
)
from llm_config import (
SUPPORTED_EMBEDDING_MODELS,
SUPPORTED_RERANK_MODELS,
SUPPORTED_LLM_MODELS,
)
from huggingface_hub import login
# Locations of the shared LLM configuration module and of the example PDFs.
config_shared_path = Path("../../utils/llm_config.py")
config_dst_path = Path("llm_config.py")
text_example_en_path = Path("text_example_en.pdf")
text_example_cn_path = Path("text_example_cn.pdf")
text_example_en = "https://github.com/openvinotoolkit/openvino_notebooks/files/15039728/Platform.Brief_Intel.vPro.with.Intel.Core.Ultra_Final.pdf"
text_example_cn = "https://github.com/openvinotoolkit/openvino_notebooks/files/15039713/Platform.Brief_Intel.vPro.with.Intel.Core.Ultra_Final_CH.pdf"
llm_config_url = "https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/llm_config.py"


def _download(url, destination, binary=False):
    """Fetch *url* and store it at *destination*.

    Params:
      url: address of the resource to download
      destination: path of the file to (over)write
      binary: store the raw bytes when True, otherwise the decoded text as UTF-8
    Raises:
      requests.HTTPError: when the server answers with an error status, so an
        HTML error page is never saved in place of the expected file
    """
    response = requests.get(url=url, timeout=60)
    response.raise_for_status()
    if binary:
        Path(destination).write_bytes(response.content)
    else:
        Path(destination).write_text(response.text, encoding="utf-8")


# NOTE(review): ``from llm_config import ...`` near the top of this file runs
# before this block, so on a fresh checkout that import fails before the
# module can be fetched here — consider moving the import below this block.
if not config_dst_path.exists():
    if config_shared_path.exists():
        try:
            os.symlink(config_shared_path, config_dst_path)
        except (OSError, NotImplementedError):
            # Symlinks may be unavailable (e.g. missing privileges); copy instead.
            shutil.copy(config_shared_path, config_dst_path)
    else:
        _download(llm_config_url, config_dst_path)
elif not os.path.islink(config_dst_path):
    # A regular (non-symlink) copy is refreshed on every start.
    print("LLM config will be updated")
    if config_shared_path.exists():
        shutil.copy(config_shared_path, config_dst_path)
    else:
        _download(llm_config_url, config_dst_path)

# Example documents are only fetched when they are not present yet.
if not text_example_en_path.exists():
    _download(text_example_en, text_example_en_path, binary=True)
if not text_example_cn_path.exists():
    _download(text_example_cn, text_example_cn_path, binary=True)
# Language and LLM used by the demo; the id is looked up in SUPPORTED_LLM_MODELS.
model_language = "English"
llm_model_id = "llama-3.2-3b-instruct"  # alternative: "llama-3-8b-instruct"
llm_model_configuration = SUPPORTED_LLM_MODELS[model_language][llm_model_id]
print(f"Selected LLM model {llm_model_id}")

# Weight formats to prepare: INT4 only; INT8 and FP16 are switched off.
prepare_int4_model, prepare_int8_model, prepare_fp16_model = True, False, False
enable_awq = False

# The Hugging Face token must be supplied through the environment.
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
if hf_token is None:
    raise ValueError(
        "HUGGINGFACE_TOKEN environment variable not set. "
        "Please set it in your environment variables or repository secrets."
    )
# Log in to Hugging Face Hub
login(token=hf_token)

pt_model_id = llm_model_configuration["model_id"]
# One output directory per weight format, all under a folder named after the model id.
fp16_model_dir, int8_model_dir, int4_model_dir = (
    Path(llm_model_id) / precision_dir
    for precision_dir in ("FP16", "INT8_compressed_weights", "INT4_compressed_weights")
)
def convert_to_fp16():
    """
    Build the ``optimum-cli`` command that exports the LLM with FP16 weights.

    The command is only assembled, never executed: the caller decides whether
    to run it (for example ``subprocess.run(shlex.split(cmd), check=True)``).
    Previously the assembled command was discarded.

    Returns:
      the export command line, or None when ``openvino_model.xml`` already
      exists in ``fp16_model_dir``
    """
    if (fp16_model_dir / "openvino_model.xml").exists():
        return None
    export_command_base = (
        f"optimum-cli export openvino --model {pt_model_id} "
        "--task text-generation-with-past --weight-format fp16"
    )
    if llm_model_configuration.get("remote_code", False):
        export_command_base += " --trust-remote-code"
    return export_command_base + " " + str(fp16_model_dir)
def convert_to_int8():
    """
    Build the ``optimum-cli`` command that exports the LLM with INT8 weights.

    The target directory is created, but the command itself is only assembled,
    never executed: the caller decides whether to run it (for example
    ``subprocess.run(shlex.split(cmd), check=True)``). Previously the
    assembled command was discarded.

    Returns:
      the export command line, or None when ``openvino_model.xml`` already
      exists in ``int8_model_dir``
    """
    if (int8_model_dir / "openvino_model.xml").exists():
        return None
    int8_model_dir.mkdir(parents=True, exist_ok=True)
    export_command_base = (
        f"optimum-cli export openvino --model {pt_model_id} "
        "--task text-generation-with-past --weight-format int8"
    )
    if llm_model_configuration.get("remote_code", False):
        export_command_base += " --trust-remote-code"
    return export_command_base + " " + str(int8_model_dir)
def convert_to_int4():
    """
    Build the ``optimum-cli`` command that exports the LLM with INT4 weights.

    Compression parameters (symmetric mode, group size, ratio) are chosen per
    model id, falling back to the "default" entry. AWQ options are appended
    when ``enable_awq`` is set. The command is only assembled, never executed:
    the caller decides whether to run it (for example
    ``subprocess.run(shlex.split(cmd), check=True)``).

    Returns:
      the export command line, or None when ``openvino_model.xml`` already
      exists in ``int4_model_dir``
    """
    # Nothing to do when a converted model is already present.
    if (int4_model_dir / "openvino_model.xml").exists():
        return None

    compression_configs = {
        "zephyr-7b-beta": {"sym": True, "group_size": 64, "ratio": 0.6},
        "mistral-7b": {"sym": True, "group_size": 64, "ratio": 0.6},
        "minicpm-2b-dpo": {"sym": True, "group_size": 64, "ratio": 0.6},
        "gemma-2b-it": {"sym": True, "group_size": 64, "ratio": 0.6},
        "notus-7b-v1": {"sym": True, "group_size": 64, "ratio": 0.6},
        "neural-chat-7b-v3-1": {"sym": True, "group_size": 64, "ratio": 0.6},
        "llama-2-chat-7b": {"sym": True, "group_size": 128, "ratio": 0.8},
        "llama-3-8b-instruct": {"sym": True, "group_size": 128, "ratio": 0.8},
        "gemma-7b-it": {"sym": True, "group_size": 128, "ratio": 0.8},
        "chatglm2-6b": {"sym": True, "group_size": 128, "ratio": 0.72},
        "qwen-7b-chat": {"sym": True, "group_size": 128, "ratio": 0.6},
        "red-pajama-3b-chat": {"sym": False, "group_size": 128, "ratio": 0.5},
        "default": {"sym": False, "group_size": 128, "ratio": 0.8},
    }
    model_compression_params = compression_configs.get(llm_model_id, compression_configs["default"])

    export_command_base = (
        f"optimum-cli export openvino --model {pt_model_id} "
        "--task text-generation-with-past --weight-format int4"
    )
    int4_compression_args = " --group-size {} --ratio {}".format(
        model_compression_params["group_size"], model_compression_params["ratio"]
    )
    if model_compression_params["sym"]:
        int4_compression_args += " --sym"
    if enable_awq:
        int4_compression_args += " --awq --dataset wikitext2 --num-samples 128"
    export_command_base += int4_compression_args
    if llm_model_configuration.get("remote_code", False):
        export_command_base += " --trust-remote-code"
    return export_command_base + " " + str(int4_model_dir)
# Run only the conversions that were requested through the prepare_* flags.
if prepare_fp16_model:
    convert_to_fp16()
if prepare_int8_model:
    convert_to_int8()
if prepare_int4_model:
    convert_to_int4()

# Weight files whose sizes are reported when they exist on disk.
fp16_weights = fp16_model_dir / "openvino_model.bin"
int8_weights = int8_model_dir / "openvino_model.bin"
int4_weights = int4_model_dir / "openvino_model.bin"

if fp16_weights.exists():
    print(f"Size of FP16 model is {fp16_weights.stat().st_size / 1024 / 1024:.2f} MB")
for precision, compressed_weights in ((8, int8_weights), (4, int4_weights)):
    if not compressed_weights.exists():
        continue
    print(f"Size of model with INT{precision} compressed weights is {compressed_weights.stat().st_size / 1024 / 1024:.2f} MB")
    # The compression rate needs the FP16 file as the reference size.
    if fp16_weights.exists():
        print(f"Compression rate for INT{precision} model: {fp16_weights.stat().st_size / compressed_weights.stat().st_size:.3f}")
# Embedding model; other ids mentioned by the author: 'bge-large-en-v1.5', 'bge-m3'.
embedding_model_id = 'bge-small-en-v1.5'
embedding_model_configuration = SUPPORTED_EMBEDDING_MODELS[model_language][embedding_model_id]
print(f"Selected {embedding_model_id} model")
# The export commands below are only assembled; nothing in this block runs them.
export_command_base = "optimum-cli export openvino --model {} --task feature-extraction".format(embedding_model_configuration["model_id"])
export_command = export_command_base + " " + str(embedding_model_id)

# Rerank model; other ids mentioned by the author: 'bge-reranker-large', 'bge-reranker-base'.
rerank_model_id = "bge-reranker-v2-m3"
rerank_model_configuration = SUPPORTED_RERANK_MODELS[rerank_model_id]
print(f"Selected {rerank_model_id} model")
export_command_base = "optimum-cli export openvino --model {} --task text-classification".format(rerank_model_configuration["model_id"])
export_command = export_command_base + " " + str(rerank_model_id)

embedding_device = "CPU"
USING_NPU = embedding_device == "NPU"
npu_embedding_dir = embedding_model_id + "-npu"
npu_embedding_path = Path(npu_embedding_dir) / "openvino_model.xml"
if USING_NPU and not Path(npu_embedding_dir).exists():
    r = requests.get(
        url="https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py",
        timeout=60,
    )
    # Fail loudly instead of saving an HTTP error page as a Python module.
    r.raise_for_status()
    with open("notebook_utils.py", "w", encoding="utf-8") as f:
        f.write(r.text)
    import notebook_utils as utils

    shutil.copytree(embedding_model_id, npu_embedding_dir)
    utils.optimize_bge_embedding(Path(embedding_model_id) / "openvino_model.xml", npu_embedding_path)

rerank_device = "CPU"
llm_device = "CPU"
from langchain_community.embeddings import OpenVINOBgeEmbeddings

# The NPU path uses its own model directory and a batch of one.
if USING_NPU:
    embedding_model_name = npu_embedding_dir
    batch_size = 1
else:
    embedding_model_name = embedding_model_id
    batch_size = 4
embedding_model_kwargs = dict(device=embedding_device, compile=False)
encode_kwargs = dict(
    mean_pooling=embedding_model_configuration["mean_pooling"],
    normalize_embeddings=embedding_model_configuration["normalize_embeddings"],
    batch_size=batch_size,
)
# NOTE(review): embedding_model_name is computed above but the model id is hard-coded here.
embedding = OpenVINOBgeEmbeddings(
    model_name_or_path="BAAI/bge-small-en-v1.5",
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=encode_kwargs,
)
if USING_NPU:
    embedding.ov_model.reshape(1, 512)
# The model was created with compile=False, so it is compiled explicitly.
embedding.ov_model.compile()

# Quick smoke test of the embedding model.
text = "This is a test document."
embedding_result = embedding.embed_query(text)
embedding_result[:3]

from langchain_community.document_compressors.openvino_rerank import OpenVINOReranker

rerank_model_name = rerank_model_id
rerank_model_kwargs = dict(device=rerank_device)
rerank_top_n = 2
# NOTE(review): rerank_model_name is not used; the model id is hard-coded here.
reranker = OpenVINOReranker(
    model_name_or_path="BAAI/bge-reranker-v2-m3",
    model_kwargs=rerank_model_kwargs,
    top_n=rerank_top_n,
)
model_to_run = "INT4"
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

# Map the selected precision to its local model directory.
if model_to_run == "INT4":
    model_dir = int4_model_dir
elif model_to_run == "INT8":
    model_dir = int8_model_dir
else:
    model_dir = fp16_model_dir
# NOTE(review): model_dir is only reported here; the pipeline below is created
# from a hard-coded Hugging Face model id, not from this directory.
print(f"Loading model from {model_dir}")

ov_config = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "CACHE_DIR": ""}
if "GPU" in llm_device and "qwen2-7b-instruct" in llm_model_id:
    ov_config["GPU_ENABLE_SDPA_OPTIMIZATION"] = "NO"
# On a GPU device a model is executed in FP16 precision. For the red-pajama-3b-chat model there are
# known accuracy issues caused by this, which we avoid by setting the precision hint to "f32".
# The device list is queried through ov.Core(); the previous code referenced an undefined `core`.
if llm_model_id == "red-pajama-3b-chat" and llm_device in ["GPU", "AUTO"] and "GPU" in ov.Core().available_devices:
    ov_config["INFERENCE_PRECISION_HINT"] = "f32"

llm = HuggingFacePipeline.from_model_id(
    model_id="meta-llama/Llama-3.2-3B-Instruct",  # alternative: "meta-llama/Meta-Llama-3-8B"
    task="text-generation",
    backend="openvino",
    model_kwargs={
        "device": llm_device,
        "ov_config": ov_config,
        "trust_remote_code": True,
    },
    pipeline_kwargs={"max_new_tokens": 2},
)
# Short generation used as a smoke test of the loaded pipeline.
llm.invoke("2 + 2 =")
import re
from typing import List
from langchain.text_splitter import (
CharacterTextSplitter,
RecursiveCharacterTextSplitter,
MarkdownTextSplitter,
)
from langchain.document_loaders import (
CSVLoader,
EverNoteLoader,
PyPDFLoader,
TextLoader,
UnstructuredEPubLoader,
UnstructuredHTMLLoader,
UnstructuredMarkdownLoader,
UnstructuredODTLoader,
UnstructuredPowerPointLoader,
UnstructuredWordDocumentLoader,
)
class ChineseTextSplitter(CharacterTextSplitter):
    """Character splitter that cuts text into sentences at sentence-ending punctuation."""

    def __init__(self, pdf: bool = False, **kwargs):
        """
        Params:
          pdf: when True, newline runs are normalised before splitting
          kwargs: forwarded unchanged to CharacterTextSplitter
        """
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        """
        Split text into sentences.
        Params:
          text: text to split
        Returns:
          list of sentences; each separator is attached to the sentence before it
        """
        if self.pdf:
            # Collapse runs of three or more newlines, then drop remaining blank lines.
            text = re.sub(r"\n{3,}", "\n", text).replace("\n\n", "")
        boundary = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')
        sentences = []
        for piece in boundary.split(text):
            if sentences and boundary.match(piece):
                # A separator (possibly empty) belongs to the previous sentence.
                sentences[-1] += piece
            elif piece:
                sentences.append(piece)
        return sentences
# Text splitter classes selectable by name (the names are the choices of the UI dropdown).
TEXT_SPLITERS = {
    "Character": CharacterTextSplitter,
    "RecursiveCharacter": RecursiveCharacterTextSplitter,
    "Markdown": MarkdownTextSplitter,
    "Chinese": ChineseTextSplitter,
}
# File extension -> (loader class, extra keyword arguments passed to the loader).
LOADERS = {
    ".csv": (CSVLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PyPDFLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
}
# Example questions offered in the UI, one list per supported language.
chinese_examples = [
    ["英特尔®酷睿™ Ultra处理器可以降低多少功耗?"],
    ["相比英特尔之前的移动处理器产品,英特尔®酷睿™ Ultra处理器的AI推理性能提升了多少?"],
    ["英特尔博锐® Enterprise系统提供哪些功能?"],
]
english_examples = [
    ["How much power consumption can Intel® Core™ Ultra Processors help save?"],
    ["Compared to Intel’s previous mobile processor, what is the advantage of Intel® Core™ Ultra Processors for Artificial Intelligence?"],
    ["What can Intel vPro® Enterprise systems offer?"],
]

# Documents indexed at start-up. This is always a list of paths: it is iterated
# document by document when the vector store is built and is also the initial
# value of a multi-file upload component. (A bare string would be walked
# character by character.)
if model_language == "English":
    # previously: "text_example_en.pdf"
    text_example_path = [
        'Supervisors-Guide-Accurate-Timekeeping_AH edits.docx',
        'Salary-vs-Hourly-Guide_AH edits.docx',
        'Employee-Guide-Accurate-Timekeeping_AH edits.docx',
        'Eller Overtime Guidelines.docx',
        'Eller FLSA information 9.2024_AH edits.docx',
        'Accurate Timekeeping Supervisors 12.2.20_AH edits.docx',
    ]
else:
    text_example_path = ["text_example_cn.pdf"]
examples = chinese_examples if (model_language == "Chinese") else english_examples
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.docstore.document import Document
from langchain.retrievers import ContextualCompressionRetriever
from threading import Thread
import gradio as gr
# Optional stop tokens of the selected LLM (None when the configuration has no such entry).
stop_tokens = llm_model_configuration.get("stop_tokens")
# Prompt template for RAG; it is formatted elsewhere with `input` and `context`.
rag_prompt_template = llm_model_configuration["rag_prompt_template"]
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the last generated token is one of the given ids."""

    def __init__(self, token_ids):
        """
        Params:
          token_ids: ids of the tokens that end generation
        """
        self.token_ids = token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop token."""
        return any(input_ids[0][-1] == stop_id for stop_id in self.token_ids)
# Normalise the configured stop tokens into a one-element list of stopping criteria.
if stop_tokens is not None:
    given_as_text = isinstance(stop_tokens[0], str)
    if given_as_text:
        # Token strings are translated to ids with the LLM tokenizer.
        stop_tokens = llm.pipeline.tokenizer.convert_tokens_to_ids(stop_tokens)
    stop_tokens = [StopOnTokens(stop_tokens)]
def load_single_document(file_path: str) -> "List[Document]":
    """
    helper for loading a single document
    Params:
      file_path: document path; the loader is chosen from the text after the
        last ".", compared case-insensitively (so ".PDF" and ".pdf" are equal)
    Returns:
      documents loaded
    Raises:
      ValueError: when no loader is registered for the file extension
    """
    ext = "." + file_path.rsplit(".", 1)[-1].lower()
    if ext not in LOADERS:
        raise ValueError(f"Unsupported file extension '{ext}' for document '{file_path}'")
    loader_class, loader_args = LOADERS[ext]
    loader = loader_class(file_path, **loader_args)
    return loader.load()
def default_partial_text_processor(partial_text: str, new_text: str):
    """
    Default way of extending a partially generated answer.
    Params:
      partial_text: text generated so far
      new_text: text produced by the current generation step
    Returns:
      the text generated so far followed by the new text
    """
    return partial_text + new_text
# Use the model-specific partial-text processor when the configuration provides one.
text_processor = llm_model_configuration.get("partial_text_processor", default_partial_text_processor)
def create_vectordb(
    docs, spliter_name, chunk_size, chunk_overlap, vector_search_top_k, vector_rerank_top_n, run_rerank, search_method, score_threshold, progress=gr.Progress()
):
    """
    Initialize a vector database

    Rebinds the module-level `db`, `retriever`, `combine_docs_chain` and `rag_chain`.

    Params:
      docs: original documents provided by user; a list of paths or of objects
        exposing the path as `.name`. A single path string is accepted as well.
      spliter_name: spliter method
      chunk_size:  size of a single sentence chunk
      chunk_overlap: overlap size between 2 chunks
      vector_search_top_k: Vector search top k
      vector_rerank_top_n: Search rerank top n
      run_rerank: whether run reranker
      search_method: top k search method
      score_threshold: score threshold when selecting 'similarity_score_threshold' method
      progress: not used by this function
    Returns:
      status message for the UI
    """
    global db
    global retriever
    global combine_docs_chain
    global rag_chain

    # A bare string would otherwise be iterated character by character.
    if isinstance(docs, str):
        docs = [docs]
    # Nothing to index (e.g. the file list was cleared): keep the current state.
    if not docs:
        return "Vector Store is Not ready"

    if vector_rerank_top_n > vector_search_top_k:
        gr.Warning("Search top k must >= Rerank top n")

    documents = []
    for doc in docs:
        # Non-string entries are expected to carry their path in `.name`.
        doc_path = doc if isinstance(doc, str) else doc.name
        documents.extend(load_single_document(doc_path))

    text_splitter = TEXT_SPLITERS[spliter_name](chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(documents)
    db = FAISS.from_documents(texts, embedding)

    # The score threshold is only passed for the threshold-based search type.
    if search_method == "similarity_score_threshold":
        search_kwargs = {"k": vector_search_top_k, "score_threshold": score_threshold}
    else:
        search_kwargs = {"k": vector_search_top_k}
    retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)
    if run_rerank:
        reranker.top_n = vector_rerank_top_n
        retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)

    prompt = PromptTemplate.from_template(rag_prompt_template)
    combine_docs_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(retriever, combine_docs_chain)
    return "Vector database is Ready"
def update_retriever(vector_search_top_k, vector_rerank_top_n, run_rerank, search_method, score_threshold):
    """
    Rebuild the retriever and the RAG chain on top of the existing vector database.

    Params:
      vector_search_top_k: Vector search top k
      vector_rerank_top_n: Search rerank top n
      run_rerank: whether run reranker
      search_method: top k search method
      score_threshold: score threshold when selecting 'similarity_score_threshold' method
    Returns:
      status message for the UI
    """
    global db, retriever, combine_docs_chain, rag_chain

    if vector_rerank_top_n > vector_search_top_k:
        gr.Warning("Search top k must >= Rerank top n")

    # Start from the common arguments; the threshold is added only when it applies.
    search_kwargs = {"k": vector_search_top_k}
    if search_method == "similarity_score_threshold":
        search_kwargs["score_threshold"] = score_threshold

    retriever = db.as_retriever(search_kwargs=search_kwargs, search_type=search_method)
    if run_rerank:
        retriever = ContextualCompressionRetriever(base_compressor=reranker, base_retriever=retriever)
        reranker.top_n = vector_rerank_top_n

    rag_chain = create_retrieval_chain(retriever, combine_docs_chain)
    return "Vector database is Ready"
def user(message, history):
    """
    Callback run on submit: moves the typed message into the conversation.

    Params:
      message: current message
      history: conversation history
    Returns:
      an empty string (the new content of the message box) and a new history
      list with a [message, ""] turn appended
    """
    new_turn = [message, ""]
    updated_history = history + [new_turn]
    return "", updated_history
def bot(history, temperature, top_p, top_k, repetition_penalty, hide_full_prompt, do_rag):
    """
    Callback run after `user`: generates the answer and streams it into the chat.

    Params:
      history: conversation history; the last entry holds the question to answer
      temperature: sampling temperature; sampling is enabled only when it is above 0
      top_p: nucleus sampling threshold passed to generation
      top_k: number of highest-probability tokens considered, passed to generation
      repetition_penalty: penalty for repeated tokens, passed to generation
      hide_full_prompt: whether the prompt is skipped in the streamed output
      do_rag: whether do RAG when generating texts.
    Yields:
      the history, with the last answer extended after every streamed piece of text
    """
    streamer = TextIteratorStreamer(
        llm.pipeline.tokenizer,
        timeout=60.0,
        skip_prompt=hide_full_prompt,
        skip_special_tokens=True,
    )
    generation_params = {
        "max_new_tokens": 512,
        "temperature": temperature,
        "do_sample": temperature > 0.0,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }
    if stop_tokens is not None:
        generation_params["stopping_criteria"] = StoppingCriteriaList(stop_tokens)
    llm.pipeline._forward_params = generation_params

    # Generation runs in a background thread so the streamer can be consumed here.
    if do_rag:
        t1 = Thread(target=rag_chain.invoke, args=({"input": history[-1][0]},))
    else:
        input_text = rag_prompt_template.format(input=history[-1][0], context="")
        t1 = Thread(target=llm.invoke, args=(input_text,))
    t1.start()

    answer_so_far = ""
    for new_text in streamer:
        answer_so_far = text_processor(answer_so_far, new_text)
        history[-1][1] = answer_so_far
        yield history
def request_cancel():
    """Cancel the request object held by the LLM pipeline's model (used by the Stop button)."""
    active_request = llm.pipeline.model.request
    active_request.cancel()
def clear_files():
    """Return the status text shown once the uploaded files have been cleared."""
    status_message = "Vector Store is Not ready"
    return status_message
# Build the vector store once at start-up from the example documents.
create_vectordb(
    docs=text_example_path,
    spliter_name="RecursiveCharacter",
    chunk_size=400,
    chunk_overlap=50,
    vector_search_top_k=10,
    vector_rerank_top_n=2,
    run_rerank=True,
    search_method="similarity_score_threshold",
    score_threshold=0.5,
)
# Gradio UI: a narrow left column (documents, vector store and generation settings)
# and a wide right column (chat, examples and retriever settings), followed by the
# event wiring. Nesting below is reconstructed from the component structure.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    gr.Markdown("""<h1><center>QA over Document</center></h1>""")
    gr.Markdown(f"""<center>Powered by OpenVINO and {llm_model_id} </center>""")
    with gr.Row():
        # Left column: document upload, vector store and generation configuration.
        with gr.Column(scale=1):
            docs = gr.File(
                label="Step 1: Load text files",
                value=text_example_path, #changed
                file_count="multiple",
                file_types=[
                    ".csv",
                    ".doc",
                    ".docx",
                    ".enex",
                    ".epub",
                    ".html",
                    ".md",
                    ".odt",
                    ".pdf",
                    ".ppt",
                    ".pptx",
                    ".txt",
                ],
            )
            load_docs = gr.Button("Step 2: Build Vector Store", variant="primary")
            db_argument = gr.Accordion("Vector Store Configuration", open=False)
            with db_argument:
                spliter = gr.Dropdown(
                    ["Character", "RecursiveCharacter", "Markdown", "Chinese"],
                    value="RecursiveCharacter",
                    label="Text Spliter",
                    info="Method used to splite the documents",
                    multiselect=False,
                )
                chunk_size = gr.Slider(
                    label="Chunk size",
                    value=400,
                    minimum=50,
                    maximum=2000,
                    step=50,
                    interactive=True,
                    info="Size of sentence chunk",
                )
                chunk_overlap = gr.Slider(
                    label="Chunk overlap",
                    value=50,
                    minimum=0,
                    maximum=400,
                    step=10,
                    interactive=True,
                    info=("Overlap between 2 chunks"),
                )
            langchain_status = gr.Textbox(
                label="Vector Store Status",
                value="Vector Store is Ready",
                interactive=False,
            )
            do_rag = gr.Checkbox(
                value=True,
                label="RAG is ON",
                interactive=True,
                info="Whether to do RAG for generation",
            )
            with gr.Accordion("Generation Configuration", open=False):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            temperature = gr.Slider(
                                label="Temperature",
                                value=0.1,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.1,
                                interactive=True,
                                info="Higher values produce more diverse outputs",
                            )
                    with gr.Column():
                        with gr.Row():
                            top_p = gr.Slider(
                                label="Top-p (nucleus sampling)",
                                value=1.0,
                                minimum=0.0,
                                maximum=1,
                                step=0.01,
                                interactive=True,
                                info=(
                                    "Sample from the smallest possible set of tokens whose cumulative probability "
                                    "exceeds top_p. Set to 1 to disable and sample from all tokens."
                                ),
                            )
                    with gr.Column():
                        with gr.Row():
                            top_k = gr.Slider(
                                label="Top-k",
                                value=50,
                                minimum=0.0,
                                maximum=200,
                                step=1,
                                interactive=True,
                                info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                            )
                    with gr.Column():
                        with gr.Row():
                            repetition_penalty = gr.Slider(
                                label="Repetition Penalty",
                                value=1.1,
                                minimum=1.0,
                                maximum=2.0,
                                step=0.1,
                                interactive=True,
                                info="Penalize repetition — 1.0 to disable.",
                            )
        # Right column: chat window, message box, examples and retriever configuration.
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                height=800,
                label="Step 3: Input Query",
            )
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        msg = gr.Textbox(
                            label="QA Message Box",
                            placeholder="Chat Message Box",
                            show_label=False,
                            container=False,
                        )
                with gr.Column():
                    with gr.Row():
                        submit = gr.Button("Submit", variant="primary")
                        stop = gr.Button("Stop")
                        clear = gr.Button("Clear")
            gr.Examples(examples, inputs=msg, label="Click on any example and press the 'Submit' button")
            retriever_argument = gr.Accordion("Retriever Configuration", open=True)
            with retriever_argument:
                with gr.Row():
                    with gr.Row():
                        do_rerank = gr.Checkbox(
                            value=True,
                            label="Rerank searching result",
                            interactive=True,
                        )
                        hide_context = gr.Checkbox(
                            value=True,
                            label="Hide searching result in prompt",
                            interactive=True,
                        )
                    with gr.Row():
                        search_method = gr.Dropdown(
                            ["similarity_score_threshold", "similarity", "mmr"],
                            value="similarity_score_threshold",
                            label="Searching Method",
                            info="Method used to search vector store",
                            multiselect=False,
                            interactive=True,
                        )
                    with gr.Row():
                        score_threshold = gr.Slider(
                            0.01,
                            0.99,
                            value=0.5,
                            step=0.01,
                            label="Similarity Threshold",
                            info="Only working for 'similarity score threshold' method",
                            interactive=True,
                        )
                    with gr.Row():
                        vector_rerank_top_n = gr.Slider(
                            1,
                            10,
                            value=2,
                            step=1,
                            label="Rerank top n",
                            info="Number of rerank results",
                            interactive=True,
                        )
                    with gr.Row():
                        vector_search_top_k = gr.Slider(
                            1,
                            50,
                            value=10,
                            step=1,
                            label="Search top k",
                            info="Search top k must >= Rerank top n",
                            interactive=True,
                        )
    # Event wiring: clearing the files resets the status; the build button recreates the store.
    docs.clear(clear_files, outputs=[langchain_status], queue=False)
    load_docs.click(
        create_vectordb,
        inputs=[docs, spliter, chunk_size, chunk_overlap, vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
        queue=False,
    )
    # Enter in the message box and the Submit button both run `user` and then `bot`.
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, temperature, top_p, top_k, repetition_penalty, hide_context, do_rag],
        chatbot,
        queue=True,
    )
    submit_click_event = submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot,
        [chatbot, temperature, top_p, top_k, repetition_penalty, hide_context, do_rag],
        chatbot,
        queue=True,
    )
    # Stop calls request_cancel and cancels both pending submit events.
    stop.click(
        fn=request_cancel,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    # Any change to a retriever setting rebuilds the retriever with update_retriever.
    vector_search_top_k.release(
        update_retriever,
        [vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    vector_rerank_top_n.release(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    do_rerank.change(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    search_method.change(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
    score_threshold.change(
        update_retriever,
        inputs=[vector_search_top_k, vector_rerank_top_n, do_rerank, search_method, score_threshold],
        outputs=[langchain_status],
    )
# Enable request queuing before the app is started.
demo.queue()
# if you are launching remotely, specify server_name and server_port
# demo.launch(server_port=8082)
# if you have any issue to launch on your platform, you can pass share=True to launch method:
demo.launch(share=True)
# it creates a publicly shareable link for the interface. Read more in the docs: https://gradio.app/docs/
# demo.launch()