Spaces: Running on T4

thomasht86 committed • b7897bb
Parent(s): bb4f59a

Upload folder using huggingface_hub

Browse files:
- .DS_Store +0 -0
- .env.example +2 -1
- .gitignore +3 -1
- README.md +20 -1
- backend/colpali.py +25 -7
- frontend/app.py +46 -13
- frontend/layout.py +95 -14
- globals.css +22 -1
- main.py +106 -19
- output.css +77 -1
- prepare_feed_deploy.py +977 -0
- static/.DS_Store +0 -0
- uv.lock +0 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
.env.example CHANGED
@@ -1,3 +1,4 @@
 VESPA_APP_URL=https://abcde.z.vespa-app.cloud
 HF_TOKEN=hf_xxxxxxxxxx
-VESPA_CLOUD_SECRET_TOKEN=vespa_cloud_xxxxxxxx
+VESPA_CLOUD_SECRET_TOKEN=vespa_cloud_xxxxxxxx
+GEMINI_API_KEY=
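Note: the app reads these variables via python-dotenv at startup. A minimal sketch of that pattern (the variable names match `.env.example`; everything else here is illustrative):

```python
import os

from dotenv import load_dotenv  # python-dotenv, used elsewhere in this repo

load_dotenv()  # read key=value pairs from .env into the process environment

vespa_url = os.getenv("VESPA_APP_URL")
vespa_token = os.getenv("VESPA_CLOUD_SECRET_TOKEN")
gemini_key = os.getenv("GEMINI_API_KEY")  # the new entry in this commit
```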
.gitignore CHANGED
@@ -1,8 +1,10 @@
 .sesskey
 .venv/
 __pycache__/
+ipynb_checkpoints/
 .python-version
 .env
 template/
 *.json
-output/
+output/
+pdfs/
README.md CHANGED
@@ -27,7 +27,7 @@ preload_from_hub:
 
 # Visual Retrieval ColPali
 
-#
+# Prepare data and Vespa application
 
 First, install `uv`:
 
@@ -35,6 +35,25 @@ First, install `uv`:
 curl -LsSf https://astral.sh/uv/install.sh | sh
 ```
 
+Then, run:
+
+```bash
+uv sync --extra dev --extra feed
+```
+
+Convert `prepare_feed_deploy.py` to a notebook:
+
+```bash
+jupytext --to notebook prepare_feed_deploy.py
+```
+
+Then launch a Jupyter instance; see https://docs.astral.sh/uv/guides/integration/jupyter/ for the recommended approach.
+
+Open and follow the `prepare_feed_deploy.ipynb` notebook to prepare the data and deploy the Vespa application.
+
+# Developing on the web app
+
 Then, in this directory, run:
 
 ```bash
backend/colpali.py CHANGED
@@ -170,13 +170,13 @@ def gen_similarity_maps(
     if vespa_sim_maps:
         print("Using provided similarity maps")
         # A sim map looks like this:
-        # "
+        # "quantized": [
         #   {
         #     "address": {
         #       "patch": "0",
         #       "querytoken": "0"
         #     },
-        #     "value":
+        #     "value": 12, # score in range [-128, 127]
         #   },
         # ... and so on.
         # Now turn these into a tensor of same shape as previous similarity map
@@ -189,7 +189,7 @@ def gen_similarity_maps(
             )
         )
         for idx, vespa_sim_map in enumerate(vespa_sim_maps):
-            for cell in vespa_sim_map["
+            for cell in vespa_sim_map["quantized"]["cells"]:
                 patch = int(cell["address"]["patch"])
                 # if dummy model then just use 1024 as the image_seq_length
 
@@ -359,7 +359,7 @@ async def query_vespa_default(
         start = time.perf_counter()
         response: VespaQueryResponse = await session.query(
             body={
-                "yql": "select id,title,url,
+                "yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
                 "ranking": "default",
                 "query": query,
                 "timeout": timeout,
@@ -392,7 +392,7 @@ async def query_vespa_bm25(
         start = time.perf_counter()
         response: VespaQueryResponse = await session.query(
             body={
-                "yql": "select id,title,url,
+                "yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
                 "ranking": "bm25",
                 "query": query,
                 "timeout": timeout,
@@ -472,7 +472,7 @@ async def query_vespa_nearest_neighbor(
                 **query_tensors,
                 "presentation.timing": True,
                 # if we use rank({nn_string}, userQuery()), dynamic summary doesn't work, see https://github.com/vespa-engine/vespa/issues/28704
-                "yql": f"select id,title,snippet,text,url,
+                "yql": f"select id,title,snippet,text,url,blur_image,page_number,summaryfeatures from pdf_page where {nn_string} or userQuery()",
                 "ranking.profile": "retrieval-and-rerank",
                 "timeout": timeout,
                 "hits": hits,
@@ -492,6 +492,24 @@ def is_special_token(token: str) -> bool:
         return True
     return False
 
+async def get_full_image_from_vespa(
+    app: Vespa,
+    id: str) -> str:
+    async with app.asyncio(connections=1, total_timeout=120) as session:
+        start = time.perf_counter()
+        response: VespaQueryResponse = await session.query(
+            body={
+                "yql": f"select full_image from pdf_page where id contains \"{id}\"",
+                "ranking": "unranked",
+                "presentation.timing": True,
+            },
+        )
+        assert response.is_successful(), response.json
+        stop = time.perf_counter()
+        print(
+            f"Getting image from Vespa took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
+        )
+        return response.json["root"]["children"][0]["fields"]["full_image"]
 
 async def get_result_from_query(
     app: Vespa,
@@ -538,7 +556,7 @@ def add_sim_maps_to_result(
     imgs: List[str] = []
     vespa_sim_maps: List[str] = []
     for single_result in result["root"]["children"]:
-        img = single_result["fields"]["
+        img = single_result["fields"]["blur_image"]
         if img:
             imgs.append(img)
         vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
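For context on the similarity-map change above: Vespa returns MaxSim scores as a sparse list of cells under `quantized.cells`, and the code folds them back into a dense (query token, patch) tensor. A minimal sketch of that reconstruction, assuming the cell layout shown in the code comment (shape and sample values are made up):

```python
import torch

# One result's sim map as returned in Vespa summaryfeatures (illustrative values)
vespa_sim_map = {
    "quantized": {
        "cells": [
            {"address": {"querytoken": "0", "patch": "0"}, "value": 12},
            {"address": {"querytoken": "0", "patch": "1"}, "value": -3},
            {"address": {"querytoken": "1", "patch": "0"}, "value": 7},
        ]
    }
}

n_query_tokens, n_patches = 2, 2
sim_map = torch.zeros(n_query_tokens, n_patches)
for cell in vespa_sim_map["quantized"]["cells"]:
    token = int(cell["address"]["querytoken"])
    patch = int(cell["address"]["patch"])
    sim_map[token, patch] = cell["value"]  # int8 score in [-128, 127]
```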
frontend/app.py CHANGED
@@ -131,9 +131,13 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
 
 def SampleQueries():
     sample_queries = [
-        "
-        "
-        "
+        "Proportion of female new hires 2021-2023?",
+        "Total amount of performance-based pay awarded in 2023?",
+        "What is the percentage distribution of employees with performance-based pay relative to the limit in 2023?",
+        "What is the breakdown of management costs by investment strategy in 2023?",
+        "2023 profit loss portfolio",
+        "net cash flow operating activities",
+        "fund currency basket returns",
     ]
 
     query_badges = []
@@ -193,21 +197,23 @@ def Search(request, search_results=[]):
     )
     return Div(
         Div(
-            SearchBox(query_value=query_value, ranking_value=ranking_value),
             Div(
+                SearchBox(query_value=query_value, ranking_value=ranking_value),
+                Div(
+                    LoadingMessage(),
+                    id="search-results",  # This will be replaced by the search results
+                ),
+                cls="grid",
             ),
             cls="grid",
         ),
-        cls="grid",
     )
 
 
-def LoadingMessage():
+def LoadingMessage(display_text="Retrieving search results"):
     return Div(
         Lucide(icon="loader-circle", cls="size-5 mr-1.5 animate-spin"),
-        Span(
+        Span(display_text, cls="text-base text-center"),
         cls="p-10 text-muted-foreground flex items-center justify-center",
         id="loading-indicator",
     )
@@ -250,7 +256,7 @@ def SearchResult(results: list, query_id: Optional[str] = None):
     result_items = []
     for idx, result in enumerate(results):
         fields = result["fields"]  # Extract the 'fields' part of each result
-
+        blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"
 
         # Filter sim_map fields that are words with 4 or more characters
         sim_map_fields = {
@@ -286,7 +292,7 @@ def SearchResult(results: list, query_id: Optional[str] = None):
             "Reset",
             variant="outline",
             size="sm",
-            data_image_src=
+            data_image_src=blur_image_base64,
             cls="reset-button pointer-events-auto font-mono text-xs h-5 rounded-none px-2",
         )
 
@@ -312,7 +318,11 @@ def SearchResult(results: list, query_id: Optional[str] = None):
                 Div(
                     Div(
                         Img(
-                            src=
+                            src=blur_image_base64,
+                            hx_get=f"/full_image?id={fields['id']}",
+                            style="filter: blur(5px);",
+                            hx_trigger="load",
+                            hx_swap="outerHTML",
                             alt=fields["title"],
                             cls="result-image w-full h-full object-contain",
                         ),
@@ -350,12 +360,35 @@ def SearchResult(results: list, query_id: Optional[str] = None):
                 ),
                 cls="bg-background px-3 py-5 hidden md:block",
             ),
-            cls="grid grid-cols-1 md:grid-cols-2 col-span-2",
+            cls="grid grid-cols-1 md:grid-cols-2 col-span-2 border-t",
         )
     )
+
     return Div(
         *result_items,
         image_swapping,
         id="search-results",
         cls="grid grid-cols-2 gap-px bg-border",
    )
+
+
+def ChatResult(query_id: str, query: str):
+    return Div(
+        Div("Chat", cls="text-xl font-semibold p-3"),
+        Div(
+            Div(
+                Div(
+                    LoadingMessage(display_text="Waiting for response..."),
+                    cls="bg-muted/80 dark:bg-muted/40 text-black dark:text-white p-2 rounded-md",
+                    hx_ext="sse",
+                    sse_connect=f"/get-message?query_id={query_id}&query={quote_plus(query)}",
+                    sse_swap="message",
+                    sse_close="close",
+                    hx_swap="innerHTML",
+                ),
+            ),
+            id="chat-messages",
+            cls="overflow-auto min-h-0 grid items-end px-3",
+        ),
+        cls="h-full grid grid-rows-[auto_1fr_auto] min-h-0 gap-3",
+    )
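The `Img` change above is a progressive-loading pattern: the result card first renders the small inline `blur_image`, and `hx_trigger="load"` makes htmx fetch the full-resolution image from `/full_image` as soon as the element appears, replacing the whole tag via `hx_swap="outerHTML"`. A standalone sketch of the same idea (component and parameter names here are illustrative, not the app's code):

```python
from fasthtml.components import Img


def BlurredThumb(doc_id: str, blur_b64: str, title: str):
    # Paints instantly from the tiny inline image; htmx then swaps in the
    # full-quality <img> returned by the /full_image endpoint in main.py.
    return Img(
        src=f"data:image/jpeg;base64,{blur_b64}",
        style="filter: blur(5px);",
        hx_get=f"/full_image?id={doc_id}",
        hx_trigger="load",    # fire once the tag is in the DOM
        hx_swap="outerHTML",  # replace the blurred <img> entirely
        alt=title,
    )
```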
frontend/layout.py CHANGED
@@ -1,15 +1,96 @@
-from fasthtml.components import Div, Img, Nav, Title
-from fasthtml.xtend import A
+from fasthtml.components import Body, Div, Header, Img, Nav, Title
+from fasthtml.xtend import A, Script
 from lucide_fasthtml import Lucide
 from shad4fast import Button, Separator
 
+script = Script(
+    """
+    document.addEventListener("DOMContentLoaded", function () {
+        const main = document.querySelector('main');
+        const aside = document.querySelector('aside');
+        const body = document.body;
+
+        if (main && aside && main.nextElementSibling === aside) {
+            // Main + Aside layout
+            body.classList.add('grid-cols-[minmax(0,_4fr)_minmax(0,_1fr)]');
+            aside.classList.remove('hidden');
+        } else if (main) {
+            // Only Main layout (full width)
+            body.classList.add('grid-cols-[1fr]');
+        }
+    });
+    """
+)
+
+overlay_scrollbars = Script(
+    """
+    (function () {
+        const { OverlayScrollbars } = OverlayScrollbarsGlobal;
+
+        function getPreferredTheme() {
+            return localStorage.theme === 'dark' || (!('theme' in localStorage) && window.matchMedia('(prefers-color-scheme: dark)').matches)
+                ? 'dark'
+                : 'light';
+        }
+
+        function applyOverlayScrollbars(element, scrollbarTheme) {
+            // Destroy existing OverlayScrollbars instance if it exists
+            const instance = OverlayScrollbars(element);
+            if (instance) {
+                instance.destroy();
+            }
+
+            // Reinitialize OverlayScrollbars with the new theme
+            OverlayScrollbars(element, {
+                scrollbars: {
+                    theme: scrollbarTheme,
+                    visibility: 'auto',
+                    autoHide: 'leave',
+                    autoHideDelay: 800
+                }
+            });
+        }
+
+        function updateScrollbarTheme() {
+            const isDarkMode = getPreferredTheme() === 'dark';
+            const scrollbarTheme = isDarkMode ? 'os-theme-light' : 'os-theme-dark'; // Light theme in dark mode, dark theme in light mode
+
+            const mainElement = document.querySelector('main');
+            const chatMessagesElement = document.querySelector('#chat-messages'); // Select the chat message container by ID
+
+            if (mainElement) {
+                applyOverlayScrollbars(mainElement, scrollbarTheme);
+            }
+
+            if (chatMessagesElement) {
+                applyOverlayScrollbars(chatMessagesElement, scrollbarTheme);
+            }
+        }
+
+        // Apply the correct theme immediately when the page loads
+        updateScrollbarTheme();
+
+        // Observe changes in the 'dark' class on the <html> element
+        const observer = new MutationObserver(updateScrollbarTheme);
+        observer.observe(document.documentElement, { attributes: true, attributeFilter: ['class'] });
+    })();
+    """
+)
+
 
 def Logo():
     return Div(
-        Img(
-
-
-
+        Img(
+            src="https://assets.vespa.ai/logos/vespa-logo-black.svg",
+            alt="Vespa Logo",
+            cls="h-full dark:hidden",
+        ),
+        Img(
+            src="https://assets.vespa.ai/logos/vespa-logo-white.svg",
+            alt="Vespa Logo Dark Mode",
+            cls="h-full hidden dark:block",
+        ),
+        cls="h-[27px]",
     )
 
 
@@ -38,23 +119,23 @@ def Links():
         ),
         Separator(orientation="vertical"),
         ThemeToggle(),
-        cls=
+        cls="flex items-center space-x-3",
     )
 
 
 def Layout(*c, **kwargs):
     return (
-        Title(
+        Title("Visual Retrieval ColPali"),
         Body(
             Header(
                 A(Logo(), href="/"),
                 Links(),
-                cls=
-            ),
-            Main(
-                *c, **kwargs,
-                cls='flex-1 h-full'
+                cls="min-h-[55px] h-[55px] w-full flex items-center justify-between px-4",
             ),
-
+            *c,
+            **kwargs,
+            cls="grid grid-rows-[55px_1fr] min-h-0",
         ),
+        script,
+        overlay_scrollbars,
    )
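The reworked `Layout` turns the `Body` into a two-row grid (header plus content) and leaves the column split to the `DOMContentLoaded` script above: when a route renders a `Main` followed directly by an `Aside`, the body gets the 4fr/1fr column template and the aside is unhidden; otherwise it stays single-column. A usage sketch under those assumptions (page contents are illustrative):

```python
from fasthtml.components import Aside, Main, P

from frontend.layout import Layout

# Full-width page: the script applies grid-cols-[1fr]
page = Layout(Main(P("search results")))

# Main immediately followed by Aside: the script applies
# grid-cols-[minmax(0,_4fr)_minmax(0,_1fr)] and removes 'hidden' on the aside
page_with_chat = Layout(
    Main(P("search results")),
    Aside(P("chat panel")),
)
```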
globals.css CHANGED
@@ -183,4 +183,25 @@
   width: 100%;
   height: 100%;
   z-index: 10;
-}
+}
+
+header {
+  grid-column: 1/-1;
+}
+
+main {
+  overflow: auto;
+}
+
+aside {
+  overflow: auto;
+}
+
+.scroll-container {
+  padding-right: 10px;
+}
+
+.question-message {
+  background-color: #61D790;
+  color: #2E2F27;
+}
main.py CHANGED
@@ -1,22 +1,25 @@
 import asyncio
+import hashlib
+import time
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 
 from fasthtml.common import *
 from shad4fast import *
 from vespa.application import Vespa
-import time
 
+from backend.cache import LRUCache
 from backend.colpali import (
-    get_result_from_query,
-    get_query_embeddings_and_token_map,
     add_sim_maps_to_result,
+    get_query_embeddings_and_token_map,
+    get_result_from_query,
     is_special_token,
+    get_full_image_from_vespa,
 )
-from backend.vespa_app import get_vespa_app
-from backend.cache import LRUCache
 from backend.modelmanager import ModelManager
+from backend.vespa_app import get_vespa_app
 from frontend.app import (
+    ChatResult,
     Home,
     Search,
     SearchBox,
@@ -25,7 +28,10 @@ from frontend.app import (
     SimMapButtonReady,
 )
 from frontend.layout import Layout
-import
+import google.generativeai as genai
+from PIL import Image
+import io
+import base64
 
 highlight_js_theme_link = Link(id="highlight-theme", rel="stylesheet", href="")
 highlight_js_theme = Script(src="/static/js/highlightjs-theme.js")
@@ -35,15 +41,27 @@ highlight_js = HighlightJS(
     light="github",
 )
 
+overlayscrollbars_link = Link(
+    rel="stylesheet",
+    href="https://cdnjs.cloudflare.com/ajax/libs/overlayscrollbars/2.10.0/styles/overlayscrollbars.min.css",
+    type="text/css",
+)
+overlayscrollbars_js = Script(
+    src="https://cdnjs.cloudflare.com/ajax/libs/overlayscrollbars/2.10.0/browser/overlayscrollbars.browser.es5.min.js"
+)
+sselink = Script(src="https://unpkg.com/htmx-ext-sse@2.2.1/sse.js")
+
 app, rt = fast_app(
-    htmlkw={"cls": "h-full"},
+    htmlkw={"cls": "grid h-full"},
     pico=False,
     hdrs=(
         ShadHead(tw_cdn=False, theme_handle=True),
         highlight_js,
         highlight_js_theme_link,
         highlight_js_theme,
+        overlayscrollbars_link,
+        overlayscrollbars_js,
+        sselink,
     ),
 )
 vespa_app: Vespa = get_vespa_app()
@@ -53,6 +71,16 @@ task_cache = LRUCache(
     max_size=1000
 )  # Map from query_id to boolean value - False if not all results are ready.
 thread_pool = ThreadPoolExecutor()
+# Gemini config
+
+genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
+GEMINI_SYSTEM_PROMPT = """If the user query is a question, try your best to answer it based on the provided images.
+If the user query is not an obvious question, reply with 'No question detected.'. Your response should be HTML formatted.
+This means that newlines will be replaced with <br> tags, bold text will be enclosed in <b> tags, and so on.
+"""
+gemini_model = genai.GenerativeModel(
+    "gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
+)
 
 
 @app.on_event("startup")
@@ -72,7 +100,7 @@ def serve_static(filepath: str):
 
 @rt("/")
 def get():
-    return Layout(Home())
+    return Layout(Main(Home()))
 
 
 @rt("/search")
@@ -86,16 +114,18 @@ def get(request):
     if not query_value:
         # Show SearchBox and a message for missing query
         return Layout(
-            SearchBox(query_value=query_value, ranking_value=ranking_value),
+            Main(
                 Div(
+                    SearchBox(query_value=query_value, ranking_value=ranking_value),
+                    Div(
+                        P(
+                            "No query provided. Please enter a query.",
+                            cls="text-center text-muted-foreground",
+                        ),
+                        cls="p-10",
                     ),
-            cls="
-            )
-            cls="grid",
+                    cls="grid",
+                )
             )
         )
     # Generate a unique query_id based on the query and ranking value
@@ -107,7 +137,12 @@ def get(request):
     # search_results = get_results_children(result)
     # return Layout(Search(request, search_results))
     # Show the loading message if a query is provided
-    return Layout(
+    return Layout(
+        Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
+        Aside(
+            ChatResult(query_id=query_id, query=query_value), cls="border-t border-l"
+        ),
+    )  # Show SearchBox and Loading message initially
 
 
 @rt("/fetch_results")
@@ -215,15 +250,67 @@ async def get_sim_map(query_id: str, idx: int, token: str):
     sim_map_b64 = search_results[idx]["fields"].get(sim_map_key, None)
     if sim_map_b64 is None:
         return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
-    sim_map_img_src = f"data:image/
+    sim_map_img_src = f"data:image/png;base64,{sim_map_b64}"
     return SimMapButtonReady(
         query_id=query_id, idx=idx, token=token, img_src=sim_map_img_src
     )
 
 
+@app.get("/full_image")
+async def full_image(id: str):
+    """
+    Endpoint to get the full quality image for a given result id.
+    """
+    image_data = await get_full_image_from_vespa(vespa_app, id)
+
+    # Decode the base64 image data
+    # image_data = base64.b64decode(image_data)
+    image_data = "data:image/jpeg;base64," + image_data
+
+    return Img(
+        src=image_data,
+        alt="something",
+        cls="result-image w-full h-full object-contain",
+    )
+
+
+async def message_generator(query_id: str, query: str):
+    result = None
+    while result is None:
+        result = result_cache.get(query_id)
+        await asyncio.sleep(0.5)
+    search_results = get_results_children(result)
+    images = [result["fields"]["blur_image"] for result in search_results]
+    # from b64 to PIL image
+    images = [Image.open(io.BytesIO(base64.b64decode(img))) for img in images]
+
+    # If newlines are present in the response, the connection will be closed.
+    def replace_newline_with_br(text):
+        return text.replace("\n", "<br>")
+
+    response_text = ""
+    async for chunk in await gemini_model.generate_content_async(
+        images + ["\n\n Query: ", query], stream=True
+    ):
+        if chunk.text:
+            response_text += chunk.text
+            response_text = replace_newline_with_br(response_text)
+            yield f"event: message\ndata: {response_text}\n\n"
+            await asyncio.sleep(0.5)
+    yield "event: close\ndata: \n\n"
+
+
+@app.get("/get-message")
+async def get_message(query_id: str, query: str):
+    return StreamingResponse(
+        message_generator(query_id=query_id, query=query),
+        media_type="text/event-stream",
+    )
+
+
 @rt("/app")
 def get():
-    return Layout(Div(P(f"Connected to Vespa at {vespa_app.url}"), cls="p-4"))
+    return Layout(Main(Div(P(f"Connected to Vespa at {vespa_app.url}"), cls="p-4")))
 
 
 if __name__ == "__main__":
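The `/get-message` route streams the Gemini answer as server-sent events; each `yield` in `message_generator` follows the `event:`/`data:` wire format with a blank-line terminator, and a final `close` event tells the htmx SSE extension to disconnect. A minimal sketch of a client reading that stream (host, port, and query values are assumptions for local testing):

```python
import httpx

# Stream the SSE response line by line; "data:" lines carry incremental HTML.
with httpx.stream(
    "GET",
    "http://localhost:5001/get-message",
    params={"query_id": "abc123", "query": "fund currency basket returns"},
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):])
```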
output.css CHANGED
@@ -927,6 +927,10 @@ body {
   max-height: 100vh;
 }
 
+.min-h-0 {
+  min-height: 0px;
+}
+
 .min-h-\[55px\] {
   min-height: 55px;
 }
@@ -1096,6 +1100,22 @@ body {
   grid-template-columns: repeat(2, minmax(0, 1fr));
 }
 
+.grid-cols-\[1fr\] {
+  grid-template-columns: 1fr;
+}
+
+.grid-cols-\[minmax\(0\2c _4fr\)_minmax\(0\2c _1fr\)\] {
+  grid-template-columns: minmax(0, 4fr) minmax(0, 1fr);
+}
+
+.grid-rows-\[55px_1fr\] {
+  grid-template-rows: 55px 1fr;
+}
+
+.grid-rows-\[auto_1fr_auto\] {
+  grid-template-rows: auto 1fr auto;
+}
+
 .flex-col {
   flex-direction: column;
 }
@@ -1112,10 +1132,18 @@ body {
   align-content: flex-start;
 }
 
+.items-end {
+  align-items: flex-end;
+}
+
 .items-center {
   align-items: center;
 }
 
+.justify-end {
+  justify-content: flex-end;
+}
+
 .justify-center {
   justify-content: center;
 }
@@ -1136,6 +1164,10 @@ body {
   gap: 0.5rem;
 }
 
+.gap-3 {
+  gap: 0.75rem;
+}
+
 .gap-4 {
   gap: 1rem;
 }
@@ -1200,6 +1232,10 @@ body {
   margin-bottom: calc(0.5rem * var(--tw-space-y-reverse));
 }
 
+.self-end {
+  align-self: flex-end;
+}
+
 .self-stretch {
   align-self: stretch;
 }
@@ -1252,6 +1288,11 @@ body {
   border-width: 2px;
 }
 
+.border-x {
+  border-left-width: 1px;
+  border-right-width: 1px;
+}
+
 .border-b {
   border-bottom-width: 1px;
 }
@@ -1493,6 +1534,10 @@ body {
   padding-top: 1rem;
 }
 
+.pr-3 {
+  padding-right: 0.75rem;
+}
+
 .text-left {
   text-align: left;
 }
@@ -1577,6 +1622,11 @@ body {
   letter-spacing: 0.025em;
 }
 
+.text-black {
+  --tw-text-opacity: 1;
+  color: rgb(0 0 0 / var(--tw-text-opacity));
+}
+
 .text-card-foreground {
   color: hsl(var(--card-foreground));
 }
@@ -1993,6 +2043,27 @@ body {
   z-index: 10;
 }
 
+header {
+  grid-column: 1/-1;
+}
+
+main {
+  overflow: auto;
+}
+
+aside {
+  overflow: auto;
+}
+
+.scroll-container {
+  padding-right: 10px;
+}
+
+.question-message {
+  background-color: #61D790;
+  color: #2E2F27;
+}
+
 :root:has(.data-\[state\=open\]\:no-bg-scroll[data-state="open"]) {
   overflow: hidden;
 }
@@ -2537,6 +2608,11 @@ body {
   --tw-gradient-to: #d1d5db var(--tw-gradient-to-position);
 }
 
+.dark\:text-white:where(.dark, .dark *) {
+  --tw-text-opacity: 1;
+  color: rgb(255 255 255 / var(--tw-text-opacity));
+}
+
 .dark\:hover\:border-white:hover:where(.dark, .dark *) {
   --tw-border-opacity: 1;
   border-color: rgb(255 255 255 / var(--tw-border-opacity));
@@ -2610,4 +2686,4 @@ body {
 
 .\[\&_tr\]\:border-b tr {
   border-bottom-width: 1px;
-}
+}
prepare_feed_deploy.py
ADDED
@@ -0,0 +1,977 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %% [markdown]
|
2 |
+
# # Visual PDF Retrieval - demo application
|
3 |
+
#
|
4 |
+
# In this notebook, we will prepare the Vespa backend application for our visual retrieval demo.
|
5 |
+
# We will use ColPali as the model to extract patch vectors from images of pdf pages.
|
6 |
+
# At query time, we use MaxSim to retrieve and/or (based on the configuration) rank the page results.
|
7 |
+
#
|
8 |
+
# To see the application in action, visit TODO:
|
9 |
+
#
|
10 |
+
# The web application is written in FastHTML, meaning the complete application is written in python.
|
11 |
+
#
|
12 |
+
# The steps we will take in this notebook are:
|
13 |
+
#
|
14 |
+
# 0. Setup and configuration
|
15 |
+
# 1. Download the data
|
16 |
+
# 2. Prepare the data
|
17 |
+
# 3. Generate queries for evaluation and typeahead search suggestions
|
18 |
+
# 4. Deploy the Vespa application
|
19 |
+
# 5. Create the Vespa application
|
20 |
+
# 6. Feed the data to the Vespa application
|
21 |
+
#
|
22 |
+
# All the steps that are needed to provision the Vespa application, including feeding the data, can be done from this notebook.
|
23 |
+
# We have tried to make it easy for others to run this notebook, to create your own PDF Enterprise Search application using Vespa.
|
24 |
+
#
|
25 |
+
|
26 |
+
# %% [markdown]
|
27 |
+
# ## 0. Setup and Configuration
|
28 |
+
#
|
29 |
+
|
30 |
+
# %%
|
31 |
+
import os
|
32 |
+
import asyncio
|
33 |
+
import json
|
34 |
+
from typing import Tuple
|
35 |
+
import hashlib
|
36 |
+
import numpy as np
|
37 |
+
|
38 |
+
# Vespa
|
39 |
+
from vespa.package import (
|
40 |
+
ApplicationPackage,
|
41 |
+
Field,
|
42 |
+
Schema,
|
43 |
+
Document,
|
44 |
+
HNSW,
|
45 |
+
RankProfile,
|
46 |
+
Function,
|
47 |
+
FieldSet,
|
48 |
+
SecondPhaseRanking,
|
49 |
+
Summary,
|
50 |
+
DocumentSummary,
|
51 |
+
)
|
52 |
+
from vespa.deployment import VespaCloud
|
53 |
+
from vespa.application import Vespa
|
54 |
+
from vespa.io import VespaResponse
|
55 |
+
|
56 |
+
# Google Generative AI
|
57 |
+
import google.generativeai as genai
|
58 |
+
|
59 |
+
# Torch and other ML libraries
|
60 |
+
import torch
|
61 |
+
from torch.utils.data import DataLoader
|
62 |
+
from tqdm import tqdm
|
63 |
+
from pdf2image import convert_from_path
|
64 |
+
from pypdf import PdfReader
|
65 |
+
|
66 |
+
# ColPali model and processor
|
67 |
+
from colpali_engine.models import ColPali, ColPaliProcessor
|
68 |
+
from colpali_engine.utils.torch_utils import get_torch_device
|
69 |
+
from vidore_benchmark.utils.image_utils import scale_image, get_base64_image
|
70 |
+
|
71 |
+
# Other utilities
|
72 |
+
from bs4 import BeautifulSoup
|
73 |
+
import httpx
|
74 |
+
from urllib.parse import urljoin, urlparse
|
75 |
+
|
76 |
+
# Load environment variables
|
77 |
+
from dotenv import load_dotenv
|
78 |
+
|
79 |
+
load_dotenv()
|
80 |
+
|
81 |
+
# Avoid warning from huggingface tokenizers
|
82 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
83 |
+
|
84 |
+
# %% [markdown]
|
85 |
+
# ### Create a free trial in Vespa Cloud
|
86 |
+
#
|
87 |
+
# Create a tenant from [here](https://vespa.ai/free-trial/).
|
88 |
+
# The trial includes $300 credit.
|
89 |
+
# Take note of your tenant name.
|
90 |
+
#
|
91 |
+
|
92 |
+
# %%
|
93 |
+
VESPA_TENANT_NAME = "vespa-team"
|
94 |
+
|
95 |
+
# %% [markdown]
|
96 |
+
# Here, set your desired application name. (Will be created in later steps)
|
97 |
+
# Note that you can not have hyphen `-` or underscore `_` in the application name.
|
98 |
+
#
|
99 |
+
|
100 |
+
# %%
|
101 |
+
VESPA_APPLICATION_NAME = "colpalidemo2"
|
102 |
+
VESPA_SCHEMA_NAME = "pdf_page"
|
103 |
+
|
104 |
+
# %% [markdown]
|
105 |
+
# Next, you need to create some tokens for feeding data, and querying the application.
|
106 |
+
# We recommend separate tokens for feeding and querying, (the former with write permission, and the latter with read permission).
|
107 |
+
# The tokens can be created from the [Vespa Cloud console](https://console.vespa-cloud.com/) in the 'Account' -> 'Tokens' section.
|
108 |
+
#
|
109 |
+
|
110 |
+
# %%
|
111 |
+
VESPA_TOKEN_ID_WRITE = "colpalidemo_write"
|
112 |
+
VESPA_TOKEN_ID_READ = "colpalidemo_read"
|
113 |
+
|
114 |
+
# %% [markdown]
|
115 |
+
# We also need to set the value of the write token to be able to feed data to the Vespa application.
|
116 |
+
#
|
117 |
+
|
118 |
+
# %%
|
119 |
+
VESPA_CLOUD_SECRET_TOKEN = os.getenv("VESPA_CLOUD_SECRET_TOKEN") or input(
|
120 |
+
"Enter Vespa cloud secret token: "
|
121 |
+
)
|
122 |
+
|
123 |
+
# %% [markdown]
|
124 |
+
# We will also use the Gemini API to create sample queries for our images.
|
125 |
+
# You can also use other VLM's to create these queries.
|
126 |
+
# Create a Gemini API key from [here](https://aistudio.google.com/app/apikey).
|
127 |
+
#
|
128 |
+
|
129 |
+
# %%
|
130 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or input(
|
131 |
+
"Enter Google Generative AI API key: "
|
132 |
+
)
|
133 |
+
|
134 |
+
# %%
|
135 |
+
MODEL_NAME = "vidore/colpali-v1.2"
|
136 |
+
|
137 |
+
# Configure Google Generative AI
|
138 |
+
genai.configure(api_key=GEMINI_API_KEY)
|
139 |
+
|
140 |
+
# Set device for Torch
|
141 |
+
device = get_torch_device("auto")
|
142 |
+
print(f"Using device: {device}")
|
143 |
+
|
144 |
+
# Load the ColPali model and processor
|
145 |
+
model = ColPali.from_pretrained(
|
146 |
+
MODEL_NAME,
|
147 |
+
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
148 |
+
device_map=device,
|
149 |
+
).eval()
|
150 |
+
|
151 |
+
processor = ColPaliProcessor.from_pretrained(MODEL_NAME)
|
152 |
+
|
153 |
+
# %% [markdown]
|
154 |
+
# ## 1. Download PDFs
|
155 |
+
#
|
156 |
+
# We are going to use public reports from the Norwegian Government Pension Fund Global (also known as the Oil Fund).
|
157 |
+
# The fund puts transparency at the forefront and publishes reports on its investments, holdings, and returns, as well as its strategy and governance.
|
158 |
+
#
|
159 |
+
# These reports are the ones we are going to use for this showcase.
|
160 |
+
# Here are some sample images:
|
161 |
+
#
|
162 |
+
# ![Sample1](./static/img/gfpg-sample-1.png)
|
163 |
+
# ![Sample2](./static/img/gfpg-sample-2.png)
|
164 |
+
#
|
165 |
+
|
166 |
+
# %% [markdown]
|
167 |
+
# As we can see, a lot of the information is in the form of tables, charts and numbers.
|
168 |
+
# These are not easily extractable using pdf-readers or OCR tools.
|
169 |
+
#
|
170 |
+
|
171 |
+
# %%
|
172 |
+
import requests
|
173 |
+
|
174 |
+
url = "https://www.nbim.no/en/publications/reports/"
|
175 |
+
response = requests.get(url)
|
176 |
+
response.raise_for_status()
|
177 |
+
html_content = response.text
|
178 |
+
|
179 |
+
# Parse with BeautifulSoup
|
180 |
+
soup = BeautifulSoup(html_content, "html.parser")
|
181 |
+
|
182 |
+
links = []
|
183 |
+
|
184 |
+
# Find all <a> elements with the specific classes
|
185 |
+
for a_tag in soup.find_all("a", href=True):
|
186 |
+
classes = a_tag.get("class", [])
|
187 |
+
if "button" in classes and "button--download-secondary" in classes:
|
188 |
+
href = a_tag["href"]
|
189 |
+
full_url = urljoin(url, href)
|
190 |
+
links.append(full_url)
|
191 |
+
|
192 |
+
links
|
193 |
+
|
194 |
+
# %%
|
195 |
+
# Limit the number of PDFs to download
|
196 |
+
NUM_PDFS = 2 # Set to None to download all PDFs
|
197 |
+
links = links[:NUM_PDFS] if NUM_PDFS else links
|
198 |
+
links
|
199 |
+
|
200 |
+
# %%
|
201 |
+
from nest_asyncio import apply
|
202 |
+
from typing import List
|
203 |
+
|
204 |
+
apply()
|
205 |
+
|
206 |
+
max_attempts = 3
|
207 |
+
|
208 |
+
|
209 |
+
async def download_pdf(session, url, filename):
|
210 |
+
attempt = 0
|
211 |
+
while attempt < max_attempts:
|
212 |
+
try:
|
213 |
+
response = await session.get(url)
|
214 |
+
response.raise_for_status()
|
215 |
+
|
216 |
+
# Use Content-Disposition header to get the filename if available
|
217 |
+
content_disposition = response.headers.get("Content-Disposition")
|
218 |
+
if content_disposition:
|
219 |
+
import re
|
220 |
+
|
221 |
+
fname = re.findall('filename="(.+)"', content_disposition)
|
222 |
+
if fname:
|
223 |
+
filename = fname[0]
|
224 |
+
|
225 |
+
# Ensure the filename is safe to use on the filesystem
|
226 |
+
safe_filename = filename.replace("/", "_").replace("\\", "_")
|
227 |
+
if not safe_filename or safe_filename == "_":
|
228 |
+
print(f"Invalid filename: {filename}")
|
229 |
+
continue
|
230 |
+
|
231 |
+
filepath = os.path.join("pdfs", safe_filename)
|
232 |
+
with open(filepath, "wb") as f:
|
233 |
+
f.write(response.content)
|
234 |
+
print(f"Downloaded {safe_filename}")
|
235 |
+
return filepath
|
236 |
+
except Exception as e:
|
237 |
+
print(f"Error downloading {filename}: {e}")
|
238 |
+
print(f"Retrying ({attempt})...")
|
239 |
+
await asyncio.sleep(1) # Wait a bit before retrying
|
240 |
+
attempt += 1
|
241 |
+
return None
|
242 |
+
|
243 |
+
|
244 |
+
async def download_pdfs(links: List[str]) -> List[dict]:
|
245 |
+
"""Download PDFs from a list of URLs. Add the filename to the dictionary."""
|
246 |
+
async with httpx.AsyncClient() as client:
|
247 |
+
tasks = []
|
248 |
+
|
249 |
+
for idx, link in enumerate(links):
|
250 |
+
# Try to get the filename from the URL
|
251 |
+
path = urlparse(link).path
|
252 |
+
filename = os.path.basename(path)
|
253 |
+
|
254 |
+
# If filename is empty,skip
|
255 |
+
if not filename:
|
256 |
+
continue
|
257 |
+
tasks.append(download_pdf(client, link, filename))
|
258 |
+
|
259 |
+
# Run the tasks concurrently
|
260 |
+
paths = await asyncio.gather(*tasks)
|
261 |
+
pdf_files = [
|
262 |
+
{"url": link, "path": path} for link, path in zip(links, paths) if path
|
263 |
+
]
|
264 |
+
return pdf_files
|
265 |
+
|
266 |
+
|
267 |
+
# Create the pdfs directory if it doesn't exist
|
268 |
+
os.makedirs("pdfs", exist_ok=True)
|
269 |
+
# Now run the download_pdfs function with the URL
|
270 |
+
pdfs = asyncio.run(download_pdfs(links))
|
271 |
+
|
272 |
+
# %%
|
273 |
+
pdfs
|
274 |
+
|
275 |
+
# %% [markdown]
|
276 |
+
# ## 2. Convert PDFs to Images
|
277 |
+
#
|
278 |
+
|
279 |
+
|
280 |
+
# %%
|
281 |
+
def get_pdf_images(pdf_path):
|
282 |
+
reader = PdfReader(pdf_path)
|
283 |
+
page_texts = []
|
284 |
+
for page_number in range(len(reader.pages)):
|
285 |
+
page = reader.pages[page_number]
|
286 |
+
text = page.extract_text()
|
287 |
+
page_texts.append(text)
|
288 |
+
images = convert_from_path(pdf_path)
|
289 |
+
# Convert to PIL images
|
290 |
+
assert len(images) == len(page_texts)
|
291 |
+
return images, page_texts
|
292 |
+
|
293 |
+
|
294 |
+
pdf_folder = "pdfs"
|
295 |
+
pdf_pages = []
|
296 |
+
for pdf in tqdm(pdfs):
|
297 |
+
pdf_file = pdf["path"]
|
298 |
+
title = os.path.splitext(os.path.basename(pdf_file))[0]
|
299 |
+
images, texts = get_pdf_images(pdf_file)
|
300 |
+
for page_no, (image, text) in enumerate(zip(images, texts)):
|
301 |
+
pdf_pages.append(
|
302 |
+
{
|
303 |
+
"title": title,
|
304 |
+
"url": pdf["url"],
|
305 |
+
"path": pdf_file,
|
306 |
+
"image": image,
|
307 |
+
"text": text,
|
308 |
+
"page_no": page_no,
|
309 |
+
}
|
310 |
+
)
|
311 |
+
|
312 |
+
# %%
|
313 |
+
len(pdf_pages)
|
314 |
+
|
315 |
+
# %%
|
316 |
+
from collections import Counter
|
317 |
+
|
318 |
+
# Print the length of the text fields - mean, max and min
|
319 |
+
text_lengths = [len(page["text"]) for page in pdf_pages]
|
320 |
+
print(f"Mean text length: {np.mean(text_lengths)}")
|
321 |
+
print(f"Max text length: {np.max(text_lengths)}")
|
322 |
+
print(f"Min text length: {np.min(text_lengths)}")
|
323 |
+
print(f"Median text length: {np.median(text_lengths)}")
|
324 |
+
print(f"Number of text with length == 0: {Counter(text_lengths)[0]}")
|
325 |
+
|
326 |
+
# %% [markdown]
|
327 |
+
# ## 3. Generate Queries
|
328 |
+
#
|
329 |
+
# In this step, we want to generate queries for each page image.
|
330 |
+
# These will be useful for 2 reasons:
|
331 |
+
#
|
332 |
+
# 1. We can use these queries as typeahead suggestions in the search bar.
|
333 |
+
# 2. We can use the queries to generate an evaluation dataset. See [Improving Retrieval with LLM-as-a-judge](https://blog.vespa.ai/improving-retrieval-with-llm-as-a-judge/) for a deeper dive into this topic.
|
334 |
+
#
|
335 |
+
# The prompt for generating queries is taken from [this](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html#an-update-retrieval-focused-prompt) wonderful blog post by Daniel van Strien.
|
336 |
+
#
|
337 |
+
# We will use the Gemini API to generate these queries, with `gemini-1.5-flash-8b` as the model.
|
338 |
+
#
|
339 |
+
|
340 |
+
# %%
|
341 |
+
from pydantic import BaseModel
|
342 |
+
|
343 |
+
|
344 |
+
class GeneratedQueries(BaseModel):
|
345 |
+
broad_topical_question: str
|
346 |
+
broad_topical_query: str
|
347 |
+
specific_detail_question: str
|
348 |
+
specific_detail_query: str
|
349 |
+
visual_element_question: str
|
350 |
+
visual_element_query: str
|
351 |
+
|
352 |
+
|
353 |
+
def get_retrieval_prompt() -> Tuple[str, GeneratedQueries]:
|
354 |
+
prompt = (
|
355 |
+
prompt
|
356 |
+
) = """You are an investor, stock analyst and financial expert. You will be presented an image of a document page from a report published by the Norwegian Government Pension Fund Global (GPFG). The report may be annual or quarterly reports, or policy reports, on topics such as responsible investment, risk etc.
|
357 |
+
Your task is to generate retrieval queries and questions that you would use to retrieve this document (or ask based on this document) in a large corpus.
|
358 |
+
Please generate 3 different types of retrieval queries and questions.
|
359 |
+
A retrieval query is a keyword based query, made up of 2-5 words, that you would type into a search engine to find this document.
|
360 |
+
A question is a natural language question that you would ask, for which the document contains the answer.
|
361 |
+
The queries should be of the following types:
|
362 |
+
1. A broad topical query: This should cover the main subject of the document.
|
363 |
+
2. A specific detail query: This should cover a specific detail or aspect of the document.
|
364 |
+
3. A visual element query: This should cover a visual element of the document, such as a chart, graph, or image.
|
365 |
+
|
366 |
+
Important guidelines:
|
367 |
+
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
|
368 |
+
- Use a fact-based natural language style for the questions.
|
369 |
+
- Frame the queries as if someone is searching for this document in a large corpus.
|
370 |
+
- Make the queries diverse and representative of different search strategies.
|
371 |
+
|
372 |
+
Format your response as a JSON object with the structure of the following example:
|
373 |
+
{
|
374 |
+
"broad_topical_question": "What was the Responsible Investment Policy in 2019?",
|
375 |
+
"broad_topical_query": "responsible investment policy 2019",
|
376 |
+
"specific_detail_question": "What is the percentage of investments in renewable energy?",
|
377 |
+
"specific_detail_query": "renewable energy investments percentage",
|
378 |
+
"visual_element_question": "What is the trend of total holding value over time?",
|
379 |
+
"visual_element_query": "total holding value trend"
|
380 |
+
}
|
381 |
+
|
382 |
+
If there are no relevant visual elements, provide an empty string for the visual element question and query.
|
383 |
+
Here is the document image to analyze:
|
384 |
+
Generate the queries based on this image and provide the response in the specified JSON format.
|
385 |
+
Only return JSON. Don't return any extra explanation text. """
|
386 |
+
|
387 |
+
return prompt, GeneratedQueries
|
388 |
+
|
389 |
+
|
390 |
+
prompt_text, pydantic_model = get_retrieval_prompt()
|
391 |
+
|
392 |
+
# %%
|
393 |
+
gemini_model = genai.GenerativeModel("gemini-1.5-flash-8b")
|
394 |
+
|
395 |
+
|
396 |
+
def generate_queries(image, prompt_text, pydantic_model):
|
397 |
+
try:
|
398 |
+
response = gemini_model.generate_content(
|
399 |
+
[image, "\n\n", prompt_text],
|
400 |
+
generation_config=genai.GenerationConfig(
|
401 |
+
response_mime_type="application/json",
|
402 |
+
response_schema=pydantic_model,
|
403 |
+
),
|
404 |
+
)
|
405 |
+
queries = json.loads(response.text)
|
406 |
+
except Exception as _e:
|
407 |
+
queries = {
|
408 |
+
"broad_topical_question": "",
|
409 |
+
"broad_topical_query": "",
|
410 |
+
"specific_detail_question": "",
|
411 |
+
"specific_detail_query": "",
|
412 |
+
"visual_element_question": "",
|
413 |
+
"visual_element_query": "",
|
414 |
+
}
|
415 |
+
return queries
|
416 |
+
|
417 |
+
|
418 |
+
# %%
|
419 |
+
for pdf in tqdm(pdf_pages):
|
420 |
+
image = pdf.get("image")
|
421 |
+
pdf["queries"] = generate_queries(image, prompt_text, pydantic_model)

# %%
pdf_pages[46]["image"]

# %%
pdf_pages[46]["queries"]

# %%
# Generate queries async - keeping this for now, as we will probably need it when applying to the full dataset
# import asyncio
# from tenacity import retry, stop_after_attempt, wait_exponential
# import google.generativeai as genai
# from tqdm.asyncio import tqdm_asyncio

# max_in_flight = 200  # Maximum number of concurrent requests


# async def generate_queries_for_image_async(model, image, semaphore):
#     @retry(stop=stop_after_attempt(3), wait=wait_exponential(), reraise=True)
#     async def _generate():
#         async with semaphore:
#             result = await model.generate_content_async(
#                 [image, "\n\n", prompt_text],
#                 generation_config=genai.GenerationConfig(
#                     response_mime_type="application/json",
#                     response_schema=pydantic_model,
#                 ),
#             )
#             return json.loads(result.text)

#     try:
#         return await _generate()
#     except Exception as e:
#         print(f"Error generating queries for image: {e}")
#         return None  # Return None or handle as needed


# async def enrich_pdfs():
#     gemini_model = genai.GenerativeModel("gemini-1.5-flash-8b")
#     semaphore = asyncio.Semaphore(max_in_flight)
#     tasks = []
#     for pdf in pdf_pages:
#         pdf["queries"] = []
#         image = pdf.get("image")
#         if image:
#             task = generate_queries_for_image_async(gemini_model, image, semaphore)
#             tasks.append((pdf, task))

#     # Await the tasks concurrently, with a progress bar
#     for pdf, task in tqdm_asyncio(tasks):
#         result = await task
#         if result:
#             pdf["queries"] = result
#     return pdf_pages


# pdf_pages = asyncio.run(enrich_pdfs())

# %%
# Write title, url, page_no, text, and queries (but not the image) to JSON
with open("output/pdf_pages.json", "w") as f:
    to_write = [{k: v for k, v in pdf.items() if k != "image"} for pdf in pdf_pages]
    json.dump(to_write, f, indent=2)

# with open("pdfs/pdf_pages.json", "r") as f:
#     saved_pdf_pages = json.load(f)
# for pdf, saved_pdf in zip(pdf_pages, saved_pdf_pages):
#     pdf.update(saved_pdf)

# %% [markdown]
# ## 4. Generate embeddings
#
# Now that we have the queries, we can use the ColPali model to generate embeddings for each page image.
#


# %%
def generate_embeddings(images, model, processor, batch_size=2) -> np.ndarray:
    """
    Generate embeddings for a list of images.
    Move to CPU only once per batch.

    Args:
        images (List[PIL.Image]): List of PIL images.
        model (nn.Module): The model used to generate embeddings.
        processor: The processor used to preprocess images.
        batch_size (int, optional): Batch size for processing. Defaults to 2.

    Returns:
        np.ndarray: Embeddings for the images, shape
        (len(images), processor.max_patch_length (1030 for ColPali), model.config.hidden_size (the patch embedding dimension, 128 for ColPali)).
    """
    embeddings_list = []

    def collate_fn(batch):
        # Batch is a list of images
        return processor.process_images(batch)  # Should return a dict of tensors

    dataloader = DataLoader(
        images,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    for batch_doc in tqdm(dataloader, desc="Generating embeddings"):
        with torch.no_grad():
            # Move batch to the device
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_batch = model(**batch_doc)
            embeddings_list.append(torch.unbind(embeddings_batch.to("cpu"), dim=0))
    # Concatenate all embeddings and create a numpy array
    all_embeddings = np.concatenate(embeddings_list, axis=0)
    return all_embeddings


# %%
# Generate embeddings for all images
images = [pdf["image"] for pdf in pdf_pages]
embeddings = generate_embeddings(images, model, processor)

# %%
embeddings.shape

# %% [markdown]
# ## 5. Prepare Data on Vespa Format
#
# Now that we have all the data we need, all that remains is to make sure it is in the right format for Vespa.
#


# %%
def float_to_binary_embedding(float_query_embedding: dict) -> dict:
    """Utility function to convert float query embeddings to binary query embeddings."""
    binary_query_embeddings = {}
    for k, v in float_query_embedding.items():
        binary_vector = (
            np.packbits(np.where(np.array(v) > 0, 1, 0)).astype(np.int8).tolist()
        )
        binary_query_embeddings[k] = binary_vector
    return binary_query_embeddings

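
# %% [markdown]
# A quick sanity check (an added sketch, not part of the original pipeline): each
# 128-dim float patch vector is binarized by sign and packed into 16 int8 values
# via `np.packbits`, which is why the schema in the next section stores the
# embedding as `tensor<int8>(patch{}, v[16])`.

# %%
example_patch = np.random.randn(128).tolist()  # stand-in for a real ColPali patch embedding
packed = float_to_binary_embedding({0: example_patch})
assert len(packed[0]) == 16  # 128 bits -> 16 signed bytes
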
# %%
vespa_feed = []
for pdf, embedding in zip(pdf_pages, embeddings):
    url = pdf["url"]
    title = pdf["title"]
    image = pdf["image"]
    text = pdf.get("text", "")
    page_no = pdf["page_no"]
    query_dict = pdf["queries"]
    questions = [v for k, v in query_dict.items() if "question" in k and v]
    queries = [v for k, v in query_dict.items() if "query" in k and v]
    base_64_image = get_base64_image(
        scale_image(image, 32), add_url_prefix=False
    )  # Scaled-down image to return fast on search (~1 kB)
    base_64_full_image = get_base64_image(image, add_url_prefix=False)
    embedding_dict = {k: v for k, v in enumerate(embedding)}
    binary_embedding = float_to_binary_embedding(embedding_dict)
    # id_hash is the md5 hash of url and page_number
    id_hash = hashlib.md5(f"{url}_{page_no}".encode()).hexdigest()
    page = {
        "id": id_hash,
        "fields": {
            "id": id_hash,
            "url": url,
            "title": title,
            "page_number": page_no,
            "blur_image": base_64_image,
            "full_image": base_64_full_image,
            "text": text,
            "embedding": binary_embedding,
            "queries": queries,
            "questions": questions,
        },
    }
    vespa_feed.append(page)

# %%
# Save the Vespa feed, including the embeddings and the generated queries, to vespa_feed.json
os.makedirs("output", exist_ok=True)
with open("output/vespa_feed.json", "w") as f:
    vespa_feed_to_save = []
    for page in vespa_feed:
        document_id = page["id"]
        put_id = f"id:{VESPA_APPLICATION_NAME}:{VESPA_SCHEMA_NAME}::{document_id}"
        vespa_feed_to_save.append({"put": put_id, "fields": page["fields"]})
    json.dump(vespa_feed_to_save, f)

# %%
# import json

# with open("output/vespa_feed.json", "r") as f:
#     vespa_feed = json.load(f)

# %%
len(vespa_feed)

# %% [markdown]
# ## 6. Prepare Vespa Application
#

# %%
# Define the Vespa schema
colpali_schema = Schema(
    name=VESPA_SCHEMA_NAME,
    document=Document(
        fields=[
            Field(
                name="id",
                type="string",
                indexing=["summary", "index"],
                match=["word"],
            ),
            Field(name="url", type="string", indexing=["summary", "index"]),
            Field(
                name="title",
                type="string",
                indexing=["summary", "index"],
                match=["text"],
                index="enable-bm25",
            ),
            Field(name="page_number", type="int", indexing=["summary", "attribute"]),
            Field(name="blur_image", type="raw", indexing=["summary"]),
            Field(name="full_image", type="raw", indexing=["summary"]),
            Field(
                name="text",
                type="string",
                indexing=["summary", "index"],
                match=["text"],
                index="enable-bm25",
            ),
            Field(
                name="embedding",
                type="tensor<int8>(patch{}, v[16])",
                indexing=[
                    "attribute",
                    "index",
                ],
                ann=HNSW(
                    distance_metric="hamming",
                    max_links_per_node=32,
                    neighbors_to_explore_at_insert=400,
                ),
            ),
            Field(
                name="questions",
                type="array<string>",
                indexing=["summary", "index", "attribute"],
                index="enable-bm25",
                stemming="best",
            ),
            Field(
                name="queries",
                type="array<string>",
                indexing=["summary", "index", "attribute"],
                index="enable-bm25",
                stemming="best",
            ),
            # Add synthetic fields for the questions and queries
            # Field(
            #     name="questions_exact",
            #     type="array<string>",
            #     indexing=["input questions", "index", "attribute"],
            #     match=["word"],
            #     is_document_field=False,
            # ),
            # Field(
            #     name="queries_exact",
            #     type="array<string>",
            #     indexing=["input queries", "index"],
            #     match=["word"],
            #     is_document_field=False,
            # ),
        ]
    ),
    fieldsets=[
        FieldSet(
            name="default",
            fields=["title", "url", "blur_image", "page_number", "text"],
        ),
        FieldSet(
            name="image",
            fields=["full_image"],
        ),
    ],
    document_summaries=[
        DocumentSummary(
            name="default",
            summary_fields=[
                Summary(
                    name="text",
                    fields=[("bolding", "on")],
                ),
                Summary(
                    name="snippet",
                    fields=[("source", "text"), "dynamic"],
                ),
            ],
            from_disk=True,
        ),
    ],
)

# Define similarity functions used in all rank profiles
mapfunctions = [
    Function(
        name="similarities",  # computes similarity scores between each query token and image patch
        expression="""
            sum(
                query(qt) * unpack_bits(attribute(embedding)), v
            )
        """,
    ),
    Function(
        name="normalized",  # normalizes the similarity scores to [-1, 1]
        expression="""
            (similarities - reduce(similarities, min)) / (reduce((similarities - reduce(similarities, min)), max)) * 2 - 1
        """,
    ),
    Function(
        name="quantized",  # quantizes the normalized similarity scores to signed 8-bit integers [-128, 127]
        expression="""
            cell_cast(normalized * 127.999, int8)
        """,
    ),
]

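# %% [markdown]
# The `normalized` expression above is min-max rescaling into [-1, 1], i.e.
# 2 * (s - min(s)) / (max(s) - min(s)) - 1, and `quantized` then maps that onto the
# int8 range so the similarity maps are cheap to ship to the frontend. A numpy
# equivalent of the two steps (illustrative only, not used by the pipeline):

# %%
s = np.array([0.2, 1.4, 3.0])  # toy similarity scores
normalized = (s - s.min()) / (s.max() - s.min()) * 2 - 1
quantized = (normalized * 127.999).astype(np.int8)  # numpy cast truncates into [-127, 127]
print(normalized, quantized)

# %%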
# Define the 'bm25' rank profile
colpali_bm25_profile = RankProfile(
    name="bm25",
    inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
    first_phase="bm25(title) + bm25(text)",
    functions=mapfunctions,
    summary_features=["quantized"],
)
colpali_schema.add_rank_profile(colpali_bm25_profile)

# Update the 'default' rank profile
colpali_profile = RankProfile(
    name="default",
    inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
    first_phase="bm25_score",
    second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
    functions=mapfunctions
    + [
        Function(
            name="max_sim",
            expression="""
                sum(
                    reduce(
                        sum(
                            query(qt) * unpack_bits(attribute(embedding)), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            """,
        ),
        Function(name="bm25_score", expression="bm25(title) + bm25(text)"),
    ],
    summary_features=["quantized"],
)
colpali_schema.add_rank_profile(colpali_profile)

# Update the 'retrieval-and-rerank' rank profile
input_query_tensors = []
MAX_QUERY_TERMS = 64
for i in range(MAX_QUERY_TERMS):
    input_query_tensors.append((f"query(rq{i})", "tensor<int8>(v[16])"))

input_query_tensors.extend(
    [
        ("query(qt)", "tensor<float>(querytoken{}, v[128])"),
        ("query(qtb)", "tensor<int8>(querytoken{}, v[16])"),
    ]
)

colpali_retrieval_profile = RankProfile(
    name="retrieval-and-rerank",
    inputs=input_query_tensors,
    first_phase="max_sim_binary",
    second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
    functions=mapfunctions
    + [
        Function(
            name="max_sim",
            expression="""
                sum(
                    reduce(
                        sum(
                            query(qt) * unpack_bits(attribute(embedding)), v
                        ),
                        max, patch
                    ),
                    querytoken
                )
            """,
        ),
        Function(
            name="max_sim_binary",
            expression="""
                sum(
                    reduce(
                        1 / (1 + sum(
                            hamming(query(qtb), attribute(embedding)), v)
                        ),
                        max, patch
                    ),
                    querytoken
                )
            """,
        ),
    ],
    summary_features=["quantized"],
)
colpali_schema.add_rank_profile(colpali_retrieval_profile)

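# %% [markdown]
# For reference, a sketch of how a client is expected to call the
# `retrieval-and-rerank` profile (assumed query-side usage; this function is
# illustrative and not executed in this notebook). The binary per-token tensors
# `query(rq0)`..`query(rq63)` drive the nearestNeighbor candidate retrieval and the
# `max_sim_binary` first phase, while `query(qt)` (float) and `query(qtb)` (binary)
# feed the `max_sim` reranking. `float_query_embedding` maps token index -> 128-dim list.


# %%
def build_retrieval_body(float_query_embedding: dict, target_hits: int = 20) -> dict:
    binary_query_embedding = float_to_binary_embedding(float_query_embedding)
    # One nearestNeighbor clause per query token, OR-ed together
    nn_clauses = " OR ".join(
        f"({{targetHits:{target_hits}}}nearestNeighbor(embedding, rq{i}))"
        for i in range(len(binary_query_embedding))
    )
    body = {
        "yql": f"select id, title, url, page_number from {VESPA_SCHEMA_NAME} where {nn_clauses}",
        "ranking": "retrieval-and-rerank",
        "input.query(qt)": float_query_embedding,
        "input.query(qtb)": binary_query_embedding,
    }
    for i, vector in binary_query_embedding.items():
        body[f"input.query(rq{i})"] = vector
    return body
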
# %%
from vespa.configuration.services import (
    services,
    container,
    search,
    document_api,
    document_processing,
    clients,
    client,
    config,
    content,
    redundancy,
    documents,
    node,
    certificate,
    token,
    document,
    nodes,
)
from vespa.configuration.vt import vt
from vespa.package import ServicesConfiguration

service_config = ServicesConfiguration(
    application_name=VESPA_APPLICATION_NAME,
    services_config=services(
        container(
            search(),
            document_api(),
            document_processing(),
            clients(
                client(
                    certificate(file="security/clients.pem"),
                    id="mtls",
                    permissions="read,write",
                ),
                client(
                    token(id=f"{VESPA_TOKEN_ID_WRITE}"),
                    id="token_write",
                    permissions="read,write",
                ),
                client(
                    token(id=f"{VESPA_TOKEN_ID_READ}"),
                    id="token_read",
                    permissions="read",
                ),
            ),
            config(
                vt("tag")(
                    vt("bold")(
                        vt("open", "<strong>"),
                        vt("close", "</strong>"),
                    ),
                    vt("separator", "..."),
                ),
                name="container.qr-searchers",
            ),
            id=f"{VESPA_APPLICATION_NAME}_container",
            version="1.0",
        ),
        content(
            redundancy("1"),
            documents(document(type="pdf_page", mode="index")),
            nodes(node(distribution_key="0", hostalias="node1")),
            config(
                vt("max_matches", "2", replace_underscores=False),
                vt("length", "1000"),
                vt("surround_max", "500", replace_underscores=False),
                vt("min_length", "300", replace_underscores=False),
                name="vespa.config.search.summary.juniperrc",
            ),
            id=f"{VESPA_APPLICATION_NAME}_content",
            version="1.0",
        ),
        version="1.0",
    ),
)

# %%
# Create the Vespa application package
vespa_application_package = ApplicationPackage(
    name=VESPA_APPLICATION_NAME,
    schema=[colpali_schema],
    services_config=service_config,
)

# %% [markdown]
# ## 7. Deploy Vespa Application
#

# %%
VESPA_TEAM_API_KEY = os.getenv("VESPA_TEAM_API_KEY") or input(
    "Enter Vespa team API key: "
)

# %%
vespa_cloud = VespaCloud(
    tenant=VESPA_TENANT_NAME,
    application=VESPA_APPLICATION_NAME,
    key_content=VESPA_TEAM_API_KEY,
    application_package=vespa_application_package,
)

# Deploy the application
vespa_cloud.deploy()

# Output the endpoint URL
endpoint_url = vespa_cloud.get_token_endpoint()
print(f"Application deployed. Token endpoint URL: {endpoint_url}")

# %% [markdown]
# Make sure to take note of the token endpoint URL.
# You need to put it in your `.env` file - `VESPA_APP_URL=https://abcd.vespa-app.cloud` - to access the Vespa application from your web application.
#

# %% [markdown]
# ## 8. Feed Data to Vespa
#

# %%
# Instantiate the Vespa connection using the token
app = Vespa(url=endpoint_url, vespa_cloud_secret_token=VESPA_CLOUD_SECRET_TOKEN)
app.get_application_status()


# %%
def callback(response: VespaResponse, id: str):
    if not response.is_successful():
        print(
            f"Failed to feed document {id} with status code {response.status_code}: Reason {response.get_json()}"
        )


# Feed data into Vespa asynchronously
app.feed_async_iterable(vespa_feed, schema=VESPA_SCHEMA_NAME, callback=callback)
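
# %% [markdown]
# Optional spot check (an added suggestion, not part of the original flow): fetch one
# document back by id to confirm the feed landed in the index.

# %%
response = app.get_data(schema=VESPA_SCHEMA_NAME, data_id=vespa_feed[0]["id"])
assert response.is_successful()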
pyproject.toml
CHANGED
@@ -8,7 +8,7 @@ license = { text = "Apache-2.0" }
 dependencies = [
     "python-fasthtml",
     "huggingface-hub",
-    "pyvespa",
+    "pyvespa>=0.50.0",
     "vespacli",
     "torch",
     "vidore-benchmark[interpretability]>=4.0.0,<5.0.0",
@@ -18,6 +18,7 @@
     "setuptools",
     "python-dotenv",
     "shad4fast>=1.2.1",
+    "google-generativeai>=0.7.2"
 ]
 
 # dev-dependencies
@@ -27,3 +28,11 @@
     "python-dotenv",
     "huggingface_hub[cli]"
 ]
+feed = [
+    "ipykernel",
+    "jupytext",
+    "pydantic",
+    "beautifulsoup4",
+    "pdf2image",
+    "google-generativeai"
+]
static/.DS_Store
ADDED
Binary file (6.15 kB).

uv.lock
CHANGED
The diff for this file is too large to render.