Create app.py
app.py
ADDED
import gradio as gr
import torch
import torchvision
from torchvision import transforms

def normalize():
    # ImageNet channel statistics, matching the pretrained encoder.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=MEAN, std=STD)

def denormalize():
    # Inverts normalize(): if out = (x - mean) / std,
    # then x = (out - (-mean/std)) / (1/std).
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]
    MEAN = [-mean / std for mean, std in zip(MEAN, STD)]
    STD = [1 / std for std in STD]
    return transforms.Normalize(mean=MEAN, std=STD)
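
# Sanity check (illustrative only, not used by the app): denormalize() should
# invert normalize() up to floating-point error.
#   x = torch.rand(3, 8, 8)
#   assert torch.allclose(denormalize()(normalize()(x)), x, atol=1e-5)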

def transformer(imsize=None, cropsize=None):
    # Optional resize/crop, then tensor conversion and normalization.
    transformer = []
    if imsize:
        transformer.append(transforms.Resize(imsize))
    if cropsize:
        transformer.append(transforms.RandomCrop(cropsize))

    transformer.append(transforms.ToTensor())
    transformer.append(normalize())
    return transforms.Compose(transformer)
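
# Note: with an int size, transforms.Resize matches the image's shorter edge
# to `imsize` and keeps the aspect ratio, so transformer(imsize=512) yields
# tensors whose smaller spatial dimension is 512.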

def tensor_to_img(tensor):
    denormalizer = denormalize()
    if tensor.is_cuda:
        tensor = tensor.cpu()

    # make_grid gives a displayable 3-channel tensor; clamp_ keeps the
    # denormalized values inside the valid [0, 1] range.
    tensor = torchvision.utils.make_grid(denormalizer(tensor.squeeze()))
    image = transforms.functional.to_pil_image(tensor.clamp_(0., 1.))
    return image

def style_transfer(content_img, style_strength, style_img_1=None, iw_1=0., style_img_2=None, iw_2=0., style_img_3=None, iw_3=0., preserve_color=None):
    transform = transformer(imsize=512)

    content = transform(content_img).unsqueeze(0).cuda()

    # Normalize the slider values so the interpolation weights sum to 1;
    # fall back to the first style if every slider is left at 0.
    iw = [iw_1, iw_2, iw_3]
    total = sum(iw)
    interpolation_weights = [i / total for i in iw] if total else [1., 0., 0.]

    style_imgs = [style_img_1, style_img_2, style_img_3]
    styles = []
    for style_img in style_imgs:
        if style_img is not None:
            styles.append(transform(style_img).unsqueeze(0).cuda())

    # Map the dropdown labels to the values the model expects.
    if preserve_color == "None":
        preserve_color = None
    elif preserve_color == "Whitening":
        preserve_color = "batch_wct"
    # elif preserve_color == "Histogram matching": preserve_color = "histogram_matching"

    # `model` is assumed to be defined elsewhere in the Space before launch.
    with torch.no_grad():
        stylized_img = model(content, styles, interpolation_weights, preserve_color, style_strength)
    return tensor_to_img(stylized_img)
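
# Illustrative sketch (an assumption, not this Space's actual model): a common
# way to consume `interpolation_weights` is AdaIN-style feature interpolation,
# where the decoder input is a weighted sum of per-style stylized features.
# `_adain` and `_interpolate_styles` are hypothetical helpers, not used above.
def _adain(content_feat, style_feat, eps=1e-5):
    # Re-scale content features to match the style features' channel statistics.
    c_mean = content_feat.mean(dim=(2, 3), keepdim=True)
    c_std = content_feat.std(dim=(2, 3), keepdim=True) + eps
    s_mean = style_feat.mean(dim=(2, 3), keepdim=True)
    s_std = style_feat.std(dim=(2, 3), keepdim=True) + eps
    return (content_feat - c_mean) / c_std * s_std + s_mean

def _interpolate_styles(content_feat, style_feats, weights):
    # Weighted sum of stylized features; weights are assumed to sum to 1.
    blended = torch.zeros_like(content_feat)
    for feat, w in zip(style_feats, weights):
        blended = blended + w * _adain(content_feat, feat)
    return blended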

title = "Artistic Style Transfer"

content_img = gr.components.Image(label="Content image", type="pil")

style_img_1 = gr.components.Image(label="Style image 1", type="pil")
iw_1 = gr.components.Slider(0., 1., label="Style 1 interpolation")
style_img_2 = gr.components.Image(label="Style image 2", type="pil")
iw_2 = gr.components.Slider(0., 1., label="Style 2 interpolation")
style_img_3 = gr.components.Image(label="Style image 3", type="pil")
iw_3 = gr.components.Slider(0., 1., label="Style 3 interpolation")
style_strength = gr.components.Slider(0., 1., label="Adjust style strength")
preserve_color = gr.components.Dropdown(["None", "Whitening"], label="Choose color preserving mode")
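
# Gradio passes inputs positionally, so the list below must follow
# style_transfer's parameter order: content image, style strength, then each
# style image paired with its weight, and finally the color-preservation mode.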
interface = gr.Interface(
    fn=style_transfer,
    inputs=[
        content_img,
        style_strength,
        style_img_1,
        iw_1,
        style_img_2,
        iw_2,
        style_img_3,
        iw_3,
        preserve_color,
    ],
    outputs=gr.components.Image(),
    title=title,
    description=None,
)
interface.queue()
interface.launch(share=True, debug=True)
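
# Note: when this file runs on Hugging Face Spaces the platform serves the app
# itself, so share=True is typically unnecessary there (Gradio may warn that it
# is ignored); it mainly matters when running this file locally.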