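"""Gradio demo that compares four face detectors on the same input image:
Kornia YuNet, RetinaFace (ResNet50), RetinaFace (MobileNetV1) and DSFD.
Each detector draws its bounding boxes on a copy of the image and the four
annotated results are shown side by side."""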
import argparse
import cv2
import numpy as np
import torch
import kornia as K
from kornia.contrib import FaceDetector, FaceDetectorResult
import gradio as gr
import face_detection
def detect_faces(img: np.ndarray, method: str):
    # Run every detector on the input so the results can be compared side by side;
    # `method` is currently unused (the per-method dispatch below is disabled).
    frame = np.array(img)

    kornia_detections = kornia_detect(frame)
    retina_detections = retina_detect(frame)
    retina_mobile_detections = retina_mobilenet_detect(frame)
    dsfd_detections = dsfd_detect(frame)

    # if method == "Kornia YuNet":
    #     re_im = kornia_detect(frame)
    # elif method == "RetinaFace":
    #     re_im = retina_detect(frame)

    return kornia_detections, retina_detections, retina_mobile_detections, dsfd_detections
def scale_image(img: np.ndarray, size: int) -> np.ndarray:
    # Resize so the width equals `size`, keeping the aspect ratio.
    h, w = img.shape[:2]
    scale = 1.0 * size / w
    return cv2.resize(img, (int(w * scale), int(h * scale)))
def base_detect(detector, img):
    # Shared pipeline for the face_detection-library detectors: downscale the image,
    # run the detector and draw each box ([xmin, ymin, xmax, ymax, score]) in green.
    img = scale_image(img, 400)
    detections = detector.detect(img)

    img_vis = img.copy()
    for box in detections:
        img_vis = cv2.rectangle(img_vis,
                                box[:2].astype(int).tolist(),
                                box[2:4].astype(int).tolist(),
                                (0, 255, 0), 1)
    return img_vis
def retina_detect(img):
    detector = face_detection.build_detector(
        "RetinaNetResNet50", confidence_threshold=.5, nms_iou_threshold=.3)
    img_vis = base_detect(detector, img)
    return img_vis


def retina_mobilenet_detect(img):
    detector = face_detection.build_detector(
        "RetinaNetMobileNetV1", confidence_threshold=.5, nms_iou_threshold=.3)
    img_vis = base_detect(detector, img)
    return img_vis


def dsfd_detect(img):
    detector = face_detection.build_detector(
        "DSFDDetector", confidence_threshold=.5, nms_iou_threshold=.3)
    img_vis = base_detect(detector, img)
    return img_vis
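
# Possible optimization (sketch only, not wired in): the three helpers above
# rebuild their detectors on every call, which is slow on CPU. They could be
# cached instead, e.g. with functools.lru_cache:
#
#     from functools import lru_cache
#
#     @lru_cache(maxsize=None)
#     def get_detector(name: str):
#         return face_detection.build_detector(
#             name, confidence_threshold=.5, nms_iou_threshold=.3)
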
def kornia_detect(img):
    # select the device
    device = torch.device('cpu')
    vis_threshold = 0.6

    # load the image and scale
    img_raw = scale_image(img, 400)

    # preprocess: HWC uint8 image -> BCHW float tensor
    img_t = K.image_to_tensor(img_raw, keepdim=False).to(device)
    img_t = K.color.bgr_to_rgb(img_t.float())

    # create the detector and find the faces
    face_detector = FaceDetector().to(device)
    with torch.no_grad():
        dets = face_detector(img_t)
    dets = [FaceDetectorResult(o) for o in dets]

    # draw the detections on a copy of the scaled image
    img_vis = img_raw.copy()
    for b in dets:
        if b.score < vis_threshold:
            continue
        # draw face bounding box
        img_vis = cv2.rectangle(img_vis,
                                b.top_left.int().tolist(),
                                b.bottom_right.int().tolist(),
                                (0, 255, 0),
                                1)
    return img_vis
input_image = gr.components.Image()
image_kornia = gr.components.Image(label="Kornia YuNet")
image_retina = gr.components.Image(label="RetinaFace")
image_retina_mobile = gr.components.Image(label="Retina Mobilenet")
image_dsfd = gr.components.Image(label="DSFD")
# Defined but not yet wired into the interface below.
confidence_slider = gr.components.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Confidence Threshold")
nms_slider = gr.components.Slider(minimum=0.1, maximum=0.9, value=0.5, label="NMS IoU Threshold")
# scale_slider = gr.components.Slider(minimum=1.1, maximum=2.0, value=1.3, step=0.1, label="Scale Factor")
# classifier_radio = gr.components.Radio(s)
methods_dropdown = gr.components.Dropdown(["Kornia YuNet", "RetinaFace", "RetinaMobile", "DSFD"], value="Kornia YuNet", label="Choose a method")
description = """Face Detection"""
Iface = gr.Interface(
    fn=detect_faces,
    inputs=[input_image, methods_dropdown],  # , size_slider, neighbour_slider, scale_slider],
    outputs=[image_kornia, image_retina, image_retina_mobile, image_dsfd],
    examples=[["data/9_Press_Conference_Press_Conference_9_86.jpg"],
              ["data/12_Group_Group_12_Group_Group_12_39.jpg"],
              ["data/31_Waiter_Waitress_Waiter_Waitress_31_55.jpg"]],
    title="Face Detection",
    description=description,
).launch()
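
# To try the demo locally (assuming this file is the Space's app.py):
#   python app.py
# launch() starts a local Gradio server and prints its URL to the console.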