Hunter-X-Hunter-Anime-Classification
/
pages
/04-π Image Similarity with Prototypical Networks.py
import numpy as np | |
import streamlit as st | |
from PIL import Image | |
from models.prototypical_networks import ImageSimilarity, PrototypicalNetworksGradCAM | |
from utils import configs | |
from utils.functional import ( | |
generate_empty_space, | |
get_default_images, | |
get_most_salient_object, | |
set_page_config, | |
set_seed, | |
) | |
# Fix the random seed (utils.functional.set_seed) before any model is built.
set_seed()

# Page title and icon. Presumably this wraps st.set_page_config, which Streamlit
# requires to run before any other st.* call on the page.
# NOTE(review): the "π" icon looks like a mis-decoded emoji; restore the intended one.
set_page_config("Image Similarity with Prototypical Networks", "π")

# --- Sidebar: which model variant to build ---------------------------------
model_options = tuple(configs.NAME_MODELS.keys())
name_model = st.sidebar.selectbox(label="Select Model", options=model_options)
support_set_method = st.sidebar.selectbox(
    label="Select Support Set Method",
    options=configs.SUPPORT_SET_METHODS,
)
freeze_model = st.sidebar.checkbox(label="Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox(label="Pretrained Model", value=True)
# Load Model
# session_state slot holding the most recently built (config, models) pair.
_MODEL_CACHE_KEY = "_prototypical_networks_models"


def load_model(
    name_model: str, support_set_method: str, freeze_model: bool, pretrained_model: bool
):
    """Return the similarity model and its Grad-CAM companion for a sidebar config.

    Streamlit re-executes this whole script on every widget interaction.
    Building the models unconditionally would therefore re-instantiate them on
    every click. The most recently built pair is kept in ``st.session_state``
    and reused while the configuration is unchanged. Only one pair is retained
    per session, so switching configurations does not accumulate models in
    memory. Each browser session gets its own objects, so nothing is shared
    across users.

    NOTE(review): this assumes ``get_similarity`` / ``get_grad_cam`` do not
    accumulate per-call state on the model objects (e.g. hooks) that a fresh
    instance used to reset; confirm against models.prototypical_networks.

    Args:
        name_model: Key of ``configs.NAME_MODELS`` chosen in the sidebar.
        support_set_method: Entry of ``configs.SUPPORT_SET_METHODS`` chosen in
            the sidebar.
        freeze_model: Forwarded to both model constructors.
        pretrained_model: Forwarded to both model constructors.

    Returns:
        Tuple ``(image_similarity, custom_grad_cam)`` holding an
        ``ImageSimilarity`` and a ``PrototypicalNetworksGradCAM``.
    """
    config = (name_model, support_set_method, freeze_model, pretrained_model)
    cached = st.session_state.get(_MODEL_CACHE_KEY)
    if cached is not None and cached[0] == config:
        return cached[1]

    # Note the constructor argument order differs from this function's.
    image_similarity = ImageSimilarity(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    custom_grad_cam = PrototypicalNetworksGradCAM(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    models = (image_similarity, custom_grad_cam)
    st.session_state[_MODEL_CACHE_KEY] = (config, models)
    return models
# Build the similarity model and its Grad-CAM counterpart for the current
# sidebar selection.
image_similarity, custom_grad_cam = load_model(
    name_model=name_model,
    support_set_method=support_set_method,
    freeze_model=freeze_model,
    pretrained_model=pretrained_model,
)
# Application Description
# NOTE(review): the "β" and "π" glyphs below look like mis-decoded emoji;
# restore the intended characters.
st.markdown("# β Application Description")
# Paragraphs are separated by blank lines on purpose: Markdown merges
# consecutive non-blank lines into a single paragraph.
st.write(
    """
Looking for a fun way to find similar images using cutting-edge technology? Look no further than Image Similarity with Prototypical Networks! π

Our powerful and efficient algorithm allows you to quickly and accurately identify similar images based on their visual features. Whether you're an artist looking for inspiration or just want to see how two images compare, our user-friendly interface makes it easy to get started.

With just a few clicks, you can upload two images (or pick from the default images) and see how they stack up against each other. Our neural network will do the rest, reporting a similarity score and label together with Grad-CAM heatmaps and the most salient object in each image.

So why wait? Try Image Similarity with Prototypical Networks today and discover a whole new world of image analysis and exploration! π
"""
)
def _load_rgb_image(source) -> np.ndarray:
    """Decode ``source`` (an uploaded file object or an image path) to an RGB array.

    The image is opened in a ``with`` block so a file opened from a path is
    released even for multi-frame formats such as TIFF. ``np.array`` copies the
    pixel data, so the returned array stays valid after the image is closed.
    """
    with Image.open(source) as img:
        return np.array(img.convert("RGB"))


def _image_selector(column, index: int):
    """Render the upload/default picker for image ``index`` inside ``column``.

    When the "Select Image <index>" button is clicked, the uploaded file is used
    if there is one, otherwise the chosen default image. It is decoded and
    stored in ``st.session_state["image<index>"]``. Widget labels match the
    original page exactly, so widget identity and state are preserved.

    Returns:
        ``(uploaded_file, selected_default_image, select_clicked)``.
    """
    uploaded_file = column.file_uploader(
        f"Upload image file {index}", type=["jpg", "jpeg", "png", "bmp", "tiff"]
    )
    selected_default = column.selectbox(
        f"Select default images {index}", get_default_images()
    )
    column.caption(f"Default Images {index} will be used if no image is uploaded.")
    select_clicked = column.button(f"Select Image {index}")

    if select_clicked:
        source = uploaded_file if uploaded_file is not None else selected_default
        if source is None:
            # Nothing uploaded and the default-image list is empty.
            st.warning(
                f"No image available for Image {index}: "
                "upload a file or add default images."
            )
        else:
            try:
                st.session_state[f"image{index}"] = _load_rgb_image(source)
            except (OSError, Image.DecompressionBombError) as err:
                # Uploads are untrusted: report unreadable or oversized files
                # instead of crashing the page with a traceback. Any previously
                # selected image for this slot is left in place.
                st.error(f"Could not read Image {index}: {err}")
            else:
                # Only confirm once the image has actually been decoded.
                st.success(f"Image {index} selected")

    return uploaded_file, selected_default, select_clicked


col1, col2 = st.columns(2)
uploaded_file1, select_default_images1, select_image_button1 = _image_selector(col1, 1)
uploaded_file2, select_default_images2, select_image_button2 = _image_selector(col2, 2)
# Both images must be chosen before the preview and the similarity run appear.
if not (
    st.session_state.get("image1") is None
    or st.session_state.get("image2") is None
):
    image1 = st.session_state.get("image1")
    image2 = st.session_state.get("image2")

    # Side-by-side previews of the two selected images.
    preview_columns = st.columns(2)
    for slot, (column, picture) in enumerate(
        zip(preview_columns, (image1, image2)), start=1
    ):
        column.write(f"## πΈ Preview Image {slot}")
        column.image(picture, use_column_width=True)

    predict_image_button = st.button("Get Image Similarity")
    generate_empty_space(2)

    if predict_image_button:
        with st.spinner("Getting Image Similarity..."):
            result_similarity = image_similarity.get_similarity(image1, image2)
            grad_cams = [custom_grad_cam.get_grad_cam(img) for img in (image1, image2)]
        inference_time = result_similarity["inference_time"]

        # Grad-CAM output, one column per image.
        for slot, (column, cam) in enumerate(zip(st.columns(2), grad_cams), start=1):
            column.write(f"### π Grad CAM Image {slot}")
            column.image(cam, use_column_width=True)

        # Most salient object, one column per image.
        for slot, (column, picture) in enumerate(
            zip(st.columns(2), (image1, image2)), start=1
        ):
            column.write(f"### π€ Most Salient Object Image {slot}")
            column.image(get_most_salient_object(picture), use_column_width=True)

        st.write("### π Result")
        similarity = result_similarity["similarity"]
        st.write(f"Similarity Score: {similarity * 100:.2f}%")
        label = result_similarity["result_similarity"].title()
        st.write(f"Similarity Label: {label}")
        st.write(f"Inference Time: {inference_time:.2f} s")

        # Reset both slots: on the next rerun the preview section stays hidden
        # until two images are selected again.
        for key in ("image1", "image2"):
            st.session_state[key] = None