# Spaces: Sleeping
import gradio as gr | |
import os | |
import numpy as np | |
import torch | |
import clip | |
from PIL import Image | |
from sklearn.metrics.pairwise import cosine_similarity | |
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP "ViT-B/32" model together with its image preprocessing callable.
model, preprocess = clip.load("ViT-B/32", device=device)

# Directory that holds the flag image files.
FLAG_IMAGE_DIR = "./named_flags"
# Function to search flags with specific queries using CLIP
def search_by_query(query, top_n=10):
    """Search flags based on a text query using CLIP.

    Args:
        query: Free-text description to compare against the stored flag
            embeddings.
        top_n: Number of top-ranked flags to consider.

    Returns:
        A list of ``(image_path, caption)`` tuples ordered from most to least
        similar. The list is empty for a blank query. Top-ranked flags whose
        image file is missing under ``FLAG_IMAGE_DIR`` are skipped (and
        reported on stdout), so fewer than ``top_n`` entries may be returned.
    """
    # A blank query gives CLIP nothing meaningful to rank by.
    if not query or not query.strip():
        return []

    # truncate=True keeps an over-long query from raising inside clip.tokenize.
    tokens = clip.tokenize([query], truncate=True).to(device)

    # Encode the text query
    with torch.no_grad():
        text_embedding = model.encode_text(tokens)

    # Convert to NumPy once instead of once per flag.
    query_vector = text_embedding.cpu().numpy()

    # Compare the query embedding with all flag embeddings.
    # NOTE(review): each stored embedding is handed to cosine_similarity as-is,
    # so it is presumably a 2-D array with a single row -- confirm.
    similarities = {
        flag: cosine_similarity(query_vector, embedding)[0][0]
        for flag, embedding in flag_embeddings.items()
    }

    # Sort and return the top_n results
    sorted_flags = sorted(
        similarities.items(), key=lambda item: item[1], reverse=True
    )
    results = []
    for flag_file, similarity in sorted_flags[:top_n]:
        flag_path = os.path.join(FLAG_IMAGE_DIR, flag_file)
        if os.path.exists(flag_path):
            caption = f"{get_country_name(flag_file)} (Similarity: {similarity:.3f})"
            results.append((flag_path, caption))
        else:
            print(f"File not found: {flag_file}")
    return results
# Get all image paths.
# The extension check is case-insensitive so files such as "FLAG.PNG" are
# included, and the result is sorted because os.listdir returns entries in
# arbitrary order.
image_paths = sorted(
    os.path.join(FLAG_IMAGE_DIR, img)
    for img in os.listdir(FLAG_IMAGE_DIR)
    if img.lower().endswith((".png", ".jpg", ".jpeg"))
)

# Load precomputed embeddings.
# The .npy file stores a single pickled object; .item() unwraps it. The rest
# of the file iterates it with .items() as a mapping from flag file name to
# embedding.
# NOTE(review): allow_pickle=True unpickles the file contents, which can run
# arbitrary code -- only load embedding files from a trusted source.
FLAG_EMBEDDINGS_PATH = "./flag_embeddings_1.npy"
flag_embeddings = np.load(FLAG_EMBEDDINGS_PATH, allow_pickle=True).item()
def get_country_name(image_filename):
    """Return the upper-cased file stem of *image_filename*.

    Any directory part and the final extension are removed, and the remaining
    name is converted to upper case.
    """
    base_name = os.path.basename(image_filename)
    stem, _extension = os.path.splitext(base_name)
    return stem.upper()
def get_image_embedding(image_path):
    """Get the CLIP embedding for an input image.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        The output of ``model.encode_image`` for a batch containing only this
        image, moved to the CPU and converted to a NumPy array.
    """
    # Open through a context manager so the file handle is closed
    # deterministically; convert() returns a separate, fully loaded image that
    # stays usable after the file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # unsqueeze(0) adds the leading batch dimension before encoding.
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model.encode_image(image_input)
    return embedding.cpu().numpy()
def find_similar_flags(image_path, top_n=10):
    """Find similar flags based on cosine similarity.

    Args:
        image_path: Path of the query flag image.
        top_n: Maximum number of similar flags to return.

    Returns:
        A list of ``(flag_key, similarity)`` tuples ordered from most to least
        similar, with the query flag itself left out.
    """
    query_embedding = get_image_embedding(image_path)
    similarities = {
        flag: cosine_similarity(query_embedding, embedding)[0][0]
        for flag, embedding in flag_embeddings.items()
    }
    sorted_flags = sorted(
        similarities.items(), key=lambda item: item[1], reverse=True
    )

    # Leave out the query flag by file name rather than assuming it is always
    # the top-ranked entry.
    query_name = os.path.basename(image_path)
    others = [
        item for item in sorted_flags if os.path.basename(item[0]) != query_name
    ]
    if len(others) == len(sorted_flags):
        # No stored flag shares the query's file name: keep the previous
        # behaviour of skipping the best match as the presumed same flag.
        others = sorted_flags[1:]
    return others[:top_n]
def search_flags(query):
    """Search flags based on country name.

    Args:
        query: Text to look for inside each flag's country name. Matching is
            case-insensitive and ignores leading/trailing whitespace.

    Returns:
        ``image_paths`` itself when the query is empty or blank, otherwise a
        new list with the paths whose country name contains the query.
    """
    # Normalise once up front instead of lower-casing the query per image.
    needle = query.strip().lower() if query else ""
    if not needle:
        return image_paths
    return [img for img in image_paths if needle in get_country_name(img).lower()]
def analyze_and_display(selected_flag):
    """Main function to analyze flag similarity and prepare display.

    Args:
        selected_flag: Path of the selected flag image, or ``None`` when
            nothing is selected.

    Returns:
        ``None`` when no flag is selected, otherwise a list of
        ``(image_path, caption)`` tuples for the similar flags.

    Raises:
        gr.Error: If computing the similar flags fails. Gradio only reports a
            ``gr.Error`` to the user when it is raised, so it is raised here
            rather than returned as the output value.
    """
    if selected_flag is None:
        return None

    try:
        similar_flags = find_similar_flags(selected_flag)
        output_images = []
        for flag_file, similarity in similar_flags:
            flag_path = os.path.join(FLAG_IMAGE_DIR, flag_file)
            country_name = get_country_name(flag_file)
            output_images.append(
                (flag_path, f"{country_name} (Similarity: {similarity:.3f})")
            )
        return output_images
    except Exception as e:
        raise gr.Error(f"Error processing image: {str(e)}") from e
# Build the Gradio UI: one tab for image-based similarity, one for text queries.
with gr.Blocks() as demo:
    gr.Markdown("# Flag Similarity Analysis")
    gr.Markdown(
        "Select a flag from the gallery to find similar flags based on visual features or search using text queries."
    )

    with gr.Tabs():
        with gr.Tab("Similarity Search"):
            with gr.Row():
                # Left column: name filter and the gallery of selectable flags.
                with gr.Column(scale=1):
                    search_box = gr.Textbox(
                        label="Search Flags",
                        placeholder="Enter country name...",
                    )
                    input_gallery = gr.Gallery(
                        label="Available Flags",
                        show_label=True,
                        elem_id="gallery",
                        columns=4,
                        height="auto",
                    )
                # Right column: flags similar to the current selection.
                with gr.Column(scale=1):
                    output_gallery = gr.Gallery(
                        label="Similar Flags",
                        show_label=True,
                        elem_id="output",
                        columns=2,
                        height="auto",
                    )

            def update_gallery(query):
                """Return (path, country name) pairs for flags matching *query*."""
                gallery_items = []
                for path in search_flags(query):
                    gallery_items.append((path, get_country_name(path)))
                return gallery_items

            def on_select(evt: gr.SelectData, gallery):
                """Run the similarity analysis for the gallery entry that was selected."""
                # The first element of the selected entry is passed on as the
                # flag path -- presumably entries are (path, caption) pairs.
                selected_entry = gallery[evt.index]
                return analyze_and_display(selected_entry[0])

            # Typing in the search box re-filters the input gallery.
            search_box.change(
                fn=update_gallery,
                inputs=[search_box],
                outputs=[input_gallery],
            )
            # Selecting a flag fills the output gallery with similar flags.
            input_gallery.select(
                fn=on_select,
                inputs=[input_gallery],
                outputs=[output_gallery],
            )

        with gr.Tab("Advanced Search"):
            gr.Markdown("### Search Flags with Nuanced Queries")
            nuanced_query_box = gr.Textbox(
                label="Enter Advanced Query",
                placeholder="e.g., 'Find flags with crescent' or 'flags with animals'",
            )
            advanced_output_gallery = gr.Gallery(
                label="Matching Flags",
                show_label=True,
                elem_id="advanced_output",
                columns=3,
                height="auto",
            )

            def advanced_search(query):
                """Forward the text query to the CLIP text search."""
                return search_by_query(query)

            nuanced_query_box.change(
                fn=advanced_search,
                inputs=[nuanced_query_box],
                outputs=[advanced_output_gallery],
            )

    def init_gallery():
        """Return (path, country name) pairs for every known flag image."""
        all_items = []
        for path in image_paths:
            all_items.append((path, get_country_name(path)))
        return all_items

    # Fill the input gallery with all flags when the page loads.
    demo.load(fn=init_gallery, outputs=[input_gallery])

# Start the app.
demo.launch()