Spaces:
Runtime error
Runtime error
File size: 3,172 Bytes
94b55f0 602d806 3649694 602d806 b5297f4 602d806 5dfd724 602d806 b5297f4 5dfd724 602d806 5dfd724 602d806 5dfd724 602d806 b5297f4 5dfd724 602d806 5dfd724 602d806 654c2e1 602d806 5dfd724 602d806 d5db6a5 94b55f0 3649694 5dfd724 3649694 602d806 5dfd724 602d806 dad1e49 5dfd724 602d806 5dfd724 5923654 5dfd724 602d806 5dfd724 602d806 5dfd724 602d806 5dfd724 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import os
import gradio as gr
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
import spaces
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
@spaces.GPU
def search(query: str, ds, images):
    """Find the indexed page that best matches a text query.

    Args:
        query: free-text search query.
        ds: list of per-page ColPali embeddings produced by ``index``.
        images: page images, in the same order as ``ds``.

    Returns:
        A ``(message, image)`` pair for the two Gradio outputs. When nothing
        has been indexed yet, the image is ``None`` and the message asks the
        user to upload documents first.
    """
    # Searching before anything was indexed would hand an empty corpus to the
    # evaluator and then index into an empty image list; bail out early.
    if not ds or not images:
        return "No documents indexed yet: upload and convert PDFs first", None

    with torch.no_grad():
        # process_queries is called with a placeholder image (mock_image)
        # alongside the text query.
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {k: v.to(device) for k, v in batch_query.items()}
        embeddings_query = model(**batch_query)
        # One multi-vector embedding per query, moved to CPU for scoring.
        qs = list(torch.unbind(embeddings_query.to("cpu")))

    # Score the query against every indexed page (multi-vector / late interaction).
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    # Single query -> take the argmax over the page axis.
    best_page = int(scores.argmax(axis=1).item())
    return f"The most relevant page is {best_page}", images[best_page]
@spaces.GPU
def index(file, ds):
    """Convert uploaded PDFs to page images and embed every page with ColPali.

    Args:
        file: the value of the multi-file ``gr.File`` component. It is
            ``None`` or empty when nothing was uploaded. Items are either file
            paths or file wrappers exposing ``.name``, depending on the Gradio
            version.
        ds: the previous embeddings state. It is not extended: the index is
            rebuilt from scratch so that it stays aligned one-to-one with the
            page images returned alongside it.

    Returns:
        ``(status message, page embeddings, page images)``. Embeddings and
        images have the same length and order.
    """
    if not file:
        return "Files not yet uploaded", [], []

    images = []
    for f in file:
        # Accept both plain paths and tempfile-like wrappers with a .name.
        images.extend(convert_from_path(getattr(f, "name", f)))

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    # Fresh list: appending to `ds` would keep embeddings from a previous
    # conversion while `images` only holds the new pages, so search results
    # would point at the wrong (or a non-existent) image.
    embeddings = []
    with torch.no_grad():
        for batch_doc in tqdm(dataloader):
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            # One embedding tensor per page, kept on the CPU.
            embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", embeddings, images
# Hex colour palette (not referenced by the code shown in this file).
COLORS = [
    "#4285f4",
    "#db4437",
    "#f4b400",
    "#0f9d58",
    "#e48ef1",
]
# Load model: PaliGemma base weights (bfloat16, on CUDA) with the ColPali
# adapter from `model_name` loaded on top, plus the matching processor.
model_name = "vidore/colpali"
token = os.getenv("HF_TOKEN")  # may be None if the variable is unset
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    token=token,
)
model.eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=token)
device = model.device
# Blank 448x448 white image passed to process_queries with text queries.
mock_image = Image.new(mode="RGB", size=(448, 448), color=(255, 255, 255))
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚🔍")

    gr.Markdown("## 1️⃣ Upload PDFs")
    # Gradio extension filters need a leading dot (".pdf"); a bare "pdf" is
    # not treated as a file extension.
    file = gr.File(file_types=[".pdf"], file_count="multiple")

    gr.Markdown("## 2️⃣ Convert the PDFs and upload")
    convert_button = gr.Button("🔄 Convert and upload")
    message = gr.Textbox("Files not yet uploaded")
    # gr.State values: page embeddings and the page images they belong to.
    embeds = gr.State(value=[])
    imgs = gr.State(value=[])

    # Define the actions
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])

    gr.Markdown("## 3️⃣ Search")
    query = gr.Textbox(placeholder="Enter your query here")
    search_button = gr.Button("🔍 Search")
    message2 = gr.Textbox("Query not yet set")
    output_img = gr.Image()

    search_button.click(search, inputs=[query, embeds, imgs], outputs=[message2, output_img])
if __name__ == "__main__":
    # Enable request queuing (at most 10 pending events), then start the
    # server with debug output.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(debug=True)