thomasht86 committed
Commit 8dc2c8a · verified · 1 Parent(s): f434932

Upload folder using huggingface_hub

Files changed (4):
  1. backend/colpali.py +12 -5
  2. backend/vespa_app.py +13 -12
  3. frontend/app.py +34 -19
  4. main.py +42 -20
backend/colpali.py CHANGED
@@ -14,6 +14,8 @@ from colpali_engine.utils.torch_utils import get_torch_device
 from vidore_benchmark.interpretability.torch_utils import (
     normalize_similarity_map_per_query_token,
 )
+from functools import lru_cache
+import logging
 
 
 class SimMapGenerator:
@@ -21,10 +23,14 @@ class SimMapGenerator:
     Generates similarity maps based on query embeddings and image patches using the ColPali model.
     """
 
-    COLPALI_GEMMA_MODEL_NAME = "vidore/colpaligemma-3b-pt-448-base"
     colormap = cm.get_cmap("viridis")  # Preload colormap for efficiency
 
-    def __init__(self, model_name: str = "vidore/colpali-v1.2", n_patch: int = 32):
+    def __init__(
+        self,
+        logger: logging.Logger,
+        model_name: str = "vidore/colpali-v1.2",
+        n_patch: int = 32,
+    ):
         """
         Initializes the SimMapGenerator class with a specified model and patch dimension.
 
@@ -35,7 +41,8 @@ class SimMapGenerator:
         self.model_name = model_name
         self.n_patch = n_patch
         self.device = get_torch_device("auto")
-        print(f"Using device: {self.device}")
+        self.logger = logger
+        self.logger.info(f"Using device: {self.device}")
         self.model, self.processor = self.load_model()
 
     def load_model(self) -> Tuple[ColPali, ColPaliProcessor]:
@@ -47,7 +54,7 @@ class SimMapGenerator:
         """
         model = ColPali.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+            torch_dtype=torch.bfloat16,  # Note that the embeddings created during feed were float32 -> binarized, yet setting this seem to produce the most similar results both locally (mps) and HF (Cuda)
             device_map=self.device,
         ).eval()
 
@@ -250,7 +257,7 @@ class SimMapGenerator:
         )
         return bool(pattern.match(token))
 
-    # TODO: Would be nice to @lru_cache this method.
+    @lru_cache(maxsize=128)
     def get_query_embeddings_and_token_map(
         self, query: str
     ) -> Tuple[torch.Tensor, dict]:
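Note on the colpali.py hunks: the old TODO is replaced by an actual @lru_cache(maxsize=128) on get_query_embeddings_and_token_map, and a logger is threaded through the constructor. A minimal usage sketch follows; it is not part of the commit, and the logger name is only illustrative.

import logging

from backend.colpali import SimMapGenerator

logger = logging.getLogger("vespa_app")
generator = SimMapGenerator(logger=logger)

# The first call runs the model; an identical query string afterwards should be served
# from the functools.lru_cache (keyed on the bound arguments, including `self`),
# skipping a second forward pass.
q_embs, idx_to_token = generator.get_query_embeddings_and_token_map("what is colpali?")
q_embs_again, _ = generator.get_query_embeddings_and_token_map("what is colpali?")

# Caveat: lru_cache on an instance method also keeps a reference to the instance alive
# for as long as its cache entries do.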
backend/vespa_app.py CHANGED
@@ -9,6 +9,7 @@ from vespa.application import Vespa
 from vespa.io import VespaQueryResponse
 from .colpali import SimMapGenerator
 import backend.stopwords
+import logging
 
 
 class VespaQueryClient:
@@ -16,14 +17,15 @@ class VespaQueryClient:
     VESPA_SCHEMA_NAME = "pdf_page"
     SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text"
 
-    def __init__(self):
+    def __init__(self, logger: logging.Logger):
         """
         Initialize the VespaQueryClient by loading environment variables and establishing a connection to the Vespa application.
         """
         load_dotenv()
+        self.logger = logger
 
         if os.environ.get("USE_MTLS") == "true":
-            print("Connected using mTLS")
+            self.logger.info("Connected using mTLS")
             mtls_key = os.environ.get("VESPA_CLOUD_MTLS_KEY")
             mtls_cert = os.environ.get("VESPA_CLOUD_MTLS_CERT")
 
@@ -52,7 +54,7 @@ class VespaQueryClient:
                 url=self.vespa_app_url, key=mtls_key_path, cert=mtls_cert_path
             )
         else:
-            print("Connected using token")
+            self.logger.info("Connected using token")
            self.vespa_app_url = os.environ.get("VESPA_APP_TOKEN_URL")
            if not self.vespa_app_url:
                raise ValueError(
@@ -73,7 +75,7 @@ class VespaQueryClient:
             )
 
         self.app.wait_for_application_up()
-        print(f"Connected to Vespa at {self.vespa_app_url}")
+        self.logger.info(f"Connected to Vespa at {self.vespa_app_url}")
 
     def get_fields(self, sim_map: bool = False):
         if not sim_map:
@@ -99,7 +101,7 @@ class VespaQueryClient:
         query_time = round(query_time, 2)
         count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
         result_text = f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
-        print(result_text)
+        self.logger.debug(result_text)
         return response.json
 
     async def query_vespa_default(
@@ -143,7 +145,7 @@ class VespaQueryClient:
         )
         assert response.is_successful(), response.json
         stop = time.perf_counter()
-        print(
+        self.logger.debug(
             f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
             f"{response.json.get('timing', {}).get('searchtime', -1)} s"
         )
@@ -190,7 +192,7 @@ class VespaQueryClient:
         )
         assert response.is_successful(), response.json
         stop = time.perf_counter()
-        print(
+        self.logger.debug(
             f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
             f"{response.json.get('timing', {}).get('searchtime', -1)} s"
         )
@@ -215,7 +217,7 @@ class VespaQueryClient:
                 )
                 binary_query_embeddings[key] = binary_vector
                 if len(binary_query_embeddings) >= self.MAX_QUERY_TERMS:
-                    print(
+                    self.logger.warning(
                         f"Warning: Query has more than {self.MAX_QUERY_TERMS} terms. Truncating."
                     )
                     break
@@ -292,12 +294,11 @@ class VespaQueryClient:
             result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
         else:
             raise ValueError(f"Unsupported ranking: {rank_method}")
-        # Print score, title id, and text of the results
         if "root" not in result or "children" not in result["root"]:
             result["root"] = {"children": []}
             return result
         for single_result in result["root"]["children"]:
-            print(single_result["fields"].keys())
+            self.logger.debug(single_result["fields"].keys())
         return result
 
     def get_sim_maps_from_query(
@@ -349,7 +350,7 @@ class VespaQueryClient:
         )
         assert response.is_successful(), response.json
         stop = time.perf_counter()
-        print(
+        self.logger.debug(
             f"Getting image from Vespa took: {stop - start} s, Vespa reported searchtime was "
             f"{response.json.get('timing', {}).get('searchtime', -1)} s"
         )
@@ -386,7 +387,7 @@ class VespaQueryClient:
         )
         assert response.is_successful(), response.json
         stop = time.perf_counter()
-        print(
+        self.logger.debug(
             f"Getting suggestions from Vespa took: {stop - start} s, Vespa reported searchtime was "
             f"{response.json.get('timing', {}).get('searchtime', -1)} s"
        )
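Note on the vespa_app.py hunks: every print now goes through the injected logger at debug, info, or warning level. A short wiring sketch follows; the stdout handler and DEBUG level here are assumptions for illustration, while the real logger configuration lives in main.py further down.

import logging
import sys

from backend.vespa_app import VespaQueryClient

logger = logging.getLogger("vespa_app")
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.setLevel(logging.DEBUG)  # DEBUG surfaces the per-query timing and field-key messages

# The constructor still reads its connection settings from environment variables,
# exactly as before; only the logging destination changes.
client = VespaQueryClient(logger=logger)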
frontend/app.py CHANGED
@@ -1,7 +1,7 @@
 from typing import Optional
 from urllib.parse import quote_plus
 
-from fasthtml.components import H1, H2, H3, Br, Div, Form, Img, NotStr, P, Span
+from fasthtml.components import H1, H2, H3, Br, Div, Form, Img, NotStr, P, Span, Strong
 from fasthtml.xtend import A, Script
 from lucide_fasthtml import Lucide
 from shad4fast import Badge, Button, Input, Label, RadioGroup, RadioGroupItem, Separator
@@ -137,6 +137,19 @@ dynamic_elements_scrollbars = Script(
     """
 )
 
+submit_form_on_radio_change = Script(
+    """
+    document.addEventListener('click', function (e) {
+        // if target has data-ref="radio-item" and type is button
+        if (e.target.getAttribute('data-ref') === 'radio-item' && e.target.type === 'button') {
+            console.log('Radio button clicked');
+            const form = e.target.closest('form');
+            form.submit();
+        }
+    });
+    """
+)
+
 
 def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
     grid_cls = "grid gap-2 items-center p-3 bg-muted w-full"
@@ -183,6 +196,7 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
                 name="ranking",
                 default_value=ranking_value,
                 cls="grid-flow-col gap-x-5 text-muted-foreground",
+                # Submit form when radio button is clicked
             ),
             cls="grid grid-flow-col items-center gap-x-3 border border-input px-3 rounded-sm",
         ),
@@ -197,9 +211,10 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
         ),
         check_input_script,
         autocomplete_script,
+        submit_form_on_radio_change,
         action=f"/search?query={quote_plus(query_value)}&ranking={quote_plus(ranking_value)}",
         method="GET",
-        hx_get=f"/fetch_results?query={quote_plus(query_value)}&ranking={quote_plus(ranking_value)}",
+        hx_get="/fetch_results",  # As the component is a form, input components query and ranking are sent as query parameters automatically, see https://htmx.org/docs/#parameters
         hx_trigger="load",
         hx_target="#search-results",
         hx_swap="outerHTML",
@@ -310,9 +325,6 @@ def AboutThisDemo():
 def Search(request, search_results=[]):
     query_value = request.query_params.get("query", "").strip()
     ranking_value = request.query_params.get("ranking", "nn+colpali")
-    print(
-        f"Search: Fetching results for query: {query_value}, ranking: {ranking_value}"
-    )
     return Div(
         Div(
             Div(
@@ -371,8 +383,13 @@ def SimMapButtonPoll(query_id, idx, token, token_idx):
 def SearchInfo(search_time, total_count):
     return (
         Div(
-            NotStr(
-                f"<span>Found <strong>{total_count}</strong> results in <strong>{search_time}</strong> seconds.</span>"
+            Span(
+                "Retrieved ",
+                Strong(total_count),
+                Span(" results"),
+                Span(" in "),
+                Strong(f"{search_time:.3f}"),  # 3 significant digits
+                Span(" seconds."),
             ),
             cls="grid bg-background border-t text-sm text-center p-3",
         ),
@@ -381,7 +398,8 @@ def SearchInfo(search_time, total_count):
 
 def SearchResult(
     results: list,
-    query: str, query_id: Optional[str] = None,
+    query: str,
+    query_id: Optional[str] = None,
     search_time: float = 0,
     total_count: int = 0,
 ):
@@ -516,7 +534,7 @@ def SearchResult(
                 Div(
                     A(
                         Lucide(icon="external-link", size="18"),
-                        f"PDF Source (Page {fields['page_number']})",
+                        f"PDF Source (Page {fields['page_number'] + 1})",
                         href=f"{fields['url']}#page={fields['page_number'] + 1}",
                         target="_blank",
                         cls="flex items-center gap-1.5 font-mono bold text-sm",
@@ -584,16 +602,13 @@ def SearchResult(
     return [
         Div(
             SearchInfo(search_time, total_count),
-            *result_items,
-            image_swapping,
-            toggle_text_content,
-            dynamic_elements_scrollbars,
-            id="search-results",
-            cls="grid grid-cols-1 gap-px bg-border min-h-0",
-        )
-
-
-        ,
+            *result_items,
+            image_swapping,
+            toggle_text_content,
+            dynamic_elements_scrollbars,
+            id="search-results",
+            cls="grid grid-cols-1 gap-px bg-border min-h-0",
+        ),
         Div(
             ChatResult(query_id=query_id, query=query, doc_ids=doc_ids),
             hx_swap_oob="true",
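Note on the frontend/app.py hunks: the hx_get change relies on htmx's standard form-parameter handling. Because SearchBox renders as a form, the named query and ranking inputs are serialized onto the GET request automatically, so the URL no longer needs to be built by hand. A stripped-down FastHTML sketch of the same pattern follows; the component below is hypothetical and not part of the commit.

from fasthtml.components import Form, Input

def MiniSearchBox(ranking_value: str = "nn+colpali"):
    # htmx appends ?query=...&ranking=... from the named inputs when the form's
    # hx-get request fires, mirroring the SearchBox change above.
    return Form(
        Input(type="search", name="query", placeholder="Search..."),
        Input(type="hidden", name="ranking", value=ranking_value),
        hx_get="/fetch_results",
        hx_target="#search-results",
        hx_swap="outerHTML",
    )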
main.py CHANGED
@@ -3,6 +3,8 @@ import base64
 import os
 import time
 import uuid
+import logging
+import sys
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 
@@ -68,6 +70,20 @@ awesomplete_js = Script(
 )
 sselink = Script(src="https://unpkg.com/htmx-ext-sse@2.2.1/sse.js")
 
+# Get log level from environment variable, default to INFO
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
+# Configure logger
+logger = logging.getLogger("vespa_app")
+handler = logging.StreamHandler(sys.stdout)
+handler.setFormatter(
+    logging.Formatter(
+        "%(levelname)s: \t %(asctime)s \t %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+)
+logger.addHandler(handler)
+logger.setLevel(getattr(logging, LOG_LEVEL))
+
 app, rt = fast_app(
     htmlkw={"cls": "grid h-full"},
     pico=False,
@@ -83,7 +99,7 @@ app, rt = fast_app(
         ShadHead(tw_cdn=False, theme_handle=True),
     ),
 )
-vespa_app: Vespa = VespaQueryClient()
+vespa_app: Vespa = VespaQueryClient(logger=logger)
 thread_pool = ThreadPoolExecutor()
 # Gemini config
 
@@ -107,7 +123,7 @@ os.makedirs(SIM_MAP_DIR, exist_ok=True)
 
 @app.on_event("startup")
 def load_model_on_startup():
-    app.sim_map_generator = SimMapGenerator()
+    app.sim_map_generator = SimMapGenerator(logger=logger)
     return
 
 
@@ -141,7 +157,7 @@ def get():
 
 @rt("/search")
 def get(request, query: str = "", ranking: str = "nn+colpali"):
-    print("/search: Fetching results for ranking_value:", ranking)
+    logger.info(f"/search: Fetching results for query: {query}, ranking: {ranking}")
 
     # Always render the SearchBox first
     if not query:
@@ -180,12 +196,16 @@ async def get(session, request, query: str, ranking: str):
 
     # Get the hash of the query and ranking value
     query_id = generate_query_id(query, ranking)
-    print(f"Query id in /fetch_results: {query_id}")
+    logger.info(f"Query id in /fetch_results: {query_id}")
     # Run the embedding and query against Vespa app
-
+    start_inference = time.perf_counter()
     q_embs, idx_to_token = app.sim_map_generator.get_query_embeddings_and_token_map(
         query
     )
+    end_inference = time.perf_counter()
+    logger.info(
+        f"Inference time for query_id: {query_id} \t {end_inference - start_inference:.2f} seconds"
+    )
 
     start = time.perf_counter()
     # Fetch real search results from Vespa
@@ -196,8 +216,8 @@ async def get(session, request, query: str, ranking: str):
         idx_to_token=idx_to_token,
     )
     end = time.perf_counter()
-    print(
-        f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
+    logger.info(
+        f"Search results fetched in {end - start:.2f} seconds. Vespa search time: {result['timing']['searchtime']}"
     )
     search_time = result["timing"]["searchtime"]
     total_count = result["root"]["fields"]["totalCount"]
@@ -228,7 +248,7 @@ async def poll_vespa_keepalive():
     while True:
         await asyncio.sleep(5)
         await vespa_app.keepalive()
-        print(f"Vespa keepalive: {time.time()}")
+        logger.debug(f"Vespa keepalive: {time.time()}")
 
 
 @threaded
@@ -252,7 +272,7 @@ def get_and_store_sim_maps(
 ):
     time.sleep(0.2)
     if not all([os.path.exists(img_path) for img_path in img_paths]):
-        print(f"Images not ready in 5 seconds for query_id: {query_id}")
+        logger.warning(f"Images not ready in 5 seconds for query_id: {query_id}")
         return False
     sim_map_generator = app.sim_map_generator.gen_similarity_maps(
         query=query,
@@ -264,7 +284,7 @@
     for idx, token, token_idx, blended_img_base64 in sim_map_generator:
         with open(SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png", "wb") as f:
             f.write(base64.b64decode(blended_img_base64))
-        print(
+        logger.debug(
             f"Sim map saved to disk for query_id: {query_id}, idx: {idx}, token: {token}"
         )
     return True
@@ -279,7 +299,9 @@ async def get_sim_map(query_id: str, idx: int, token: str, token_idx: int):
     """
     sim_map_path = SIM_MAP_DIR / f"{query_id}_{idx}_{token_idx}.png"
    if not os.path.exists(sim_map_path):
-        print(f"Sim map not ready for query_id: {query_id}, idx: {idx}, token: {token}")
+        logger.debug(
+            f"Sim map not ready for query_id: {query_id}, idx: {idx}, token: {token}"
+        )
         return SimMapButtonPoll(
             query_id=query_id, idx=idx, token=token, token_idx=token_idx
         )
@@ -304,7 +326,7 @@ async def full_image(doc_id: str):
         # image data is base 64 encoded string. Save it to disk as jpg.
         with open(img_path, "wb") as f:
             f.write(base64.b64decode(image_data))
-        print(f"Full image saved to disk for doc_id: {doc_id}")
+        logger.debug(f"Full image saved to disk for doc_id: {doc_id}")
     else:
         with open(img_path, "rb") as f:
             image_data = base64.b64encode(f.read()).decode("utf-8")
@@ -330,7 +352,7 @@ async def get_suggestions(query: str = ""):
 
 async def message_generator(query_id: str, query: str, doc_ids: list):
     """Generator function to yield SSE messages for chat response"""
-    images = {}
+    images = []
     num_images = 3  # Number of images before firing chat request
     max_wait = 10  # seconds
     start_time = time.time()
@@ -339,21 +361,22 @@ async def message_generator(query_id: str, query: str, doc_ids: list):
         len(images) < min(num_images, len(doc_ids))
         and time.time() - start_time < max_wait
     ):
+        images = []
         for idx in range(num_images):
            image_filename = IMG_DIR / f"{doc_ids[idx]}.jpg"
            if not os.path.exists(image_filename):
-                print(
+                logger.debug(
                     f"Message generator: Full image not ready for query_id: {query_id}, idx: {idx}"
                 )
                 continue
             else:
-                print(
+                logger.debug(
                     f"Message generator: image ready for query_id: {query_id}, idx: {idx}"
                 )
-                images[image_filename] = Image.open(image_filename)
-        await asyncio.sleep(0.2)
+                images.append(Image.open(image_filename))
+        if len(images) < num_images:
+            await asyncio.sleep(0.2)
 
-    images = list(images.values())
     # yield message with number of images ready
     yield f"event: message\ndata: Generating response based on {len(images)} images...\n\n"
     if not images:
@@ -391,7 +414,6 @@ def get():
 
 
 if __name__ == "__main__":
-    # ModelManager.get_instance() # Initialize once at startup
     HOT_RELOAD = os.getenv("HOT_RELOAD", "False").lower() == "true"
-    print(f"Starting app with hot reload: {HOT_RELOAD}")
+    logger.info(f"Starting app with hot reload: {HOT_RELOAD}")
     serve(port=7860, reload=HOT_RELOAD)
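Note on the main.py hunks: a single named "vespa_app" logger is built once and passed to both VespaQueryClient and SimMapGenerator, with its level taken from the LOG_LEVEL environment variable. A small sketch (not from the commit) of how that level lookup resolves, and its one sharp edge:

import logging
import os

LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()

# getattr(logging, "INFO") -> 20, getattr(logging, "DEBUG") -> 10, and so on.
# An unrecognized value (e.g. LOG_LEVEL=verbose) raises AttributeError here, so the
# variable is expected to hold one of the standard logging level names.
logging.getLogger("vespa_app").setLevel(getattr(logging, LOG_LEVEL))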