# Hunter-X-Hunter-Anime-Classification/pages/03-π HxH Character Anime Detection with Prototypical Networks.py
import cv2 | |
import numpy as np | |
import streamlit as st | |
from PIL import Image | |
from models.anime_face_detection_model import SingleShotDetectorModel | |
from models.prototypical_networks import ( | |
PrototypicalNetworksGradCAM, | |
PrototypicalNetworksModel, | |
) | |
from utils import configs | |
from utils.functional import ( | |
generate_empty_space, | |
get_default_images, | |
get_most_salient_object, | |
set_page_config, | |
set_seed, | |
) | |
# Fix the random seed, then configure the page title and icon.
set_seed()
set_page_config("HxH Character Anime Detection with Prototypical Networks", "π")

# Sidebar widgets: model name, support-set method, and the two model flags
# that are later handed to load_model().
sidebar = st.sidebar
model_choices = tuple(configs.NAME_MODELS.keys())
name_model = sidebar.selectbox("Select Model", model_choices)
support_set_method = sidebar.selectbox(
    "Select Support Set Method",
    configs.SUPPORT_SET_METHODS,
)
freeze_model = sidebar.checkbox("Freeze Model", value=True)
pretrained_model = sidebar.checkbox("Pretrained Model", value=True)
# Load Model
@st.cache_resource
def load_model(
    name_model: str, support_set_method: str, freeze_model: bool, pretrained_model: bool
):
    """Build the three models used by this page, cached across reruns.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching all three models would be constructed again on each
    click. ``st.cache_resource`` keeps one set of instances per distinct
    argument combination and returns the same objects on later reruns.

    Args:
        name_model: Model name chosen in the sidebar; forwarded unchanged to
            both prototypical-network constructors.
        support_set_method: Support-set method chosen in the sidebar;
            forwarded unchanged to both prototypical-network constructors.
        freeze_model: Value of the "Freeze Model" checkbox; forwarded
            unchanged to both prototypical-network constructors.
        pretrained_model: Value of the "Pretrained Model" checkbox; forwarded
            unchanged to both prototypical-network constructors.

    Returns:
        A ``(prototypical_networks, custom_grad_cam, ssd_model)`` tuple.
    """
    # NOTE(review): cached resources are shared between sessions and are not
    # copied on access; this assumes the models are only used for inference
    # and are not mutated per request - confirm.
    prototypical_networks = PrototypicalNetworksModel(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    custom_grad_cam = PrototypicalNetworksGradCAM(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    ssd_model = SingleShotDetectorModel()
    return prototypical_networks, custom_grad_cam, ssd_model
# Instantiate the classifier, its Grad-CAM companion, and the face detector
# for the current sidebar selection.
prototypical_networks, custom_grad_cam, ssd_model = load_model(
    name_model=name_model,
    support_set_method=support_set_method,
    freeze_model=freeze_model,
    pretrained_model=pretrained_model,
)
# Application Description
st.markdown("# β Application Description")

# The disclaimer lists every class name the classifier can output.
supported_characters = ", ".join(configs.CLASS_CHARACTERS)
application_description = f"""
Welcome to our HxH Character Anime Detection with Prototypical Networks application! π΅οΈββοΈπ¦ΈββοΈπ
This powerful and efficient tool allows you to quickly and accurately identify your favorite anime characters from Hunter x Hunter using state-of-the-art Prototypical Networks. Simply upload an image or select one of our default options, and let our model do the rest! With our user-friendly interface, anyone can easily classify HxH anime characters with just a few clicks.
But that's not all! Our application also features a powerful Grad-CAM visualization tool that lets you see which parts of the image the model is using to make its predictions. Plus, with lightning-fast inference times, you won't have to wait long to get your results.
Whether you're a hardcore anime fan or just looking for a fun way to pass the time, our HxH Character Anime Detection app is sure to entertain and delight. So what are you waiting for? Give it a try and see how many characters you can identify!
DISCLAIMER: The output of this app only {supported_characters}
"""
st.write(application_description)
# Image input: either an uploaded file or one of the bundled default images.
uploaded_file = st.file_uploader(
    "Upload image file",
    type=["jpg", "jpeg", "png", "bmp", "tiff"],
)
select_default_images = st.selectbox("Select default images", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")

if select_image_button:
    st.success("Image selected")
    # An uploaded file takes precedence; otherwise use the selected default.
    image_source = select_default_images if uploaded_file is None else uploaded_file
    image = np.array(Image.open(image_source).convert("RGB"))
    # Persist the RGB array so it survives the rerun triggered by later clicks.
    st.session_state["image"] = image
def _clamp_box(box, height, width):
    """Return ``(left, top, right, bottom)`` as ints clipped to the image.

    ``box`` holds the corner coordinates in its first four entries, in the
    order x1, y1, x2, y2. Clipping keeps a negative coordinate from being
    interpreted as a wrap-around index when the image array is sliced.
    """
    left, top, right, bottom = (int(value) for value in box[:4])
    left = min(max(left, 0), width)
    right = min(max(right, 0), width)
    top = min(max(top, 0), height)
    bottom = min(max(bottom, 0), height)
    return left, top, right, bottom


# Preview the selected image and, on request, detect faces, classify each
# one, and render the results.
if st.session_state.get("image") is not None:
    image = st.session_state.get("image")
    col1, col2, col3 = st.columns(3)
    col2.write("## πΈ Preview Image")
    col2.image(image, use_column_width=True)
    predict_image_button = col2.button("Detect Image Character")
    generate_empty_space(2)
    if predict_image_button:
        with st.spinner("Detecting Image Character..."):
            results_face_anime_detection = ssd_model.detect_anime_face(image)
            result_grad_cam = custom_grad_cam.get_grad_cam(image)
            # Draw on a copy so the source image can be shown unmodified.
            bounding_box_image = image.copy()
            # Total time = detector time + the time of every classification.
            inference_time = results_face_anime_detection["inference_time"]
            results_anime_face = []
            image_height, image_width = image.shape[:2]
            if results_face_anime_detection["anime_face"]:
                for result in results_face_anime_detection["anime_face"]:
                    left, top, right, bottom = _clamp_box(
                        result, image_height, image_width
                    )
                    if right <= left or bottom <= top:
                        # Nothing of this box lies inside the image; an empty
                        # crop cannot be classified, so skip it.
                        continue
                    crop_image = image[top:bottom, left:right]
                    character = prototypical_networks.predict(crop_image)
                    character_grad_cam = custom_grad_cam.get_grad_cam(crop_image)
                    results_anime_face.append(
                        {
                            "face": crop_image,
                            "face_grad_cam": character_grad_cam,
                            "most_salient_object": get_most_salient_object(
                                crop_image
                            ),
                            "character": character["character"],
                            "confidence_detection": result[4],
                            "confidence_classification": character["confidence"],
                        }
                    )
                    inference_time += character["inference_time"]
                    cv2.rectangle(
                        bounding_box_image,
                        (left, top),
                        (right, bottom),
                        (255, 255, 0),
                        4,
                    )
                    # Label sits just above the top-left corner of the box.
                    cv2.putText(
                        bounding_box_image,
                        character["character"],
                        (left, top - 10),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,
                        (255, 255, 0),
                        2,
                    )

        # Whole-image overview.
        col1, col2, col3, col4 = st.columns(4)
        col1.write("### π Source Image")
        col1.image(image, use_column_width=True)
        col2.write("### π Detected Image")
        col2.image(bounding_box_image, use_column_width=True)
        col3.write("### π Grad CAM Image")
        col3.image(result_grad_cam, use_column_width=True)
        col4.write("### π€ Most Salient Object")
        col4.image(get_most_salient_object(image), use_column_width=True)

        st.write("### π Result")
        st.write(f"Inference Time: {inference_time:.2f} s")

        # One row of three images plus scores per detected face.
        for result in results_anime_face:
            col1, col2, col3 = st.columns(3)
            col1.write("#### π Cropped Face Image")
            col1.image(result["face"], use_column_width=True)
            col2.write("#### π Cropped Face Grad CAM Image")
            col2.image(result["face_grad_cam"], use_column_width=True)
            col3.write("### π€ Most Salient Object")
            # The stored value is already the output of
            # get_most_salient_object(crop); show it directly instead of
            # running the function a second time on its own output.
            col3.image(result["most_salient_object"], use_column_width=True)
            st.write(f"Character: {result['character'].title()}")
            st.write(
                f"Confidence Score Detection: {result['confidence_detection']*100:.2f}%"
            )
            st.write(
                f"Confidence Score Classification: {result['confidence_classification']*100:.2f}%"
            )
            generate_empty_space(2)

        # Clear the selection once results are shown so the next rerun starts
        # from the image picker again.
        st.session_state["image"] = None