# ask_my_thesis / app.py
# Commit 847325c by jordyvl: "GPU enabled - small bug fix for LLM"
# TODO: return all pages used to form answer
# TODO: question samples
# TEST: with and without GPU instance
# TODO: visual questions on page image (in same app)?
# TODO: expose more parameters
import torch
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import SimpleDirectoryReader
from llama_index.core import VectorStoreIndex, SummaryIndex
from llama_index.core.prompts import PromptTemplate
from llama_index.core import Settings
from PIL import Image
import gradio as gr
# NOTE(review): CHEAPMODE is True when CUDA *is* available, which looks inverted
# relative to the earlier model switch kept below (cheap phi-2 when CHEAPMODE).
# It is not read anywhere in the visible code -- confirm before relying on it.
CHEAPMODE = torch.cuda.is_available()
# Earlier switch: LLM = "HuggingFaceH4/zephyr-7b-alpha" if not CHEAPMODE else "microsoft/phi-2"

# Default pipeline settings. ask_my_thesis overwrites these values in place
# when the UI sliders change.
# Other LLMs tried: "meta-llama/Meta-Llama-3-8B", "google/gemma-2b", "microsoft/phi-2".
config = dict(
    LLM="HuggingFaceH4/zephyr-7b-alpha",
    embeddings="BAAI/bge-small-en-v1.5",
    # retrieval
    similarity_top_k=2,
    # generation
    context_window=2048,
    max_new_tokens=200,
    temperature=0.7,
    top_k=5,
    top_p=0.95,
    # indexing
    chunk_size=512,
    chunk_overlap=50,
)
def center_element(el):
    """Return ``el`` wrapped in a ``<div>`` whose text is horizontally centered."""
    template = "<div style='text-align: center;'>{}</div>"
    return template.format(el)
# Page header: both strings are wrapped in a centered <div> and passed to
# gr.Interface as its title and description.
title = center_element(
    "Ask my thesis: Intelligent Automation for AI-Driven Document Understanding"
)
# Mixes raw HTML (clickable cover image linking to the PDF) with markdown links.
description = """Chat with the thesis manuscript by asking questions and receiving answers with reference to the page.
<div class="center">
<a href="https://jordy-vl.github.io/assets/phdthesis/VanLandeghem_Jordy_PhD-thesis.pdf">
<img src="https://ideogram.ai/api/images/direct/cc3Um6ClQkWJpVdXx6pWVA.png"
title="Thesis.pdf" alt="Ideogram image generated with prompt engineering" width="500" class="center"/></a>
</div> Click the visual above to be redirected to the PDF of the manuscript.
Technology used: [Llama-index](https://www.llamaindex.ai/), OS LLMs from HuggingFace
Spoiler: a quickly hacked together RAG application with a >1B LLM and online vector store can be quite slow on a 290 page document ⏳ (10s+)
"""
description = center_element(description)
def messages_to_prompt(
    messages,
    *,
    system_prompt=(
        "You are an expert in the research field of document understanding, "
        "bayesian deep learning and neural networks."
    ),
):
    """Render chat messages into the Zephyr-style ``<|role|>`` prompt format.

    The content of every system message is replaced by ``system_prompt``, so the
    model gets the thesis-domain persona. User and assistant turns are copied
    verbatim. Messages with any other role are skipped.

    Args:
        messages: Iterable of messages exposing ``.role`` and ``.content``
            (presumably llama-index ``ChatMessage`` objects when called by
            ``HuggingFaceLLM``).
        system_prompt: Text used in place of each system message's content.
            It defaults to the persona that was previously hard-coded.

    Returns:
        The prompt string. It always starts with a ``<|system|>`` turn, which is
        blank when the first message is not a system message. It ends with an
        open ``<|assistant|>`` turn for the model to complete.
    """
    parts = []
    for message in messages:
        if message.role == "system":
            parts.append(f"<|system|>\n{system_prompt}</s>\n")
        elif message.role == "user":
            parts.append(f"<|user|>\n{message.content}</s>\n")
        elif message.role == "assistant":
            parts.append(f"<|assistant|>\n{message.content}</s>\n")
    prompt = "".join(parts)
    # Always start with a system turn; insert an empty one if none came first.
    if not prompt.startswith("<|system|>\n"):
        prompt = "<|system|>\n</s>\n" + prompt
    # Leave an open assistant turn for generation.
    return prompt + "<|assistant|>\n"
def load_RAG_pipeline(config):
    """Build the LLM + embedding + vector-index pipeline and return a query engine.

    Side effects: the function overwrites the global llama-index ``Settings``
    (``llm``, ``embed_model``, ``chunk_size``, ``chunk_overlap``). It also loads
    the LLM weights and reads and indexes the documents in ``assets/txts``, so
    every call is expensive.

    Args:
        config: Settings dict with the keys "LLM", "embeddings",
            "similarity_top_k", "context_window", "max_new_tokens",
            "temperature", "top_k", "top_p", "chunk_size" and "chunk_overlap".

    Returns:
        A llama-index query engine over the indexed documents. It uses the
        "compact" response mode and retrieves ``similarity_top_k`` chunks.
    """
    # The UI sliders may return integer settings as floats. The tokenizer,
    # generate() and the text splitter expect real ints.
    context_window = int(config["context_window"])
    max_new_tokens = int(config["max_new_tokens"])
    similarity_top_k = int(config["similarity_top_k"])
    temperature = float(config["temperature"])

    # LLM: 4-bit bitsandbytes quantization is configured only when CUDA is
    # available. On CPU the model loads with quantization_config=None.
    quantization_config = None
    if torch.cuda.is_available():
        from transformers import BitsAndBytesConfig

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

    # transformers' generate() ignores temperature/top_k/top_p unless
    # do_sample=True, so without it the sampling sliders had no effect.
    # Sampling also needs a strictly positive temperature, so temperature 0
    # selects greedy decoding instead.
    if temperature > 0:
        generate_kwargs = {
            "do_sample": True,
            "temperature": temperature,
            "top_k": int(config["top_k"]),
            "top_p": float(config["top_p"]),
        }
    else:
        generate_kwargs = {"do_sample": False}

    llm = HuggingFaceLLM(
        model_name=config["LLM"],
        tokenizer_name=config["LLM"],
        # Zephyr-style chat markup around plain completion prompts.
        query_wrapper_prompt=PromptTemplate("<|system|>\n</s>\n<|user|>\n{query_str}</s>\n<|assistant|>\n"),
        context_window=context_window,
        max_new_tokens=max_new_tokens,
        model_kwargs={"quantization_config": quantization_config},
        generate_kwargs=generate_kwargs,
        messages_to_prompt=messages_to_prompt,
        device_map="auto",
    )

    # Global llama-index settings, picked up by the index and query engine below.
    Settings.llm = llm
    Settings.embed_model = HuggingFaceEmbedding(model_name=config["embeddings"])
    print(Settings)
    Settings.chunk_size = int(config["chunk_size"])
    Settings.chunk_overlap = int(config["chunk_overlap"])

    # Raw data: the text files under assets/txts, embedded into an in-memory index.
    documents = SimpleDirectoryReader("assets/txts").load_data()
    vector_index = VectorStoreIndex.from_documents(documents)
    # NOTE: the index is rebuilt (re-embedded) on every call. Persisting it
    # would avoid that; see
    # https://docs.llamaindex.ai/en/v0.10.17/understanding/storing/storing.html
    return vector_index.as_query_engine(
        response_mode="compact", similarity_top_k=similarity_top_k
    )
# Build the default pipeline once at import time (this loads the LLM and embeds
# the documents). Requests reuse it unless the UI settings differ from `config`.
default_query_engine = load_RAG_pipeline(config)
# Query helpers: get_answer runs a question through a query engine, and
# get_answer_page maps the top-ranked source chunk back to its page number and image.
def get_answer(question, query_engine=default_query_engine):
    """Send ``question`` to a llama-index query engine and return its response.

    Args:
        question: Free-text question from the user.
        query_engine: Engine returned by ``load_RAG_pipeline``. It defaults to
            the engine built at import time. The default is bound when this
            function is defined.

    Returns:
        The object returned by ``query_engine.query``. Callers read its
        ``.response`` text and its ``.source_nodes``.
    """
    # Per-query overrides: https://docs.llamaindex.ai/en/stable/module_guides/supporting_modules/settings/#setting-local-configurations
    result = query_engine.query(question)
    print(f"A: {result}")
    return result
def get_answer_page(response):
    """Return the page image and a navigation hint for the best-matching source.

    The function uses the first (top-ranked) source node of ``response``. The
    page number is the last 4 characters of that node's file stem (e.g.
    ``..._0042.txt`` -> 42). The page image is found at the same location with
    "txt" replaced by "png" in the file's folder and file name (e.g.
    ``assets/txts/x_0042.txt`` -> ``assets/pngs/x_0042.png``).

    Args:
        response: llama-index query response exposing ``source_nodes``.

    Returns:
        ``(image, message)``. ``image`` is the opened page image, or ``None``
        when the response cites no source. ``message`` is a user-facing string
        for the page label.
    """
    from pathlib import Path

    if not response.source_nodes:
        # Nothing retrieved: clear the image instead of raising IndexError.
        return None, "No source page found for this answer"

    txt_path = Path(response.source_nodes[0].metadata["file_path"])
    answer_page = int(txt_path.stem[-4:])
    # Rewrite only the folder and file name. A blanket str.replace over the
    # whole path would also corrupt any "txt" in an ancestor directory.
    # Assumes the file sits inside a named folder (e.g. assets/txts).
    png_dir = txt_path.parent.with_name(txt_path.parent.name.replace("txt", "png"))
    png_path = png_dir / txt_path.name.replace("txt", "png")
    image = Image.open(png_path)
    return image, f"Navigate to page {answer_page}"
# Gradio callback used as the gr.Interface fn below.
def ask_my_thesis(
    question,
    similarity_top_k=config["similarity_top_k"],
    context_window=config["context_window"],
    max_new_tokens=config["max_new_tokens"],
    temperature=config["temperature"],
    top_k=config["top_k"],
    top_p=config["top_p"],
    chunk_size=config["chunk_size"],
    chunk_overlap=config["chunk_overlap"],
):
    """Gradio callback: answer ``question`` and point to the supporting page.

    The keyword arguments mirror the additional-input sliders; their defaults
    are the ``config`` values at definition time. If any of them differs from
    the current ``config``, ``config`` is updated in place and the whole
    pipeline is rebuilt, which is slow because it reloads the LLM and
    re-embeds the documents. The rebuilt engine replaces the module-level
    ``default_query_engine``, so later calls with the same settings reuse it.

    Returns:
        ``(answer_text, page_image, page_message)`` for the three outputs.
    """
    global default_query_engine

    print(f"Got Q: {question}")
    # "LLM" and "embeddings" are not exposed in the UI yet.
    requested = {
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "similarity_top_k": similarity_top_k,
        "context_window": context_window,
        "top_k": top_k,
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
    }
    if any(config[key] != value for key, value in requested.items()):
        config.update(requested)
        # Replace the shared engine rather than using a local one-off. Before
        # this fix, `config` was already updated, so the next identical request
        # compared equal and silently fell back to the stale original engine.
        # NOTE(review): get_answer's default argument still references the
        # engine built at import, so that model stays in memory after a rebuild.
        default_query_engine = load_RAG_pipeline(config)
    answer = get_answer(question, query_engine=default_query_engine)
    image, answer_page = get_answer_page(answer)
    return answer.response, image, answer_page
# Gradio interface: output components, example questions and the extra
# "additional inputs" sliders passed to ask_my_thesis after the question.
output_image = gr.Image(label="Answer Page")

# Example questions (cache_examples=True below has Gradio precompute their outputs).
examples = [
    ["What model is state-of-the-art on DUDE?"],
    ["Why is knowledge distillation interesting?"],
    ["What is ANLS?"],
]

# Sliders, in the same order as ask_my_thesis's keyword parameters.
# Integer settings get an explicit step=1. Without a step, Gradio derives one
# from the range (10 for the context-window and chunk-size sliders), which
# puts the defaults 2048 and 512 off the slider grid.
# NOTE(review): the maxima 8048 and 4024 look like typos for 8192 and 4096 -- confirm.
# TODO: expose "LLM" and "embeddings" as text inputs (ask_my_thesis would need
# matching parameters).
additional_inputs = [
    gr.Slider(1, 5, value=config["similarity_top_k"], label="Similarity Top K", step=1),
    gr.Slider(512, 8048, value=config["context_window"], label="Context Window", step=1),
    gr.Slider(20, 500, value=config["max_new_tokens"], label="Max New Tokens", step=1),
    gr.Slider(0, 1, value=config["temperature"], label="Temperature"),
    gr.Slider(1, 10, value=config["top_k"], label="Top K", step=1),
    gr.Slider(0, 1, value=config["top_p"], label="Nucleus Sampling"),
    gr.Slider(128, 4024, value=config["chunk_size"], label="Chunk Size", step=1),
    gr.Slider(0, 200, value=config["chunk_overlap"], label="Chunk Overlap", step=1),
]

iface = gr.Interface(
    fn=ask_my_thesis,
    inputs=[gr.Textbox(label="Question", placeholder="Type your question here...")],
    additional_inputs=additional_inputs,
    outputs=[gr.Textbox(label="Answer"), output_image, gr.Label()],
    examples=examples,
    title=title,
    description=description,
    allow_flagging="auto",
    cache_examples=True,
)
# https://github.com/gradio-app/gradio/issues/4309
# https://discuss.huggingface.co/t/add-background-image/16381/4 background image
# Start the Gradio app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    iface.launch()