FlagCLIP1 / app.py
PuristanLabs1's picture
Update app.py
f39f590 verified
import gradio as gr
import os
import numpy as np
import torch
import clip
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
# Load CLIP model
# Run on the GPU when PyTorch reports CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# `model` is used below for encode_text / encode_image; `preprocess` is the
# transform applied to a PIL image before it is passed to encode_image.
model, preprocess = clip.load("ViT-B/32", device=device)
# Configuration
# Directory that is listed for flag images and joined with embedding keys
# to build the paths shown in the galleries.
FLAG_IMAGE_DIR = "./named_flags"
# Function to search flags with specific queries using CLIP
def search_by_query(query, top_n=10):
    """Search flags based on a text query using CLIP.

    Args:
        query: Free-text description to compare against the flag embeddings.
        top_n: Number of best-scoring flags to consider.

    Returns:
        A list of ``(flag_path, caption)`` tuples ordered by descending
        cosine similarity. Flags whose image file does not exist under
        ``FLAG_IMAGE_DIR`` are reported on stdout and omitted, so the list
        can be shorter than ``top_n``. An empty or whitespace-only query
        returns an empty list.
    """
    # The query textbox triggers this on every change, including when it is
    # cleared; a blank prompt has no meaningful ranking, so skip the model.
    if not query or not query.strip():
        return []

    # Encode the text query. truncate=True keeps over-long prompts from
    # raising inside clip.tokenize instead of crashing the handler.
    with torch.no_grad():
        tokens = clip.tokenize([query], truncate=True).to(device)
        text_embedding = model.encode_text(tokens)
    # Convert once, outside the loop: the query vector is loop-invariant.
    query_vector = text_embedding.cpu().numpy()

    # Compare the query embedding with all flag embeddings
    similarities = {
        flag: cosine_similarity(query_vector, embedding)[0][0]
        for flag, embedding in flag_embeddings.items()
    }

    # Sort and return the top_n results
    sorted_flags = sorted(similarities.items(), key=lambda item: item[1], reverse=True)
    results = []
    for flag_file, similarity in sorted_flags[:top_n]:
        flag_path = os.path.join(FLAG_IMAGE_DIR, flag_file)
        if os.path.exists(flag_path):
            caption = f"{get_country_name(flag_file)} (Similarity: {similarity:.3f})"
            results.append((flag_path, caption))
        else:
            print(f"File not found: {flag_file}")
    return results
# Get all image paths
# os.listdir returns entries in arbitrary order, so sort for a stable gallery
# order across runs. Extensions are matched case-insensitively so files such
# as "FRANCE.PNG" are not silently left out.
image_paths = [
    os.path.join(FLAG_IMAGE_DIR, img)
    for img in sorted(os.listdir(FLAG_IMAGE_DIR))
    if img.lower().endswith((".png", ".jpg", ".jpeg"))
]
# Load precomputed embeddings
FLAG_EMBEDDINGS_PATH = "./flag_embeddings_1.npy"
# NOTE(security): allow_pickle=True makes np.load unpickle the file contents,
# which can execute arbitrary code. Only point this at a file you trust.
# .item() unwraps the stored 0-d object array into the mapping that the
# search functions iterate with .items().
flag_embeddings = np.load(FLAG_EMBEDDINGS_PATH, allow_pickle=True).item()
def get_country_name(image_filename):
    """Return the upper-cased file stem of *image_filename* as the country name."""
    file_name = os.path.basename(image_filename)
    stem, _extension = os.path.splitext(file_name)
    return stem.upper()
def get_image_embedding(image_path):
    """Get the CLIP embedding for an input image.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        The result of ``model.encode_image`` for a batch containing just
        this image, moved to the CPU and converted to a NumPy array.
    """
    # Open through a context manager so the underlying file is closed
    # deterministically; convert() returns an independent in-memory image
    # that stays usable after the file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension before the model call.
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = model.encode_image(image_input)
    return embedding.cpu().numpy()
def find_similar_flags(image_path, top_n=10):
    """Find similar flags based on cosine similarity.

    Args:
        image_path: Path of the flag image used as the query.
        top_n: Maximum number of similar flags to return.

    Returns:
        A list of ``(flag_key, similarity)`` tuples sorted by descending
        similarity, with the query flag itself excluded.
    """
    query_embedding = get_image_embedding(image_path)
    similarities = {
        flag: cosine_similarity(query_embedding, embedding)[0][0]
        for flag, embedding in flag_embeddings.items()
    }
    ranked = sorted(similarities.items(), key=lambda item: item[1], reverse=True)

    # Drop the query flag by name when its file name is an embedding key, so
    # a near-identical flag that happens to outrank it is not discarded in
    # its place. When the name is unknown, fall back to skipping the top hit,
    # which is presumed to be the query flag itself.
    query_key = os.path.basename(image_path)
    if query_key in similarities:
        ranked = [item for item in ranked if item[0] != query_key]
    else:
        ranked = ranked[1:]
    return ranked[:top_n]
def search_flags(query):
    """Return the image paths whose country name contains *query*.

    The comparison ignores case. A falsy query (empty string or None)
    returns the full ``image_paths`` list unchanged.
    """
    if query:
        def name_matches(path):
            # Case-insensitive substring test against the derived name.
            return query.lower() in get_country_name(path).lower()

        return list(filter(name_matches, image_paths))
    return image_paths
def analyze_and_display(selected_flag):
    """Main function to analyze flag similarity and prepare display.

    Args:
        selected_flag: Path of the selected flag image, or None when
            nothing is selected.

    Returns:
        None when ``selected_flag`` is None; otherwise a list of
        ``(flag_path, caption)`` tuples for the most similar flags.

    Raises:
        gr.Error: If computing or assembling the results fails. The
            original exception is chained as the cause.
    """
    if selected_flag is None:
        return None
    try:
        similar_flags = find_similar_flags(selected_flag)
        output_images = []
        for flag_file, similarity in similar_flags:
            flag_path = os.path.join(FLAG_IMAGE_DIR, flag_file)
            country_name = get_country_name(flag_file)
            output_images.append((flag_path, f"{country_name} (Similarity: {similarity:.3f})"))
        return output_images
    except Exception as e:
        # gr.Error is an exception type: it has to be raised for Gradio to
        # show the message. Returning it would hand the exception object to
        # the output gallery as if it were data.
        raise gr.Error(f"Error processing image: {str(e)}") from e
# Create Gradio interface
# Layout: two tabs. "Similarity Search" pairs a filterable gallery of all
# flags with a gallery of visually similar flags; "Advanced Search" ranks
# flags against a free-text query.
with gr.Blocks() as demo:
    gr.Markdown("# Flag Similarity Analysis")
    gr.Markdown("Select a flag from the gallery to find similar flags based on visual features or search using text queries.")
    with gr.Tabs():
        with gr.Tab("Similarity Search"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Search and input gallery
                    search_box = gr.Textbox(label="Search Flags", placeholder="Enter country name...")
                    #query_box = gr.Textbox(label="Search by Query", placeholder="e.g., 'crescent in the center'")
                    input_gallery = gr.Gallery(
                        label="Available Flags",
                        show_label=True,
                        elem_id="gallery",
                        columns=4,
                        height="auto"
                    )
                with gr.Column(scale=1):
                    # Output gallery
                    output_gallery = gr.Gallery(
                        label="Similar Flags",
                        show_label=True,
                        elem_id="output",
                        columns=2,
                        height="auto"
                    )
            # Event handlers
            def update_gallery(query):
                """Return (path, country name) pairs for flags matching the search text."""
                matching_flags = search_flags(query)
                return [(path, get_country_name(path)) for path in matching_flags]
            def on_select(evt: gr.SelectData, gallery):
                """Handle flag selection from gallery"""
                # `gallery` is the current value of input_gallery; evt.index
                # picks the clicked item. Assumes each item is indexable with
                # the image path at position 0, as produced by
                # update_gallery / init_gallery — TODO confirm for the
                # installed Gradio version.
                selected_flag_path = gallery[evt.index][0]
                return analyze_and_display(selected_flag_path)
            # Connect event handlers
            # Re-filter the available flags whenever the search text changes.
            search_box.change(
                update_gallery,
                inputs=[search_box],
                outputs=[input_gallery]
            )
            # Selecting a flag fills the output gallery with similar flags.
            input_gallery.select(
                on_select,
                inputs=[input_gallery],
                outputs=[output_gallery]
            )
        with gr.Tab("Advanced Search"):
            gr.Markdown("### Search Flags with Nuanced Queries")
            nuanced_query_box = gr.Textbox(label="Enter Advanced Query", placeholder="e.g., 'Find flags with crescent' or 'flags with animals'")
            advanced_output_gallery = gr.Gallery(
                label="Matching Flags",
                show_label=True,
                elem_id="advanced_output",
                columns=3,
                height="auto"
            )
            def advanced_search(query):
                """Delegate to search_by_query with its default top_n."""
                return search_by_query(query)
            # The text search runs on every change of the query box.
            nuanced_query_box.change(
                advanced_search,
                inputs=[nuanced_query_box],
                outputs=[advanced_output_gallery]
            )
    # Initialize gallery with all flags
    def init_gallery():
        """Return (path, country name) pairs for every discovered flag image."""
        return [(path, get_country_name(path)) for path in image_paths]
    # Registered inside the Blocks context so the gallery is filled on load.
    demo.load(init_gallery, outputs=[input_gallery])
# Launch the app
demo.launch()