davanstrien committed
Commit bf43437
1 Parent(s): ab1b9b5

chore: Update requirements.txt with new dependencies

Files changed (1):
  1. app.py +120 -39
app.py CHANGED
@@ -1,15 +1,30 @@
+ import asyncio
+ import html
  import os
- import torch
+ from io import BytesIO
+
+ import aiohttp
+ import dotenv
  import gradio as gr
  import requests
+ import torch
+ from colpali_engine.models import ColQwen2, ColQwen2Processor
  from PIL import Image
- from io import BytesIO
  from qdrant_client import QdrantClient
- from colpali_engine.models import ColQwen2, ColQwen2Processor

+ dotenv.load_dotenv()
+
+ if torch.cuda.is_available():
+     device = "cuda:0"
+ elif torch.backends.mps.is_available():
+     device = "mps"
+ else:
+     device = "cpu"
+
+
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
  # Initialize ColPali model and processor
  model_name = "vidore/colqwen2-v0.1"
- device = "cuda:0" if torch.cuda.is_available() else "cpu"  # You can change this to "mps" for Apple Silicon if needed
  colpali_model = ColQwen2.from_pretrained(
      model_name,
      torch_dtype=torch.bfloat16,
@@ -21,15 +36,22 @@ colpali_processor = ColQwen2Processor.from_pretrained(

  # Initialize Qdrant client
  QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
- qdrant_client = QdrantClient(url="https://davanstrien-qdrant-test.hf.space",
-                              port=None, api_key=QDRANT_API_KEY, timeout=10)
+ qdrant_client = QdrantClient(
+     url="https://davanstrien-qdrant-test.hf.space",
+     port=None,
+     api_key=QDRANT_API_KEY,
+     timeout=10,
+ )

  collection_name = "song_sheets"  # Replace with your actual collection name

+
  def search_images_by_text(query_text, top_k=5):
      # Process and encode the text query
      with torch.no_grad():
-         batch_query = colpali_processor.process_queries([query_text]).to(colpali_model.device)
+         batch_query = colpali_processor.process_queries([query_text]).to(
+             colpali_model.device
+         )
          query_embedding = colpali_model(**batch_query)

      # Convert the query embedding to a list of vectors
@@ -45,42 +67,101 @@ def search_images_by_text(query_text, top_k=5):

      return search_result

+
  def modify_iiif_url(url, size_percent):
      # Modify the IIIF URL to use percentage scaling
-     parts = url.split('/')
+     parts = url.split("/")
      size_index = -3
      parts[size_index] = f"pct:{size_percent}"
-     return '/'.join(parts)
+     return "/".join(parts)
+
+
+ async def fetch_image(session, url):
+     async with session.get(url) as response:
+         content = await response.read()
+         return Image.open(BytesIO(content)).convert("RGB")
+

- def search_and_display(query, top_k, size_percent):
+ async def fetch_all_images(urls):
+     async with aiohttp.ClientSession() as session:
+         tasks = [fetch_image(session, url) for url in urls]
+         return await asyncio.gather(*tasks)
+
+
+ async def search_and_display(query, top_k, size_percent):
      results = search_images_by_text(query, top_k)
-     images = []
-     captions = []
-
-     for result in results.points:
-         modified_url = modify_iiif_url(result.payload['image_url'], size_percent)
-         response = requests.get(modified_url)
-         img = Image.open(BytesIO(response.content)).convert("RGB")
-         images.append(img)
-         captions.append(f"Score: {result.score:.2f}")
-
-     return images, captions
-
- # Define Gradio interface
- iface = gr.Interface(
-     fn=search_and_display,
-     inputs=[
-         gr.Textbox(label="Search Query"),
-         gr.Slider(minimum=1, maximum=20, step=1, label="Number of Results", value=5),
-         gr.Slider(minimum=1, maximum=100, step=1, label="Image Size (%)", value=100)
-     ],
-     outputs=[
-         gr.Gallery(label="Search Results", show_label=False, columns=5, height="auto"),
-         gr.JSON(label="Captions")
-     ],
-     title="Image Search with IIIF Percentage Resizing",
-     description="Enter a text query to search for images. You can adjust the number of results and the size of the returned images as a percentage of the original size."
- )
+     modified_urls = [
+         modify_iiif_url(result.payload["image_url"], size_percent)
+         for result in results.points
+     ]
+
+     images = await fetch_all_images(modified_urls)
+     html_output = (
+         "<div style='display: flex; flex-wrap: wrap; justify-content: space-around;'>"
+     )
+     for i, (image, result) in enumerate(zip(images, results.points)):
+         image_url = modified_urls[i]
+         item_url = result.payload["item_url"]
+         score = result.score
+         html_output += f"""
+         <div style='margin: 10px; text-align: center; width: 300px;'>
+             <img src='{image_url}' style='max-width: 100%; height: auto;'>
+             <p>Score: {score:.2f}</p>
+             <a href='{item_url}' target='_blank'>View Item</a>
+         </div>
+         """
+     html_output += "</div>"
+     return html_output
+
+
+ # Wrapper function for synchronous Gradio interface
+ def search_and_display_wrapper(query, top_k, size_percent):
+     return asyncio.run(search_and_display(query, top_k, size_percent))
+
+
+ with gr.Blocks() as demo:
+     gr.HTML(
+         """
+         <h1 style='text-align: center; color: #2a4b7c;'>America Singing: Nineteenth-Century Song Sheets ColPali Search</h1>
+         <div style="display: flex; align-items: flex-start; margin-bottom: 20px;">
+             <div style="flex: 2; padding-right: 20px;">
+                 <p>This app allows you to search through the Library of Congress's <a href="https://www.loc.gov/collections/nineteenth-century-song-sheets/about-this-collection/" target="_blank">"America Singing: Nineteenth-Century Song Sheets"</a> collection using natural language queries. The collection contains 4,291 song sheets from the 19th century, offering a unique window into American history, culture, and music.</p>
+
+                 <p>This search functionality is powered by <a href="https://huggingface.co/blog/manu/colpali" target="_blank">ColPali</a>, an efficient document retrieval system that uses Vision Language Models. ColPali allows for searching through documents (including images and complex layouts) without the need for traditional text extraction or OCR. It works by directly embedding page images and using a <a href="https://jina.ai/news/what-is-colbert-and-late-interaction-and-why-they-matter-in-search/" target="_blank">late interaction mechanism</a> to match queries with relevant document patches.</p>
+
+                 <p>ColPali's approach:
+                 <ul>
+                     <li>Uses a Vision Language Model to encode document page images directly</li>
+                     <li>Splits images into patches and creates contextualized patch embeddings</li>
+                     <li>Employs a late interaction mechanism to efficiently match query tokens to document patches</li>
+                     <li>Eliminates the need for complex OCR and document parsing pipelines</li>
+                     <li>Captures both textual and visual information from documents</li>
+                 </ul>
+                 </p>
+             </div>
+             <div style="flex: 1;">
+                 <img src="https://tile.loc.gov/image-services/iiif/service:rbc:amss:hc:00:00:3b:hc00003b:001a/full/pct:50/0/default.jpg" alt="Example Song Sheet" style="width: 100%; height: auto; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
+                 <p style="text-align: center;"><em>Example of a song sheet from the collection</em></p>
+             </div>
+         </div>
+         """
+     )
+     with gr.Row():
+         with gr.Column(scale=4):
+             search_box = gr.Textbox(
+                 label="Search Query", placeholder="i.e. Irish migrant experience"
+             )
+         with gr.Column(scale=1):
+             submit_button = gr.Button("Search", variant="primary")
+     num_results = gr.Slider(
+         minimum=1, maximum=20, step=1, label="Number of Results", value=5
+     )
+     results_html = gr.HTML(label="Search Results")
+
+     submit_button.click(
+         fn=lambda query, top_k: search_and_display_wrapper(query, top_k, 100),
+         inputs=[search_box, num_results],
+         outputs=results_html,
+     )

- # Launch the Gradio interface
- iface.launch()
+ demo.launch()
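A note on the second hunk: the diff context hides the body of search_images_by_text between the query encoding and `return search_result`. Judging from how the results are consumed later (`results.points`, `result.score`, `result.payload`), the hidden code most likely performs a multivector search against Qdrant. A hypothetical sketch of that step, reusing `qdrant_client`, `collection_name`, `query_embedding`, and `top_k` from the surrounding code; this is an assumption, not the committed implementation:

    # Hypothetical reconstruction of the search step hidden by the diff context.
    # ColQwen2 emits one embedding per query token (a multivector); qdrant-client's
    # query_points accepts such a list of vectors and scores with late interaction.
    multivector = query_embedding[0].cpu().float().numpy().tolist()
    search_result = qdrant_client.query_points(
        collection_name=collection_name,
        query=multivector,
        limit=top_k,
    )
    # search_result.points exposes .score and .payload, matching the usage above.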
 
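modify_iiif_url leans on the IIIF Image API path layout, which always ends in {region}/{size}/{rotation}/{quality}.{format}; splitting on "/" therefore puts the size directive at index -3. A quick check with the function from the diff above, using the example image URL from the header HTML:

    # parts[-3] is the IIIF size directive in {region}/{size}/{rotation}/{quality}.{format}
    url = (
        "https://tile.loc.gov/image-services/iiif/"
        "service:rbc:amss:hc:00:00:3b:hc00003b:001a/full/pct:50/0/default.jpg"
    )
    print(modify_iiif_url(url, 25))
    # https://tile.loc.gov/image-services/iiif/service:rbc:amss:hc:00:00:3b:hc00003b:001a/full/pct:25/0/default.jpg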
 
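The main functional change is that images are no longer downloaded one at a time with requests.get: fetch_image and fetch_all_images overlap the downloads with aiohttp, so total latency is closer to the slowest single image than to the sum of all of them. A usage sketch, assuming the two helpers and modify_iiif_url from the diff above (the size percentages are arbitrary):

    # Download two renditions of the example sheet concurrently; asyncio.gather
    # runs the aiohttp requests in overlapping fashion rather than sequentially.
    base = (
        "https://tile.loc.gov/image-services/iiif/"
        "service:rbc:amss:hc:00:00:3b:hc00003b:001a/full/pct:50/0/default.jpg"
    )
    urls = [modify_iiif_url(base, pct) for pct in (25, 50)]
    pil_images = asyncio.run(fetch_all_images(urls))  # two RGB PIL.Image objects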
 
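On the "late interaction mechanism" mentioned in the interface text: it is ColBERT-style MaxSim scoring. Every query-token embedding is compared with every page-patch embedding, the best match per token is kept, and the per-token maxima are summed into the page score. A toy illustration (the shapes are made up for the example, not ColQwen2's real dimensions):

    import torch

    # Toy MaxSim late-interaction score between one query and one page image.
    query_tokens = torch.randn(12, 128)   # (num_query_tokens, embedding_dim)
    page_patches = torch.randn(768, 128)  # (num_image_patches, embedding_dim)

    sim = query_tokens @ page_patches.T         # all token/patch dot products
    score = sim.max(dim=1).values.sum().item()  # best patch per token, summed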