Hunter-X-Hunter-Anime-Classification
/
pages
/02-🦸 HxH Character Anime Classification with Prototypical Networks.py
import numpy as np | |
import streamlit as st | |
from PIL import Image | |
from models.prototypical_networks import ( | |
PrototypicalNetworksGradCAM, | |
PrototypicalNetworksModel, | |
) | |
from utils import configs | |
from utils.functional import ( | |
generate_empty_space, | |
get_default_images, | |
get_most_salient_object, | |
set_page_config, | |
set_seed, | |
) | |
# Call the project's seeding helper before anything else runs.
# NOTE(review): presumably fixes RNG seeds for reproducibility - confirm in utils.functional.
set_seed()

# Page-level configuration through the project helper (two string arguments;
# presumably page title and page icon - confirm against utils.functional).
set_page_config("HxH Character Anime Classification with Prototypical Networks", "π¦Έ")

# Sidebar widgets. The four values collected here are the arguments later
# handed to load_model().
sidebar = st.sidebar
model_choices = tuple(configs.NAME_MODELS.keys())
name_model = sidebar.selectbox("Select Model", model_choices)
support_set_method = sidebar.selectbox(
    "Select Support Set Method",
    configs.SUPPORT_SET_METHODS,
)
freeze_model = sidebar.checkbox("Freeze Model", value=True)
pretrained_model = sidebar.checkbox("Pretrained Model", value=True)
# Load Model | |
def load_model(
    name_model: str, support_set_method: str, freeze_model: bool, pretrained_model: bool
):
    """Create the classifier and its Grad-CAM counterpart from one set of options.

    Args:
        name_model: Model name, forwarded unchanged to both constructors.
        support_set_method: Support-set method, forwarded unchanged to both
            constructors.
        freeze_model: Forwarded unchanged to both constructors.
        pretrained_model: Forwarded unchanged to both constructors.

    Returns:
        A 2-tuple ``(PrototypicalNetworksModel, PrototypicalNetworksGradCAM)``.
    """
    # NOTE(review): this runs on every Streamlit script rerun, so both objects
    # are rebuilt each time; consider caching if construction is expensive.

    # The constructors take the options in a different order than this
    # function's signature: support_set_method goes last.
    constructor_args = (name_model, freeze_model, pretrained_model, support_set_method)
    classifier = PrototypicalNetworksModel(*constructor_args)
    grad_cam = PrototypicalNetworksGradCAM(*constructor_args)
    return classifier, grad_cam
# Build the classifier and the Grad-CAM helper from the current sidebar choices.
prototypical_networks, custom_grad_cam = load_model(
    name_model=name_model,
    support_set_method=support_set_method,
    freeze_model=freeze_model,
    pretrained_model=pretrained_model,
)
# Application Description: a heading followed by the introductory text.
st.markdown("# β Application Description")

# Comma-separated list of the class names the app can output, taken from the
# project config and interpolated into the disclaimer line below.
supported_characters = ", ".join(configs.CLASS_CHARACTERS)
app_description = f"""
Welcome to our HxH Character Anime Classification with Prototypical Networks π¦Έ app! With just a few clicks, you can classify your favorite anime characters from Hunter x Hunter using our powerful and efficient Prototypical Networks. Our user-friendly interface makes it easy for anyone to get started, whether you're a hardcore anime fan or just looking for a fun way to pass the time.
Simply upload an image or select one of our default images, and let our app do the rest! Our app will accurately identify and classify the character, and even provide you with a Grad-CAM image to show you which parts of the image contributed most to the classification.
So what are you waiting for? Try our HxH Character Anime Classification app now and see if you can correctly identify all your favorite characters!
DISCLAIMER: The output of this app only {supported_characters}
"""
st.write(app_description)
# Image input: an optional upload plus a fallback picker over the default images.
accepted_extensions = ["jpg", "jpeg", "png", "bmp", "tiff"]
uploaded_file = st.file_uploader("Upload image file", type=accepted_extensions)
select_default_images = st.selectbox("Select default images", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")

select_image_button = st.button("Select Image")
if select_image_button:
    st.success("Image selected")
    # An uploaded file wins; otherwise fall back to the chosen default image.
    image_source = select_default_images if uploaded_file is None else uploaded_file
    # Decode to an RGB numpy array and keep it in session state so it is
    # still available on later reruns of the script.
    image = np.array(Image.open(image_source).convert("RGB"))
    st.session_state["image"] = image
# Preview and classification. Runs only while an image is stored in session state.
stored_image = st.session_state.get("image")
if stored_image is not None:
    image = stored_image

    # Three equal columns; only the middle one is used, which centres the preview.
    _, preview_col, _ = st.columns(3)
    preview_col.write("## πΈ Preview Image")
    preview_col.image(image, use_column_width=True)
    predict_image_button = preview_col.button("Classify Image Character")
    generate_empty_space(2)

    if predict_image_button:
        with st.spinner("Classifying Image Character..."):
            result_class = prototypical_networks.predict(image)
            result_grad_cam = custom_grad_cam.get_grad_cam(image)
            inference_time = result_class["inference_time"]

            # Side-by-side: source image, Grad-CAM output, most salient object.
            source_col, grad_cam_col, salient_col = st.columns(3)
            source_col.write("### π Source Image")
            source_col.image(image, use_column_width=True)
            grad_cam_col.write("### π Grad CAM Image")
            grad_cam_col.image(result_grad_cam, use_column_width=True)
            salient_col.write("### π€ Most Salient Object")
            salient_col.image(get_most_salient_object(image), use_column_width=True)

            # Textual summary read from the prediction result dict.
            st.write("### π Result")
            st.write(f"Predicted Character: {result_class['character'].title()}")
            st.write(f"Confidence Score: {result_class['confidence'] * 100:.2f}%")
            st.write(f"Inference Time: {inference_time:.2f} s")

        # Clear the stored image so the next rerun skips this section until a
        # new image is selected.
        st.session_state["image"] = None