1-13-am commited on
Commit
a5f14de
1 Parent(s): 4509b17

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torchvision
3
+
4
+ def normalize():
5
+ MEAN = [0.485, 0.456, 0.406]
6
+ STD = [0.229, 0.224, 0.225]
7
+ return transforms.Normalize(mean = MEAN, std = STD)
8
+
9
+ def denormalize():
10
+ # out = (x - mean) / std
11
+ MEAN = [0.485, 0.456, 0.406]
12
+ STD = [0.229, 0.224, 0.225]
13
+ MEAN = [-mean/std for mean, std in zip(MEAN, STD)]
14
+ STD = [1/std for std in STD]
15
+ return transforms.Normalize(mean=MEAN, std=STD)
16
+ def transformer(imsize = None, cropsize = None):
17
+ transformer = []
18
+ if imsize:
19
+ transformer.append(transforms.Resize(imsize))
20
+ if cropsize:
21
+ transformer.append(transforms.RandomCrop(cropsize))
22
+
23
+ transformer.append(transforms.ToTensor())
24
+ transformer.append(normalize())
25
+ return transforms.Compose(transformer)
26
+ def tensor_to_img(tensor):
27
+ denormalizer = denormalize()
28
+ if tensor.device == "cuda":
29
+ tensor = tensor.cpu()
30
+ #
31
+ tensor = torchvision.utils.make_grid(denormalizer(tensor.squeeze()))
32
+ image = transforms.functional.to_pil_image(tensor.clamp_(0., 1.))
33
+ return image
34
+
35
+ def style_transfer(content_img, style_strength, style_img_1 = None, iw_1 = 0., style_img_2 = None, iw_2 = 0., style_img_3 = None, iw_3 = 0., preserve_color = None):
36
+ transform = transformer(imsize = 512)
37
+
38
+ content = transform(content_img).unsqueeze(0).cuda()
39
+
40
+ iw = [iw_1, iw_2, iw_3]
41
+ interpolation_weights = [i/ sum(iw) for i in iw]
42
+
43
+ style_imgs = [style_img_1, style_img_2, style_img_3]
44
+ styles = []
45
+ for style_img in style_imgs:
46
+ if style_img is not None:
47
+ styles.append(transform(style_img).unsqueeze(0).cuda())
48
+ if preserve_color == "None": preserve_color = None
49
+ elif preserve_color == "Whitening": preserve_color = "batch_wct"
50
+ #elif preserve_color == "Histogram matching": preserve_color = "histogram_matching"
51
+ with torch.no_grad():
52
+ stylized_img = model(content, styles, interpolation_weights, preserve_color, style_strength)
53
+ return tensor_to_img(stylized_img)
54
+
55
+ title = "Artistic Style Transfer"
56
+
57
+ content_img = gr.components.Image(label="Content image", type = "pil")
58
+
59
+ style_img_1 = gr.components.Image(label="Style images", type = "pil")
60
+ iw_1 = gr.components.Slider(0., 1., label = "Style 1 interpolation")
61
+ style_img_2 = gr.components.Image(label="Style images", type = "pil")
62
+ iw_2 = gr.components.Slider(0., 1., label = "Style 2 interpolation")
63
+ style_img_3 = gr.components.Image(label="Style images", type = "pil")
64
+ iw_3 = gr.components.Slider(0., 1., label = "Style 3 interpolation")
65
+ style_strength = gr.components.Slider(0., 1., label = "Adjust style strength")
66
+ preserve_color = gr.components.Dropdown(["None", "Whitening"], label = "Choose color preserving mode")
67
+
68
+ interface = gr.Interface(fn = style_transfer,
69
+ inputs = [content_img,
70
+ style_strength,
71
+ style_img_1,
72
+ iw_1,
73
+ style_img_2,
74
+ iw_2,
75
+ style_img_3,
76
+ iw_3,
77
+ preserve_color],
78
+ outputs = gr.components.Image(),
79
+ title = title,
80
+ description = None
81
+ )
82
+ interface.queue()
83
+ interface.launch(share = True, debug = True)