"""Gradio demo: text-to-image search backed by ColQwen2 embeddings and Qdrant.

A text query is embedded with ColQwen2 and sent as a multi-vector query to a
Qdrant collection. For each hit, the IIIF URL stored in the point payload
(``image_url``) is rewritten to request the image at a user-chosen percentage
of its size, fetched, and shown in a gallery alongside its similarity score.
"""

import logging
import os
from io import BytesIO
from urllib.parse import urlsplit, urlunsplit

import gradio as gr
import requests
import torch
from PIL import Image
from colpali_engine.models import ColQwen2, ColQwen2Processor
from qdrant_client import QdrantClient

logger = logging.getLogger(__name__)

# Initialize ColPali model and processor
model_name = "vidore/colqwen2-v0.1"
device = "cuda:0" if torch.cuda.is_available() else "cpu"  # You can change this to "mps" for Apple Silicon if needed

colpali_model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
)
colpali_processor = ColQwen2Processor.from_pretrained(
    model_name,
)

# Initialize Qdrant client
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
# port=None keeps the client from adding its default REST port to the Space URL.
qdrant_client = QdrantClient(
    url="https://davanstrien-qdrant-test.hf.space",
    port=None,
    api_key=QDRANT_API_KEY,
    timeout=10,
)
collection_name = "song_sheets"  # Replace with your actual collection name

# Seconds to wait for each IIIF image download before giving up on it.
IMAGE_FETCH_TIMEOUT = 30


def search_images_by_text(query_text, top_k=5):
    """Embed a text query with ColQwen2 and retrieve the closest points from Qdrant.

    Args:
        query_text: Free-text search query.
        top_k: Maximum number of points to return. Coerced to ``int`` in case
            the UI slider delivers it as a float.

    Returns:
        The response of ``QdrantClient.query_points``; the matches are in its
        ``.points`` attribute.
    """
    # Process and encode the text query (inference only, no autograd graph).
    with torch.no_grad():
        batch_query = colpali_processor.process_queries([query_text]).to(colpali_model.device)
        query_embedding = colpali_model(**batch_query)

    # Row 0 is the single query in the batch; presumably shaped
    # (n_tokens, dim) -- TODO confirm. .float() is needed because the model
    # runs in bfloat16, which NumPy cannot represent.
    multivector_query = query_embedding[0].cpu().float().numpy().tolist()

    # Search in Qdrant
    return qdrant_client.query_points(
        collection_name=collection_name,
        query=multivector_query,
        limit=int(top_k),
        timeout=800,
    )


def modify_iiif_url(url, size_percent):
    """Return ``url`` with its IIIF size segment replaced by ``pct:<size_percent>``.

    IIIF Image API URLs end in ``.../{region}/{size}/{rotation}/{quality}.{format}``,
    so the size is the third path segment from the end. Only the URL path is
    edited; any query string or fragment is preserved.

    Raises:
        ValueError: if ``size_percent`` is not a positive number, or the URL
            path is too short to contain an identifier plus the four IIIF
            image parameters.
    """
    pct = float(size_percent)
    if not pct > 0:  # also rejects NaN
        raise ValueError(f"size_percent must be positive, got {size_percent!r}")
    # Render whole numbers without a trailing ".0" (pct:50, not pct:50.0).
    pct_text = str(int(pct)) if pct.is_integer() else str(pct)

    parts = urlsplit(url)
    segments = parts.path.split("/")
    # identifier + region + size + rotation + quality.format
    if sum(1 for segment in segments if segment) < 5:
        raise ValueError(f"Not a IIIF image URL: {url!r}")
    segments[-3] = f"pct:{pct_text}"
    return urlunsplit(parts._replace(path="/".join(segments)))


def search_and_display(query, top_k, size_percent):
    """Search for ``query`` and download each hit's image for the Gradio outputs.

    Args:
        query: Text entered in the search box.
        top_k: Number of results requested.
        size_percent: Requested image size as a percentage of the original.

    Returns:
        A pair ``(images, captions)``: RGB ``PIL.Image`` objects for the
        gallery and parallel ``"Score: x.xx"`` strings for the JSON panel.
        A blank query yields two empty lists. Hits whose image is missing,
        cannot be downloaded, or cannot be decoded are logged and skipped, so
        one bad URL does not discard the rest of the results.
    """
    if not (query or "").strip():
        return [], []

    results = search_images_by_text(query, top_k)

    images = []
    captions = []
    # One Session reuses the HTTP connection across all image downloads.
    with requests.Session() as session:
        for result in results.points:
            image_url = (result.payload or {}).get("image_url")
            if not image_url:
                logger.warning("Point %s has no 'image_url' in its payload; skipping", result.id)
                continue
            try:
                response = session.get(
                    modify_iiif_url(image_url, size_percent),
                    timeout=IMAGE_FETCH_TIMEOUT,
                )
                # Surface HTTP errors here instead of handing an error page to PIL.
                response.raise_for_status()
                img = Image.open(BytesIO(response.content)).convert("RGB")
            except (requests.RequestException, OSError, ValueError) as err:
                logger.warning("Could not load image %s: %s", image_url, err)
                continue
            images.append(img)
            captions.append(f"Score: {result.score:.2f}")

    return images, captions


# Define Gradio interface
iface = gr.Interface(
    fn=search_and_display,
    inputs=[
        gr.Textbox(label="Search Query"),
        gr.Slider(minimum=1, maximum=20, step=1, label="Number of Results", value=5),
        gr.Slider(minimum=1, maximum=100, step=1, label="Image Size (%)", value=100),
    ],
    outputs=[
        gr.Gallery(label="Search Results", show_label=False, columns=5, height="auto"),
        gr.JSON(label="Captions"),
    ],
    title="Image Search with IIIF Percentage Resizing",
    description="Enter a text query to search for images. You can adjust the number of results and the size of the returned images as a percentage of the original size.",
)

# Launch the Gradio interface
iface.launch()