"""Gradio demo for multi-style artistic style transfer.

Up to three style images can be blended with user-chosen interpolation
weights; an optional colour-preserving mode and a global style strength are
forwarded to the model.

NOTE(review): ``model`` is not defined in this file. It is presumably created
earlier (e.g. in a previous notebook cell) and must be callable as
``model(content, styles, interpolation_weights, preserve_color, style_strength)``
-- TODO confirm.
"""
import gradio as gr
import torch
import torchvision
from torchvision import transforms

# ImageNet channel statistics used to normalise every input image.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Run on the GPU when one is available; fall back to the CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dropdown label -> value passed to the model as ``preserve_color``.
# A "Histogram matching" -> "histogram_matching" entry was previously commented
# out; presumably the model supports it -- TODO confirm before enabling.
PRESERVE_COLOR_MODES = {
    "None": None,
    "Whitening": "batch_wct",
}


def normalize():
    """Return a transform that normalises a tensor with ImageNet statistics."""
    return transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)


def denormalize():
    """Return a transform that inverts :func:`normalize`.

    ``Normalize`` computes ``y = (x - mean) / std``; solving for ``x`` gives
    ``x = (y - (-mean / std)) / (1 / std)``, which is again a ``Normalize``.
    """
    mean = [-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)]
    std = [1 / s for s in IMAGENET_STD]
    return transforms.Normalize(mean=mean, std=std)


def transformer(imsize=None, cropsize=None):
    """Build the preprocessing pipeline for content and style images.

    Args:
        imsize: If truthy, passed to ``transforms.Resize``.
        cropsize: If truthy, passed to ``transforms.RandomCrop``.

    Returns:
        A ``transforms.Compose`` that optionally resizes and crops, converts
        the PIL image to a tensor and normalises it with ImageNet statistics.
    """
    steps = []
    if imsize:
        steps.append(transforms.Resize(imsize))
    if cropsize:
        steps.append(transforms.RandomCrop(cropsize))
    steps.append(transforms.ToTensor())
    steps.append(normalize())
    return transforms.Compose(steps)


def tensor_to_img(tensor, denorm=False):
    """Convert a model output tensor to a PIL image.

    Args:
        tensor: Image tensor of shape ``(C, H, W)`` or ``(1, C, H, W)``; a
            leading batch dimension of size 1 is dropped.
        denorm: If True, undo the ImageNet normalisation before converting.
            Off by default to keep the previous output unchanged.
            NOTE(review): whether the model output is in normalised space is
            not visible here -- TODO confirm and flip the default if it is.

    Returns:
        A ``PIL.Image`` built from the tensor clamped to ``[0, 1]``.
    """
    # .cpu() is a no-op for CPU tensors, so no device test is needed. (The old
    # ``tensor.device == "cuda"`` compared a torch.device to a str and was
    # always False.)
    tensor = tensor.detach().cpu()
    if tensor.dim() == 4 and tensor.size(0) == 1:
        # to_pil_image only accepts 2-D or 3-D tensors.
        tensor = tensor.squeeze(0)
    if denorm:
        tensor = denormalize()(tensor)
    # Out-of-place clamp so the caller's tensor is left untouched.
    return transforms.functional.to_pil_image(tensor.clamp(0.0, 1.0))


def _to_model_input(img, transform):
    """Turn a PIL image into a normalised ``(1, 3, H, W)`` tensor on DEVICE."""
    # Drop alpha / expand grayscale: normalize() has exactly 3 channel stats.
    return transform(img.convert("RGB")).unsqueeze(0).to(DEVICE)


def style_transfer(content_img, style_strength,
                   style_img_1=None, iw_1=0.,
                   style_img_2=None, iw_2=0.,
                   style_img_3=None, iw_3=0.,
                   preserve_color=None):
    """Stylise ``content_img`` with a weighted blend of up to three styles.

    Args:
        content_img: PIL image to stylise.
        style_strength: Forwarded unchanged to ``model``.
        style_img_1, style_img_2, style_img_3: Optional PIL style images;
            empty slots (``None``) are skipped.
        iw_1, iw_2, iw_3: Interpolation weight for the matching style image.
            The weights of the supplied styles are rescaled to sum to 1; if
            they are all 0 the supplied styles are weighted equally.
        preserve_color: Dropdown label; ``"None"`` maps to ``None`` and
            ``"Whitening"`` to ``"batch_wct"``. Other values pass through.

    Returns:
        The stylised image as a PIL image.

    Raises:
        gr.Error: If the content image or every style image is missing.
    """
    if content_img is None:
        raise gr.Error("Please upload a content image.")

    # Keep each weight paired with its own image so that an empty slot does
    # not shift the remaining weights onto the wrong styles.
    candidates = ((style_img_1, iw_1), (style_img_2, iw_2), (style_img_3, iw_3))
    pairs = [(img, w) for img, w in candidates if img is not None]
    if not pairs:
        raise gr.Error("Please upload at least one style image.")

    transform = transformer(imsize=512)
    content = _to_model_input(content_img, transform)
    styles = [_to_model_input(img, transform) for img, _ in pairs]

    weights = [float(w or 0.0) for _, w in pairs]
    total = sum(weights)
    if total > 0:
        interpolation_weights = [w / total for w in weights]
    else:
        # All sliders at 0 (their default): blend the supplied styles equally
        # instead of dividing by zero.
        interpolation_weights = [1.0 / len(weights)] * len(weights)

    preserve_color = PRESERVE_COLOR_MODES.get(preserve_color, preserve_color)

    # NOTE(review): ``model`` comes from outside this file -- see module doc.
    with torch.no_grad():
        stylized_img = model(content, styles, interpolation_weights,
                             preserve_color, style_strength)
    return tensor_to_img(stylized_img)


title = "Artistic Style Transfer"
content_img = gr.components.Image(label="Content image", type="pil")
style_img_1 = gr.components.Image(label="Style images", type="pil")
iw_1 = gr.components.Slider(0., 1., label="Style 1 interpolation")
style_img_2 = gr.components.Image(label="Style images", type="pil")
iw_2 = gr.components.Slider(0., 1., label="Style 2 interpolation")
style_img_3 = gr.components.Image(label="Style images", type="pil")
iw_3 = gr.components.Slider(0., 1., label="Style 3 interpolation")
style_strength = gr.components.Slider(0., 1., label="Adjust style strength")
preserve_color = gr.components.Dropdown(
    list(PRESERVE_COLOR_MODES), label="Choose color preserving mode"
)

interface = gr.Interface(
    fn=style_transfer,
    inputs=[content_img, style_strength,
            style_img_1, iw_1,
            style_img_2, iw_2,
            style_img_3, iw_3,
            preserve_color],
    outputs=gr.components.Image(),
    title=title,
    description=None,
)
interface.queue()

if __name__ == "__main__":
    interface.launch(share=True, debug=True)