"""Gradio app for exploring flag images with CLIP.

Two ways to search are offered:

* pick a flag from a gallery and see the flags whose precomputed CLIP
  embeddings are most similar to it;
* type a free-text query and see the flags whose embeddings best match the
  CLIP text embedding of that query.
"""

import os

import clip
import gradio as gr
import numpy as np
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity

# Load CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Configuration
FLAG_IMAGE_DIR = "./named_flags"
FLAG_EMBEDDINGS_PATH = "./flag_embeddings_1.npy"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Get all image paths. The extension check is case-insensitive so files such
# as "FRANCE.PNG" are picked up, and the listing is sorted because
# os.listdir() returns entries in arbitrary order.
image_paths = [
    os.path.join(FLAG_IMAGE_DIR, img)
    for img in sorted(os.listdir(FLAG_IMAGE_DIR))
    if img.lower().endswith(IMAGE_EXTENSIONS)
]

# Load precomputed embeddings: a dict whose keys are used below as file names
# inside FLAG_IMAGE_DIR and whose values are passed to cosine_similarity
# (presumably arrays of shape (1, dim) -- confirm against the script that
# produced the file).
# SECURITY NOTE: allow_pickle=True executes arbitrary pickled code on load;
# only ever point FLAG_EMBEDDINGS_PATH at a trusted file.
flag_embeddings = np.load(FLAG_EMBEDDINGS_PATH, allow_pickle=True).item()


def get_country_name(image_filename):
    """Return the upper-cased file stem of ``image_filename``.

    Directory components and the extension are dropped, so
    ``"./named_flags/france.png"`` becomes ``"FRANCE"``.
    """
    return os.path.splitext(os.path.basename(image_filename))[0].upper()


def _rank_flags(query_embedding):
    """Rank every known flag against ``query_embedding``.

    Args:
        query_embedding: 2-D array accepted by ``cosine_similarity``.

    Returns:
        A list of ``(flag_key, similarity)`` pairs sorted from most to least
        similar.
    """
    scores = [
        (flag, cosine_similarity(query_embedding, embedding)[0][0])
        for flag, embedding in flag_embeddings.items()
    ]
    return sorted(scores, key=lambda item: item[1], reverse=True)


def _to_gallery_items(ranked_flags):
    """Turn ``(flag_key, similarity)`` pairs into ``(path, caption)`` tuples.

    Flags whose image file does not exist in FLAG_IMAGE_DIR are reported on
    stdout and skipped, so a stale embedding entry cannot break the gallery.
    """
    items = []
    for flag_file, similarity in ranked_flags:
        flag_path = os.path.join(FLAG_IMAGE_DIR, flag_file)
        if not os.path.exists(flag_path):
            print(f"File not found: {flag_file}")
            continue
        caption = f"{get_country_name(flag_file)} (Similarity: {similarity:.3f})"
        items.append((flag_path, caption))
    return items


def search_by_query(query, top_n=10):
    """Search flags based on a text query using CLIP.

    Args:
        query: Free-text description, e.g. ``"crescent in the center"``.
        top_n: Maximum number of results to return.

    Returns:
        A list of ``(image_path, caption)`` tuples for the best matches, or
        an empty list when ``query`` is empty or only whitespace.
    """
    # The textbox fires on every change, including when it is cleared;
    # ranking flags against an empty prompt is meaningless.
    if not query or not query.strip():
        return []

    with torch.no_grad():
        text_embedding = model.encode_text(clip.tokenize([query]).to(device))

    # Move the embedding to the CPU once instead of once per flag.
    query_embedding = text_embedding.cpu().numpy()
    return _to_gallery_items(_rank_flags(query_embedding)[:top_n])


def get_image_embedding(image_path):
    """Get the CLIP embedding for the image at ``image_path``.

    Returns:
        The output of ``model.encode_image`` for the single preprocessed
        image, as a NumPy array on the CPU.
    """
    # Use a context manager so the underlying file handle is closed promptly.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model.encode_image(image_input)
    return embedding.cpu().numpy()


def find_similar_flags(image_path, top_n=10):
    """Find similar flags based on cosine similarity.

    Args:
        image_path: Path of the query image.
        top_n: Maximum number of neighbours to return.

    Returns:
        A list of ``(flag_key, similarity)`` pairs, most similar first. The
        single best match is dropped on the assumption that it is the query
        flag itself.
    """
    ranked = _rank_flags(get_image_embedding(image_path))
    return ranked[1:top_n + 1]  # Skip the first one as it's the same flag


def search_flags(query):
    """Search flags based on country name.

    Returns every path in ``image_paths`` whose country name contains
    ``query`` (case-insensitive); an empty query returns all paths.
    """
    if not query:
        return image_paths
    needle = query.lower()
    return [img for img in image_paths if needle in get_country_name(img).lower()]


def analyze_and_display(selected_flag):
    """Main function to analyze flag similarity and prepare display.

    Args:
        selected_flag: Path of the selected flag image, or ``None``.

    Returns:
        A list of ``(image_path, caption)`` tuples for the output gallery, or
        ``None`` when nothing is selected.

    Raises:
        gr.Error: If the image cannot be processed. Gradio only displays the
            message to the user when the error is raised, not returned.
    """
    if selected_flag is None:
        return None
    try:
        return _to_gallery_items(find_similar_flags(selected_flag))
    except Exception as e:
        raise gr.Error(f"Error processing image: {str(e)}") from e


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Flag Similarity Analysis")
    gr.Markdown("Select a flag from the gallery to find similar flags based on visual features or search using text queries.")

    with gr.Tabs():
        with gr.Tab("Similarity Search"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Search and input gallery
                    search_box = gr.Textbox(label="Search Flags", placeholder="Enter country name...")
                    input_gallery = gr.Gallery(
                        label="Available Flags",
                        show_label=True,
                        elem_id="gallery",
                        columns=4,
                        height="auto"
                    )

                with gr.Column(scale=1):
                    # Output gallery
                    output_gallery = gr.Gallery(
                        label="Similar Flags",
                        show_label=True,
                        elem_id="output",
                        columns=2,
                        height="auto"
                    )

            # Event handlers
            def update_gallery(query):
                """Return gallery items for the flags matching ``query``."""
                matching_flags = search_flags(query)
                return [(path, get_country_name(path)) for path in matching_flags]

            def on_select(evt: gr.SelectData, gallery):
                """Handle flag selection from gallery"""
                # Each gallery entry is indexed as (image, caption); take the image.
                selected_flag_path = gallery[evt.index][0]
                return analyze_and_display(selected_flag_path)

            # Connect event handlers
            search_box.change(
                update_gallery,
                inputs=[search_box],
                outputs=[input_gallery]
            )

            input_gallery.select(
                on_select,
                inputs=[input_gallery],
                outputs=[output_gallery]
            )

        with gr.Tab("Advanced Search"):
            gr.Markdown("### Search Flags with Nuanced Queries")
            nuanced_query_box = gr.Textbox(label="Enter Advanced Query", placeholder="e.g., 'Find flags with crescent' or 'flags with animals'")
            advanced_output_gallery = gr.Gallery(
                label="Matching Flags",
                show_label=True,
                elem_id="advanced_output",
                columns=3,
                height="auto"
            )

            def advanced_search(query):
                """Run a CLIP text search for ``query``."""
                return search_by_query(query)

            nuanced_query_box.change(
                advanced_search,
                inputs=[nuanced_query_box],
                outputs=[advanced_output_gallery]
            )

    # Initialize gallery with all flags
    def init_gallery():
        """Return gallery items for every available flag image."""
        return [(path, get_country_name(path)) for path in image_paths]

    demo.load(init_gallery, outputs=[input_gallery])

# Launch the app
demo.launch()