# Hugging Face Spaces: Running on Zero
import os | |
import spaces | |
import gradio as gr | |
import torch | |
from colpali_engine.models.paligemma_colbert_architecture import ColPali | |
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator | |
from colpali_engine.utils.colpali_processing_utils import ( | |
process_images, | |
process_queries, | |
) | |
from pdf2image import convert_from_path | |
from PIL import Image | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor | |
from qwen_vl_utils import process_vision_info | |
import re | |
import time | |
from PIL import Image | |
import torch | |
import subprocess | |
def _install_flash_attn():
    """Install flash-attn at start-up (best effort) for the Qwen2-VL attention backend.

    ``FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE`` asks the flash-attn build to skip
    compiling its CUDA extension (presumably falling back to a prebuilt wheel --
    confirm against flash-attn's setup script).
    """
    import sys  # local: only this start-up step needs the interpreter path

    completed = subprocess.run(
        # `python -m pip` installs into the interpreter running this app, and an
        # argv list avoids going through a shell.
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        # Extend the inherited environment instead of replacing it: a bare dict
        # would strip PATH, HOME, etc. from the pip process.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        check=False,
    )
    if completed.returncode != 0:
        # Still best effort (the app keeps starting), but make the failure visible.
        print(f"flash-attn installation failed with exit code {completed.returncode}")


_install_flash_attn()
def model_inference(
    images, text, max_new_tokens: int = 128,
):
    """Answer ``text`` about the retrieved pages with Qwen2-VL-2B-Instruct.

    The model and processor are loaded on every call and released afterwards,
    so GPU memory is only held while a request is being served.

    Args:
        images: Value of the "Retrieved Documents" ``gr.Gallery``. Each item is
            indexed with ``[0]`` to obtain an image file path (presumably
            ``(path, caption)`` pairs as Gradio delivers them -- confirm for the
            installed Gradio version).
        text: The user's question.
        max_new_tokens: Maximum number of tokens to generate (default 128, the
            value that used to be hard-coded).

    Returns:
        The decoded answer string.

    Raises:
        gr.Error: If there are no retrieved pages (search has not been run).

    NOTE(review): ``spaces`` is imported at file level but no ``@spaces.GPU``
    decorator is visible; on ZeroGPU hardware CUDA work normally has to run
    inside such a function -- confirm.
    """
    if not images:
        raise gr.Error("No retrieved pages to answer from. Please run a search first.")

    # One user turn: every retrieved page image followed by the question.
    content = [{"type": "image", "image": Image.open(item[0])} for item in images]
    content.append({"type": "text", "text": text})
    messages = [{"role": "user", "content": content}]

    model = None
    processor = None
    try:
        # flash_attention_2 is recommended upstream for multi-image inputs.
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-2B-Instruct",
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
            torch_dtype="auto",
        ).cuda().eval()
        # The per-image visual-token budget can be narrowed by passing
        # min_pixels / max_pixels (e.g. 256*28*28 .. 1280*28*28) here to trade
        # detail for speed and memory.
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

        prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Each output row starts with the prompt tokens; keep only the new ones.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
    finally:
        # Drop the only references to the weights (also on failure), then hand
        # cached CUDA blocks back so the next request can load its models.
        del model, processor
        torch.cuda.empty_cache()
    return output_text[0]
def search(query: str, ds, images, k):
    """Return the ``k`` indexed pages that best match ``query`` according to ColPali.

    Args:
        query: Free-text search query.
        ds: Page embeddings produced by ``index_gpu`` (one CPU tensor per page).
        images: Page images, index-aligned with ``ds``.
        k: Number of results to return. It is coerced with ``int()`` so a
            float coming from the UI slider also works.

    Returns:
        Up to ``k`` entries of ``images``, best match first.

    Raises:
        gr.Error: If no documents have been indexed yet.
    """
    if not ds:
        raise gr.Error("No documents indexed yet. Please upload and index PDFs first.")
    top_k = int(k)

    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    model = None
    processor = None
    try:
        # Base PaliGemma weights plus the ColPali adapter.
        model = ColPali.from_pretrained(
            "vidore/colpaligemma-3b-pt-448-base",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            token=token,
        ).eval()
        model.load_adapter(model_name)
        model = model.eval()
        processor = AutoProcessor.from_pretrained(model_name, token=token)
        # process_queries is given a blank 448x448 placeholder image (presumably
        # required by the PaliGemma processor -- confirm).
        mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Compare devices as torch.device objects; str vs torch.device never matches.
        if torch.device(device) != model.device:
            model.to(device)

        with torch.no_grad():
            batch_query = process_queries(processor, [query], mock_image)
            batch_query = {name: tensor.to(device) for name, tensor in batch_query.items()}
            embeddings_query = model(**batch_query)
            qs = list(torch.unbind(embeddings_query.to("cpu")))
    finally:
        # Release the model even if embedding the query fails.
        del model, processor
        torch.cuda.empty_cache()

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Row 0 holds the scores of our single query. Sort descending, then take the
    # first k; the previous `[-k:]` slice returned every page when k == 0.
    ranking = scores.argsort(axis=1)[0][::-1]
    results = [images[idx] for idx in ranking[:top_k]]
    print("done")
    return results
def index(files, ds):
    """Rasterise the uploaded PDFs, then embed the resulting pages with ColPali.

    Returns exactly what ``index_gpu`` returns for the converted pages.
    """
    print("Converting files")
    converted = convert_files(files)
    page_count = len(converted)
    print(f"Files converted with {page_count} images.")
    result = index_gpu(converted, ds)
    return result
def convert_files(files, max_pages: int = 150):
    """Rasterise every uploaded PDF into one flat list of page images.

    Args:
        files: Paths of the uploaded PDF files, as delivered by ``gr.File``.
        max_pages: Exclusive upper bound on the total page count (default 150).

    Returns:
        The pages of all files, in upload order, as returned by
        ``convert_from_path``.

    Raises:
        gr.Error: If nothing was uploaded, or the page count reaches ``max_pages``.
    """
    if not files:
        raise gr.Error("Please upload at least one PDF file.")
    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))
        # Check after every file so an oversized upload is rejected without
        # rasterising the remaining PDFs first.
        if len(images) >= max_pages:
            raise gr.Error(f"The number of images in the dataset should be less than {max_pages}.")
    return images
def index_gpu(images, ds):
    """Embed page images with ColPali and store the embeddings in ``ds``.

    ``ds`` is replaced (in place) rather than appended to, so ``ds[i]`` always
    belongs to ``images[i]``. Previously, re-indexing kept the old embeddings
    while ``images`` was replaced, so search results could point at the wrong
    page or past the end of the list. ``ds`` is only touched after every batch
    has been embedded, so a failed run leaves the previous index intact.

    Args:
        images: Page images to embed (e.g. the output of ``convert_files``).
        ds: List that receives one CPU embedding tensor per page; modified in place.

    Returns:
        ``(status message, ds, images)`` for the Gradio outputs.
    """
    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    embeddings = []
    model = None
    processor = None
    try:
        # Base PaliGemma weights plus the ColPali adapter.
        model = ColPali.from_pretrained(
            "vidore/colpaligemma-3b-pt-448-base",
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            token=token,
        ).eval()
        model.load_adapter(model_name)
        model = model.eval()
        processor = AutoProcessor.from_pretrained(model_name, token=token)

        dataloader = DataLoader(
            images,
            batch_size=4,
            shuffle=False,  # preserve page order so embeddings stay aligned with `images`
            collate_fn=lambda x: process_images(processor, x),
        )
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # Compare devices as torch.device objects; str vs torch.device never matches.
        if torch.device(device) != model.device:
            model.to(device)

        with torch.no_grad():
            for batch_doc in tqdm(dataloader):
                batch_doc = {name: tensor.to(device) for name, tensor in batch_doc.items()}
                embeddings_doc = model(**batch_doc)
                embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    finally:
        # Release the model even if embedding fails part-way.
        del model, processor
        torch.cuda.empty_cache()

    ds[:] = embeddings
    print("done")
    return f"Uploaded and converted {len(images)} pages", ds, images
def get_example():
    """Build the Gradio example rows: ``[[pdf_path], question]`` for the sample report."""
    sample_pdf = "RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"
    questions = (
        "Quels sont les 4 axes majeurs des achats?",
        "Quelles sont les actions entreprise en Afrique du Sud?",
        "fais moi un tableau markdown de la répartition homme femme",
    )
    # A fresh one-element file list per row, just like the literal it replaces.
    return [[[sample_pdf], question] for question in questions]
# Gradio UI: (1) upload + index PDFs, (2) search the indexed pages,
# (3) answer the query from the retrieved pages with Qwen2-VL.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali + Qwen2VL 2B: Efficient Document Retrieval with Vision Language Models 📚")
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            # Gradio expects extensions with a leading dot. A bare "pdf" is
            # treated as a MIME category, which PDF uploads ("application/pdf")
            # do not match, so newer Gradio releases reject them.
            file = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            convert_button = gr.Button("🔄 Index documents")
            # Per-session state: page embeddings and the matching page images.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
            img_chunk = gr.State(value=[])  # not wired to any event in this file
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
            search_button = gr.Button("🔍 Search", variant="primary")
    with gr.Row():
        gr.Examples(
            examples=get_example(),
            inputs=[file, query],
        )
    # Define the actions
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])
    answer_button = gr.Button("Answer", variant="primary")
    output = gr.Markdown(label="Output")
    # The gallery's retrieved pages plus the query are passed to the VLM.
    answer_button.click(model_inference, inputs=[output_gallery, query], outputs=output)
if __name__ == "__main__":
    # Bound the request queue to 10 pending events, then start the server.
    app = demo.queue(max_size=10)
    app.launch(debug=True)