import argparse
import functools
import cv2
import numpy as np
import torch
import kornia as K
from kornia.contrib import FaceDetector, FaceDetectorResult
import gradio as gr

import face_detection

# Width (in pixels) every image is resized to before running a detector.
TARGET_WIDTH = 400
# Settings shared by the three `face_detection` back-ends.
CONFIDENCE_THRESHOLD = .5
NMS_IOU_THRESHOLD = .3
# Appearance of the drawn bounding boxes (passed straight to cv2.rectangle).
BOX_COLOR = (0, 255, 0)
BOX_THICKNESS = 1


def detect_faces(img: np.ndarray, method: str):
    """Run every detector on ``img`` and return one annotated image per detector.

    Args:
        img: Input image as delivered by the Gradio image component.
        method: Value of the method dropdown. It is currently ignored: all
            four detectors always run because the interface shows four
            output images side by side.

    Returns:
        A 4-tuple of annotated images in the order Kornia YuNet, RetinaFace,
        Retina MobileNet, DSFD. All four entries are ``None`` when ``img`` is
        ``None``.
    """
    # Guard against a missing image so the detectors are never run on an
    # empty (0-d) array, which would fail while unpacking the shape.
    if img is None:
        return None, None, None, None

    frame = np.array(img)
    return (
        kornia_detect(frame),
        retina_detect(frame),
        retina_mobilenet_detect(frame),
        dsfd_detect(frame),
    )


def scale_image(img: np.ndarray, size: int) -> np.ndarray:
    """Resize ``img`` so that its width equals ``size``, keeping the aspect ratio.

    Args:
        img: Image array whose first two dimensions are (height, width).
        size: Target width in pixels.

    Returns:
        The resized image produced by ``cv2.resize``.
    """
    h, w = img.shape[:2]
    scale = size / w
    # Use the requested width directly: int(w * scale) can truncate to
    # size - 1 because of floating-point rounding.
    new_w = max(1, int(size))
    # Clamp so extremely wide inputs never yield a zero-height target.
    new_h = max(1, int(h * scale))
    return cv2.resize(img, (new_w, new_h))


@functools.lru_cache(maxsize=None)
def _build_detector(name: str):
    """Build the named ``face_detection`` detector once and reuse it afterwards."""
    return face_detection.build_detector(
        name,
        confidence_threshold=CONFIDENCE_THRESHOLD,
        nms_iou_threshold=NMS_IOU_THRESHOLD,
    )


@functools.lru_cache(maxsize=None)
def _kornia_detector(device_type: str):
    """Create the Kornia ``FaceDetector`` for ``device_type`` once and reuse it."""
    return FaceDetector().to(torch.device(device_type))


def base_detect(detector, img):
    """Detect faces with ``detector`` and draw their boxes on a resized copy.

    Args:
        detector: Object exposing ``detect(image)``; each returned row is
            read as ``[x_min, y_min, x_max, y_max, ...]``.
        img: Input image array.

    Returns:
        A copy of the resized image with one rectangle per detection.
    """
    img = scale_image(img, TARGET_WIDTH)
    detections = detector.detect(img)
    img_vis = img.copy()
    for box in detections:
        top_left = box[:2].astype(int).tolist()
        bottom_right = box[2:4].astype(int).tolist()
        img_vis = cv2.rectangle(
            img_vis, top_left, bottom_right, BOX_COLOR, BOX_THICKNESS)
    return img_vis


def retina_detect(img):
    """Annotate ``img`` with faces found by the RetinaNet ResNet-50 detector."""
    return base_detect(_build_detector("RetinaNetResNet50"), img)


def retina_mobilenet_detect(img):
    """Annotate ``img`` with faces found by the RetinaNet MobileNetV1 detector."""
    return base_detect(_build_detector("RetinaNetMobileNetV1"), img)


def dsfd_detect(img):
    """Annotate ``img`` with faces found by the DSFD detector."""
    return base_detect(_build_detector("DSFDDetector"), img)


def kornia_detect(img, vis_threshold=0.6):
    """Annotate ``img`` with faces found by Kornia's ``FaceDetector``.

    Args:
        img: Input image array.
        vis_threshold: Detections scoring below this value are not drawn.

    Returns:
        A copy of the resized image with one rectangle per kept detection.
    """
    # select the device
    device = torch.device('cpu')

    # load the image and scale
    img_raw = scale_image(img, TARGET_WIDTH)

    # preprocess
    tensor = K.image_to_tensor(img_raw, keepdim=False).to(device)
    # NOTE(review): this swaps the first and last colour channels. If the
    # Gradio component already delivers RGB, the detector ends up seeing BGR
    # -- confirm the intended channel order before changing this.
    tensor = K.color.bgr_to_rgb(tensor.float())

    # fetch the (cached) detector and find the faces
    detector = _kornia_detector(device.type)
    with torch.no_grad():
        dets = detector(tensor)
    # NOTE(review): assumes iterating `dets` yields one detection per item;
    # verify against the installed kornia version.
    dets = [FaceDetectorResult(o) for o in dets]

    # draw every sufficiently confident face on a copy of the resized image
    img_vis = img_raw.copy()
    for b in dets:
        if b.score < vis_threshold:
            continue
        img_vis = cv2.rectangle(
            img_vis,
            b.top_left.int().tolist(),
            b.bottom_right.int().tolist(),
            BOX_COLOR,
            BOX_THICKNESS,
        )
    return img_vis


input_image = gr.components.Image()
image_kornia = gr.components.Image(label="Kornia YuNet")
image_retina = gr.components.Image(label="RetinaFace")
image_retina_mobile = gr.components.Image(label="Retina Mobilenet")
image_dsfd = gr.components.Image(label="DSFD")

# NOTE(review): the two sliders below are created but not passed to the
# interface, so they have no effect on detection.
confidence_slider = gr.components.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Confidence Threshold")
nms_slider = gr.components.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Min Number of Neighbours")

methods_dropdown = gr.components.Dropdown(["Kornia YuNet", "RetinaFace", "RetinaMobile", "DSFD"], value="Kornia YuNet", label="Choose a method")

# Not passed to the interface below; kept as a module-level name.
description = """Face Detection"""

# `Iface` holds whatever `launch()` returns, not the Interface object itself.
Iface = gr.Interface(
    fn=detect_faces,
    inputs=[input_image, methods_dropdown],
    outputs=[image_kornia, image_retina, image_retina_mobile, image_dsfd],
    examples=[["data/9_Press_Conference_Press_Conference_9_86.jpg"], ["data/12_Group_Group_12_Group_Group_12_39.jpg"], ["data/31_Waiter_Waitress_Waiter_Waitress_31_55.jpg"]],
    title="Face Detection",
).launch()