File size: 4,153 Bytes
1f7d4dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from skimage.exposure import match_histograms
from skimage import io
import os
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms

def normalize():
  """Build the per-channel normalisation transform.

  Returns:
    A ``transforms.Normalize`` configured with the three-channel mean and
    standard deviation hard-coded below (presumably the usual ImageNet
    statistics -- confirm against the model these images are fed to).
  """
  channel_mean = [0.485, 0.456, 0.406]
  channel_std = [0.229, 0.224, 0.225]
  return transforms.Normalize(mean=channel_mean, std=channel_std)

def denormalize():
  """Build the transform that inverts the normalisation used in this module.

  ``Normalize`` computes ``out = (x - mean) / std``.  Passing it
  ``mean' = -mean / std`` and ``std' = 1 / std`` gives
  ``(y + mean / std) * std = y * std + mean``, which is the inverse mapping.

  Returns:
    A ``transforms.Normalize`` carrying the inverted statistics.
  """
  channel_mean = [0.485, 0.456, 0.406]
  channel_std = [0.229, 0.224, 0.225]
  inverse_mean = [-m / s for m, s in zip(channel_mean, channel_std)]
  inverse_std = [1 / s for s in channel_std]
  return transforms.Normalize(mean=inverse_mean, std=inverse_std)

def transformer(imsize=None, cropsize=None):
  """Compose the preprocessing pipeline applied to PIL images.

  Args:
    imsize: If truthy, a ``transforms.Resize(imsize)`` step is added first.
    cropsize: If truthy, a ``transforms.RandomCrop(cropsize)`` step follows.

  Returns:
    A ``transforms.Compose`` made of the optional steps above followed by
    ``ToTensor`` and the module's ``normalize()`` transform.
  """
  steps = []
  if imsize:
    steps.append(transforms.Resize(imsize))
  if cropsize:
    steps.append(transforms.RandomCrop(cropsize))

  # Tensor conversion and normalisation are always the last two steps.
  steps += [transforms.ToTensor(), normalize()]
  return transforms.Compose(steps)

def load_img(path, imsize=None, cropsize=None):
  """Load an image file and return it as a preprocessed tensor.

  Args:
    path: Location of the image file, passed to ``Image.open``.
    imsize: Optional resize argument forwarded to ``transformer``.
    cropsize: Optional random-crop argument forwarded to ``transformer``.

  Returns:
    The transformed image with an extra leading axis of size 1 added by
    ``unsqueeze(0)``.
  """
  transform = transformer(imsize=imsize, cropsize=cropsize)
  # Open inside a context manager so the underlying file handle is released
  # deterministically; ``convert`` returns an independent in-memory image.
  with Image.open(path) as source:
    rgb_image = source.convert("RGB")
  # torchvision.transforms supports PIL Images
  return transform(rgb_image).unsqueeze(0)

def tensor_to_img(tensor):
  """Turn a normalised image tensor back into a PIL image.

  Args:
    tensor: Normalised image tensor; all singleton axes are squeezed away
      before denormalisation (assumes the result is CxHxW or BxCxHxW --
      TODO confirm against callers).

  Returns:
    A PIL image built from the denormalised values clamped to ``[0, 1]``.
    When several images remain after squeezing, ``make_grid`` tiles them
    into one image.
  """
  denormalizer = denormalize()
  # ``.cpu()`` is a no-op for tensors already on the CPU, so no device check
  # is needed.  Comparing ``tensor.device`` with the string "cuda" does not
  # reliably detect CUDA tensors (e.g. ``cuda:0``).  ``detach`` lets tensors
  # that still track gradients be converted as well.
  tensor = tensor.detach().cpu()
  grid = torchvision.utils.make_grid(denormalizer(tensor.squeeze()))
  return transforms.functional.to_pil_image(grid.clamp(0., 1.))

def save_img(tensor, path):
  """Denormalise ``tensor`` and write it to ``path`` as an image file.

  Args:
    tensor: Normalised image tensor accepted by ``tensor_to_img``.
    path: Destination file; PIL derives the image format from its extension.

  Returns:
    None.
  """
  tensor_to_img(tensor).save(path)

def histogram_matching(image, reference):
  """Match the per-channel histogram of ``image`` to that of ``reference``.

  Args:
    image: Style image tensor, channel-first (CxHxW).
    reference: Original image tensor, channel-first (CxHxW).

  Returns:
    A CxHxW tensor on the same device as ``image``: the style image with a
    colour histogram resembling the original image's.
  """
  device = image.device
  # scikit-image expects channel-last numpy arrays.  ``detach`` allows
  # tensors that still track gradients to be converted.
  reference = reference.detach().cpu().permute(1, 2, 0).numpy()
  image = image.detach().cpu().permute(1, 2, 0).numpy()
  try:
    # Newer scikit-image releases take ``channel_axis``.
    output = match_histograms(image, reference, channel_axis=-1)
  except TypeError:
    # Older releases only know the ``multichannel`` keyword.
    output = match_histograms(image, reference, multichannel=True)
  return torch.Tensor(output).permute(2, 0, 1).to(device)

def batch_histogram_matching(images, reference):
  """Histogram-match every image of a batch against one reference image.

  Args:
    images: Tensor of shape BxCxHxW.
    reference: Tensor of shape 1xCxHxW (a plain CxHxW tensor also works).

  Returns:
    A tensor shaped like ``images`` holding the matched images.
  """
  # Strip only leading singleton axes.  A bare ``squeeze()`` would also drop
  # a channel, height or width axis of size 1 and break the CxHxW layout
  # that ``histogram_matching`` relies on.
  while reference.dim() > 3 and reference.shape[0] == 1:
    reference = reference.squeeze(0)
  output = torch.zeros_like(images)
  for i, image in enumerate(images):
    output[i] = histogram_matching(image, reference)
  return output

def statistics(f, inverse=False, eps=1e-10):
  """Compute the channel mean and a matrix power of the channel scatter.

  Args:
    f: Feature tensor of shape (c, h, w).
    inverse: When true the scatter matrix is raised to the power -0.5,
      otherwise to the power 0.5.
    eps: Singular values below this threshold, and every one after the first
      such value, are discarded.

  Returns:
    A pair ``(mean, power)``: ``mean`` has shape (c, 1) and ``power`` has
    shape (c, c).  The scatter matrix is the product of the centred features
    with their transpose and is not divided by the number of pixels.
  """
  channels, height, width = f.shape
  flat = f.view(channels, height * width)
  f_mean = flat.mean(dim=1, keepdim=True)
  centered = flat - f_mean
  scatter = centered.mm(centered.t())

  u, s, v = torch.svd(scatter)

  # Keep the leading singular values up to the first one below ``eps``.
  negligible = torch.nonzero(s < eps)
  rank = int(negligible[0]) if negligible.numel() else channels

  exponent = -0.5 if inverse else 0.5

  scaled_basis = u[:, :rank].mm(torch.diag(s[:rank].pow(exponent)))
  return f_mean, scaled_basis.mm(v[:, :rank].t())

def whitening(f):
  """Whiten a (c, h, w) feature tensor.

  The features are centred with their channel mean and multiplied by the
  inverse square root of their scatter matrix, both taken from
  ``statistics(f, inverse=True)``.

  Returns:
    The whitened features, reshaped back to (c, h, w).
  """
  channels, height, width = f.shape
  mean, inv_sqrt_scatter = statistics(f, inverse=True)
  centered = f.view(channels, height * width) - mean
  return inv_sqrt_scatter.mm(centered).view(channels, height, width)

def batch_whitening(f):
  """Apply ``whitening`` independently to every item of a (b, c, h, w) batch.

  Returns:
    A new tensor with the shape, dtype and device of ``f``.
  """
  batch, channels, height, width = f.shape
  result = f.new_zeros((batch, channels, height, width))
  for index, features in enumerate(f):
    result[index] = whitening(features)
  return result

def coloring(style, content):
  """Give ``style`` the channel statistics of ``content``.

  ``style`` is whitened, multiplied by the square root of the content
  scatter matrix and shifted by the content channel mean.

  Args:
    style: Tensor of shape (c, h, w).
    content: Tensor of shape (c, h', w') supplying the target statistics.

  Returns:
    A tensor with the shape of ``style``.
  """
  channels, height, width = style.shape
  content_mean, content_sqrt_scatter = statistics(content, inverse=False)
  whitened = whitening(style).view(channels, height * width)
  recolored = content_sqrt_scatter.mm(whitened) + content_mean
  return recolored.view(channels, height, width)

def batch_coloring(styles, content):
  """Run ``coloring`` on matching items of two batches.

  Item ``i`` of ``styles`` is coloured with the statistics of ``content[i]``,
  so ``content`` must provide at least as many items as ``styles``.

  Returns:
    A new tensor shaped like ``styles``.
  """
  result = torch.zeros_like(styles)
  for index, style in enumerate(styles):
    target = content[index]
    result[index] = coloring(style, target)
  return result

def batch_wct(styles, content):
  """Whiten a batch of style features, then colour them with ``content``.

  Note that ``batch_coloring`` whitens its input once more through
  ``coloring``, so the features pass through ``whitening`` twice.
  """
  return batch_coloring(batch_whitening(styles), content)

class Image_Set(torch.utils.data.Dataset):
  """Dataset yielding every file of one directory as a preprocessed tensor.

  Args:
    root_path: Directory whose entries are served, in sorted name order.
    imsize: Resize argument forwarded to ``transformer``.
    cropsize: Random-crop argument forwarded to ``transformer``.
  """

  def __init__(self, root_path, imsize, cropsize):
    super().__init__()
    self.root_path = root_path
    # Sorted so that an index always maps to the same file.
    self.files = sorted(os.listdir(self.root_path))
    self.transformer = transformer(imsize, cropsize)

  def __len__(self):
    """Return the number of directory entries."""
    return len(self.files)

  def __getitem__(self, index):
    """Load entry ``index`` as RGB and return it transformed."""
    # Join with a separator; plain string concatenation only works when
    # ``root_path`` already ends with one.
    file_path = os.path.join(self.root_path, self.files[index])
    # The context manager releases the file handle; ``convert`` yields an
    # independent in-memory image.
    with Image.open(file_path) as source:
      image = source.convert("RGB")
    return self.transformer(image)