Spaces:
Running
Running
File size: 4,153 Bytes
1f7d4dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
from skimage.exposure import match_histograms
from skimage import io
import os
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms
def normalize():
    """Build the per-channel normalisation transform used by this module.

    Returns:
        A ``transforms.Normalize`` that maps every channel value ``x`` to
        ``(x - mean) / std`` with the fixed three-channel constants below.
    """
    # NOTE(review): these are the values commonly quoted as ImageNet
    # statistics -- confirm that is the intended source.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=channel_mean, std=channel_std)
def denormalize():
    """Build the transform that undoes ``normalize()``.

    ``Normalize`` computes ``out = (x - mean) / std``.  Configuring it with
    ``mean' = -mean / std`` and ``std' = 1 / std`` therefore produces
    ``(out + mean / std) * std == out * std + mean``, which recovers ``x``.

    Returns:
        A ``transforms.Normalize`` holding the inverted constants.
    """
    forward_mean = [0.485, 0.456, 0.406]
    forward_std = [0.229, 0.224, 0.225]
    inverse_mean = [-m / s for m, s in zip(forward_mean, forward_std)]
    inverse_std = [1 / s for s in forward_std]
    return transforms.Normalize(mean=inverse_mean, std=inverse_std)
def transformer(imsize=None, cropsize=None):
    """Compose the preprocessing pipeline applied to PIL images.

    Args:
        imsize: Forwarded to ``transforms.Resize``; the step is skipped when
            the value is falsy.
        cropsize: Forwarded to ``transforms.RandomCrop``; the step is skipped
            when the value is falsy.

    Returns:
        A ``transforms.Compose`` running, in order: optional resize, optional
        random crop, ``ToTensor`` and ``normalize()``.
    """
    optional_steps = (
        (imsize, transforms.Resize),
        (cropsize, transforms.RandomCrop),
    )
    pipeline = [factory(arg) for arg, factory in optional_steps if arg]
    pipeline += [transforms.ToTensor(), normalize()]
    return transforms.Compose(pipeline)
def load_img(path, imsize=None, cropsize=None):
    """Load an image file as a normalised tensor with a leading batch axis.

    Args:
        path: Location of the image, passed to ``PIL.Image.open``.
        imsize: Optional resize argument forwarded to ``transformer``.
        cropsize: Optional random-crop argument forwarded to ``transformer``.

    Returns:
        The transformed image with an extra dimension inserted at index 0
        (``unsqueeze(0)``).
    """
    transform = transformer(imsize=imsize, cropsize=cropsize)
    # ``convert`` returns a new, fully loaded image, so the underlying file
    # can be closed deterministically as soon as the conversion is done.
    with Image.open(path) as source:
        rgb_image = source.convert("RGB")
    # torchvision.transforms supports PIL Images
    return transform(rgb_image).unsqueeze(0)
def tensor_to_img(tensor):
    """Convert a normalised image tensor into a PIL image.

    The tensor is moved to the CPU, stripped of all size-1 dimensions,
    de-normalised with ``denormalize()``, passed through
    ``torchvision.utils.make_grid``, clamped to ``[0, 1]`` and finally turned
    into a PIL image.

    Args:
        tensor: Image tensor produced with this module's ``normalize()``
            transform (any device).

    Returns:
        The PIL image returned by ``transforms.functional.to_pil_image``.
    """
    denormalizer = denormalize()
    # ``tensor.device`` is a ``torch.device`` object and does not compare
    # equal to the plain string "cuda", so the former guard never fired.
    # ``Tensor.cpu()`` returns the tensor itself when it already lives on the
    # CPU, which makes an unconditional call both correct and cheap.
    tensor = tensor.cpu()
    grid = torchvision.utils.make_grid(denormalizer(tensor.squeeze()))
    # In-place clamp is safe: ``grid`` is derived from the de-normalised copy.
    image = transforms.functional.to_pil_image(grid.clamp_(0., 1.))
    return image
def save_img(tensor, path):
    """Write a normalised image tensor to disk.

    The tensor is converted with ``tensor_to_img`` and the resulting PIL image
    is saved to ``path``.  This used to be an empty stub that silently did
    nothing.

    Args:
        tensor: Image tensor accepted by ``tensor_to_img``.
        path: Destination file name; Pillow picks the file format from the
            extension.
    """
    tensor_to_img(tensor).save(path)
def histogram_matching(image, reference):
    """Match the colour histogram of ``image`` to that of ``reference``.

    Args:
        image: Channel-first 3-D tensor (the style image) whose histogram is
            adjusted.
        reference: Channel-first 3-D tensor (the original image) providing the
            target histogram.

    Returns:
        A tensor with the channel-first layout of ``image``, placed on the
        device ``image`` came from, whose per-channel histogram resembles that
        of ``reference``.
    """
    import inspect

    device = image.device
    # scikit-image works on channel-last NumPy arrays; ``detach`` keeps the
    # conversion working for tensors that track gradients.
    reference_hwc = reference.detach().cpu().permute(1, 2, 0).numpy()
    image_hwc = image.detach().cpu().permute(1, 2, 0).numpy()
    # ``multichannel`` was superseded by ``channel_axis`` and is rejected by
    # recent scikit-image releases; use whichever keyword the installed
    # version accepts.
    if "channel_axis" in inspect.signature(match_histograms).parameters:
        matched = match_histograms(image_hwc, reference_hwc, channel_axis=-1)
    else:
        matched = match_histograms(image_hwc, reference_hwc, multichannel=True)
    return torch.Tensor(matched).permute(2, 0, 1).to(device)
def batch_histogram_matching(images, reference):
    """Apply ``histogram_matching`` to every image of a batch.

    Args:
        images: Tensor of shape BxCxHxW.
        reference: Tensor of shape 1xCxHxW; all size-1 dimensions are squeezed
            away before it is used as the shared reference.

    Returns:
        A tensor shaped and typed like ``images`` holding the matched images.
    """
    target = reference.squeeze()
    matched = torch.zeros_like(images, dtype=images.dtype)
    for idx, image in enumerate(images):
        matched[idx] = histogram_matching(image, target)
    return matched
def statistics(f, inverse=False, eps=1e-10):
    """Return the channel mean and a (possibly inverse) covariance square root.

    The feature map is flattened to ``(c, h*w)``.  The covariance is the plain
    product of the zero-mean features with their transpose (it is *not*
    divided by the number of positions).

    Args:
        f: Feature tensor of shape ``(c, h, w)``; it may be non-contiguous.
        inverse: When true the covariance is raised to the power ``-0.5``
            instead of ``0.5``.
        eps: Singular values from the first one below ``eps`` onwards are
            discarded before the power is taken.

    Returns:
        ``(f_mean, f_covsqrt)`` where ``f_mean`` has shape ``(c, 1)`` and
        ``f_covsqrt`` has shape ``(c, c)``.
    """
    c, h, w = f.shape
    # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs such as
    # spatially transposed feature maps.
    flat = f.reshape(c, h * w)
    f_mean = torch.mean(flat, dim=1, keepdim=True)
    f_zeromean = flat - f_mean
    f_cov = torch.mm(f_zeromean, f_zeromean.t())
    u, s, v = torch.svd(f_cov)

    # Index of the first singular value below ``eps`` (or ``c`` when there is
    # none), found with one vectorised comparison instead of a Python loop.
    below_eps = torch.nonzero(s < eps)
    k = int(below_eps[0]) if below_eps.numel() > 0 else c

    p = -0.5 if inverse else 0.5
    f_covsqrt = torch.mm(
        torch.mm(u[:, :k], torch.diag(s[:k].pow(p))),
        v[:, :k].t(),
    )
    return f_mean, f_covsqrt
def whitening(f):
    """Whiten a single feature map.

    The per-channel mean is removed and the result is multiplied by the
    inverse covariance square root returned by ``statistics``.

    Args:
        f: Feature tensor of shape ``(c, h, w)``; it may be non-contiguous.

    Returns:
        The whitened features, shaped ``(c, h, w)``.
    """
    c, h, w = f.shape
    f_mean, f_inv_covsqrt = statistics(f, inverse=True)
    # ``reshape`` instead of ``view`` so non-contiguous inputs are accepted.
    centered = f.reshape(c, h * w) - f_mean
    whitened_f = torch.mm(f_inv_covsqrt, centered)
    return whitened_f.view(c, h, w)
def batch_whitening(f):
    """Whiten every sample of a 4-D batch independently.

    Args:
        f: Tensor of shape ``(b, c, h, w)``.

    Returns:
        A new tensor with the shape, dtype and device of ``f`` whose ``i``-th
        entry is ``whitening(f[i])``.
    """
    batch, channels, height, width = f.shape
    result = torch.zeros(
        size=(batch, channels, height, width),
        dtype=f.dtype,
        device=f.device,
    )
    for idx, sample in enumerate(f):
        result[idx] = whitening(sample)
    return result
def coloring(style, content):
    """Give ``style`` features the mean and covariance square root of ``content``.

    ``style`` is whitened first, then multiplied by the covariance square root
    of ``content`` and shifted by the channel mean of ``content``.

    Args:
        style: Feature tensor of shape ``(c, h, w)``.
        content: Feature tensor whose statistics are applied.

    Returns:
        The re-coloured features with the shape of ``style``.
    """
    channels, height, width = style.shape
    content_mean, content_covsqrt = statistics(content, inverse=False)
    whitened_style = whitening(style).view(channels, height * width)
    colored = torch.mm(content_covsqrt, whitened_style) + content_mean
    return colored.view(channels, height, width)
def batch_coloring(styles, content):
    """Apply ``coloring`` sample by sample.

    Args:
        styles: Batch of style features; the leading axis is iterated.
        content: Batch of content features, indexed in step with ``styles``.

    Returns:
        A tensor shaped like ``styles`` whose ``i``-th entry is
        ``coloring(styles[i], content[i])``.
    """
    result = torch.zeros_like(styles, dtype=styles.dtype, device=styles.device)
    for idx in range(len(styles)):
        result[idx] = coloring(styles[idx], content[idx])
    return result
def batch_wct(styles, content):
    """Run the whitening-and-colouring transform over a batch.

    ``coloring`` whitens each style sample itself before applying the content
    statistics.  Whitening the batch here as well meant every sample was
    whitened twice, costing a second SVD per sample for what is
    mathematically the same result, so the extra pass is skipped.

    Args:
        styles: Batch of style features.
        content: Batch of content features, indexed in step with ``styles``.

    Returns:
        The tensor produced by ``batch_coloring(styles, content)``.
    """
    return batch_coloring(styles, content)
class Image_Set(torch.utils.data.Dataset):
    """Dataset yielding every file of a directory as a transformed RGB image.

    Args:
        root_path: Directory whose entries are listed once, at construction.
        imsize: Resize argument forwarded to ``transformer``.
        cropsize: Random-crop argument forwarded to ``transformer``.
    """

    def __init__(self, root_path, imsize, cropsize):
        super(Image_Set, self).__init__()
        self.root_path = root_path
        # Sorted so that an index always refers to the same file.
        self.files = sorted(os.listdir(self.root_path))
        self.transformer = transformer(imsize, cropsize)

    def __len__(self):
        """Return the number of directory entries found at construction."""
        # Previously read the never-assigned ``self.file_names``.
        return len(self.files)

    def __getitem__(self, index):
        """Load, convert to RGB and transform the ``index``-th file."""
        # Join with a separator; plain concatenation only worked when
        # ``root_path`` happened to end with one.
        image_path = os.path.join(self.root_path, self.files[index])
        # ``convert`` yields a fully loaded copy, so the file can be closed.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        return self.transformer(image)
|