manu commited on
Commit
602d806
Β·
verified Β·
1 Parent(s): 75ef360

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pdf2image import convert_from_path
3
+
4
+ import torch
5
+ from PIL import Image
6
+ from torch.utils.data import DataLoader
7
+ from tqdm import tqdm
8
+ from transformers import AutoProcessor
9
+
10
+ from custom_colbert.models.paligemma_colbert_architecture import ColPali
11
+ from custom_colbert.trainer.retrieval_evaluator import CustomEvaluator
12
+
13
+
14
+ def process_images(processor, images, max_length: int = 50):
15
+ texts_doc = ["Describe the image."] * len(images)
16
+ images = [image.convert("RGB") for image in images]
17
+
18
+ batch_doc = processor(
19
+ text=texts_doc,
20
+ images=images,
21
+ return_tensors="pt",
22
+ padding="longest",
23
+ max_length=max_length + processor.image_seq_length,
24
+ )
25
+ return batch_doc
26
+
27
+
28
+ def process_queries(processor, queries, mock_image, max_length: int = 50):
29
+ texts_query = []
30
+ for query in queries:
31
+ query = f"Question: {query}<unused0><unused0><unused0><unused0><unused0>"
32
+ texts_query.append(query)
33
+
34
+ batch_query = processor(
35
+ images=[mock_image.convert("RGB")] * len(texts_query),
36
+ # NOTE: the image is not used in batch_query but it is required for calling the processor
37
+ text=texts_query,
38
+ return_tensors="pt",
39
+ padding="longest",
40
+ max_length=max_length + processor.image_seq_length,
41
+ )
42
+ del batch_query["pixel_values"]
43
+
44
+ batch_query["input_ids"] = batch_query["input_ids"][..., processor.image_seq_length :]
45
+ batch_query["attention_mask"] = batch_query["attention_mask"][..., processor.image_seq_length :]
46
+ return batch_query
47
+
48
+
49
+ def search(query: str, ds, images) -> str:
50
+ qs = []
51
+ with torch.no_grad():
52
+ batch_query = process_queries(processor, [query], mock_image)
53
+ batch_query = {k: v.to(device) for k, v in batch_query.items()}
54
+ embeddings_query = model(**batch_query)
55
+ qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
56
+
57
+ # run evaluation
58
+ retriever_evaluator = CustomEvaluator(is_multi_vector=True)
59
+ scores = retriever_evaluator.evaluate(qs, ds)
60
+
61
+ return f"The most relevant page is {scores.argmax(axis=1)}", images[scores.argmax(axis=1)]
62
+ # return f"Query: {query}, most relevant page: 1, {len(ds)}", images[1]
63
+
64
+
65
+ def index(file):
66
+ """Example script to run inference with ColPali"""
67
+ images = []
68
+ for f in file:
69
+ images.extend(convert_from_path(f))
70
+
71
+ # run inference - docs
72
+ dataloader = DataLoader(
73
+ images,
74
+ batch_size=4,
75
+ shuffle=False,
76
+ collate_fn=lambda x: process_images(processor, x),
77
+ )
78
+ ds = ["test", "double test"]
79
+ for batch_doc in tqdm(dataloader):
80
+ with torch.no_grad():
81
+ batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
82
+ embeddings_doc = model(**batch_doc)
83
+ ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
84
+ return f"Uploaded and converted {len(images)} pages", ds, images
85
+
86
+
87
+ COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
88
+ # Load model
89
+ model_name = "coldoc/colpali-3b-mix-448"
90
+ model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
91
+ model.load_adapter(model_name)
92
+ processor = AutoProcessor.from_pretrained(model_name)
93
+ device = model.device
94
+ mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
95
+
96
+ with gr.Blocks() as demo:
97
+ gr.Markdown("# PDF to πŸ€— Dataset")
98
+ gr.Markdown("## 1️⃣ Upload PDFs")
99
+ file = gr.File(file_types=["pdf"], file_count="multiple")
100
+
101
+ gr.Markdown("## 2️⃣ Convert the PDFs and upload")
102
+ convert_button = gr.Button("πŸ”„ Convert and upload")
103
+ message = gr.Textbox("Files not yet uploaded")
104
+ embeds = gr.State()
105
+ imgs = gr.State()
106
+
107
+ # Define the actions
108
+ convert_button.click(
109
+ index,
110
+ inputs=[file],
111
+ outputs=[message, embeds, imgs]
112
+ )
113
+
114
+ gr.Markdown("## 3️⃣ Search")
115
+ query = gr.Textbox(placeholder="Enter your query here")
116
+ search_button = gr.Button("πŸ” Search")
117
+ message2 = gr.Textbox("Query not yet set")
118
+ output_img = gr.Image()
119
+
120
+ search_button.click(
121
+ search, inputs=[query, embeds, imgs],
122
+ outputs=[message2, output_img]
123
+ )
124
+
125
+
126
+ if __name__ == "__main__":
127
+ demo.queue(max_size=10).launch(debug=True)