"""Gradio demo: blur the region of an image selected by a segmentation model.

The input image is fitted onto a square canvas, run through a TorchScript
model that produces a per-pixel class map, and the pixels whose class is
non-zero are replaced by a heavily blurred copy of the image.
"""
import cv2
import gradio as gr
from typing import Union, Tuple
from PIL import Image, ImageOps
import numpy as np
import torch

# (left, top, right, bottom) border widths, in the order ImageOps.expand uses.
_Padding = Tuple[int, int, int, int]

# Per-channel normalisation constants applied after scaling pixels to [0, 1].
# NOTE(review): these are the usual ImageNet mean/std values; presumably the
# model was trained with them -- confirm against the training pipeline.
_CHANNEL_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
_CHANNEL_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)

# The TorchScript model is loaded once at import time and kept in eval mode.
model = torch.jit.load('./model/model.pt').eval()


def resize_with_padding(
    img: Image.Image, expected_size: Tuple[int, int]
) -> Tuple[Image.Image, _Padding]:
    """Fit *img* inside ``expected_size`` and pad it to exactly that size.

    The image is shrunk with ``Image.thumbnail`` (aspect ratio preserved) and
    the remaining space is split between the two sides of each axis; when the
    remainder is odd, the right/bottom side receives the extra pixel.

    Args:
        img: Source image. It is not modified; a copy is resized.
        expected_size: Target ``(width, height)``.

    Returns:
        A tuple ``(padded_image, padding)`` where ``padding`` is
        ``(left, top, right, bottom)``.
    """
    # thumbnail() resizes in place, so work on a copy to leave the caller's
    # image untouched.
    img = img.copy()
    img.thumbnail((expected_size[0], expected_size[1]))

    delta_width = expected_size[0] - img.size[0]
    delta_height = expected_size[1] - img.size[1]
    pad_width = delta_width // 2
    pad_height = delta_height // 2
    padding = (
        pad_width,
        pad_height,
        delta_width - pad_width,
        delta_height - pad_height,
    )
    return ImageOps.expand(img, padding), padding


def preprocess_image(
    img: Image.Image, size: int = 512
) -> Tuple[Image.Image, torch.Tensor, _Padding]:
    """Prepare *img* for the model.

    Args:
        img: Source image in any PIL mode; it is converted to RGB because the
            normalisation below assumes exactly three channels.
        size: Side length of the square canvas fed to the model.

    Returns:
        A tuple ``(padded_image, batch, padding)``: the padded RGB PIL image,
        a float32 tensor of shape ``(1, 3, size, size)`` holding the
        normalised pixels, and the ``(left, top, right, bottom)`` padding
        that was added.
    """
    # RGBA / grayscale / palette inputs would otherwise fail to broadcast
    # against the (1, 1, 3) normalisation arrays.
    rgb_img = img if img.mode == "RGB" else img.convert("RGB")
    pil_img, padding = resize_with_padding(rgb_img, (size, size))

    pixels = np.array(pil_img).astype(np.float32) / 255
    pixels = (pixels - _CHANNEL_MEAN) / _CHANNEL_STD
    # HWC -> CHW, then add the leading batch axis.
    pixels = np.transpose(pixels, (2, 0, 1))
    return pil_img, torch.tensor(pixels[None]), padding


def soft_blur_with_mask(
    image: Image.Image, mask: np.ndarray, padding: _Padding
) -> Image.Image:
    """Blend a blurred copy of *image* into the masked region and crop padding.

    Args:
        image: The padded image that was fed to the model.
        mask: 2-D array of per-pixel class indices.
            NOTE(review): the blend assumes values are 0 (keep) or 1 (blur);
            larger class indices would over-weight the blurred image --
            confirm the model is binary.
        padding: ``(left, top, right, bottom)`` border widths to crop away.

    Returns:
        The blended image with the padding removed.
    """
    image = np.array(image)
    # Create a blurred copy of the original image.
    blurred_image = cv2.GaussianBlur(image, (221, 221), sigmaX=20, sigmaY=20)

    image_height, image_width = image.shape[:2]
    # cv2.resize takes the target size as (width, height).
    mask = cv2.resize(
        mask.astype(np.uint8),
        (image_width, image_height),
        interpolation=cv2.INTER_NEAREST,
    )
    # Blur the mask itself to get a softer mask with no hard edges.
    mask = cv2.GaussianBlur(mask.astype(np.float32), (11, 11), 10, 10)[:, :, None]

    # Take the blurred image where the mask is positive and the original
    # image where the mask is zero.
    image = mask * blurred_image + (1.0 - mask) * image

    # Crop each side by its own padding: right/bottom can be one pixel larger
    # than left/top when the padding delta is odd.
    pad_left, pad_top, pad_right, pad_bottom = padding
    image = image[
        pad_top:image_height - pad_bottom,
        pad_left:image_width - pad_right,
        :,
    ]
    return Image.fromarray(image.astype(np.uint8))


def run(image, size):
    """Gradio callback: segment *image* at resolution *size* and blur the mask.

    Args:
        image: PIL image supplied by the Gradio image input.
        size: Side length of the square model input chosen in the UI.

    Returns:
        The PIL image with the masked region blurred.
    """
    pil_image, torch_image, padding = preprocess_image(image, size=size)

    with torch.inference_mode():
        mask = model(torch_image)
        # Per-pixel class index; squeeze drops the batch axis.
        mask = mask.argmax(dim=1).numpy().squeeze()

    return soft_blur_with_mask(pil_image, mask, padding)


# UI wiring: one image input plus a radio button selecting the model input size.
content_image_input = gr.inputs.Image(label="Entrada", type="pil")
model_image_size = gr.inputs.Radio(
    [256, 384, 512, 1024],
    type="value",
    default=512,
    label="Ajustar nivel de inferencia",
)

app_interface = gr.Interface(
    fn=run,
    inputs=[content_image_input, model_image_size],
    outputs="image",
)
app_interface.launch()