"""Gradio demo: index uploaded PDF pages with ColPali and retrieve the page most relevant to a text query."""
import os

import gradio as gr
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from custom_colbert.models.paligemma_colbert_architecture import ColPali
from custom_colbert.trainer.retrieval_evaluator import CustomEvaluator
def process_images(processor, images, max_length: int = 50):
    """Run ``processor`` over a batch of page images.

    Every image is converted to RGB and paired with the same fixed prompt,
    ``"Describe the image."``.

    Args:
        processor: Callable accepting ``text``, ``images``, ``return_tensors``,
            ``padding`` and ``max_length`` keyword arguments and exposing an
            ``image_seq_length`` attribute.
        images: Sequence of objects with a ``convert(mode)`` method
            (presumably PIL images -- confirm against callers).
        max_length: Text budget; ``processor.image_seq_length`` is added on
            top of it before being passed to the processor.

    Returns:
        Whatever ``processor`` returns for the batch.
    """
    prompts = ["Describe the image."] * len(images)
    rgb_pages = [page.convert("RGB") for page in images]
    total_length = max_length + processor.image_seq_length
    return processor(
        text=prompts,
        images=rgb_pages,
        return_tensors="pt",
        padding="longest",
        max_length=total_length,
    )
def process_queries(processor, queries, mock_image, max_length: int = 50):
    """Run ``processor`` over text queries and strip the image part of the result.

    Each query is wrapped as ``"Question: <query>"`` followed by five
    ``<unused0>`` markers. The processor is called with one copy of
    ``mock_image`` per query; afterwards ``pixel_values`` is removed and the
    first ``processor.image_seq_length`` positions are cut from ``input_ids``
    and ``attention_mask``.

    Args:
        processor: Callable with the same keyword interface used here and an
            ``image_seq_length`` attribute.
        queries: Iterable of query strings.
        mock_image: Placeholder image exposing ``convert(mode)``.
        max_length: Text budget; ``processor.image_seq_length`` is added to it.

    Returns:
        The processor output without ``pixel_values`` and with the leading
        image positions sliced off the last axis of the two remaining fields.
    """
    prompts = [
        f"Question: {item}<unused0><unused0><unused0><unused0><unused0>"
        for item in queries
    ]

    # NOTE: the image is not used in the query batch, but the processor
    # requires one per text entry.
    placeholder_images = [mock_image.convert("RGB")] * len(prompts)
    batch_query = processor(
        images=placeholder_images,
        text=prompts,
        return_tensors="pt",
        padding="longest",
        max_length=max_length + processor.image_seq_length,
    )

    del batch_query["pixel_values"]
    for field in ("input_ids", "attention_mask"):
        batch_query[field] = batch_query[field][..., processor.image_seq_length :]
    return batch_query
def search(query: str, ds, images) -> tuple:
    """Find the indexed page that best matches ``query``.

    Args:
        query: Free-text question typed by the user.
        ds: List of per-page embeddings produced by ``index`` (``None`` or
            empty until a PDF has been indexed).
        images: Page images in the same order as ``ds``.

    Returns:
        A ``(message, image)`` pair for the two Gradio outputs. ``image`` is
        ``None`` when nothing has been indexed yet.
    """
    # The Gradio State values are None until `index` has run; report that
    # instead of crashing inside the evaluator.
    if not ds or not images:
        return "Upload and convert PDFs before searching", None

    with torch.no_grad():
        batch_query = process_queries(processor, [query], mock_image)
        batch_query = {k: v.to(device) for k, v in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    # run evaluation
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # argmax(axis=1) yields one entry per query (presumably scores is
    # (num_queries, num_pages) -- confirm against CustomEvaluator); there is
    # exactly one query, so reduce it to a plain int usable as a list index.
    best_page = int(scores.argmax(axis=1)[0])
    return f"The most relevant page is {best_page}", images[best_page]
def index(file):
    """Convert the uploaded PDFs to page images and embed every page.

    Args:
        file: Iterable of PDF paths as delivered by the Gradio file input
            (``None`` or empty when nothing was uploaded).

    Returns:
        A ``(message, ds, images)`` triple: a status string, the list of
        per-page embeddings moved to CPU, and the page images in matching order.
    """
    if not file:
        return "No files uploaded", [], []

    images = []
    for f in file:
        images.extend(convert_from_path(f))

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )

    # Start from an empty list so that position i in `ds` is the embedding of
    # images[i]; `search` relies on that alignment to pick the page.
    ds = []
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Uploaded and converted {len(images)} pages", ds, images
# Hex colour palette; not referenced by the code visible in this file.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Load model
model_name = "coldoc/colpali-3b-mix-448"
# Hub access token taken from the environment; None when HF_TOKEN is unset.
token = os.environ.get("HF_TOKEN")

# Base checkpoint in bfloat16, placed on CUDA, switched to eval mode.
model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    token=token,
)
model = model.eval()
# Attach the adapter published under `model_name` to the base checkpoint.
model.load_adapter(model_name)

processor = AutoProcessor.from_pretrained(model_name, token=token)
# Input batches are moved to the model's device before each forward pass.
device = model.device
# Blank white image handed to the processor for text-only query batches.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
# Three-step UI: upload PDFs, convert/embed them, then search the pages.
with gr.Blocks() as demo:
    gr.Markdown("# PDF to 🤗 Dataset")
    gr.Markdown("## 1️⃣ Upload PDFs")
    file = gr.File(file_types=["pdf"], file_count="multiple")

    gr.Markdown("## 2️⃣ Convert the PDFs and upload")
    convert_button = gr.Button("🔄 Convert and upload")
    message = gr.Textbox("Files not yet uploaded")
    # Per-session storage for the page embeddings and page images that
    # `index` produces and `search` consumes.
    embeds = gr.State()
    imgs = gr.State()

    # Define the actions
    convert_button.click(
        index,
        inputs=[file],
        outputs=[message, embeds, imgs],
    )

    gr.Markdown("## 3️⃣ Search")
    query = gr.Textbox(placeholder="Enter your query here")
    search_button = gr.Button("🔍 Search")
    message2 = gr.Textbox("Query not yet set")
    output_img = gr.Image()

    search_button.click(
        search,
        inputs=[query, embeds, imgs],
        outputs=[message2, output_img],
    )
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    # Event queue capped at 10 entries; debug mode enabled for the launch.
    demo.queue(max_size=10).launch(debug=True)