Spaces:
Runtime error
Runtime error
| import os | |
| from pathlib import Path | |
| from PIL import Image | |
| import torch | |
| import torch.backends.cudnn as cudnn | |
| from numpy import random | |
| from models.experimental import attempt_load | |
| from utils.datasets import LoadImages | |
| from utils.general import (non_max_suppression, scale_coords, xyxy2xywh) | |
| from utils.torch_utils import select_device | |
| import gradio as gr | |
| import huggingface_hub | |
| from crop import crop | |
class FaceCrop:
    """Detect faces with a YOLOv5-style model and crop them out of an image.

    Typical use: call ``load_model`` once, then for each input call
    ``load_dataset``, ``set_crop_config`` and ``process``. The crops are
    collected in ``self.results``.
    """

    def __init__(self):
        self.device = select_device()
        # Use fp16 weights/inputs on any non-CPU device.
        self.half = self.device.type != 'cpu'
        # Crops produced by the most recent call to process().
        self.results = []

    def load_dataset(self, source):
        """Prepare ``source`` (an image path) for detection.

        The path is also kept in ``self.source`` because ``process`` hands it
        to ``crop`` for every detected face.
        """
        self.source = source
        self.dataset = LoadImages(source)
        print(f'Successfully load {source}')

    def load_model(self, model):
        """Load detector weights from the file path ``model`` onto ``self.device``."""
        self.model = attempt_load(model, map_location=self.device)
        if self.half:
            # Match the fp16 input tensors built in process().
            self.model.half()
        print(f'Successfully load model weights from {model}')

    def set_crop_config(self, target_size, mode=0, face_ratio=3, threshold=1.5):
        """Store the options forwarded to ``crop`` for every detected face.

        Args:
            target_size: ``(width, height)`` of the output crops.
            mode: crop-mode index, as chosen in the UI's "Crop Mode"
                dropdown (0 = 'Auto').
            face_ratio: forwarded to ``crop`` as ``face_ratio``.
            threshold: forwarded to ``crop`` as ``shreshold``.
        """
        self.target_size = target_size
        self.mode = mode
        self.face_ratio = face_ratio
        self.threshold = threshold

    def info(self):
        """Debug helper: print every non-dunder, non-callable attribute."""
        for attribute in dir(self):
            if attribute.startswith('__'):
                continue
            value = getattr(self, attribute)
            if not callable(value):
                print(attribute, " = ", value)

    @torch.no_grad()  # pure inference: do not build autograd graphs
    def process(self, conf_thres=0.6):
        """Detect faces in the loaded dataset and crop each confident one.

        ``self.results`` is reset first, so it holds only the crops from this
        call. Without the reset, crops would accumulate across calls and every
        Gradio run would also return the crops of all earlier runs.

        Args:
            conf_thres: minimum detection confidence for a face to be cropped
                (strictly greater than). The default of 0.6 is the value that
                was previously hard-coded.
        """
        self.results = []
        for _path, img, im0, _vid_cap in self.dataset:
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)  # add the batch dimension

            # Inference
            pred = self.model(img, augment=False)[0]
            # Apply NMS (library default thresholds)
            pred = non_max_suppression(pred)

            # Normalization gain whwh: divides pixel boxes into fractions of the
            # original image size. It is the same for every detection, so it is
            # computed once per image.
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]
            for det in pred:  # detections per image
                if det is None or not len(det):
                    continue
                # Rescale boxes from the network input size to the original image size.
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                for *xyxy, conf, _cls in det:
                    if conf <= conf_thres:
                        continue
                    x, y, w, h = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                    # NOTE(review): the keyword 'shreshold' is kept exactly as in
                    # the original call. It presumably matches crop()'s signature;
                    # verify before renaming it.
                    self.results.append(crop(self.source, (x, y), mode=self.mode,
                                             size=self.target_size, box=(w, h),
                                             face_ratio=self.face_ratio,
                                             shreshold=self.threshold))
def run(img, mode, width, height, face_ratio, threshold):
    """Gradio callback: crop every detected face in ``img`` and return the crops.

    Args:
        img: file path of the uploaded image, or None if nothing was uploaded.
        mode: crop-mode index from the "Crop Mode" dropdown.
        width, height: target crop size in pixels.
        face_ratio, threshold: forwarded to the crop configuration.

    Returns:
        The list of crops collected by the module-level ``face_crop_pipeline``,
        or an empty list when no image was supplied.
    """
    # Clicking "Run" before uploading an image passes None, which the dataset
    # loader cannot open. Show an empty gallery instead of crashing.
    if img is None:
        return []
    face_crop_pipeline.load_dataset(img)
    # Defensive cast: pixel sizes must be integers even if the slider reports floats.
    face_crop_pipeline.set_crop_config(
        mode=mode,
        target_size=(int(width), int(height)),
        face_ratio=face_ratio,
        threshold=threshold,
    )
    face_crop_pipeline.process()
    return face_crop_pipeline.results
if __name__ == '__main__':
    # Download the detector weights from the Hugging Face Hub.
    model_path = huggingface_hub.hf_hub_download("Carzit/yolo5x_anime", "yolov5x_anime.pt")
    # Module-level global read by the run() callback.
    face_crop_pipeline = FaceCrop()
    face_crop_pipeline.load_model(model_path)

    app = gr.Blocks()
    with app:
        gr.Markdown("# Face Crop Anime")
        with gr.Row():
            input_img = gr.Image(label="Input Image", image_mode="RGB", type='filepath')
            output_img = gr.Gallery(label="Cropped Image")
        with gr.Row():
            # type='index' hands run() the position of the chosen entry (0..3),
            # so the choice labels are display-only.
            crop_mode = gr.Dropdown(['Auto', 'No Scale', 'Full Screen', 'Fixed Face Proportion'],
                                    label="Crop Mode", value='Auto', type='index')
            tgt_width = gr.Slider(32, 2048, value=512, label="Width")
            tgt_height = gr.Slider(32, 2048, value=512, label="Height")
        with gr.Row():
            face_ratio = gr.Slider(1, 5, step=0.1, value=2, label="Face Ratio",
                                   info="Necessary if choosing 'Auto' or 'Fixed Face Proportion' Mode")
            threshold = gr.Slider(1, 5, step=0.1, value=1.5, label="Threshold",
                                  info="Necessary if choosing 'Auto' Mode")
        run_btn = gr.Button(variant="primary")
        with gr.Row():
            examples_data = [["examples/Eda.png"], ["examples/Chtholly.png"], ["examples/Fairies.png"]]
            examples = gr.Examples(examples=examples_data, inputs=input_img)
        run_btn.click(run, [input_img, crop_mode, tgt_width, tgt_height, face_ratio, threshold], [output_img])
    app.launch()