import os
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
from requests import get
import streamlit as st
import cv2
from ultralytics import YOLO
import shutil
# Root directory that Ultralytics writes prediction artefacts into (passed as `project=`).
PREDICTION_PATH = os.path.join(os.curdir, 'predictions')
@st.cache_resource
def load_od_model():
    """Build the fine-tuned YOLO face detector from ``face_detection_best.pt``.

    Wrapped in ``st.cache_resource`` so Streamlit constructs the model once and
    hands back the cached object on later calls instead of reloading weights.
    """
    return YOLO('face_detection_best.pt')
def inference(input_image_path: str, container=None):
    """Run face detection on an image and render the count plus annotated image.

    The model comes from ``load_od_model()``. Because ``save=True`` and
    ``save_txt=True`` are passed, Ultralytics writes the annotated image and
    label files under ``PREDICTION_PATH/predict/`` (``exist_ok=True`` reuses
    that folder instead of creating ``predict2``, ``predict3``, ...).

    Args:
        input_image_path: Path of the image file to run detection on.
        container: Streamlit placeholder to render into. Defaults to the
            module-level ``placeholder`` created by the script section.
    """
    if container is None:
        container = placeholder

    finetuned_model = load_od_model()
    results = finetuned_model.predict(
        input_image_path,
        show=False,
        save=True,
        save_crop=False,
        imgsz=640,
        conf=0.6,
        save_txt=True,
        project=PREDICTION_PATH,
        show_labels=False,
        show_conf=False,
        line_width=2,
        exist_ok=True,
    )

    # One class entry per detected box, i.e. per detected face.
    nfaces = sum(len(r.boxes.cls) for r in results)
    noun = 'face' if nfaces == 1 else 'faces'

    # Derive the annotated image name from the input instead of hard-coding it.
    # NOTE(review): assumes Ultralytics saves the annotated image under the
    # input's file name (true for the ".jpg" uploads this app produces) - TODO
    # confirm for other extensions.
    annotated_path = os.path.join(
        PREDICTION_PATH, 'predict', os.path.basename(input_image_path)
    )

    with container.container():
        st.markdown(f"{nfaces} {noun} detected.", unsafe_allow_html=True)
        st.image(annotated_path)
def files_cleanup(path_: str, prediction_dir=None):
    """Delete the uploaded image and the prediction output directory.

    Missing files and directories are tolerated. This makes the call safe even
    when inference never produced any output.

    Args:
        path_: Path of the uploaded image to delete.
        prediction_dir: Directory tree to remove. Defaults to the module-level
            ``PREDICTION_PATH``.
    """
    if prediction_dir is None:
        prediction_dir = PREDICTION_PATH

    try:
        os.remove(path_)
    except FileNotFoundError:
        pass  # Already gone: nothing to clean up.

    try:
        shutil.rmtree(prediction_dir)
    except FileNotFoundError:
        pass  # The predictor never created it (e.g. it failed before saving).
def get_upload_path(filename='input.jpg'):
    """Return the file path where the downloaded input image is stored.

    Creates the ``./uploads`` directory if it does not exist yet.

    Args:
        filename: Name of the file inside ``./uploads``. Defaults to
            ``"input.jpg"``, the name this function always used.

    Returns:
        ``os.path.join('.', 'uploads', filename)``.
    """
    upload_dir = os.path.join('.', 'uploads')
    # exist_ok=True replaces the racy "check, then create" sequence.
    os.makedirs(upload_dir, exist_ok=True)
    return os.path.join(upload_dir, filename)
def process_input_image(img_url, timeout=15):
    """Download an image from a URL, normalise it, and save it for inference.

    The image is decoded as 8-bit, 3-channel BGR. It is then resized to
    640x640, which does not preserve the aspect ratio. Finally it is written
    to the path returned by ``get_upload_path()``.

    Args:
        img_url: URL of the image to fetch.
        timeout: Seconds to wait on the server, forwarded to ``requests.get``.
            Without it, a stalled server would block the script indefinitely.

    Returns:
        Path of the saved image file.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        ValueError: If the response body cannot be decoded as an image.
        OSError: If the resized image cannot be written to disk.
    """
    headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/102.0.0.0 Safari/537.36'}
    # NOTE(review): the URL comes straight from the user. On a shared
    # deployment this lets anyone make the server fetch arbitrary URLs
    # (SSRF); consider restricting schemes/hosts.
    response = get(img_url, headers=headers, timeout=timeout)
    response.raise_for_status()

    buffer = np.frombuffer(response.content, np.uint8)
    # IMREAD_COLOR always yields 8-bit BGR. This avoids the crash that
    # IMREAD_UNCHANGED + BGR2RGB caused on grayscale images. It also removes
    # the pointless BGR->RGB->BGR round-trip before writing.
    input_image = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if input_image is None:
        raise ValueError(f"Could not decode an image from {img_url!r}")

    input_image = cv2.resize(input_image, (640, 640))

    upload_file_path = get_upload_path()
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(upload_file_path, input_image):
        raise OSError(f"Could not write image to {upload_file_path!r}")
    return upload_file_path
# Streamlit re-executes this script from top to bottom on every interaction.
# NOTE(review): every session shares ./uploads/input.jpg and ./predictions, so
# concurrent users can overwrite each other's files - consider per-session names.
try:
    # NOTE(review): the heading markup around this title appears to have been
    # lost from the literal (it was split across lines) - TODO restore it.
    st.markdown("Face Detection", unsafe_allow_html=True)
    desc = (
        'Dataset used to fine-tune YOLOv8\n'
        'can be found\n'
        'here.\n'
    )
    st.markdown(desc, unsafe_allow_html=True)
    img_url = st.text_input("Paste the image URL having faces:", "")
    # inference() renders into this module-level placeholder by default.
    placeholder = st.empty()
    if img_url:
        placeholder.empty()
        img_path = process_input_image(img_url)
        inference(img_path)
        files_cleanup(img_path)
except Exception as e:
    # Top-level UI boundary: show any failure to the user instead of a traceback.
    st.error(f'An unexpected error occurred: \n{e}')