# Source: Hugging Face Space app by AdrienB134 — commit 20f229d ("gedet"), 8.13 kB.
import os
import spaces
import gradio as gr
import torch
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
process_images,
process_queries,
)
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import re
import time
from PIL import Image
import torch
import subprocess
import sys

# Best-effort install of flash-attn at startup; model_inference loads Qwen2-VL
# with attn_implementation="flash_attention_2", which needs this package.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE presumably tells flash-attn's setup to skip
# compiling its CUDA extension locally — confirm against flash-attn's setup.py.
subprocess.run(
    # Use the running interpreter's pip so the package lands in this environment.
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    # Extend rather than replace the inherited environment: a bare dict passed as
    # ``env`` would drop PATH, HOME, CUDA_* etc. from the child process.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    # Not checked, as before: a failed install surfaces later when the model loads.
    check=False,
)
@spaces.GPU
def model_inference(
    images, text,
):
    """Answer ``text`` about the retrieved page images with Qwen2-VL-2B-Instruct.

    The model and processor are loaded on every call and released before the
    function exits, including when loading or generation raises.

    Args:
        images: The results-gallery value. Each item is indexed with ``[0]`` and
            passed to ``PIL.Image.open``, so it is assumed to be a
            ``(filepath, caption)`` pair — TODO confirm against the installed
            Gradio version's ``gr.Gallery`` preprocessing.
        text: The user's question.

    Returns:
        The decoded answer (generation is capped at 128 new tokens).

    Raises:
        gr.Error: If the gallery is empty or None.
    """
    if not images:
        # Nothing to answer from: the gallery has not been populated yet.
        raise gr.Error("Please run a search first so there are pages to answer from.")

    # A single multimodal user turn: every retrieved page, then the question.
    content = [{"type": "image", "image": Image.open(item[0])} for item in images]
    content.append({"type": "text", "text": text})
    print(content)

    model = None
    processor = None
    try:
        # flash_attention_2 is recommended for speed and memory savings,
        # especially with multiple images.
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-2B-Instruct",
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
            torch_dtype="auto",
        ).cuda().eval()
        # Default processor. min_pixels/max_pixels (e.g. 256*28*28 and 1280*28*28)
        # can be passed here to bound visual tokens per image for speed/memory.
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

        messages = [{"role": "user", "content": content}]
        prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        generated_ids = model.generate(**inputs, max_new_tokens=128)
        # generate() returns prompt + completion; keep only the newly generated ids.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]
    finally:
        # Drop the references to the weights before emptying the CUDA cache so the
        # memory can be returned even if loading or generation failed.
        del model, processor
        torch.cuda.empty_cache()
@spaces.GPU
def search(query: str, ds, images, k):
    """Return the page images whose ColPali embeddings best match ``query``.

    Args:
        query: The text query.
        ds: Page embeddings produced by ``index_gpu`` (one CPU tensor per page).
        images: Page images aligned position-by-position with ``ds``.
        k: Number of results to return (the UI slider value).

    Returns:
        Up to ``k`` items of ``images``, best score first.

    Raises:
        gr.Error: If no documents have been indexed (``ds`` is empty).
    """
    if not ds:
        raise gr.Error("Please index documents before searching.")
    top_k = int(k)

    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    model = None
    processor = None
    try:
        # PaliGemma base weights with the ColPali adapter loaded on top.
        model = ColPali.from_pretrained(
            "vidore/colpaligemma-3b-pt-448-base",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            token=token,
        ).eval()
        model.load_adapter(model_name)
        model = model.eval()
        processor = AutoProcessor.from_pretrained(model_name, token=token)
        # Blank 448x448 white image passed to process_queries alongside the query.
        mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        if device != model.device:
            model.to(device)

        with torch.no_grad():
            batch_query = process_queries(processor, [query], mock_image)
            batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs = list(torch.unbind(embeddings_query.to("cpu")))

        retriever_evaluator = CustomEvaluator(is_multi_vector=True)
        scores = retriever_evaluator.evaluate(qs, ds)
        # Row 0 is the single query; argsort is ascending, so reverse and take the
        # first top_k (equivalent to [-k:][::-1] for k >= 1, and empty for k == 0).
        top_k_indices = scores.argsort(axis=1)[0][::-1][:top_k]
        results = [images[idx] for idx in top_k_indices]
    finally:
        # Release model/processor and cached GPU memory even if scoring failed.
        del model, processor
        torch.cuda.empty_cache()
    print("done")
    return results
def index(files, ds):
    """Render the uploaded PDFs to page images and embed them via ``index_gpu``.

    Returns exactly what ``index_gpu`` returns for the rendered pages.
    """
    print("Converting files")
    page_images = convert_files(files)
    page_count = len(page_images)
    print(f"Files converted with {page_count} images.")
    indexed = index_gpu(page_images, ds)
    return indexed
def convert_files(files, max_pages=150):
    """Render every page of every uploaded PDF to an image.

    Args:
        files: The uploaded PDFs, each passed to ``pdf2image.convert_from_path``
            (presumably file paths from ``gr.File`` — confirm for your Gradio version).
        max_pages: Exclusive upper bound on the total page count.

    Returns:
        All page images, in file order and then page order.

    Raises:
        gr.Error: If no file was given, or the total page count reaches ``max_pages``.
    """
    if not files:
        raise gr.Error("Please upload at least one PDF.")
    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))
        # Check after each file so an oversized upload fails without rendering
        # the remaining PDFs first.
        if len(images) >= max_pages:
            raise gr.Error(f"The number of images in the dataset should be less than {max_pages}.")
    return images
@spaces.GPU
def index_gpu(images, ds):
    """Embed page images with ColPali.

    Args:
        images: Page images to index.
        ds: The previous embedding state. It is not extended. The UI replaces its
            page-image state with ``images`` on every call, and ``search`` maps score
            positions back into that image list. The returned embeddings are
            therefore rebuilt from ``images`` alone, keeping both lists aligned.

    Returns:
        A ``(status_message, embeddings, images)`` tuple, where ``embeddings`` holds
        one CPU tensor per entry of ``images``.
    """
    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    embeddings = []
    model = None
    processor = None
    try:
        # PaliGemma base weights with the ColPali adapter loaded on top.
        model = ColPali.from_pretrained(
            "vidore/colpaligemma-3b-pt-448-base",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            token=token,
        ).eval()
        model.load_adapter(model_name)
        model = model.eval()
        processor = AutoProcessor.from_pretrained(model_name, token=token)

        # Batches of 4 pages, preprocessed with ColPali's image processing.
        dataloader = DataLoader(
            images,
            batch_size=4,
            shuffle=False,
            collate_fn=lambda x: process_images(processor, x),
        )
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        if device != model.device:
            model.to(device)
        for batch_doc in tqdm(dataloader):
            with torch.no_grad():
                batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
                embeddings_doc = model(**batch_doc)
            embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    finally:
        # Release model/processor and cached GPU memory even if indexing failed.
        del model, processor
        torch.cuda.empty_cache()
    print("done")
    return f"Uploaded and converted {len(images)} pages", embeddings, images
def get_example():
    """Build the sample rows for the ``gr.Examples`` component.

    Each row is ``[file_list, query]``: a one-element list holding the sample
    PDF name, followed by a question about that document. A fresh nested list
    is returned on every call.
    """
    sample_pdf = "RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"
    questions = (
        "Quels sont les 4 axes majeurs des achats?",
        "Quelles sont les actions entreprise en Afrique du Sud?",
        "fais moi un tableau markdown de la répartition homme femme",
    )
    return [[[sample_pdf], question] for question in questions]
# Gradio UI: index uploaded PDFs with ColPali, search the pages, then answer a
# question over the retrieved pages with Qwen2-VL.
# NOTE(review): this copy had lost its indentation; the nesting below is
# reconstructed from context — verify the layout against the deployed Space.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali + Qwen2VL 2B: Efficient Document Retrieval with Vision Language Models 📚")
    with gr.Row():
        # Left column: upload and indexing.
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDFs")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            convert_button = gr.Button("🔄 Index documents")
            # State written by `index` (embeddings, page images) and read by `search`.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
            # NOTE(review): not referenced anywhere else in the visible source — likely dead.
            img_chunk = gr.State(value=[])
        # Right column: query and result count.
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
            search_button = gr.Button("🔍 Search", variant="primary")
    with gr.Row():
        # Clicking an example fills the file and query inputs.
        gr.Examples(
            examples=get_example(),
            inputs=[file, query],
        )
    # Define the actions
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
    # index -> (status text, embeddings, page images)
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    # search -> top-k page images shown in the gallery
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
    answer_button = gr.Button("Answer", variant="primary")
    output = gr.Markdown(label="Output")
    # The gallery contents become the images passed to model_inference.
    answer_button.click(model_inference, inputs=[output_gallery, query], outputs=output)
if __name__ == "__main__":
    # Cap the pending-event queue at 10, then start the app with debug enabled.
    app = demo.queue(max_size=10)
    app.launch(debug=True)