# FaceCropAnime/app.py
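"""Gradio demo that detects anime faces with a YOLOv5 model and returns a
crop centered on each detected face. The detector weights are downloaded
from the Hugging Face Hub (Carzit/yolo5x_anime); the cropping geometry
itself is delegated to crop.py."""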
import torch

# YOLOv5 modules bundled with this Space
from models.experimental import attempt_load
from utils.datasets import LoadImages
from utils.general import non_max_suppression, scale_coords, xyxy2xywh
from utils.torch_utils import select_device

import gradio as gr
import huggingface_hub

# Local cropping helper (crop.py in this repo)
from crop import crop


class FaceCrop:
    def __init__(self):
        self.device = select_device()
        self.half = self.device.type != 'cpu'  # half precision only on CUDA
        self.results = []

    def load_dataset(self, source):
        self.source = source
        self.dataset = LoadImages(source)
        self.results = []  # reset so repeated runs do not accumulate old crops
        print(f'Successfully loaded {source}')

    def load_model(self, model):
        self.model = attempt_load(model, map_location=self.device)
        if self.half:
            self.model.half()  # convert weights to fp16
        print(f'Successfully loaded model weights from {model}')

    def set_crop_config(self, target_size, mode=0, face_ratio=3, threshold=1.5):
        # 'mode' is an index from the UI dropdown: 0 = Auto, 1 = No Scale,
        # 2 = Full Screen, 3 = Fixed Face Proportion (handled in crop.py).
        self.target_size = target_size
        self.mode = mode
        self.face_ratio = face_ratio
        self.threshold = threshold

    def info(self):
        # Print all non-dunder, non-callable attributes for quick debugging.
        for attribute in dir(self):
            if not attribute.startswith('__') and not callable(getattr(self, attribute)):
                print(attribute, '=', getattr(self, attribute))
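
    # process() mirrors the standard YOLOv5 detect.py loop: preprocess the
    # image (HWC uint8 -> normalized CHW float tensor), run the forward pass,
    # apply non-max suppression, rescale boxes to the original image, then
    # hand each confident detection to crop().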
    def process(self):
        for path, img, im0s, vid_cap in self.dataset:
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)  # add batch dimension

            # Inference
            pred = self.model(img, augment=False)[0]

            # Apply NMS
            pred = non_max_suppression(pred)

            # Process detections
            for det in pred:  # detections per image
                im0 = im0s
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
                if det is not None and len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                    for *xyxy, conf, cls in det:
                        if conf > 0.6:  # keep only confident face detections
                            # Normalized box center (x, y) and size (w, h)
                            x, y, w, h = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                            # Note: 'shreshold' matches the parameter name in crop.py.
                            self.results.append(crop(self.source, (x, y), mode=self.mode,
                                                     size=self.target_size, box=(w, h),
                                                     face_ratio=self.face_ratio,
                                                     shreshold=self.threshold))


def run(img, mode, width, height, face_ratio, threshold):
    face_crop_pipeline.load_dataset(img)
    face_crop_pipeline.set_crop_config(mode=mode, target_size=(width, height),
                                       face_ratio=face_ratio, threshold=threshold)
    face_crop_pipeline.process()
    return face_crop_pipeline.results
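
# Sketch: using the pipeline programmatically, without the Gradio UI.
# The file names below are illustrative, not files shipped with this repo.
#
#   pipeline = FaceCrop()
#   pipeline.load_model('yolov5x_anime.pt')
#   pipeline.load_dataset('some_image.png')
#   pipeline.set_crop_config(target_size=(512, 512), mode=0)
#   pipeline.process()
#   crops = pipeline.results  # one crop() output per detected face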


if __name__ == '__main__':
    # Download the anime-face detector weights from the Hugging Face Hub
    model_path = huggingface_hub.hf_hub_download("Carzit/yolo5x_anime", "yolov5x_anime.pt")
    face_crop_pipeline = FaceCrop()
    face_crop_pipeline.load_model(model_path)

    app = gr.Blocks()
    with app:
        gr.Markdown("# Face Crop Anime")
        with gr.Row():
            input_img = gr.Image(label="Input Image", image_mode="RGB", type='filepath')
            output_img = gr.Gallery(label="Cropped Image")
        with gr.Row():
            crop_mode = gr.Dropdown(['Auto', 'No Scale', 'Full Screen', 'Fixed Face Proportion'],
                                    label="Crop Mode", value='Auto', type='index')
            tgt_width = gr.Slider(32, 2048, value=512, label="Width")
            tgt_height = gr.Slider(32, 2048, value=512, label="Height")
        with gr.Row():
            face_ratio = gr.Slider(1, 5, step=0.1, value=2, label="Face Ratio",
                                   info="Used by the 'Auto' and 'Fixed Face Proportion' modes")
            threshold = gr.Slider(1, 5, step=0.1, value=1.5, label="Threshold",
                                  info="Used by the 'Auto' mode")
        run_btn = gr.Button("Run", variant="primary")
        with gr.Row():
            examples_data = [["examples/Eda.png"], ["examples/Chtholly.png"], ["examples/Fairies.png"]]
            examples = gr.Examples(examples=examples_data, inputs=input_img)
        run_btn.click(run, [input_img, crop_mode, tgt_width, tgt_height, face_ratio, threshold],
                      [output_img])
    app.launch()
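
    # launch() serves the UI on localhost by default; Gradio also accepts
    # app.launch(share=True) to create a temporary public link.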