Spaces:
Sleeping
Sleeping
File size: 3,500 Bytes
a5f14de 6c74699 67754d0 a5f14de 6c74699 a5f14de 6c74699 a5f14de 6c74699 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from pathlib import Path

import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
# Load the TorchScript style-transfer network once, at import time.
# torch.jit.load accepts a local path or a file-like object, not a URL, and the
# previous ".../blob/main/..." address is the Hub's HTML viewer page rather than
# the raw checkpoint. That URL shows the .pt file at the root of this Space's
# repository, presumably next to this script (TODO confirm if deployed elsewhere).
# map_location='cpu' keeps the weights on the same device as the CPU tensors
# produced by `transformer`; eval() puts any train/eval-sensitive layers into
# inference mode.
model = torch.jit.load(str(Path(__file__).with_name('neural_style_transfer.pt')), map_location='cpu')
model.eval()
def normalize():
    """Build the input normalization transform.

    Maps each of the three channels to ``(x - mean) / std`` using the standard
    ImageNet per-channel statistics.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(channel_mean, channel_std)
def denormalize():
    """Build the transform that undoes :func:`normalize`.

    ``transforms.Normalize`` computes ``(x - m) / s``. Solving
    ``y = (x - mean) / std`` for ``x`` gives ``x = y * std + mean``, which is
    itself a Normalize with ``m = -mean / std`` and ``s = 1 / std``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    inverse_mean = []
    inverse_std = []
    for m, s in zip(channel_mean, channel_std):
        inverse_mean.append(-m / s)
        inverse_std.append(1 / s)
    return transforms.Normalize(mean=inverse_mean, std=inverse_std)
def transformer(imsize=None, cropsize=None):
    """Compose the preprocessing pipeline applied to input PIL images.

    Args:
        imsize: if truthy, start with ``transforms.Resize(imsize)``.
        cropsize: if truthy, follow with ``transforms.RandomCrop(cropsize)``
            (a random window is chosen on every call).

    Returns:
        A ``transforms.Compose`` that ends with ``ToTensor`` and then
        :func:`normalize`.
    """
    steps = [transforms.Resize(imsize)] if imsize else []
    if cropsize:
        steps.append(transforms.RandomCrop(cropsize))
    steps += [transforms.ToTensor(), normalize()]
    return transforms.Compose(steps)
def tensor_to_img(tensor):
    """Convert a normalized image tensor back into a PIL image.

    Args:
        tensor: image tensor in the space produced by :func:`normalize`; a
            leading batch axis of size 1 is dropped (presumably the model
            returns shape (1, C, H, W) -- TODO confirm).

    Returns:
        A PIL image built from the de-normalized tensor after clamping its
        values to [0, 1].
    """
    denormalizer = denormalize()
    # Always bring the data to the CPU. The old `tensor.device == "cuda"` test
    # compared a torch.device with a str and never matched e.g. cuda:0.
    # .cpu() is a no-op for CPU tensors; detach() drops any autograd history.
    tensor = tensor.detach().cpu()
    # squeeze(0) removes only the batch axis; a bare squeeze() would also
    # collapse a spatial axis that happens to have size 1.
    grid = torchvision.utils.make_grid(denormalizer(tensor.squeeze(0)))
    return transforms.functional.to_pil_image(grid.clamp(0.0, 1.0))
def style_transfer(content_img, style_strength, style_img_1=None, iw_1=0., style_img_2=None, iw_2=0., style_img_3=None, iw_3=0., preserve_color=None):
    """Stylize ``content_img`` with up to three style images.

    Args:
        content_img: content image (PIL, from the Gradio ``Image`` input).
        style_strength: forwarded unchanged to the model.
        style_img_1, style_img_2, style_img_3: optional style images; ``None``
            entries are skipped.
        iw_1, iw_2, iw_3: interpolation weight for the style image with the
            same index. Weights of the supplied styles are rescaled to sum to 1;
            if they are all zero, the supplied styles are blended equally.
        preserve_color: dropdown choice. ``"None"`` becomes ``None`` and
            ``"Whitening"`` becomes ``"batch_wct"``; any other value is passed
            through unchanged.

    Returns:
        The stylized image as a PIL image.

    Raises:
        gr.Error: if the content image, or every style image, is missing.
    """
    if content_img is None:
        raise gr.Error("Please provide a content image.")

    # Keep each weight attached to its own image, so skipping a missing style
    # does not shift later weights onto the wrong image.
    # NOTE(review): assumes the model pairs styles[i] with
    # interpolation_weights[i] -- confirm against the TorchScript model.
    provided = [
        (img, weight)
        for img, weight in ((style_img_1, iw_1), (style_img_2, iw_2), (style_img_3, iw_3))
        if img is not None
    ]
    if not provided:
        raise gr.Error("Please provide at least one style image.")

    total = sum(weight for _, weight in provided)
    if total > 0:
        interpolation_weights = [weight / total for _, weight in provided]
    else:
        # Every slider at 0 previously raised ZeroDivisionError; blend equally.
        interpolation_weights = [1.0 / len(provided)] * len(provided)

    transform = transformer(imsize=512)
    content = transform(content_img).unsqueeze(0)
    styles = [transform(img).unsqueeze(0) for img, _ in provided]

    # Translate UI labels into the mode names the model takes.
    preserve_color = {"None": None, "Whitening": "batch_wct"}.get(preserve_color, preserve_color)

    with torch.no_grad():
        stylized_img = model(content, styles, interpolation_weights, preserve_color, style_strength)
    return tensor_to_img(stylized_img)
# Gradio UI: one content image, up to three style images each with its own
# blend weight, a style-strength slider and a color-preservation mode.
title = "Artistic Style Transfer"
content_img = gr.components.Image(label="Content image", type="pil")
# Distinct labels so users can tell the three style inputs apart.
style_img_1 = gr.components.Image(label="Style image 1", type="pil")
# Style 1 defaults to full weight: with every weight slider at 0 there is
# nothing meaningful to normalize, and the single-style case should just work.
iw_1 = gr.components.Slider(0., 1., value=1., label="Style 1 interpolation")
style_img_2 = gr.components.Image(label="Style image 2", type="pil")
iw_2 = gr.components.Slider(0., 1., label="Style 2 interpolation")
style_img_3 = gr.components.Image(label="Style image 3", type="pil")
iw_3 = gr.components.Slider(0., 1., label="Style 3 interpolation")
style_strength = gr.components.Slider(0., 1., label="Adjust style strength")
preserve_color = gr.components.Dropdown(["None", "Whitening"], value="None", label="Choose color preserving mode")

# Input order must match the positional parameters of style_transfer.
interface = gr.Interface(
    fn=style_transfer,
    inputs=[
        content_img,
        style_strength,
        style_img_1,
        iw_1,
        style_img_2,
        iw_2,
        style_img_3,
        iw_3,
        preserve_color,
    ],
    outputs=gr.components.Image(label="Stylized image"),
    title=title,
    description=None,
)

if __name__ == "__main__":
    interface.queue()
    interface.launch(share=True)