# Taken from: https://huggingface.co/spaces/frgfm/torch-cam/blob/main/app.py
# streamlit run app.py
from io import BytesIO
import os
import sys
import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import tempfile
from PIL import Image
from torchvision import models
from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor
from torchvision import transforms
from torchcam.methods import CAM
from torchcam import methods as torchcam_methods
from torchcam.utils import overlay_mask
import os.path as osp

# Put this file's directory on sys.path before importing the project-local modules below.
root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from preprocessing.dataset_creation import EyeDentityDatasetCreation
from utils import get_model
from registry_utils import import_registered_modules

import_registered_modules()
# from torchcam.methods._utils import locate_candidate_layer

CAM_METHODS = [
    "CAM",
    # "GradCAM",
    # "GradCAMpp",
    # "SmoothGradCAMpp",
    # "ScoreCAM",
    # "SSCAM",
    # "ISCAM",
    # "XGradCAM",
    # "LayerCAM",
]
TV_MODELS = [
    "ResNet18",
    "ResNet50",
]
SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
UPSCALE = [2, 4]
UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
LABEL_MAP = ["left_pupil", "right_pupil"]


@torch.no_grad()
def _load_model(model_configs, device="cpu"):
    """Build a model from ``model_configs`` and load its saved weights.

    Args:
        model_configs: Mapping that must contain ``"model_path"``. The path is
            joined onto ``root_path`` with ``os.path.join``. Every other entry
            is forwarded to ``get_model``.
        device: Passed to ``torch.load`` as ``map_location`` and to
            ``model.to``.

    Returns:
        The model with the state dict loaded, moved to ``device`` and switched
        to eval mode.

    Raises:
        KeyError: If ``model_configs`` has no ``"model_path"`` entry.
    """
    # Work on a shallow copy so the caller's dict keeps its "model_path" entry.
    configs = dict(model_configs)
    model_path = os.path.join(root_path, configs.pop("model_path"))
    state_dict = torch.load(model_path, map_location=device)
    model = get_model(model_configs=configs)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model = model.eval()
    return model


def extract_frames(video_path):
    """Read every frame of the video at ``video_path`` as an RGB array.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        List of frames in read order, each converted from BGR to RGB. The list
        is empty when the first ``read()`` reports failure.
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        success, image = vidcap.read()
        while success:
            # Convert the frame to RGB (cv2 uses BGR by default)
            frames.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            success, image = vidcap.read()
    finally:
        # Release the capture handle even if reading or conversion raises.
        vidcap.release()
    return frames


# Function to check if a file is an image def 
is_image(file_extension): return file_extension.lower() in ["png", "jpeg", "jpg"] # Function to check if a file is a video def is_video(file_extension): return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"] def resize_frame(frame, max_width, max_height): image = Image.fromarray(frame) original_size = image.size # Resize the frame similarly to the image resizing logic if original_size[0] == original_size[1] and original_size[0] >= 256: max_size = (256, 256) else: max_size = list(original_size) if original_size[0] >= 640: max_size[0] = 640 elif original_size[0] < 64: max_size[0] = 64 if original_size[1] >= 480: max_size[1] = 480 elif original_size[1] < 32: max_size[1] = 32 image.thumbnail(max_size) return image def main(): # Wide mode st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide") # Designing the interface st.title("EyeDentify Playground") # For newline st.write("\n") # Set the columns cols = st.columns((1, 1)) # cols = st.columns((1, 1, 1)) cols[0].header("Input image") # cols[1].header("Raw CAM") cols[-1].header("Prediction") # Sidebar # File selection st.sidebar.title("Upload Face or Eye") # Disabling warning st.set_option("deprecation.showfileUploaderEncoding", False) # Choose your own image uploaded_file = st.sidebar.file_uploader( "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"] ) if uploaded_file is not None: # Get file extension file_extension = uploaded_file.name.split(".")[-1] input_imgs = [] if is_image(file_extension): input_img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB") # print("input_img before = ", input_img.size) max_size = [input_img.size[0], input_img.size[1]] cols[0].text(f"Input Image: {max_size[0]} x {max_size[1]}") if input_img.size[0] == input_img.size[1] and input_img.size[0] >= 256: max_size[0] = 256 max_size[1] = 256 else: if input_img.size[0] >= 640: max_size[0] = 640 elif input_img.size[0] < 64: max_size[0] = 64 if input_img.size[1] 
>= 480: max_size[1] = 480 elif input_img.size[1] < 32: max_size[1] = 32 input_img.thumbnail((max_size[0], max_size[1])) # Bicubic resampling input_imgs.append(input_img) # print("input_img after = ", input_img.size) # cols[0].image(input_img) fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10)) # Display the input image axs0.imshow(input_imgs[0]) axs0.axis("off") axs0.set_title("Input Image") # Display the plot cols[0].pyplot(fig0) cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}") # TODO: show the face features extracted from the image under 'input image' column elif is_video(file_extension): tfile = tempfile.NamedTemporaryFile(delete=False) tfile.write(uploaded_file.read()) video_path = tfile.name # Extract frames from the video frames = extract_frames(video_path) print(f"Extracted {len(frames)} frames from the video") # Process the frames for i, frame in enumerate(frames): input_imgs.append(resize_frame(frame, 640, 480)) os.remove(video_path) fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10)) # Display the input image axs0.imshow(input_imgs[0]) axs0.axis("off") axs0.set_title("Input Image") # Display the plot cols[0].pyplot(fig0) # cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}") st.sidebar.title("Setup") # Upscale selection upscale = "-" # upscale = st.sidebar.selectbox( # "Upscale", # ["-"] + UPSCALE, # help="Upscale the uploaded image 2 or 4 times. 
Keep blank for no upscaling", # ) # Upscale method selection if upscale != "-": upscale_method_or_model = st.sidebar.selectbox( "Upscale Method / Model", UPSCALE_METHODS + SR_METHODS, help="Select a method or model to upscale the uploaded image", ) else: upscale_method_or_model = None # Pupil selection pupil_selection = st.sidebar.selectbox( "Pupil Selection", ["-"] + LABEL_MAP, help="Select left or right pupil OR keep blank for both pupil diameter estimation", ) # Model selection tv_model = st.sidebar.selectbox( "Classification model", TV_MODELS, help="Supported Models for Pupil Diameter Estimation", ) cam_method = "CAM" # cam_method = st.sidebar.selectbox( # "CAM method", # CAM_METHODS, # help="The way your class activation map will be computed", # ) # target_layer = st.sidebar.text_input( # "Target layer", # default_layer, # help='If you want to target several layers, add a "+" separator (e.g. "layer3+layer4")', # ) st.sidebar.write("\n") if st.sidebar.button("Predict Diameter & Compute CAM"): if uploaded_file is None: st.sidebar.error("Please upload an image first") else: with st.spinner("Analyzing..."): model = None for input_img in input_imgs: if upscale == "-": sr_configs = None else: sr_configs = { "method": upscale_method_or_model, "params": {"upscale": upscale}, } config_file = { "sr_configs": sr_configs, "feature_extraction_configs": { "blink_detection": False, "upscale": upscale, "extraction_library": "mediapipe", }, } img = np.array(input_img) # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # if img.shape[0] > max_size or img.shape[1] > max_size: # img = cv2.resize(img, (max_size, max_size)) ds_results = EyeDentityDatasetCreation( feature_extraction_configs=config_file["feature_extraction_configs"], sr_configs=config_file["sr_configs"], )(img) # if ds_results is not None: # print("ds_results = ", ds_results.keys()) # NOTE: # ds_results.keys() contains ===> 'full_imgs', 'faces', 'eyes', 'blinks', 'iris' preprocess_steps = [ transforms.ToTensor(), 
transforms.Resize( [32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True, ), ] preprocess_function = transforms.Compose(preprocess_steps) left_eye = None right_eye = None if ds_results is None: # print("type of input_img = ", type(input_img)) input_img = preprocess_function(input_img) input_img = input_img.unsqueeze(0) if pupil_selection == "left_pupil": left_eye = input_img elif pupil_selection == "right_pupil": right_eye = input_img else: left_eye = input_img right_eye = input_img # print("type of left_eye = ", type(left_eye)) # print("type of right_eye = ", type(right_eye)) elif "eyes" in ds_results.keys(): if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None: left_eye = ds_results["eyes"]["left_eye"] # print("type of left_eye = ", type(left_eye)) left_eye = to_pil_image(left_eye).convert("RGB") # print("type of left_eye = ", type(left_eye)) left_eye = preprocess_function(left_eye) # print("type of left_eye = ", type(left_eye)) left_eye = left_eye.unsqueeze(0) if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None: right_eye = ds_results["eyes"]["right_eye"] # print("type of right_eye = ", type(right_eye)) right_eye = to_pil_image(right_eye).convert("RGB") # print("type of right_eye = ", type(right_eye)) right_eye = preprocess_function(right_eye) # print("type of right_eye = ", type(right_eye)) right_eye = right_eye.unsqueeze(0) else: # print("type of input_img = ", type(input_img)) input_img = preprocess_function(input_img) input_img = input_img.unsqueeze(0) if pupil_selection == "left_pupil": left_eye = input_img elif pupil_selection == "right_pupil": right_eye = input_img else: left_eye = input_img right_eye = input_img # print("type of left_eye = ", type(left_eye)) # print("type of right_eye = ", type(right_eye)) # print("left_eye = ", left_eye.shape) # print("right_eye = ", right_eye.shape) if pupil_selection == "-": selected_eyes = ["left_eye", "right_eye"] 
elif pupil_selection == "left_pupil": selected_eyes = ["left_eye"] elif pupil_selection == "right_pupil": selected_eyes = ["right_eye"] for eye_type in selected_eyes: if model is None: model_configs = { "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt", "registered_model_name": tv_model, "num_classes": 1, } registered_model_name = model_configs["registered_model_name"] model = _load_model(model_configs) if registered_model_name == "ResNet18": target_layer = model.resnet.layer4[-1].conv2 elif registered_model_name == "ResNet50": target_layer = model.resnet.layer4[-1].conv3 else: raise Exception(f"No target layer available for selected model: {registered_model_name}") if left_eye is not None and eye_type == "left_eye": input_img = left_eye elif right_eye is not None and eye_type == "right_eye": input_img = right_eye else: raise Exception("Wrong Data") if cam_method is not None: cam_extractor = torchcam_methods.__dict__[cam_method]( model, target_layer=target_layer, fc_layer=model.resnet.fc, input_shape=input_img.shape, ) # with torch.no_grad(): out = model(input_img) cols[-1].markdown( f"