File size: 3,365 Bytes
a5f14de
 
6c74699
a5f14de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c74699
a5f14de
 
 
 
 
 
 
 
6c74699
a5f14de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c74699
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms

def normalize():
    """Build the per-channel normalization transform for input images.

    The constants are the ImageNet per-channel mean/std commonly used with
    torchvision models.

    Returns:
        A ``transforms.Normalize`` mapping each channel ``x`` to
        ``(x - mean) / std``.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=channel_mean, std=channel_std)

def denormalize():
    """Build the transform that undoes ``normalize()``.

    ``Normalize`` computes ``(x - m) / s``. Passing it ``m' = -m / s`` and
    ``s' = 1 / s`` gives ``(x + m / s) * s = x * s + m``, which restores the
    pre-normalization values channel by channel.

    Returns:
        A ``transforms.Normalize`` configured as the inverse normalization.
    """
    forward_mean = (0.485, 0.456, 0.406)
    forward_std = (0.229, 0.224, 0.225)
    inverse_mean = [-m / s for m, s in zip(forward_mean, forward_std)]
    inverse_std = [1 / s for s in forward_std]
    return transforms.Normalize(mean=inverse_mean, std=inverse_std)
def transformer(imsize=None, cropsize=None):
    """Compose the preprocessing pipeline applied to input images.

    Args:
        imsize: When truthy, a ``transforms.Resize(imsize)`` step is added first.
        cropsize: When truthy, a ``transforms.RandomCrop(cropsize)`` step is
            added after the optional resize.

    Returns:
        A ``transforms.Compose`` of the optional resize/crop steps followed by
        ``ToTensor`` and the normalization from ``normalize()``.
    """
    optional_steps = [
        (imsize, transforms.Resize),
        (cropsize, transforms.RandomCrop),
    ]
    pipeline = [build(size) for size, build in optional_steps if size]
    pipeline += [transforms.ToTensor(), normalize()]
    return transforms.Compose(pipeline)
def tensor_to_img(tensor):
    """Convert a normalized image tensor back into a PIL image.

    Args:
        tensor: Image tensor normalized as by ``normalize()``; presumably
            ``(1, C, H, W)`` as returned by the model -- TODO confirm.

    Returns:
        A PIL image built from the denormalized tensor, clamped to ``[0, 1]``.
    """
    denormalizer = denormalize()
    # Move to host memory unconditionally: ``Tensor.cpu()`` is a no-op for CPU
    # tensors. The previous ``tensor.device == "cuda"`` check never matched,
    # because a CUDA tensor's device carries an index (``cuda:0``).
    tensor = tensor.detach().cpu()
    # Drop size-1 dims (e.g. a batch of one) before undoing normalization;
    # make_grid then tiles any remaining batch into a single image.
    grid = torchvision.utils.make_grid(denormalizer(tensor.squeeze()))
    # Denormalized values are not guaranteed to lie in [0, 1]; clamp first.
    return transforms.functional.to_pil_image(grid.clamp_(0.0, 1.0))

def style_transfer(content_img, style_strength,
                   style_img_1=None, iw_1=0.0,
                   style_img_2=None, iw_2=0.0,
                   style_img_3=None, iw_3=0.0,
                   preserve_color=None):
    """Stylize ``content_img`` with a weighted mix of up to three style images.

    Args:
        content_img: Content image (a PIL image, per the Gradio ``type="pil"``
            input).
        style_strength: Forwarded unchanged to ``model``.
        style_img_1, style_img_2, style_img_3: Optional style images; ``None``
            slots are skipped.
        iw_1, iw_2, iw_3: Interpolation weight for the style image in the same
            slot. The weights of the provided styles are normalized to sum to
            1; if they sum to 0, the provided styles are mixed evenly.
        preserve_color: Dropdown choice. ``"None"`` maps to ``None``,
            ``"Whitening"`` maps to ``"batch_wct"``, and any other value is
            forwarded as is.

    Returns:
        The stylized result converted to a PIL image by ``tensor_to_img``.

    Raises:
        gr.Error: If no content image or no style image was provided.
    """
    if content_img is None:
        raise gr.Error("Please provide a content image.")

    # Pair each weight with its own style image so that skipping an empty slot
    # cannot shift weights onto the wrong style (or mismatch list lengths).
    slots = ((style_img_1, iw_1), (style_img_2, iw_2), (style_img_3, iw_3))
    provided = [(img, float(weight or 0.0)) for img, weight in slots if img is not None]
    if not provided:
        raise gr.Error("Please provide at least one style image.")

    transform = transformer(imsize=512)
    content = transform(content_img).unsqueeze(0)
    styles = [transform(img).unsqueeze(0) for img, _ in provided]

    weights = [weight for _, weight in provided]
    total = sum(weights)
    if total > 0:
        interpolation_weights = [weight / total for weight in weights]
    else:
        # e.g. all sliders left at 0: dividing by the sum used to raise
        # ZeroDivisionError; mix the provided styles evenly instead.
        interpolation_weights = [1.0 / len(weights)] * len(weights)

    if preserve_color == "None":
        preserve_color = None
    elif preserve_color == "Whitening":
        preserve_color = "batch_wct"
    # TODO: a "Histogram matching" -> "histogram_matching" option is not wired up.

    # NOTE(review): `model` is not defined or imported in this file and must be
    # provided elsewhere. Assumes it pairs interpolation_weights with styles
    # positionally -- confirm.
    with torch.no_grad():
        stylized_img = model(content, styles, interpolation_weights, preserve_color, style_strength)
    return tensor_to_img(stylized_img)

title = "Artistic Style Transfer"

content_img = gr.components.Image(label="Content image", type="pil")

# One image input plus one interpolation-weight slider per style slot. Each
# style slot gets a distinct label so users can tell the slots apart.
style_img_1 = gr.components.Image(label="Style image 1", type="pil")
iw_1 = gr.components.Slider(0., 1., label="Style 1 interpolation")
style_img_2 = gr.components.Image(label="Style image 2", type="pil")
iw_2 = gr.components.Slider(0., 1., label="Style 2 interpolation")
style_img_3 = gr.components.Image(label="Style image 3", type="pil")
iw_3 = gr.components.Slider(0., 1., label="Style 3 interpolation")
style_strength = gr.components.Slider(0., 1., label="Adjust style strength")
preserve_color = gr.components.Dropdown(["None", "Whitening"], label="Choose color preserving mode")

# The order of `inputs` must match style_transfer's positional parameters.
interface = gr.Interface(fn=style_transfer,
                         inputs=[content_img,
                                 style_strength,
                                 style_img_1,
                                 iw_1,
                                 style_img_2,
                                 iw_2,
                                 style_img_3,
                                 iw_3,
                                 preserve_color],
                         outputs=gr.components.Image(),
                         title=title,
                         description=None
                         )
interface.queue()

# Launch only when run as a script, so importing this module has no
# server-starting side effect.
if __name__ == "__main__":
    interface.launch(share=True)