thomasht86 commited on
Commit
b7897bb
β€’
1 Parent(s): bb4f59a

Upload folder using huggingface_hub

Browse files
Files changed (14) hide show
  1. .DS_Store +0 -0
  2. .env.example +2 -1
  3. .gitignore +3 -1
  4. README.md +20 -1
  5. backend/colpali.py +25 -7
  6. frontend/app.py +46 -13
  7. frontend/layout.py +95 -14
  8. globals.css +22 -1
  9. main.py +106 -19
  10. output.css +77 -1
  11. prepare_feed_deploy.py +977 -0
  12. pyproject.toml +10 -1
  13. static/.DS_Store +0 -0
  14. uv.lock +0 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
.env.example CHANGED
@@ -1,3 +1,4 @@
1
  VESPA_APP_URL=https://abcde.z.vespa-app.cloud
2
  HF_TOKEN=hf_xxxxxxxxxx
3
- VESPA_CLOUD_SECRET_TOKEN=vespa_cloud_xxxxxxxx
 
 
1
  VESPA_APP_URL=https://abcde.z.vespa-app.cloud
2
  HF_TOKEN=hf_xxxxxxxxxx
3
+ VESPA_CLOUD_SECRET_TOKEN=vespa_cloud_xxxxxxxx
4
+ GEMINI_API_KEY=
.gitignore CHANGED
@@ -1,8 +1,10 @@
1
  .sesskey
2
  .venv/
3
  __pycache__/
 
4
  .python-version
5
  .env
6
  template/
7
  *.json
8
- output/
 
 
1
  .sesskey
2
  .venv/
3
  __pycache__/
4
+ ipynb_checkpoints/
5
  .python-version
6
  .env
7
  template/
8
  *.json
9
+ output/
10
+ pdfs/
README.md CHANGED
@@ -27,7 +27,7 @@ preload_from_hub:
27
 
28
  # Visual Retrieval ColPali
29
 
30
- # Developing
31
 
32
  First, install `uv`:
33
 
@@ -35,6 +35,25 @@ First, install `uv`:
35
  curl -LsSf https://astral.sh/uv/install.sh | sh
36
  ```
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  Then, in this directory, run:
39
 
40
  ```bash
 
27
 
28
  # Visual Retrieval ColPali
29
 
30
+ # Prepare data and Vespa application
31
 
32
  First, install `uv`:
33
 
 
35
  curl -LsSf https://astral.sh/uv/install.sh | sh
36
  ```
37
 
38
+ Then, run:
39
+
40
+ ```bash
41
+ uv sync --extra dev --extra feed
42
+ ```
43
+
44
+ Convert the `prepare_feed_deploy.py` to notebook to:
45
+
46
+ ```bash
47
+ jupytext --to notebook prepare_feed_deploy.py
48
+ ```
49
+
50
+ And launch a Jupyter instance, see https://docs.astral.sh/uv/guides/integration/jupyter/ for recommended approach.
51
+
52
+ Open and follow the `prepare_feed_deploy.ipynb` notebook to prepare the data and deploy the Vespa application.
53
+
54
+ # Developing on the web app
55
+
56
+
57
  Then, in this directory, run:
58
 
59
  ```bash
backend/colpali.py CHANGED
@@ -170,13 +170,13 @@ def gen_similarity_maps(
170
  if vespa_sim_maps:
171
  print("Using provided similarity maps")
172
  # A sim map looks like this:
173
- # "similarities": [
174
  # {
175
  # "address": {
176
  # "patch": "0",
177
  # "querytoken": "0"
178
  # },
179
- # "value": 1.2599412202835083
180
  # },
181
  # ... and so on.
182
  # Now turn these into a tensor of same shape as previous similarity map
@@ -189,7 +189,7 @@ def gen_similarity_maps(
189
  )
190
  )
191
  for idx, vespa_sim_map in enumerate(vespa_sim_maps):
192
- for cell in vespa_sim_map["similarities"]["cells"]:
193
  patch = int(cell["address"]["patch"])
194
  # if dummy model then just use 1024 as the image_seq_length
195
 
@@ -359,7 +359,7 @@ async def query_vespa_default(
359
  start = time.perf_counter()
360
  response: VespaQueryResponse = await session.query(
361
  body={
362
- "yql": "select id,title,url,full_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
363
  "ranking": "default",
364
  "query": query,
365
  "timeout": timeout,
@@ -392,7 +392,7 @@ async def query_vespa_bm25(
392
  start = time.perf_counter()
393
  response: VespaQueryResponse = await session.query(
394
  body={
395
- "yql": "select id,title,url,full_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
396
  "ranking": "bm25",
397
  "query": query,
398
  "timeout": timeout,
@@ -472,7 +472,7 @@ async def query_vespa_nearest_neighbor(
472
  **query_tensors,
473
  "presentation.timing": True,
474
  # if we use rank({nn_string}, userQuery()), dynamic summary doesn't work, see https://github.com/vespa-engine/vespa/issues/28704
475
- "yql": f"select id,title,snippet,text,url,full_image,page_number,summaryfeatures from pdf_page where {nn_string} or userQuery()",
476
  "ranking.profile": "retrieval-and-rerank",
477
  "timeout": timeout,
478
  "hits": hits,
@@ -492,6 +492,24 @@ def is_special_token(token: str) -> bool:
492
  return True
493
  return False
494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
 
496
  async def get_result_from_query(
497
  app: Vespa,
@@ -538,7 +556,7 @@ def add_sim_maps_to_result(
538
  imgs: List[str] = []
539
  vespa_sim_maps: List[str] = []
540
  for single_result in result["root"]["children"]:
541
- img = single_result["fields"]["full_image"]
542
  if img:
543
  imgs.append(img)
544
  vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
 
170
  if vespa_sim_maps:
171
  print("Using provided similarity maps")
172
  # A sim map looks like this:
173
+ # "quantized": [
174
  # {
175
  # "address": {
176
  # "patch": "0",
177
  # "querytoken": "0"
178
  # },
179
+ # "value": 12, # score in range [-128, 127]
180
  # },
181
  # ... and so on.
182
  # Now turn these into a tensor of same shape as previous similarity map
 
189
  )
190
  )
191
  for idx, vespa_sim_map in enumerate(vespa_sim_maps):
192
+ for cell in vespa_sim_map["quantized"]["cells"]:
193
  patch = int(cell["address"]["patch"])
194
  # if dummy model then just use 1024 as the image_seq_length
195
 
 
359
  start = time.perf_counter()
360
  response: VespaQueryResponse = await session.query(
361
  body={
362
+ "yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
363
  "ranking": "default",
364
  "query": query,
365
  "timeout": timeout,
 
392
  start = time.perf_counter()
393
  response: VespaQueryResponse = await session.query(
394
  body={
395
+ "yql": "select id,title,url,blur_image,page_number,snippet,text,summaryfeatures from pdf_page where userQuery();",
396
  "ranking": "bm25",
397
  "query": query,
398
  "timeout": timeout,
 
472
  **query_tensors,
473
  "presentation.timing": True,
474
  # if we use rank({nn_string}, userQuery()), dynamic summary doesn't work, see https://github.com/vespa-engine/vespa/issues/28704
475
+ "yql": f"select id,title,snippet,text,url,blur_image,page_number,summaryfeatures from pdf_page where {nn_string} or userQuery()",
476
  "ranking.profile": "retrieval-and-rerank",
477
  "timeout": timeout,
478
  "hits": hits,
 
492
  return True
493
  return False
494
 
495
+ async def get_full_image_from_vespa(
496
+ app: Vespa,
497
+ id: str) -> str:
498
+ async with app.asyncio(connections=1, total_timeout=120) as session:
499
+ start = time.perf_counter()
500
+ response: VespaQueryResponse = await session.query(
501
+ body={
502
+ "yql": f"select full_image from pdf_page where id contains \"{id}\"",
503
+ "ranking": "unranked",
504
+ "presentation.timing": True,
505
+ },
506
+ )
507
+ assert response.is_successful(), response.json
508
+ stop = time.perf_counter()
509
+ print(
510
+ f"Getting image from Vespa took: {stop - start} s, vespa said searchtime was {response.json.get('timing', {}).get('searchtime', -1)} s"
511
+ )
512
+ return response.json["root"]["children"][0]["fields"]["full_image"]
513
 
514
  async def get_result_from_query(
515
  app: Vespa,
 
556
  imgs: List[str] = []
557
  vespa_sim_maps: List[str] = []
558
  for single_result in result["root"]["children"]:
559
+ img = single_result["fields"]["blur_image"]
560
  if img:
561
  imgs.append(img)
562
  vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
frontend/app.py CHANGED
@@ -131,9 +131,13 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
131
 
132
  def SampleQueries():
133
  sample_queries = [
134
- "Percentage of non-fresh water as source?",
135
- "Policies related to nature risk?",
136
- "How much of produced water is recycled?",
 
 
 
 
137
  ]
138
 
139
  query_badges = []
@@ -193,21 +197,23 @@ def Search(request, search_results=[]):
193
  )
194
  return Div(
195
  Div(
196
- SearchBox(query_value=query_value, ranking_value=ranking_value),
197
  Div(
198
- LoadingMessage(),
199
- id="search-results", # This will be replaced by the search results
 
 
 
 
200
  ),
201
  cls="grid",
202
  ),
203
- cls="grid",
204
  )
205
 
206
 
207
- def LoadingMessage():
208
  return Div(
209
  Lucide(icon="loader-circle", cls="size-5 mr-1.5 animate-spin"),
210
- Span("Retrieving search results", cls="text-base text-center"),
211
  cls="p-10 text-muted-foreground flex items-center justify-center",
212
  id="loading-indicator",
213
  )
@@ -250,7 +256,7 @@ def SearchResult(results: list, query_id: Optional[str] = None):
250
  result_items = []
251
  for idx, result in enumerate(results):
252
  fields = result["fields"] # Extract the 'fields' part of each result
253
- full_image_base64 = f"data:image/jpeg;base64,{fields['full_image']}"
254
 
255
  # Filter sim_map fields that are words with 4 or more characters
256
  sim_map_fields = {
@@ -286,7 +292,7 @@ def SearchResult(results: list, query_id: Optional[str] = None):
286
  "Reset",
287
  variant="outline",
288
  size="sm",
289
- data_image_src=full_image_base64,
290
  cls="reset-button pointer-events-auto font-mono text-xs h-5 rounded-none px-2",
291
  )
292
 
@@ -312,7 +318,11 @@ def SearchResult(results: list, query_id: Optional[str] = None):
312
  Div(
313
  Div(
314
  Img(
315
- src=full_image_base64,
 
 
 
 
316
  alt=fields["title"],
317
  cls="result-image w-full h-full object-contain",
318
  ),
@@ -350,12 +360,35 @@ def SearchResult(results: list, query_id: Optional[str] = None):
350
  ),
351
  cls="bg-background px-3 py-5 hidden md:block",
352
  ),
353
- cls="grid grid-cols-1 md:grid-cols-2 col-span-2",
354
  )
355
  )
 
356
  return Div(
357
  *result_items,
358
  image_swapping,
359
  id="search-results",
360
  cls="grid grid-cols-2 gap-px bg-border",
361
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  def SampleQueries():
133
  sample_queries = [
134
+ "Proportion of female new hires 2021-2023?",
135
+ "Total amount of performance-based pay awarded in 2023?",
136
+ "What is the percentage distribution of employees with performance-based pay relative to the limit in 2023?",
137
+ "What is the breakdown of management costs by investment strategy in 2023?",
138
+ "2023 profit loss portfolio",
139
+ "net cash flow operating activities",
140
+ "fund currency basket returns",
141
  ]
142
 
143
  query_badges = []
 
197
  )
198
  return Div(
199
  Div(
 
200
  Div(
201
+ SearchBox(query_value=query_value, ranking_value=ranking_value),
202
+ Div(
203
+ LoadingMessage(),
204
+ id="search-results", # This will be replaced by the search results
205
+ ),
206
+ cls="grid",
207
  ),
208
  cls="grid",
209
  ),
 
210
  )
211
 
212
 
213
+ def LoadingMessage(display_text="Retrieving search results"):
214
  return Div(
215
  Lucide(icon="loader-circle", cls="size-5 mr-1.5 animate-spin"),
216
+ Span(display_text, cls="text-base text-center"),
217
  cls="p-10 text-muted-foreground flex items-center justify-center",
218
  id="loading-indicator",
219
  )
 
256
  result_items = []
257
  for idx, result in enumerate(results):
258
  fields = result["fields"] # Extract the 'fields' part of each result
259
+ blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"
260
 
261
  # Filter sim_map fields that are words with 4 or more characters
262
  sim_map_fields = {
 
292
  "Reset",
293
  variant="outline",
294
  size="sm",
295
+ data_image_src=blur_image_base64,
296
  cls="reset-button pointer-events-auto font-mono text-xs h-5 rounded-none px-2",
297
  )
298
 
 
318
  Div(
319
  Div(
320
  Img(
321
+ src=blur_image_base64,
322
+ hx_get=f"/full_image?id={fields['id']}",
323
+ style="filter: blur(5px);",
324
+ hx_trigger="load",
325
+ hx_swap="outerHTML",
326
  alt=fields["title"],
327
  cls="result-image w-full h-full object-contain",
328
  ),
 
360
  ),
361
  cls="bg-background px-3 py-5 hidden md:block",
362
  ),
363
+ cls="grid grid-cols-1 md:grid-cols-2 col-span-2 border-t",
364
  )
365
  )
366
+
367
  return Div(
368
  *result_items,
369
  image_swapping,
370
  id="search-results",
371
  cls="grid grid-cols-2 gap-px bg-border",
372
  )
373
+
374
+
375
+ def ChatResult(query_id: str, query: str):
376
+ return Div(
377
+ Div("Chat", cls="text-xl font-semibold p-3"),
378
+ Div(
379
+ Div(
380
+ Div(
381
+ LoadingMessage(display_text="Waiting for response..."),
382
+ cls="bg-muted/80 dark:bg-muted/40 text-black dark:text-white p-2 rounded-md",
383
+ hx_ext="sse",
384
+ sse_connect=f"/get-message?query_id={query_id}&query={quote_plus(query)}",
385
+ sse_swap="message",
386
+ sse_close="close",
387
+ hx_swap="innerHTML",
388
+ ),
389
+ ),
390
+ id="chat-messages",
391
+ cls="overflow-auto min-h-0 grid items-end px-3",
392
+ ),
393
+ cls="h-full grid grid-rows-[auto_1fr_auto] min-h-0 gap-3",
394
+ )
frontend/layout.py CHANGED
@@ -1,15 +1,96 @@
1
- from fasthtml.components import Div, Img, Nav, Title, Body, Header, Main
2
- from fasthtml.xtend import A
3
  from lucide_fasthtml import Lucide
4
  from shad4fast import Button, Separator
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def Logo():
8
  return Div(
9
- Img(src='https://assets.vespa.ai/logos/vespa-logo-black.svg', alt='Vespa Logo', cls='h-full dark:hidden'),
10
- Img(src='https://assets.vespa.ai/logos/vespa-logo-white.svg', alt='Vespa Logo Dark Mode',
11
- cls='h-full hidden dark:block'),
12
- cls='h-[27px]'
 
 
 
 
 
 
 
13
  )
14
 
15
 
@@ -38,23 +119,23 @@ def Links():
38
  ),
39
  Separator(orientation="vertical"),
40
  ThemeToggle(),
41
- cls='flex items-center space-x-3'
42
  )
43
 
44
 
45
  def Layout(*c, **kwargs):
46
  return (
47
- Title('Visual Retrieval ColPali'),
48
  Body(
49
  Header(
50
  A(Logo(), href="/"),
51
  Links(),
52
- cls='min-h-[55px] h-[55px] w-full flex items-center justify-between px-4'
53
- ),
54
- Main(
55
- *c, **kwargs,
56
- cls='flex-1 h-full'
57
  ),
58
- cls='h-full flex flex-col'
 
 
59
  ),
 
 
60
  )
 
1
+ from fasthtml.components import Body, Div, Header, Img, Nav, Title
2
+ from fasthtml.xtend import A, Script
3
  from lucide_fasthtml import Lucide
4
  from shad4fast import Button, Separator
5
 
6
+ script = Script(
7
+ """
8
+ document.addEventListener("DOMContentLoaded", function () {
9
+ const main = document.querySelector('main');
10
+ const aside = document.querySelector('aside');
11
+ const body = document.body;
12
+
13
+ if (main && aside && main.nextElementSibling === aside) {
14
+ // Main + Aside layout
15
+ body.classList.add('grid-cols-[minmax(0,_4fr)_minmax(0,_1fr)]');
16
+ aside.classList.remove('hidden');
17
+ } else if (main) {
18
+ // Only Main layout (full width)
19
+ body.classList.add('grid-cols-[1fr]');
20
+ }
21
+ });
22
+ """
23
+ )
24
+
25
+ overlay_scrollbars = Script(
26
+ """
27
+ (function () {
28
+ const { OverlayScrollbars } = OverlayScrollbarsGlobal;
29
+
30
+ function getPreferredTheme() {
31
+ return localStorage.theme === 'dark' || (!('theme' in localStorage) && window.matchMedia('(prefers-color-scheme: dark)').matches)
32
+ ? 'dark'
33
+ : 'light';
34
+ }
35
+
36
+ function applyOverlayScrollbars(element, scrollbarTheme) {
37
+ // Destroy existing OverlayScrollbars instance if it exists
38
+ const instance = OverlayScrollbars(element);
39
+ if (instance) {
40
+ instance.destroy();
41
+ }
42
+
43
+ // Reinitialize OverlayScrollbars with the new theme
44
+ OverlayScrollbars(element, {
45
+ scrollbars: {
46
+ theme: scrollbarTheme,
47
+ visibility: 'auto',
48
+ autoHide: 'leave',
49
+ autoHideDelay: 800
50
+ }
51
+ });
52
+ }
53
+
54
+ function updateScrollbarTheme() {
55
+ const isDarkMode = getPreferredTheme() === 'dark';
56
+ const scrollbarTheme = isDarkMode ? 'os-theme-light' : 'os-theme-dark'; // Light theme in dark mode, dark theme in light mode
57
+
58
+ const mainElement = document.querySelector('main');
59
+ const chatMessagesElement = document.querySelector('#chat-messages'); // Select the chat message container by ID
60
+
61
+ if (mainElement) {
62
+ applyOverlayScrollbars(mainElement, scrollbarTheme);
63
+ }
64
+
65
+ if (chatMessagesElement) {
66
+ applyOverlayScrollbars(chatMessagesElement, scrollbarTheme);
67
+ }
68
+ }
69
+
70
+ // Apply the correct theme immediately when the page loads
71
+ updateScrollbarTheme();
72
+
73
+ // Observe changes in the 'dark' class on the <html> element
74
+ const observer = new MutationObserver(updateScrollbarTheme);
75
+ observer.observe(document.documentElement, { attributes: true, attributeFilter: ['class'] });
76
+ })();
77
+ """
78
+ )
79
+
80
 
81
  def Logo():
82
  return Div(
83
+ Img(
84
+ src="https://assets.vespa.ai/logos/vespa-logo-black.svg",
85
+ alt="Vespa Logo",
86
+ cls="h-full dark:hidden",
87
+ ),
88
+ Img(
89
+ src="https://assets.vespa.ai/logos/vespa-logo-white.svg",
90
+ alt="Vespa Logo Dark Mode",
91
+ cls="h-full hidden dark:block",
92
+ ),
93
+ cls="h-[27px]",
94
  )
95
 
96
 
 
119
  ),
120
  Separator(orientation="vertical"),
121
  ThemeToggle(),
122
+ cls="flex items-center space-x-3",
123
  )
124
 
125
 
126
  def Layout(*c, **kwargs):
127
  return (
128
+ Title("Visual Retrieval ColPali"),
129
  Body(
130
  Header(
131
  A(Logo(), href="/"),
132
  Links(),
133
+ cls="min-h-[55px] h-[55px] w-full flex items-center justify-between px-4",
 
 
 
 
134
  ),
135
+ *c,
136
+ **kwargs,
137
+ cls="grid grid-rows-[55px_1fr] min-h-0",
138
  ),
139
+ script,
140
+ overlay_scrollbars,
141
  )
globals.css CHANGED
@@ -183,4 +183,25 @@
183
  width: 100%;
184
  height: 100%;
185
  z-index: 10;
186
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  width: 100%;
184
  height: 100%;
185
  z-index: 10;
186
+ }
187
+
188
+ header {
189
+ grid-column: 1/-1;
190
+ }
191
+
192
+ main {
193
+ overflow: auto;
194
+ }
195
+
196
+ aside {
197
+ overflow: auto;
198
+ }
199
+
200
+ .scroll-container {
201
+ padding-right: 10px;
202
+ }
203
+
204
+ .question-message {
205
+ background-color: #61D790;
206
+ color: #2E2F27;
207
+ }
main.py CHANGED
@@ -1,22 +1,25 @@
1
  import asyncio
 
 
2
  from concurrent.futures import ThreadPoolExecutor
3
  from functools import partial
4
 
5
  from fasthtml.common import *
6
  from shad4fast import *
7
  from vespa.application import Vespa
8
- import time
9
 
 
10
  from backend.colpali import (
11
- get_result_from_query,
12
- get_query_embeddings_and_token_map,
13
  add_sim_maps_to_result,
 
 
14
  is_special_token,
 
15
  )
16
- from backend.vespa_app import get_vespa_app
17
- from backend.cache import LRUCache
18
  from backend.modelmanager import ModelManager
 
19
  from frontend.app import (
 
20
  Home,
21
  Search,
22
  SearchBox,
@@ -25,7 +28,10 @@ from frontend.app import (
25
  SimMapButtonReady,
26
  )
27
  from frontend.layout import Layout
28
- import hashlib
 
 
 
29
 
30
  highlight_js_theme_link = Link(id="highlight-theme", rel="stylesheet", href="")
31
  highlight_js_theme = Script(src="/static/js/highlightjs-theme.js")
@@ -35,15 +41,27 @@ highlight_js = HighlightJS(
35
  light="github",
36
  )
37
 
 
 
 
 
 
 
 
 
 
38
 
39
  app, rt = fast_app(
40
- htmlkw={"cls": "h-full"},
41
  pico=False,
42
  hdrs=(
43
  ShadHead(tw_cdn=False, theme_handle=True),
44
  highlight_js,
45
  highlight_js_theme_link,
46
  highlight_js_theme,
 
 
 
47
  ),
48
  )
49
  vespa_app: Vespa = get_vespa_app()
@@ -53,6 +71,16 @@ task_cache = LRUCache(
53
  max_size=1000
54
  ) # Map from query_id to boolean value - False if not all results are ready.
55
  thread_pool = ThreadPoolExecutor()
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  @app.on_event("startup")
@@ -72,7 +100,7 @@ def serve_static(filepath: str):
72
 
73
  @rt("/")
74
  def get():
75
- return Layout(Home())
76
 
77
 
78
  @rt("/search")
@@ -86,16 +114,18 @@ def get(request):
86
  if not query_value:
87
  # Show SearchBox and a message for missing query
88
  return Layout(
89
- Div(
90
- SearchBox(query_value=query_value, ranking_value=ranking_value),
91
  Div(
92
- P(
93
- "No query provided. Please enter a query.",
94
- cls="text-center text-muted-foreground",
 
 
 
 
95
  ),
96
- cls="p-10",
97
- ),
98
- cls="grid",
99
  )
100
  )
101
  # Generate a unique query_id based on the query and ranking value
@@ -107,7 +137,12 @@ def get(request):
107
  # search_results = get_results_children(result)
108
  # return Layout(Search(request, search_results))
109
  # Show the loading message if a query is provided
110
- return Layout(Search(request)) # Show SearchBox and Loading message initially
 
 
 
 
 
111
 
112
 
113
  @rt("/fetch_results")
@@ -215,15 +250,67 @@ async def get_sim_map(query_id: str, idx: int, token: str):
215
  sim_map_b64 = search_results[idx]["fields"].get(sim_map_key, None)
216
  if sim_map_b64 is None:
217
  return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
218
- sim_map_img_src = f"data:image/jpeg;base64,{sim_map_b64}"
219
  return SimMapButtonReady(
220
  query_id=query_id, idx=idx, token=token, img_src=sim_map_img_src
221
  )
222
 
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  @rt("/app")
225
  def get():
226
- return Layout(Div(P(f"Connected to Vespa at {vespa_app.url}"), cls="p-4"))
227
 
228
 
229
  if __name__ == "__main__":
 
1
  import asyncio
2
+ import hashlib
3
+ import time
4
  from concurrent.futures import ThreadPoolExecutor
5
  from functools import partial
6
 
7
  from fasthtml.common import *
8
  from shad4fast import *
9
  from vespa.application import Vespa
 
10
 
11
+ from backend.cache import LRUCache
12
  from backend.colpali import (
 
 
13
  add_sim_maps_to_result,
14
+ get_query_embeddings_and_token_map,
15
+ get_result_from_query,
16
  is_special_token,
17
+ get_full_image_from_vespa,
18
  )
 
 
19
  from backend.modelmanager import ModelManager
20
+ from backend.vespa_app import get_vespa_app
21
  from frontend.app import (
22
+ ChatResult,
23
  Home,
24
  Search,
25
  SearchBox,
 
28
  SimMapButtonReady,
29
  )
30
  from frontend.layout import Layout
31
+ import google.generativeai as genai
32
+ from PIL import Image
33
+ import io
34
+ import base64
35
 
36
  highlight_js_theme_link = Link(id="highlight-theme", rel="stylesheet", href="")
37
  highlight_js_theme = Script(src="/static/js/highlightjs-theme.js")
 
41
  light="github",
42
  )
43
 
44
+ overlayscrollbars_link = Link(
45
+ rel="stylesheet",
46
+ href="https://cdnjs.cloudflare.com/ajax/libs/overlayscrollbars/2.10.0/styles/overlayscrollbars.min.css",
47
+ type="text/css",
48
+ )
49
+ overlayscrollbars_js = Script(
50
+ src="https://cdnjs.cloudflare.com/ajax/libs/overlayscrollbars/2.10.0/browser/overlayscrollbars.browser.es5.min.js"
51
+ )
52
+ sselink = Script(src="https://unpkg.com/htmx-ext-sse@2.2.1/sse.js")
53
 
54
  app, rt = fast_app(
55
+ htmlkw={"cls": "grid h-full"},
56
  pico=False,
57
  hdrs=(
58
  ShadHead(tw_cdn=False, theme_handle=True),
59
  highlight_js,
60
  highlight_js_theme_link,
61
  highlight_js_theme,
62
+ overlayscrollbars_link,
63
+ overlayscrollbars_js,
64
+ sselink,
65
  ),
66
  )
67
  vespa_app: Vespa = get_vespa_app()
 
71
  max_size=1000
72
  ) # Map from query_id to boolean value - False if not all results are ready.
73
  thread_pool = ThreadPoolExecutor()
74
+ # Gemini config
75
+
76
+ genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
77
+ GEMINI_SYSTEM_PROMPT = """If the user query is a question, try your best to answer it based on the provided images.
78
+ If the user query is not an obvious question, reply with 'No question detected.'. Your response should be HTML formatted.
79
+ This means that newlines will be replaced with <br> tags, bold text will be enclosed in <b> tags, and so on.
80
+ """
81
+ gemini_model = genai.GenerativeModel(
82
+ "gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
83
+ )
84
 
85
 
86
  @app.on_event("startup")
 
100
 
101
  @rt("/")
102
  def get():
103
+ return Layout(Main(Home()))
104
 
105
 
106
  @rt("/search")
 
114
  if not query_value:
115
  # Show SearchBox and a message for missing query
116
  return Layout(
117
+ Main(
 
118
  Div(
119
+ SearchBox(query_value=query_value, ranking_value=ranking_value),
120
+ Div(
121
+ P(
122
+ "No query provided. Please enter a query.",
123
+ cls="text-center text-muted-foreground",
124
+ ),
125
+ cls="p-10",
126
  ),
127
+ cls="grid",
128
+ )
 
129
  )
130
  )
131
  # Generate a unique query_id based on the query and ranking value
 
137
  # search_results = get_results_children(result)
138
  # return Layout(Search(request, search_results))
139
  # Show the loading message if a query is provided
140
+ return Layout(
141
+ Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
142
+ Aside(
143
+ ChatResult(query_id=query_id, query=query_value), cls="border-t border-l"
144
+ ),
145
+ ) # Show SearchBox and Loading message initially
146
 
147
 
148
  @rt("/fetch_results")
 
250
  sim_map_b64 = search_results[idx]["fields"].get(sim_map_key, None)
251
  if sim_map_b64 is None:
252
  return SimMapButtonPoll(query_id=query_id, idx=idx, token=token)
253
+ sim_map_img_src = f"data:image/png;base64,{sim_map_b64}"
254
  return SimMapButtonReady(
255
  query_id=query_id, idx=idx, token=token, img_src=sim_map_img_src
256
  )
257
 
258
 
259
+ @app.get("/full_image")
260
+ async def full_image(id: str):
261
+ """
262
+ Endpoint to get the full quality image for a given result id.
263
+ """
264
+ image_data = await get_full_image_from_vespa(vespa_app, id)
265
+
266
+ # Decode the base64 image data
267
+ # image_data = base64.b64decode(image_data)
268
+ image_data = "data:image/jpeg;base64," + image_data
269
+
270
+ return Img(
271
+ src=image_data,
272
+ alt="something",
273
+ cls="result-image w-full h-full object-contain",
274
+ )
275
+
276
+
277
+ async def message_generator(query_id: str, query: str):
278
+ result = None
279
+ while result is None:
280
+ result = result_cache.get(query_id)
281
+ await asyncio.sleep(0.5)
282
+ search_results = get_results_children(result)
283
+ images = [result["fields"]["blur_image"] for result in search_results]
284
+ # from b64 to PIL image
285
+ images = [Image.open(io.BytesIO(base64.b64decode(img))) for img in images]
286
+
287
+ # If newlines are present in the response, the connection will be closed.
288
+ def replace_newline_with_br(text):
289
+ return text.replace("\n", "<br>")
290
+
291
+ response_text = ""
292
+ async for chunk in await gemini_model.generate_content_async(
293
+ images + ["\n\n Query: ", query], stream=True
294
+ ):
295
+ if chunk.text:
296
+ response_text += chunk.text
297
+ response_text = replace_newline_with_br(response_text)
298
+ yield f"event: message\ndata: {response_text}\n\n"
299
+ await asyncio.sleep(0.5)
300
+ yield "event: close\ndata: \n\n"
301
+
302
+
303
+ @app.get("/get-message")
304
+ async def get_message(query_id: str, query: str):
305
+ return StreamingResponse(
306
+ message_generator(query_id=query_id, query=query),
307
+ media_type="text/event-stream",
308
+ )
309
+
310
+
311
  @rt("/app")
312
  def get():
313
+ return Layout(Main(Div(P(f"Connected to Vespa at {vespa_app.url}"), cls="p-4")))
314
 
315
 
316
  if __name__ == "__main__":
output.css CHANGED
@@ -927,6 +927,10 @@ body {
927
  max-height: 100vh;
928
  }
929
 
 
 
 
 
930
  .min-h-\[55px\] {
931
  min-height: 55px;
932
  }
@@ -1096,6 +1100,22 @@ body {
1096
  grid-template-columns: repeat(2, minmax(0, 1fr));
1097
  }
1098
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1099
  .flex-col {
1100
  flex-direction: column;
1101
  }
@@ -1112,10 +1132,18 @@ body {
1112
  align-content: flex-start;
1113
  }
1114
 
 
 
 
 
1115
  .items-center {
1116
  align-items: center;
1117
  }
1118
 
 
 
 
 
1119
  .justify-center {
1120
  justify-content: center;
1121
  }
@@ -1136,6 +1164,10 @@ body {
1136
  gap: 0.5rem;
1137
  }
1138
 
 
 
 
 
1139
  .gap-4 {
1140
  gap: 1rem;
1141
  }
@@ -1200,6 +1232,10 @@ body {
1200
  margin-bottom: calc(0.5rem * var(--tw-space-y-reverse));
1201
  }
1202
 
 
 
 
 
1203
  .self-stretch {
1204
  align-self: stretch;
1205
  }
@@ -1252,6 +1288,11 @@ body {
1252
  border-width: 2px;
1253
  }
1254
 
 
 
 
 
 
1255
  .border-b {
1256
  border-bottom-width: 1px;
1257
  }
@@ -1493,6 +1534,10 @@ body {
1493
  padding-top: 1rem;
1494
  }
1495
 
 
 
 
 
1496
  .text-left {
1497
  text-align: left;
1498
  }
@@ -1577,6 +1622,11 @@ body {
1577
  letter-spacing: 0.025em;
1578
  }
1579
 
 
 
 
 
 
1580
  .text-card-foreground {
1581
  color: hsl(var(--card-foreground));
1582
  }
@@ -1993,6 +2043,27 @@ body {
1993
  z-index: 10;
1994
  }
1995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1996
  :root:has(.data-\[state\=open\]\:no-bg-scroll[data-state="open"]) {
1997
  overflow: hidden;
1998
  }
@@ -2537,6 +2608,11 @@ body {
2537
  --tw-gradient-to: #d1d5db var(--tw-gradient-to-position);
2538
  }
2539
 
 
 
 
 
 
2540
  .dark\:hover\:border-white:hover:where(.dark, .dark *) {
2541
  --tw-border-opacity: 1;
2542
  border-color: rgb(255 255 255 / var(--tw-border-opacity));
@@ -2610,4 +2686,4 @@ body {
2610
 
2611
  .\[\&_tr\]\:border-b tr {
2612
  border-bottom-width: 1px;
2613
- }
 
927
  max-height: 100vh;
928
  }
929
 
930
+ .min-h-0 {
931
+ min-height: 0px;
932
+ }
933
+
934
  .min-h-\[55px\] {
935
  min-height: 55px;
936
  }
 
1100
  grid-template-columns: repeat(2, minmax(0, 1fr));
1101
  }
1102
 
1103
+ .grid-cols-\[1fr\] {
1104
+ grid-template-columns: 1fr;
1105
+ }
1106
+
1107
+ .grid-cols-\[minmax\(0\2c _4fr\)_minmax\(0\2c _1fr\)\] {
1108
+ grid-template-columns: minmax(0, 4fr) minmax(0, 1fr);
1109
+ }
1110
+
1111
+ .grid-rows-\[55px_1fr\] {
1112
+ grid-template-rows: 55px 1fr;
1113
+ }
1114
+
1115
+ .grid-rows-\[auto_1fr_auto\] {
1116
+ grid-template-rows: auto 1fr auto;
1117
+ }
1118
+
1119
  .flex-col {
1120
  flex-direction: column;
1121
  }
 
1132
  align-content: flex-start;
1133
  }
1134
 
1135
+ .items-end {
1136
+ align-items: flex-end;
1137
+ }
1138
+
1139
  .items-center {
1140
  align-items: center;
1141
  }
1142
 
1143
+ .justify-end {
1144
+ justify-content: flex-end;
1145
+ }
1146
+
1147
  .justify-center {
1148
  justify-content: center;
1149
  }
 
1164
  gap: 0.5rem;
1165
  }
1166
 
1167
+ .gap-3 {
1168
+ gap: 0.75rem;
1169
+ }
1170
+
1171
  .gap-4 {
1172
  gap: 1rem;
1173
  }
 
1232
  margin-bottom: calc(0.5rem * var(--tw-space-y-reverse));
1233
  }
1234
 
1235
+ .self-end {
1236
+ align-self: flex-end;
1237
+ }
1238
+
1239
  .self-stretch {
1240
  align-self: stretch;
1241
  }
 
1288
  border-width: 2px;
1289
  }
1290
 
1291
+ .border-x {
1292
+ border-left-width: 1px;
1293
+ border-right-width: 1px;
1294
+ }
1295
+
1296
  .border-b {
1297
  border-bottom-width: 1px;
1298
  }
 
1534
  padding-top: 1rem;
1535
  }
1536
 
1537
+ .pr-3 {
1538
+ padding-right: 0.75rem;
1539
+ }
1540
+
1541
  .text-left {
1542
  text-align: left;
1543
  }
 
1622
  letter-spacing: 0.025em;
1623
  }
1624
 
1625
+ .text-black {
1626
+ --tw-text-opacity: 1;
1627
+ color: rgb(0 0 0 / var(--tw-text-opacity));
1628
+ }
1629
+
1630
  .text-card-foreground {
1631
  color: hsl(var(--card-foreground));
1632
  }
 
2043
  z-index: 10;
2044
  }
2045
 
2046
+ header {
2047
+ grid-column: 1/-1;
2048
+ }
2049
+
2050
+ main {
2051
+ overflow: auto;
2052
+ }
2053
+
2054
+ aside {
2055
+ overflow: auto;
2056
+ }
2057
+
2058
+ .scroll-container {
2059
+ padding-right: 10px;
2060
+ }
2061
+
2062
+ .question-message {
2063
+ background-color: #61D790;
2064
+ color: #2E2F27;
2065
+ }
2066
+
2067
  :root:has(.data-\[state\=open\]\:no-bg-scroll[data-state="open"]) {
2068
  overflow: hidden;
2069
  }
 
2608
  --tw-gradient-to: #d1d5db var(--tw-gradient-to-position);
2609
  }
2610
 
2611
+ .dark\:text-white:where(.dark, .dark *) {
2612
+ --tw-text-opacity: 1;
2613
+ color: rgb(255 255 255 / var(--tw-text-opacity));
2614
+ }
2615
+
2616
  .dark\:hover\:border-white:hover:where(.dark, .dark *) {
2617
  --tw-border-opacity: 1;
2618
  border-color: rgb(255 255 255 / var(--tw-border-opacity));
 
2686
 
2687
  .\[\&_tr\]\:border-b tr {
2688
  border-bottom-width: 1px;
2689
+ }
prepare_feed_deploy.py ADDED
@@ -0,0 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # Visual PDF Retrieval - demo application
3
+ #
4
+ # In this notebook, we will prepare the Vespa backend application for our visual retrieval demo.
5
+ # We will use ColPali as the model to extract patch vectors from images of pdf pages.
6
+ # At query time, we use MaxSim to retrieve and/or (based on the configuration) rank the page results.
7
+ #
8
+ # To see the application in action, visit TODO:
9
+ #
10
+ # The web application is written in FastHTML, meaning the complete application is written in python.
11
+ #
12
+ # The steps we will take in this notebook are:
13
+ #
14
+ # 0. Setup and configuration
15
+ # 1. Download the data
16
+ # 2. Prepare the data
17
+ # 3. Generate queries for evaluation and typeahead search suggestions
18
+ # 4. Deploy the Vespa application
19
+ # 5. Create the Vespa application
20
+ # 6. Feed the data to the Vespa application
21
+ #
22
+ # All the steps that are needed to provision the Vespa application, including feeding the data, can be done from this notebook.
23
+ # We have tried to make it easy for others to run this notebook, to create your own PDF Enterprise Search application using Vespa.
24
+ #
25
+
26
+ # %% [markdown]
27
+ # ## 0. Setup and Configuration
28
+ #
29
+
30
+ # %%
31
+ import os
32
+ import asyncio
33
+ import json
34
+ from typing import Tuple
35
+ import hashlib
36
+ import numpy as np
37
+
38
+ # Vespa
39
+ from vespa.package import (
40
+ ApplicationPackage,
41
+ Field,
42
+ Schema,
43
+ Document,
44
+ HNSW,
45
+ RankProfile,
46
+ Function,
47
+ FieldSet,
48
+ SecondPhaseRanking,
49
+ Summary,
50
+ DocumentSummary,
51
+ )
52
+ from vespa.deployment import VespaCloud
53
+ from vespa.application import Vespa
54
+ from vespa.io import VespaResponse
55
+
56
+ # Google Generative AI
57
+ import google.generativeai as genai
58
+
59
+ # Torch and other ML libraries
60
+ import torch
61
+ from torch.utils.data import DataLoader
62
+ from tqdm import tqdm
63
+ from pdf2image import convert_from_path
64
+ from pypdf import PdfReader
65
+
66
+ # ColPali model and processor
67
+ from colpali_engine.models import ColPali, ColPaliProcessor
68
+ from colpali_engine.utils.torch_utils import get_torch_device
69
+ from vidore_benchmark.utils.image_utils import scale_image, get_base64_image
70
+
71
+ # Other utilities
72
+ from bs4 import BeautifulSoup
73
+ import httpx
74
+ from urllib.parse import urljoin, urlparse
75
+
76
+ # Load environment variables
77
+ from dotenv import load_dotenv
78
+
79
+ load_dotenv()
80
+
81
+ # Avoid warning from huggingface tokenizers
82
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
83
+
84
+ # %% [markdown]
85
+ # ### Create a free trial in Vespa Cloud
86
+ #
87
+ # Create a tenant from [here](https://vespa.ai/free-trial/).
88
+ # The trial includes $300 credit.
89
+ # Take note of your tenant name.
90
+ #
91
+
92
+ # %%
93
+ VESPA_TENANT_NAME = "vespa-team"
94
+
95
+ # %% [markdown]
96
+ # Here, set your desired application name. (Will be created in later steps)
97
+ # Note that you can not have hyphen `-` or underscore `_` in the application name.
98
+ #
99
+
100
+ # %%
101
+ VESPA_APPLICATION_NAME = "colpalidemo2"
102
+ VESPA_SCHEMA_NAME = "pdf_page"
103
+
104
+ # %% [markdown]
105
+ # Next, you need to create some tokens for feeding data, and querying the application.
106
+ # We recommend separate tokens for feeding and querying, (the former with write permission, and the latter with read permission).
107
+ # The tokens can be created from the [Vespa Cloud console](https://console.vespa-cloud.com/) in the 'Account' -> 'Tokens' section.
108
+ #
109
+
110
+ # %%
111
+ VESPA_TOKEN_ID_WRITE = "colpalidemo_write"
112
+ VESPA_TOKEN_ID_READ = "colpalidemo_read"
113
+
114
+ # %% [markdown]
115
+ # We also need to set the value of the write token to be able to feed data to the Vespa application.
116
+ #
117
+
118
+ # %%
119
+ VESPA_CLOUD_SECRET_TOKEN = os.getenv("VESPA_CLOUD_SECRET_TOKEN") or input(
120
+ "Enter Vespa cloud secret token: "
121
+ )
122
+
123
+ # %% [markdown]
124
+ # We will also use the Gemini API to create sample queries for our images.
125
+ # You can also use other VLM's to create these queries.
126
+ # Create a Gemini API key from [here](https://aistudio.google.com/app/apikey).
127
+ #
128
+
129
+ # %%
130
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or input(
131
+ "Enter Google Generative AI API key: "
132
+ )
133
+
134
+ # %%
135
+ MODEL_NAME = "vidore/colpali-v1.2"
136
+
137
+ # Configure Google Generative AI
138
+ genai.configure(api_key=GEMINI_API_KEY)
139
+
140
+ # Set device for Torch
141
+ device = get_torch_device("auto")
142
+ print(f"Using device: {device}")
143
+
144
+ # Load the ColPali model and processor
145
+ model = ColPali.from_pretrained(
146
+ MODEL_NAME,
147
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
148
+ device_map=device,
149
+ ).eval()
150
+
151
+ processor = ColPaliProcessor.from_pretrained(MODEL_NAME)
152
+
153
+ # %% [markdown]
154
+ # ## 1. Download PDFs
155
+ #
156
+ # We are going to use public reports from the Norwegian Government Pension Fund Global (also known as the Oil Fund).
157
+ # The fund puts transparency at the forefront and publishes reports on its investments, holdings, and returns, as well as its strategy and governance.
158
+ #
159
+ # These reports are the ones we are going to use for this showcase.
160
+ # Here are some sample images:
161
+ #
162
+ # ![Sample1](./static/img/gfpg-sample-1.png)
163
+ # ![Sample2](./static/img/gfpg-sample-2.png)
164
+ #
165
+
166
+ # %% [markdown]
167
+ # As we can see, a lot of the information is in the form of tables, charts and numbers.
168
+ # These are not easily extractable using pdf-readers or OCR tools.
169
+ #
170
+
171
+ # %%
172
+ import requests
173
+
174
+ url = "https://www.nbim.no/en/publications/reports/"
175
+ response = requests.get(url)
176
+ response.raise_for_status()
177
+ html_content = response.text
178
+
179
+ # Parse with BeautifulSoup
180
+ soup = BeautifulSoup(html_content, "html.parser")
181
+
182
+ links = []
183
+
184
+ # Find all <a> elements with the specific classes
185
+ for a_tag in soup.find_all("a", href=True):
186
+ classes = a_tag.get("class", [])
187
+ if "button" in classes and "button--download-secondary" in classes:
188
+ href = a_tag["href"]
189
+ full_url = urljoin(url, href)
190
+ links.append(full_url)
191
+
192
+ links
193
+
194
+ # %%
195
+ # Limit the number of PDFs to download
196
+ NUM_PDFS = 2 # Set to None to download all PDFs
197
+ links = links[:NUM_PDFS] if NUM_PDFS else links
198
+ links
199
+
200
+ # %%
201
+ from nest_asyncio import apply
202
+ from typing import List
203
+
204
+ apply()
205
+
206
+ max_attempts = 3
207
+
208
+
209
async def download_pdf(session, url, filename, max_attempts=3):
    """Download a single PDF into the ``pdfs/`` directory, with retries.

    Args:
        session: An ``httpx.AsyncClient`` (or any object with an async ``get``).
        url: URL of the PDF to download.
        filename: Fallback filename; overridden by the server's
            ``Content-Disposition`` header when present.
        max_attempts: Number of attempts before giving up. Defaults to 3.

    Returns:
        The local file path on success, or ``None`` if all attempts fail
        or no usable filename could be derived.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            response = await session.get(url)
            response.raise_for_status()

            # Prefer the server-provided filename from Content-Disposition.
            content_disposition = response.headers.get("Content-Disposition")
            if content_disposition:
                import re

                fname = re.findall('filename="(.+)"', content_disposition)
                if fname:
                    filename = fname[0]

            # Ensure the filename is safe to use on the filesystem.
            safe_filename = filename.replace("/", "_").replace("\\", "_")
            if not safe_filename or safe_filename == "_":
                # A retry cannot fix a bad filename, so bail out immediately.
                # (The original `continue` here never incremented the attempt
                # counter and would loop forever.)
                print(f"Invalid filename derived for {url!r}; skipping")
                return None

            filepath = os.path.join("pdfs", safe_filename)
            with open(filepath, "wb") as f:
                f.write(response.content)
            print(f"Downloaded {safe_filename}")
            return filepath
        except Exception as e:
            print(f"Error downloading {url}: {e}")
            print(f"Retrying ({attempt}/{max_attempts})...")
            await asyncio.sleep(1)  # Brief back-off before the next attempt
    return None
242
+
243
+
244
async def download_pdfs(links: List[str]) -> List[dict]:
    """Download PDFs concurrently from a list of URLs.

    Returns:
        A list of ``{"url": ..., "path": ...}`` dicts, one per link that was
        successfully downloaded.
    """
    async with httpx.AsyncClient() as client:
        kept_links = []
        tasks = []
        for link in links:
            # Derive a filename from the URL path; skip links without one.
            filename = os.path.basename(urlparse(link).path)
            if not filename:
                continue
            kept_links.append(link)
            tasks.append(download_pdf(client, link, filename))

        # Run the downloads concurrently.
        paths = await asyncio.gather(*tasks)

    # BUG FIX: pair against kept_links (not links) so url/path stay aligned
    # even when some links were skipped above — zipping against the full
    # `links` list would silently associate URLs with the wrong files.
    return [
        {"url": link, "path": path} for link, path in zip(kept_links, paths) if path
    ]
265
+
266
+
267
+ # Create the pdfs directory if it doesn't exist
268
+ os.makedirs("pdfs", exist_ok=True)
269
+ # Now run the download_pdfs function with the URL
270
+ pdfs = asyncio.run(download_pdfs(links))
271
+
272
+ # %%
273
+ pdfs
274
+
275
+ # %% [markdown]
276
+ # ## 2. Convert PDFs to Images
277
+ #
278
+
279
+
280
+ # %%
281
def get_pdf_images(pdf_path):
    """Extract per-page text and render every page of a PDF to a PIL image.

    Returns:
        A ``(images, page_texts)`` pair with one entry per page.
    """
    reader = PdfReader(pdf_path)
    page_texts = [page.extract_text() for page in reader.pages]
    images = convert_from_path(pdf_path)
    # Sanity check: exactly one rendered image per extracted text page.
    assert len(images) == len(page_texts)
    return images, page_texts
292
+
293
+
294
+ pdf_folder = "pdfs"
295
+ pdf_pages = []
296
+ for pdf in tqdm(pdfs):
297
+ pdf_file = pdf["path"]
298
+ title = os.path.splitext(os.path.basename(pdf_file))[0]
299
+ images, texts = get_pdf_images(pdf_file)
300
+ for page_no, (image, text) in enumerate(zip(images, texts)):
301
+ pdf_pages.append(
302
+ {
303
+ "title": title,
304
+ "url": pdf["url"],
305
+ "path": pdf_file,
306
+ "image": image,
307
+ "text": text,
308
+ "page_no": page_no,
309
+ }
310
+ )
311
+
312
+ # %%
313
+ len(pdf_pages)
314
+
315
+ # %%
316
+ from collections import Counter
317
+
318
+ # Print the length of the text fields - mean, max and min
319
+ text_lengths = [len(page["text"]) for page in pdf_pages]
320
+ print(f"Mean text length: {np.mean(text_lengths)}")
321
+ print(f"Max text length: {np.max(text_lengths)}")
322
+ print(f"Min text length: {np.min(text_lengths)}")
323
+ print(f"Median text length: {np.median(text_lengths)}")
324
+ print(f"Number of text with length == 0: {Counter(text_lengths)[0]}")
325
+
326
+ # %% [markdown]
327
+ # ## 3. Generate Queries
328
+ #
329
+ # In this step, we want to generate queries for each page image.
330
+ # These will be useful for 2 reasons:
331
+ #
332
+ # 1. We can use these queries as typeahead suggestions in the search bar.
333
+ # 2. We can use the queries to generate an evaluation dataset. See [Improving Retrieval with LLM-as-a-judge](https://blog.vespa.ai/improving-retrieval-with-llm-as-a-judge/) for a deeper dive into this topic.
334
+ #
335
+ # The prompt for generating queries is taken from [this](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html#an-update-retrieval-focused-prompt) wonderful blog post by Daniel van Strien.
336
+ #
337
+ # We will use the Gemini API to generate these queries, with `gemini-1.5-flash-8b` as the model.
338
+ #
339
+
340
+ # %%
341
+ from pydantic import BaseModel
342
+
343
+
344
class GeneratedQueries(BaseModel):
    """Structured response schema for the Gemini query-generation prompt.

    Three query types are produced per page image, each as both a natural
    language question and a short keyword search query. Empty strings are
    used when a type (e.g. visual element) does not apply.
    """

    # Main subject of the document.
    broad_topical_question: str
    broad_topical_query: str
    # A specific detail or aspect of the document.
    specific_detail_question: str
    specific_detail_query: str
    # A visual element such as a chart, graph, or image (may be empty).
    visual_element_question: str
    visual_element_query: str
351
+
352
+
353
def get_retrieval_prompt() -> Tuple[str, GeneratedQueries]:
    """Return the query-generation prompt and the pydantic schema of the reply.

    The prompt instructs the VLM (Gemini) to produce three question/query
    pairs (broad topical, specific detail, visual element) for one PDF page
    image, formatted as a JSON object matching :class:`GeneratedQueries`.
    """
    # NOTE(review): `prompt = (prompt) = ...` is a redundant double assignment
    # (equivalent to a single `prompt = ...`); kept unchanged here.
    prompt = (
        prompt
    ) = """You are an investor, stock analyst and financial expert. You will be presented an image of a document page from a report published by the Norwegian Government Pension Fund Global (GPFG). The report may be annual or quarterly reports, or policy reports, on topics such as responsible investment, risk etc.
Your task is to generate retrieval queries and questions that you would use to retrieve this document (or ask based on this document) in a large corpus.
Please generate 3 different types of retrieval queries and questions.
A retrieval query is a keyword based query, made up of 2-5 words, that you would type into a search engine to find this document.
A question is a natural language question that you would ask, for which the document contains the answer.
The queries should be of the following types:
1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should cover a specific detail or aspect of the document.
3. A visual element query: This should cover a visual element of the document, such as a chart, graph, or image.

Important guidelines:
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Use a fact-based natural language style for the questions.
- Frame the queries as if someone is searching for this document in a large corpus.
- Make the queries diverse and representative of different search strategies.

Format your response as a JSON object with the structure of the following example:
{
"broad_topical_question": "What was the Responsible Investment Policy in 2019?",
"broad_topical_query": "responsible investment policy 2019",
"specific_detail_question": "What is the percentage of investments in renewable energy?",
"specific_detail_query": "renewable energy investments percentage",
"visual_element_question": "What is the trend of total holding value over time?",
"visual_element_query": "total holding value trend"
}

If there are no relevant visual elements, provide an empty string for the visual element question and query.
Here is the document image to analyze:
Generate the queries based on this image and provide the response in the specified JSON format.
Only return JSON. Don't return any extra explanation text. """

    return prompt, GeneratedQueries
388
+
389
+
390
+ prompt_text, pydantic_model = get_retrieval_prompt()
391
+
392
+ # %%
393
+ gemini_model = genai.GenerativeModel("gemini-1.5-flash-8b")
394
+
395
+
396
def generate_queries(image, prompt_text, pydantic_model):
    """Ask Gemini for retrieval queries/questions for one page image.

    Returns:
        A dict matching ``pydantic_model``; on any failure, a dict with every
        expected key mapped to an empty string (best-effort, never raises).
    """
    empty_result = {
        "broad_topical_question": "",
        "broad_topical_query": "",
        "specific_detail_question": "",
        "specific_detail_query": "",
        "visual_element_question": "",
        "visual_element_query": "",
    }
    try:
        response = gemini_model.generate_content(
            [image, "\n\n", prompt_text],
            generation_config=genai.GenerationConfig(
                response_mime_type="application/json",
                response_schema=pydantic_model,
            ),
        )
        return json.loads(response.text)
    except Exception:
        # Deliberate best-effort: fall back to empty queries rather than
        # aborting the whole enrichment run on a single bad page/response.
        return empty_result
416
+
417
+
418
+ # %%
419
+ for pdf in tqdm(pdf_pages):
420
+ image = pdf.get("image")
421
+ pdf["queries"] = generate_queries(image, prompt_text, pydantic_model)
422
+
423
+ # %%
424
+ pdf_pages[46]["image"]
425
+
426
+ # %%
427
+ pdf_pages[46]["queries"]
428
+
429
+ # %%
430
+ # Generate queries async - keeping for now as we probably need when applying to the full dataset
431
+ # import asyncio
432
+ # from tenacity import retry, stop_after_attempt, wait_exponential
433
+ # import google.generativeai as genai
434
+ # from tqdm.asyncio import tqdm_asyncio
435
+
436
+ # max_in_flight = 200 # Maximum number of concurrent requests
437
+
438
+
439
+ # async def generate_queries_for_image_async(model, image, semaphore):
440
+ # @retry(stop=stop_after_attempt(3), wait=wait_exponential(), reraise=True)
441
+ # async def _generate():
442
+ # async with semaphore:
443
+ # result = await model.generate_content_async(
444
+ # [image, "\n\n", prompt_text],
445
+ # generation_config=genai.GenerationConfig(
446
+ # response_mime_type="application/json",
447
+ # response_schema=pydantic_model,
448
+ # ),
449
+ # )
450
+ # return json.loads(result.text)
451
+
452
+ # try:
453
+ # return await _generate()
454
+ # except Exception as e:
455
+ # print(f"Error generating queries for image: {e}")
456
+ # return None # Return None or handle as needed
457
+
458
+
459
+ # async def enrich_pdfs():
460
+ # gemini_model = genai.GenerativeModel("gemini-1.5-flash-8b")
461
+ # semaphore = asyncio.Semaphore(max_in_flight)
462
+ # tasks = []
463
+ # for pdf in pdf_pages:
464
+ # pdf["queries"] = []
465
+ # image = pdf.get("image")
466
+ # if image:
467
+ # task = generate_queries_for_image_async(gemini_model, image, semaphore)
468
+ # tasks.append((pdf, task))
469
+
470
+ # # Run the tasks concurrently using asyncio.gather()
471
+ # for pdf, task in tqdm_asyncio(tasks):
472
+ # result = await task
473
+ # if result:
474
+ # pdf["queries"] = result
475
+ # return pdf_pages
476
+
477
+
478
+ # pdf_pages = asyncio.run(enrich_pdfs())
479
+
480
+ # %%
481
# Persist page metadata (title, url, page_no, text, queries) to JSON,
# excluding the non-serializable PIL image objects.
# BUG FIX: create the output directory first — it was only created in a
# later cell, so this write crashed on a fresh checkout (output/ is gitignored).
os.makedirs("output", exist_ok=True)
with open("output/pdf_pages.json", "w") as f:
    to_write = [{k: v for k, v in pdf.items() if k != "image"} for pdf in pdf_pages]
    json.dump(to_write, f, indent=2)
485
+
486
+ # with open("pdfs/pdf_pages.json", "r") as f:
487
+ # saved_pdf_pages = json.load(f)
488
+ # for pdf, saved_pdf in zip(pdf_pages, saved_pdf_pages):
489
+ # pdf.update(saved_pdf)
490
+
491
+ # %% [markdown]
492
+ # ## 4. Generate embeddings
493
+ #
494
+ # Now that we have the queries, we can use the ColPali model to generate embeddings for each page image.
495
+ #
496
+
497
+
498
+ # %%
499
def generate_embeddings(images, model, processor, batch_size=2) -> np.ndarray:
    """
    Generate ColPali embeddings for a list of images.
    Move to CPU only once per batch.

    Args:
        images (List[PIL.Image]): List of PIL images.
        model (nn.Module): The model to generate embeddings.
        processor: The processor to preprocess images.
        batch_size (int, optional): Batch size for processing. Defaults to 2.

    Returns:
        np.ndarray: Embeddings for the images, shape
            (len(images), processor.max_patch_length (1030 for ColPali),
            model.config.hidden_size (patch embedding dimension - 128 for ColPali)).
    """
    embeddings_list = []

    def collate_fn(batch):
        # Batch is a list of images; returns a dict of batched input tensors.
        return processor.process_images(batch)

    dataloader = DataLoader(
        images,
        # BUG FIX: batch_size was never passed, so DataLoader silently fell
        # back to its default of 1 and the parameter had no effect.
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    for batch_doc in tqdm(dataloader, desc="Generating embeddings"):
        with torch.no_grad():
            # Move the batch to the model's device, run the forward pass,
            # then move the result back to CPU once per batch.
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_batch = model(**batch_doc)
            embeddings_list.append(torch.unbind(embeddings_batch.to("cpu"), dim=0))
    # Concatenate all per-image embeddings into one numpy array.
    all_embeddings = np.concatenate(embeddings_list, axis=0)
    return all_embeddings
535
+
536
+
537
+ # %%
538
+ # Generate embeddings for all images
539
+ images = [pdf["image"] for pdf in pdf_pages]
540
+ embeddings = generate_embeddings(images, model, processor)
541
+
542
+ # %%
543
+ embeddings.shape
544
+
545
+ # %% [markdown]
546
+ # ## 5. Prepare Data on Vespa Format
547
+ #
548
+ # Now, that we have all the data we need, all that remains is to make sure it is in the right format for Vespa.
549
+ #
550
+
551
+
552
+ # %%
553
def float_to_binary_embedding(float_query_embedding: dict) -> dict:
    """Binarize float patch embeddings for compact storage in Vespa.

    Each vector is thresholded at zero (positive -> 1, otherwise 0) and the
    bits are packed, 8 per value, into signed 8-bit integers.
    """

    def _pack(vector) -> list:
        bits = np.where(np.array(vector) > 0, 1, 0)
        return np.packbits(bits).astype(np.int8).tolist()

    return {key: _pack(vec) for key, vec in float_query_embedding.items()}
562
+
563
+
564
+ # %%
565
+ vespa_feed = []
566
+ for pdf, embedding in zip(pdf_pages, embeddings):
567
+ url = pdf["url"]
568
+ title = pdf["title"]
569
+ image = pdf["image"]
570
+ text = pdf.get("text", "")
571
+ page_no = pdf["page_no"]
572
+ query_dict = pdf["queries"]
573
+ questions = [v for k, v in query_dict.items() if "question" in k and v]
574
+ queries = [v for k, v in query_dict.items() if "query" in k and v]
575
+ base_64_image = get_base64_image(
576
+ scale_image(image, 32), add_url_prefix=False
577
+ ) # Scaled down image to return fast on search (~1kb)
578
+ base_64_full_image = get_base64_image(image, add_url_prefix=False)
579
+ embedding_dict = {k: v for k, v in enumerate(embedding)}
580
+ binary_embedding = float_to_binary_embedding(embedding_dict)
581
+ # id_hash should be md5 hash of url and page_number
582
+ id_hash = hashlib.md5(f"{url}_{page_no}".encode()).hexdigest()
583
+ page = {
584
+ "id": id_hash,
585
+ "fields": {
586
+ "id": id_hash,
587
+ "url": url,
588
+ "title": title,
589
+ "page_number": page_no,
590
+ "blur_image": base_64_image,
591
+ "full_image": base_64_full_image,
592
+ "text": text,
593
+ "embedding": binary_embedding,
594
+ "queries": queries,
595
+ "questions": questions,
596
+ },
597
+ }
598
+ vespa_feed.append(page)
599
+
600
+ # %%
601
+ # We will prepare the Vespa feed data, including the embeddings and the generated queries
602
+
603
+
604
+ # Save vespa_feed to vespa_feed.json
605
+ os.makedirs("output", exist_ok=True)
606
+ with open("output/vespa_feed.json", "w") as f:
607
+ vespa_feed_to_save = []
608
+ for page in vespa_feed:
609
+ document_id = page["id"]
610
+ put_id = f"id:{VESPA_APPLICATION_NAME}:{VESPA_SCHEMA_NAME}::{document_id}"
611
+ vespa_feed_to_save.append({"put": put_id, "fields": page["fields"]})
612
+ json.dump(vespa_feed_to_save, f)
613
+
614
+ # %%
615
+ # import json
616
+
617
+ # with open("output/vespa_feed.json", "r") as f:
618
+ # vespa_feed = json.load(f)
619
+
620
+ # %%
621
+ len(vespa_feed)
622
+
623
+ # %% [markdown]
624
+ # ## 5. Prepare Vespa Application
625
+ #
626
+
627
+ # %%
628
+ # Define the Vespa schema
629
+ colpali_schema = Schema(
630
+ name=VESPA_SCHEMA_NAME,
631
+ document=Document(
632
+ fields=[
633
+ Field(
634
+ name="id",
635
+ type="string",
636
+ indexing=["summary", "index"],
637
+ match=["word"],
638
+ ),
639
+ Field(name="url", type="string", indexing=["summary", "index"]),
640
+ Field(
641
+ name="title",
642
+ type="string",
643
+ indexing=["summary", "index"],
644
+ match=["text"],
645
+ index="enable-bm25",
646
+ ),
647
+ Field(name="page_number", type="int", indexing=["summary", "attribute"]),
648
+ Field(name="blur_image", type="raw", indexing=["summary"]),
649
+ Field(name="full_image", type="raw", indexing=["summary"]),
650
+ Field(
651
+ name="text",
652
+ type="string",
653
+ indexing=["summary", "index"],
654
+ match=["text"],
655
+ index="enable-bm25",
656
+ ),
657
+ Field(
658
+ name="embedding",
659
+ type="tensor<int8>(patch{}, v[16])",
660
+ indexing=[
661
+ "attribute",
662
+ "index",
663
+ ],
664
+ ann=HNSW(
665
+ distance_metric="hamming",
666
+ max_links_per_node=32,
667
+ neighbors_to_explore_at_insert=400,
668
+ ),
669
+ ),
670
+ Field(
671
+ name="questions",
672
+ type="array<string>",
673
+ indexing=["summary", "index", "attribute"],
674
+ index="enable-bm25",
675
+ stemming="best",
676
+ ),
677
+ Field(
678
+ name="queries",
679
+ type="array<string>",
680
+ indexing=["summary", "index", "attribute"],
681
+ index="enable-bm25",
682
+ stemming="best",
683
+ ),
684
+ # Add synthetic fields for the questions and queries
685
+ # Field(
686
+ # name="questions_exact",
687
+ # type="array<string>",
688
+ # indexing=["input questions", "index", "attribute"],
689
+ # match=["word"],
690
+ # is_document_field=False,
691
+ # ),
692
+ # Field(
693
+ # name="queries_exact",
694
+ # type="array<string>",
695
+ # indexing=["input queries", "index"],
696
+ # match=["word"],
697
+ # is_document_field=False,
698
+ # ),
699
+ ]
700
+ ),
701
+ fieldsets=[
702
+ FieldSet(
703
+ name="default",
704
+ fields=["title", "url", "blur_image", "page_number", "text"],
705
+ ),
706
+ FieldSet(
707
+ name="image",
708
+ fields=["full_image"],
709
+ ),
710
+ ],
711
+ document_summaries=[
712
+ DocumentSummary(
713
+ name="default",
714
+ summary_fields=[
715
+ Summary(
716
+ name="text",
717
+ fields=[("bolding", "on")],
718
+ ),
719
+ Summary(
720
+ name="snippet",
721
+ fields=[("source", "text"), "dynamic"],
722
+ ),
723
+ ],
724
+ from_disk=True,
725
+ ),
726
+ ],
727
+ )
728
+
729
+ # Define similarity functions used in all rank profiles
730
+ mapfunctions = [
731
+ Function(
732
+ name="similarities", # computes similarity scores between each query token and image patch
733
+ expression="""
734
+ sum(
735
+ query(qt) * unpack_bits(attribute(embedding)), v
736
+ )
737
+ """,
738
+ ),
739
+ Function(
740
+ name="normalized", # normalizes the similarity scores to [-1, 1]
741
+ expression="""
742
+ (similarities - reduce(similarities, min)) / (reduce((similarities - reduce(similarities, min)), max)) * 2 - 1
743
+ """,
744
+ ),
745
+ Function(
746
+ name="quantized", # quantizes the normalized similarity scores to signed 8-bit integers [-128, 127]
747
+ expression="""
748
+ cell_cast(normalized * 127.999, int8)
749
+ """,
750
+ ),
751
+ ]
752
+
753
+ # Define the 'bm25' rank profile
754
+ colpali_bm25_profile = RankProfile(
755
+ name="bm25",
756
+ inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
757
+ first_phase="bm25(title) + bm25(text)",
758
+ functions=mapfunctions,
759
+ summary_features=["quantized"],
760
+ )
761
+ colpali_schema.add_rank_profile(colpali_bm25_profile)
762
+
763
+ # Update the 'default' rank profile
764
+ colpali_profile = RankProfile(
765
+ name="default",
766
+ inputs=[("query(qt)", "tensor<float>(querytoken{}, v[128])")],
767
+ first_phase="bm25_score",
768
+ second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
769
+ functions=mapfunctions
770
+ + [
771
+ Function(
772
+ name="max_sim",
773
+ expression="""
774
+ sum(
775
+ reduce(
776
+ sum(
777
+ query(qt) * unpack_bits(attribute(embedding)), v
778
+ ),
779
+ max, patch
780
+ ),
781
+ querytoken
782
+ )
783
+ """,
784
+ ),
785
+ Function(name="bm25_score", expression="bm25(title) + bm25(text)"),
786
+ ],
787
+ summary_features=["quantized"],
788
+ )
789
+ colpali_schema.add_rank_profile(colpali_profile)
790
+
791
+ # Update the 'retrieval-and-rerank' rank profile
792
+ input_query_tensors = []
793
+ MAX_QUERY_TERMS = 64
794
+ for i in range(MAX_QUERY_TERMS):
795
+ input_query_tensors.append((f"query(rq{i})", "tensor<int8>(v[16])"))
796
+
797
+ input_query_tensors.extend(
798
+ [
799
+ ("query(qt)", "tensor<float>(querytoken{}, v[128])"),
800
+ ("query(qtb)", "tensor<int8>(querytoken{}, v[16])"),
801
+ ]
802
+ )
803
+
804
+ colpali_retrieval_profile = RankProfile(
805
+ name="retrieval-and-rerank",
806
+ inputs=input_query_tensors,
807
+ first_phase="max_sim_binary",
808
+ second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=10),
809
+ functions=mapfunctions
810
+ + [
811
+ Function(
812
+ name="max_sim",
813
+ expression="""
814
+ sum(
815
+ reduce(
816
+ sum(
817
+ query(qt) * unpack_bits(attribute(embedding)), v
818
+ ),
819
+ max, patch
820
+ ),
821
+ querytoken
822
+ )
823
+ """,
824
+ ),
825
+ Function(
826
+ name="max_sim_binary",
827
+ expression="""
828
+ sum(
829
+ reduce(
830
+ 1 / (1 + sum(
831
+ hamming(query(qtb), attribute(embedding)), v)
832
+ ),
833
+ max, patch
834
+ ),
835
+ querytoken
836
+ )
837
+ """,
838
+ ),
839
+ ],
840
+ summary_features=["quantized"],
841
+ )
842
+ colpali_schema.add_rank_profile(colpali_retrieval_profile)
843
+
844
+ # %%
845
+ from vespa.configuration.services import (
846
+ services,
847
+ container,
848
+ search,
849
+ document_api,
850
+ document_processing,
851
+ clients,
852
+ client,
853
+ config,
854
+ content,
855
+ redundancy,
856
+ documents,
857
+ node,
858
+ certificate,
859
+ token,
860
+ document,
861
+ nodes,
862
+ )
863
+ from vespa.configuration.vt import vt
864
+ from vespa.package import ServicesConfiguration
865
+
866
+ service_config = ServicesConfiguration(
867
+ application_name=VESPA_APPLICATION_NAME,
868
+ services_config=services(
869
+ container(
870
+ search(),
871
+ document_api(),
872
+ document_processing(),
873
+ clients(
874
+ client(
875
+ certificate(file="security/clients.pem"),
876
+ id="mtls",
877
+ permissions="read,write",
878
+ ),
879
+ client(
880
+ token(id=f"{VESPA_TOKEN_ID_WRITE}"),
881
+ id="token_write",
882
+ permissions="read,write",
883
+ ),
884
+ client(
885
+ token(id=f"{VESPA_TOKEN_ID_READ}"),
886
+ id="token_read",
887
+ permissions="read",
888
+ ),
889
+ ),
890
+ config(
891
+ vt("tag")(
892
+ vt("bold")(
893
+ vt("open", "<strong>"),
894
+ vt("close", "</strong>"),
895
+ ),
896
+ vt("separator", "..."),
897
+ ),
898
+ name="container.qr-searchers",
899
+ ),
900
+ id=f"{VESPA_APPLICATION_NAME}_container",
901
+ version="1.0",
902
+ ),
903
+ content(
904
+ redundancy("1"),
905
+ documents(document(type="pdf_page", mode="index")),
906
+ nodes(node(distribution_key="0", hostalias="node1")),
907
+ config(
908
+ vt("max_matches", "2", replace_underscores=False),
909
+ vt("length", "1000"),
910
+ vt("surround_max", "500", replace_underscores=False),
911
+ vt("min_length", "300", replace_underscores=False),
912
+ name="vespa.config.search.summary.juniperrc",
913
+ ),
914
+ id=f"{VESPA_APPLICATION_NAME}_content",
915
+ version="1.0",
916
+ ),
917
+ version="1.0",
918
+ ),
919
+ )
920
+
921
+ # %%
922
+ # Create the Vespa application package
923
+ vespa_application_package = ApplicationPackage(
924
+ name=VESPA_APPLICATION_NAME,
925
+ schema=[colpali_schema],
926
+ services_config=service_config,
927
+ )
928
+
929
+ # %% [markdown]
930
+ # ## 6. Deploy Vespa Application
931
+ #
932
+
933
+ # %%
934
+ VESPA_TEAM_API_KEY = os.getenv("VESPA_TEAM_API_KEY") or input(
935
+ "Enter Vespa team API key: "
936
+ )
937
+
938
+ # %%
939
+ vespa_cloud = VespaCloud(
940
+ tenant=VESPA_TENANT_NAME,
941
+ application=VESPA_APPLICATION_NAME,
942
+ key_content=VESPA_TEAM_API_KEY,
943
+ application_package=vespa_application_package,
944
+ )
945
+
946
+ # Deploy the application
947
+ vespa_cloud.deploy()
948
+
949
+ # Output the endpoint URL
950
+ endpoint_url = vespa_cloud.get_token_endpoint()
951
+ print(f"Application deployed. Token endpoint URL: {endpoint_url}")
952
+
953
+ # %% [markdown]
954
+ # Make sure to take note of the token endpoint_url.
955
+ # You need to put this in your `.env` file - `VESPA_APP_URL=https://abcd.vespa-app.cloud` - to access the Vespa application from your web application.
956
+ #
957
+
958
+ # %% [markdown]
959
+ # ## 7. Feed Data to Vespa
960
+ #
961
+
962
+ # %%
963
+ # Instantiate Vespa connection using token
964
+ app = Vespa(url=endpoint_url, vespa_cloud_secret_token=VESPA_CLOUD_SECRET_TOKEN)
965
+ app.get_application_status()
966
+
967
+
968
+ # %%
969
def callback(response: VespaResponse, id: str):
    """Per-document feed callback: log any document that failed to feed."""
    if response.is_successful():
        return
    print(
        f"Failed to feed document {id} with status code {response.status_code}: Reason {response.get_json()}"
    )
974
+
975
+
976
+ # Feed data into Vespa asynchronously
977
+ app.feed_async_iterable(vespa_feed, schema=VESPA_SCHEMA_NAME, callback=callback)
pyproject.toml CHANGED
@@ -8,7 +8,7 @@ license = { text = "Apache-2.0" }
8
  dependencies = [
9
  "python-fasthtml",
10
  "huggingface-hub",
11
- "pyvespa@git+https://github.com/vespa-engine/pyvespa",
12
  "vespacli",
13
  "torch",
14
  "vidore-benchmark[interpretability]>=4.0.0,<5.0.0",
@@ -18,6 +18,7 @@ dependencies = [
18
  "setuptools",
19
  "python-dotenv",
20
  "shad4fast>=1.2.1",
 
21
  ]
22
 
23
  # dev-dependencies
@@ -27,3 +28,11 @@ dev = [
27
  "python-dotenv",
28
  "huggingface_hub[cli]"
29
  ]
 
 
 
 
 
 
 
 
 
8
  dependencies = [
9
  "python-fasthtml",
10
  "huggingface-hub",
11
+ "pyvespa>=0.50.0",
12
  "vespacli",
13
  "torch",
14
  "vidore-benchmark[interpretability]>=4.0.0,<5.0.0",
 
18
  "setuptools",
19
  "python-dotenv",
20
  "shad4fast>=1.2.1",
21
+ "google-generativeai>=0.7.2"
22
  ]
23
 
24
  # dev-dependencies
 
28
  "python-dotenv",
29
  "huggingface_hub[cli]"
30
  ]
31
+ feed = [
32
+ "ipykernel",
33
+ "jupytext",
34
+ "pydantic",
35
+ "beautifulsoup4",
36
+ "pdf2image",
37
+ "google-generativeai"
38
+ ]
static/.DS_Store ADDED
Binary file (6.15 kB). View file
 
uv.lock CHANGED
The diff for this file is too large to render. See raw diff