Spaces:
Paused
Paused
# TODO: return all pages used to form the answer
# TODO: sample questions
# TODO: test with and without a GPU instance
# TODO: support visual questions on the page image (in the same app)?
# TODO: expose more parameters
import torch | |
from llama_index.llms.huggingface import HuggingFaceLLM | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
from llama_index.core import SimpleDirectoryReader | |
from llama_index.core import VectorStoreIndex, SummaryIndex | |
from llama_index.core.prompts import PromptTemplate | |
from llama_index.core import Settings | |
from PIL import Image | |
import gradio as gr | |
# NOTE(review): CHEAPMODE is True when a CUDA GPU *is* available, which looks
# inverted relative to the commented-out model choice below; it is not read
# anywhere else in this file.
CHEAPMODE = torch.cuda.is_available()
# LLM = "HuggingFaceH4/zephyr-7b-alpha" if not CHEAPMODE else "microsoft/phi-2"

# Default RAG pipeline settings. ask_my_thesis overwrites entries in this dict
# when the user changes a setting in the UI.
config = dict(
    # Other models tried: "meta-llama/Meta-Llama-3-8B",
    # "HuggingFaceH4/zephyr-7b-alpha"
    LLM="microsoft/phi-2",
    embeddings="BAAI/bge-small-en-v1.5",  # HuggingFace embedding model
    similarity_top_k=2,  # retrieved chunks per question
    context_window=4048,
    max_new_tokens=200,
    temperature=0.7,
    top_k=5,
    top_p=0.95,
    chunk_size=512,  # node size used when indexing the thesis text
    chunk_overlap=50,
)
def center_element(el):
    """Wrap *el* in a horizontally centred HTML ``<div>``.

    ``el`` is formatted into the markup as-is, so any object is accepted.
    The result is used for the Gradio page title and description.
    """
    template = "<div style='text-align: center;'>{}</div>"
    return template.format(el)
# Page header shown by the Gradio interface; both parts are centred HTML.
title = center_element(
    "Ask my thesis: Intelligent Automation for AI-Driven Document Understanding"
)
description = center_element(
    """Chat with the thesis manuscript by asking questions and receive answers with reference to the page.
<div class="center">
<a href="https://jordy-vl.github.io/assets/phdthesis/VanLandeghem_Jordy_PhD-thesis.pdf">
<img src="https://ideogram.ai/api/images/direct/cc3Um6ClQkWJpVdXx6pWVA.png"
title="Thesis.pdf" alt="Ideogram image generated with prompt engineering" width="500" class="center"/></a>
</div> Click the visual above to be redirected to the PDF of the manuscript.
Technology used: [Llama-index](https://www.llamaindex.ai/), OS LLMs from HuggingFace
Spoiler: a quickly hacked together RAG application with a >1B LLM and online vector store can be quite slow on a 290 page document ⏳ (10s+)
"""
)
def messages_to_prompt(messages):
    """Serialise chat messages into a Zephyr-style ``<|role|>`` prompt string.

    Each recognised message becomes a ``<|role|>`` header line, its content and
    a closing ``</s>``. System messages keep their tag, but their content is
    replaced by a fixed domain-expert persona. Messages whose role is not
    "system", "user" or "assistant" are dropped. If the rendered text does not
    open with a system turn, an empty one is prepended. An open
    ``<|assistant|>`` turn is always appended so the model answers next.

    Args:
        messages: Iterable of objects exposing ``.role`` and ``.content``
            (presumably llama-index ``ChatMessage`` instances).

    Returns:
        The prompt string.
    """
    persona = "You are an expert in the research field of document understanding, bayesian deep learning and neural networks."
    turns = []
    for msg in messages:
        role = msg.role
        # Plain == comparisons on purpose: they also match str-based enums.
        if role == "system":
            turns.append(f"<|system|>\n{persona}</s>\n")
        elif role == "user":
            turns.append(f"<|user|>\n{msg.content}</s>\n")
        elif role == "assistant":
            turns.append(f"<|assistant|>\n{msg.content}</s>\n")
    rendered = "".join(turns)
    # Guarantee the prompt starts with a (possibly empty) system turn.
    if not rendered.startswith("<|system|>\n"):
        rendered = "<|system|>\n</s>\n" + rendered
    return f"{rendered}<|assistant|>\n"
def _build_llm(config):
    """Create the HuggingFace LLM from *config*, 4-bit quantized when CUDA is available."""
    quantization_config = None  # CPU fallback: load the model unquantized
    if torch.cuda.is_available():
        # Imported lazily: bitsandbytes quantization is only used on GPU.
        from transformers import BitsAndBytesConfig

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    temperature = config["temperature"]
    return HuggingFaceLLM(
        model_name=config["LLM"],
        tokenizer_name=config["LLM"],
        query_wrapper_prompt=PromptTemplate(
            "<|system|>\n</s>\n<|user|>\n{query_str}</s>\n<|assistant|>\n"
        ),
        # Integer-valued settings are cast defensively in case UI sliders
        # deliver them as floats.
        context_window=int(config["context_window"]),
        max_new_tokens=int(config["max_new_tokens"]),
        model_kwargs={"quantization_config": quantization_config},
        generate_kwargs={
            # Without do_sample=True, generation is greedy and temperature,
            # top_k and top_p are ignored. A temperature of 0 (the slider
            # minimum) is treated as a request for greedy decoding.
            "do_sample": temperature > 0,
            "temperature": temperature,
            "top_k": int(config["top_k"]),
            "top_p": config["top_p"],
        },
        messages_to_prompt=messages_to_prompt,
        device_map="auto",
    )


def load_RAG_pipeline(config):
    """Build a llama-index query engine over the thesis text files.

    Configures the global llama-index ``Settings``, then indexes every file in
    ``assets/txts`` into a ``VectorStoreIndex``. The index is rebuilt on every
    call and not persisted. Returns a query engine over that index.

    Args:
        config: Settings dict with the keys "LLM", "embeddings",
            "similarity_top_k", "context_window", "max_new_tokens",
            "temperature", "top_k", "top_p", "chunk_size" and "chunk_overlap".

    Returns:
        A query engine using "compact" response synthesis that retrieves
        ``config["similarity_top_k"]`` chunks per question.

    Side effects:
        Overwrites ``Settings.llm``, ``Settings.embed_model``,
        ``Settings.chunk_size`` and ``Settings.chunk_overlap`` process-wide.
    """
    Settings.llm = _build_llm(config)
    Settings.embed_model = HuggingFaceEmbedding(model_name=config["embeddings"])
    Settings.chunk_size = int(config["chunk_size"])
    Settings.chunk_overlap = int(config["chunk_overlap"])

    documents = SimpleDirectoryReader("assets/txts").load_data()
    vector_index = VectorStoreIndex.from_documents(documents)
    # summary_index = SummaryIndex.from_documents(documents)
    # Persisting the index would avoid re-embedding on every rebuild:
    # vector_index.persist(persist_dir="vectors")
    # https://docs.llamaindex.ai/en/v0.10.17/understanding/storing/storing.html
    return vector_index.as_query_engine(
        response_mode="compact",
        similarity_top_k=int(config["similarity_top_k"]),
    )
# Engine built once at import time from the default settings. ask_my_thesis
# uses it whenever the requested settings match ``config``.
default_query_engine = load_RAG_pipeline(config=config)
# Query helpers: get_answer runs a question through the RAG query engine and
# get_answer_page locates the source page that best supports the answer.
def get_answer(question, config, query_engine=default_query_engine):
    """Query the RAG engine with *question* and return its response object.

    Args:
        question: The user's question.
        config: Settings dict. Not read here; changed settings are handled by
            passing a freshly built ``query_engine``.
        query_engine: Engine to query. The default is the engine bound to
            ``default_query_engine`` when this function was defined.

    Returns:
        The object returned by ``query_engine.query``. Callers read its
        ``.response`` text and ``.source_nodes``.
    """
    # Per-call overrides (temperature, top-p, ...) could use llama-index local
    # configuration instead of rebuilding the whole engine:
    # https://docs.llamaindex.ai/en/stable/module_guides/supporting_modules/settings/#setting-local-configurations
    result = query_engine.query(question)
    print(f"A: {result}")
    return result
def get_answer_page(response):
    """Return the page image and a navigation hint for the top source node.

    Path convention (assumed from the slicing below; verify against the
    files in ``assets/txts``):

    - The file name of the top-ranked node's ``metadata["file_path"]`` must
      end in a 4-digit page number followed by a 3-letter extension.
    - The page image path is derived by replacing *every* ``"txt"`` in the
      path with ``"png"``. This maps both the ``txts`` directory and the
      ``.txt`` suffix.

    Args:
        response: llama-index query response exposing ``source_nodes``.

    Returns:
        ``(image, message)``. On success, this is a PIL image of the page and
        ``"Navigate to page N"``. When the response carries no source nodes,
        it is ``(None, "No source page found")``.
    """
    # Guard: indexing [0] on an empty retrieval would raise IndexError.
    if not response.source_nodes:
        return None, "No source page found"
    best_match = response.source_nodes[0].metadata["file_path"]
    answer_page = int(best_match[-8:-4])
    image = Image.open(best_match.replace("txt", "png"))
    return image, f"Navigate to page {answer_page}"
# Gradio callback: answer a question using the requested pipeline settings
def ask_my_thesis(
    question,
    LLM=config["LLM"],
    embeddings=config["embeddings"],
    similarity_top_k=config["similarity_top_k"],
    context_window=config["context_window"],
    max_new_tokens=config["max_new_tokens"],
    temperature=config["temperature"],
    top_k=config["top_k"],
    top_p=config["top_p"],
    chunk_size=config["chunk_size"],
    chunk_overlap=config["chunk_overlap"],
):
    """Answer *question* and return the answer text, page image and page hint.

    The tunable settings (every argument except ``question``, ``LLM`` and
    ``embeddings``) are compared with the module-level ``config``. If any of
    them differs:

    - ``config`` is updated and the RAG pipeline is rebuilt.
    - The new engine replaces ``default_query_engine``. Later questions asked
      with the same settings therefore keep using it, instead of silently
      falling back to the engine built at import time.

    ``LLM`` and ``embeddings`` are accepted for interface compatibility but
    are not applied: the pipeline is always built with ``config["LLM"]`` and
    ``config["embeddings"]``.

    Returns:
        ``(answer_text, page_image, page_message)`` for the Gradio outputs.
    """
    global default_query_engine
    print(f"Got Q: {question}")
    requested = {
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "similarity_top_k": similarity_top_k,
        "context_window": context_window,
        "top_k": top_k,
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
    }
    if any(config[key] != value for key, value in requested.items()):
        config.update(requested)
        # Keep config and the active engine in sync (see docstring).
        default_query_engine = load_RAG_pipeline(config)
    answer = get_answer(question, config, query_engine=default_query_engine)
    image, answer_page = get_answer_page(answer)
    return answer.response, image, answer_page
# Output component that displays the source page image for an answer.
output_image = gr.Image(label="Answer Page")
# Example questions for the interface; each row fills only the question input.
examples = [
    [question]
    for question in (
        "What model is state-of-the-art on DUDE?",
        "Why is knowledge distillation interesting?",
        "What is ANLS?",
    )
]
# Extra settings shown under "Additional Inputs". Gradio passes their values to
# ask_my_thesis *positionally* after the question, so this list must mirror
# its parameter order: LLM, embeddings, similarity_top_k, context_window,
# max_new_tokens, temperature, top_k, top_p, chunk_size, chunk_overlap.
additional_inputs = [
    # LLM and embeddings are not user-editable. These hidden boxes only keep
    # the positional mapping aligned; without them every slider value landed
    # on the wrong parameter.
    gr.Textbox(value=config["LLM"], label="LLM", visible=False),
    gr.Textbox(value=config["embeddings"], label="Embeddings", visible=False),
    # Integer settings use step=1. Otherwise Gradio derives a default step
    # from the range, which can be fractional.
    gr.Slider(1, 5, value=config["similarity_top_k"], step=1, label="Similarity Top K"),
    gr.Slider(512, 8048, value=config["context_window"], step=1, label="Context Window"),
    gr.Slider(20, 250, value=config["max_new_tokens"], step=1, label="Max New Tokens"),
    gr.Slider(0, 1, value=config["temperature"], label="Temperature"),
    gr.Slider(1, 10, value=config["top_k"], step=1, label="Top K"),
    gr.Slider(0, 1, value=config["top_p"], label="Nucleus Sampling"),
    gr.Slider(128, 4024, value=config["chunk_size"], step=1, label="Chunk Size"),
    gr.Slider(0, 200, value=config["chunk_overlap"], step=1, label="Chunk Overlap"),
]
# Assemble the app. Components are created in the original order: question
# box, answer box, page label.
_question_box = gr.Textbox(label="Question", placeholder="Type your question here...")
_answer_box = gr.Textbox(label="Answer")
_page_label = gr.Label()
iface = gr.Interface(
    fn=ask_my_thesis,
    inputs=[_question_box],
    additional_inputs=additional_inputs,
    outputs=[_answer_box, output_image, _page_label],
    examples=examples,
    title=title,
    description=description,
    allow_flagging="auto",
    cache_examples=True,
)

# References kept for future UI work:
# https://github.com/gradio-app/gradio/issues/4309
# https://discuss.huggingface.co/t/add-background-image/16381/4 background image

if __name__ == "__main__":
    # Start the application.
    iface.launch()