# app.py — Hugging Face Space by davanstrien (HF staff).
# Last commit: e95b21b "Update app.py" (verified); file size 2.92 kB.
import os
import torch
import gradio as gr
import requests
from PIL import Image
from io import BytesIO
from qdrant_client import QdrantClient
from colpali_engine.models import ColQwen2, ColQwen2Processor
# --- Retrieval model -------------------------------------------------------
# ColQwen2 checkpoint used both for the model weights and its processor.
model_name = "vidore/colqwen2-v0.1"

# Use the first CUDA GPU when one is present, otherwise the CPU.
# (On Apple Silicon, "mps" could be substituted here if desired.)
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Weights are loaded in bfloat16 and placed directly on the chosen device.
colpali_model = ColQwen2.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=torch.bfloat16,
)
colpali_processor = ColQwen2Processor.from_pretrained(model_name)
# Initialize Qdrant client.
# The service URL and collection name can be overridden through environment
# variables; the defaults reproduce the values this app was built against.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")  # None when unset
QDRANT_URL = os.getenv("QDRANT_URL", "https://davanstrien-qdrant-test.hf.space")
qdrant_client = QdrantClient(
    url=QDRANT_URL,
    # NOTE(review): presumably port=None keeps the client from adding its
    # default port to the HTTPS URL — confirm against qdrant-client docs.
    port=None,
    api_key=QDRANT_API_KEY,
    timeout=10,
)
# Name of the Qdrant collection holding the image multivectors.
collection_name = os.getenv("QDRANT_COLLECTION", "song_sheets")
def search_images_by_text(query_text, top_k=5):
    """Embed a text query with ColQwen2 and run a multivector search in Qdrant.

    Args:
        query_text: Free-text search query.
        top_k: Maximum number of points Qdrant should return.

    Returns:
        The response of ``qdrant_client.query_points``; its ``.points``
        entries carry the ``payload`` and ``score`` of each hit.
    """
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        query_inputs = colpali_processor.process_queries([query_text])
        query_inputs = query_inputs.to(colpali_model.device)
        embeddings = colpali_model(**query_inputs)

    # Only one query was batched, so take row 0. The model runs in bfloat16,
    # which NumPy cannot represent, hence the cast to float32 first.
    query_vectors = embeddings[0].cpu().float().numpy().tolist()

    return qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vectors,
        limit=top_k,
        timeout=800,
    )
def modify_iiif_url(url, size_percent):
    """Return *url* with its IIIF size segment replaced by ``pct:<size_percent>``.

    IIIF Image API URLs end in ``.../{region}/{size}/{rotation}/{quality}.{format}``,
    so the size is the third path segment from the end.

    Args:
        url: A IIIF Image API image URL.
        size_percent: Scale as a percentage of the original size. A
            whole-number float (e.g. ``50.0`` from a UI slider) is written
            as ``pct:50`` rather than ``pct:50.0``.

    Returns:
        The rewritten URL.

    Raises:
        ValueError: If *url* has too few path segments to contain a size.
    """
    if isinstance(size_percent, float) and size_percent.is_integer():
        size_percent = int(size_percent)

    parts = url.split('/')
    if len(parts) < 3:
        raise ValueError(
            f"URL has too few path segments to be a IIIF image URL: {url!r}"
        )
    parts[-3] = f"pct:{size_percent}"
    return '/'.join(parts)
def search_and_display(query, top_k, size_percent):
    """Search for *query* and download each hit's image at the requested scale.

    Args:
        query: Text query entered in the UI.
        top_k: Number of results to request; coerced to ``int`` because it
            is used as Qdrant's ``limit`` (UI sliders may deliver a float).
        size_percent: Percentage of the original size, applied through the
            IIIF ``pct:`` size parameter.

    Returns:
        ``(images, captions)``: equal-length lists of RGB PIL images and
        ``"Score: x.xx"`` strings, in result order.

    Raises:
        requests.HTTPError: If the image server answers with an error status.
        requests.Timeout: If an image download exceeds the timeout.
    """
    results = search_images_by_text(query, int(top_k))
    images = []
    captions = []
    # One session for all downloads so the connection to the image server
    # is reused instead of re-established per result.
    with requests.Session() as session:
        for point in results.points:
            image_url = modify_iiif_url(point.payload['image_url'], size_percent)
            # Without a timeout, requests can block indefinitely on a stalled server.
            response = session.get(image_url, timeout=30)
            # Surface 4xx/5xx clearly instead of handing an error page to PIL.
            response.raise_for_status()
            images.append(Image.open(BytesIO(response.content)).convert("RGB"))
            captions.append(f"Score: {point.score:.2f}")
    return images, captions
# --- Gradio UI ---------------------------------------------------------------
# Inputs: query text, result count, image scale. Outputs: image gallery and
# the matching score captions as JSON.
iface = gr.Interface(
    search_and_display,
    inputs=[
        gr.Textbox(label="Search Query"),
        gr.Slider(label="Number of Results", minimum=1, maximum=20, step=1, value=5),
        gr.Slider(label="Image Size (%)", minimum=1, maximum=100, step=1, value=100),
    ],
    outputs=[
        gr.Gallery(label="Search Results", show_label=False, columns=5, height="auto"),
        gr.JSON(label="Captions"),
    ],
    title="Image Search with IIIF Percentage Resizing",
    description=(
        "Enter a text query to search for images. "
        "You can adjust the number of results and the size of the returned "
        "images as a percentage of the original size."
    ),
)

# Start the web server.
iface.launch()