# Provenance (Hugging Face file-viewer header): uploaded by 1-13-am in
# commit 1f7d4dd ("Upload 6 files"); file size 4.15 kB.
from skimage.exposure import match_histograms
from skimage import io
import os
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms
def normalize():
    """Return a ``Normalize`` transform with fixed per-channel constants.

    Each channel ``x`` is mapped to ``(x - mean) / std`` using the three
    mean/std values below (the usual ImageNet statistics).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean = channel_mean, std = channel_std)
def denormalize():
    """Return a ``Normalize`` transform that inverts :func:`normalize`.

    ``normalize`` computes ``y = (x - mean) / std``.  Solving for ``x``
    gives ``x = y * std + mean = (y - (-mean / std)) / (1 / std)``, which is
    itself a ``Normalize`` with mean ``-mean / std`` and std ``1 / std``.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    inverse_mean = [-m / s for m, s in zip(channel_mean, channel_std)]
    inverse_std = [1 / s for s in channel_std]
    return transforms.Normalize(mean=inverse_mean, std=inverse_std)
def transformer(imsize = None, cropsize = None):
    """Build the preprocessing pipeline applied to input PIL images.

    Args:
        imsize: if truthy, a ``transforms.Resize(imsize)`` step is added first.
        cropsize: if truthy, a ``transforms.RandomCrop(cropsize)`` step follows.

    Returns:
        A ``transforms.Compose`` that optionally resizes and random-crops,
        then converts to a tensor and applies :func:`normalize`.
    """
    steps = [transforms.Resize(imsize)] if imsize else []
    if cropsize:
        steps.append(transforms.RandomCrop(cropsize))
    steps += [transforms.ToTensor(), normalize()]
    return transforms.Compose(steps)
def load_img(path, imsize = None, cropsize = None):
    """Load an image file as a normalized tensor with a leading batch axis.

    Args:
        path: image file path, opened with ``PIL.Image.open``.
        imsize: forwarded to :func:`transformer` (optional resize).
        cropsize: forwarded to :func:`transformer` (optional random crop).

    Returns:
        The RGB image after :func:`transformer`, unsqueezed to 1xCxHxW.
    """
    transform = transformer(imsize = imsize, cropsize = cropsize)
    # Use a context manager so the file handle is released deterministically;
    # PIL keeps it open after load() for multi-frame formats.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    # torchvision.transforms accept PIL images directly.
    return transform(rgb).unsqueeze(0)
def tensor_to_img(tensor):
    """Convert a normalized image tensor back into a PIL image.

    The tensor is moved to the CPU, squeezed, denormalized with
    :func:`denormalize`, tiled with ``torchvision.utils.make_grid`` (so a
    batch becomes one grid image), clamped to [0, 1] and converted with
    ``to_pil_image``.
    """
    denormalizer = denormalize()
    # Always move to the CPU: comparing ``tensor.device`` with a string such
    # as "cuda" never matches a torch.device, and .cpu() is a no-op for CPU
    # tensors.  detach() keeps autograd history out of the conversion.
    tensor = tensor.detach().cpu()
    grid = torchvision.utils.make_grid(denormalizer(tensor.squeeze()))
    image = transforms.functional.to_pil_image(grid.clamp_(0., 1.))
    return image
def save_img(tensor, path):
    """Write a normalized image tensor to ``path`` as an image file.

    The tensor is converted with :func:`tensor_to_img`.  PIL picks the
    output format from the file extension of ``path``.
    """
    tensor_to_img(tensor).save(path)
def histogram_matching(image, reference):
    """Match the colour histogram of ``image`` to that of ``reference``.

    Args:
        image: CxHxW tensor (the style image) whose colours are remapped.
        reference: CxHxW tensor (the original image) providing the target
            per-channel histograms.

    Returns:
        A CxHxW float tensor on ``image``'s device: the style image
        recoloured to resemble the reference's colour histogram.
    """
    import inspect

    device = image.device
    # skimage works on channels-last numpy arrays (HxWxC); detach so inputs
    # that require grad can still be converted with .numpy().
    reference_np = reference.detach().cpu().permute(1, 2, 0).numpy()
    image_np = image.detach().cpu().permute(1, 2, 0).numpy()
    # scikit-image >= 0.19 takes ``channel_axis``; ``multichannel`` was
    # deprecated in 0.19 and removed in 0.21.  Older releases only know
    # ``multichannel``.
    if "channel_axis" in inspect.signature(match_histograms).parameters:
        output = match_histograms(image_np, reference_np, channel_axis=-1)
    else:
        output = match_histograms(image_np, reference_np, multichannel=True)
    return torch.Tensor(output).permute(2, 0, 1).to(device)
def batch_histogram_matching(images, reference):
    """Apply :func:`histogram_matching` to every image of a batch.

    Args:
        images: BxCxHxW tensor of images to recolour.
        reference: 1xCxHxW (or CxHxW) tensor used as the colour reference
            for every image in the batch.

    Returns:
        A tensor with the shape and dtype of ``images`` holding the matched
        images.
    """
    # Remove only the batch axis.  A bare squeeze() would also drop any
    # other size-1 axis (e.g. a single-channel or 1-pixel-high reference).
    if reference.dim() == 4:
        reference = reference.squeeze(0)
    output = torch.zeros_like(images, dtype = images.dtype)
    for i, image in enumerate(images):
        output[i] = histogram_matching(image, reference)
    return output
def statistics(f, inverse = False, eps = 1e-10):
    """Return the channel mean and the (inverse) square root of the scatter.

    Args:
        f: CxHxW feature tensor; each row of its Cx(H*W) view is a channel.
        inverse: if True, return ``scatter^(-1/2)`` instead of
            ``scatter^(1/2)``.
        eps: the first singular value below this threshold, and every one
            after it, is discarded before taking the power.

    Returns:
        ``(mean, covsqrt)``: ``mean`` is Cx1 and ``covsqrt`` is CxC.  The
        matrix here is ``Z @ Z.T`` of the zero-mean features; it is not
        divided by the number of pixels.
    """
    channels, height, width = f.shape
    flat = f.view(channels, height * width)
    mean = flat.mean(dim=1, keepdim=True)
    centered = flat - mean
    scatter = centered.mm(centered.t())
    u, s, v = torch.svd(scatter)
    # Rank cut-off: index of the first singular value below eps, if any.
    below = (s < eps).nonzero()
    rank = int(below[0]) if len(below) > 0 else channels
    power = -0.5 if inverse else 0.5
    scaled = u[:, :rank].mm(torch.diag(s[:rank].pow(power)))
    covsqrt = scaled.mm(v[:, :rank].t())
    return mean, covsqrt
def whitening(f):
    """Whiten a CxHxW feature map.

    Each channel is centred, then multiplied by the inverse square root of
    the channel scatter matrix from :func:`statistics`.  Over the kept
    directions, the result has (approximately) identity scatter.

    Returns:
        A tensor with the same CxHxW shape as ``f``.
    """
    channels, height, width = f.shape
    mean, inv_covsqrt = statistics(f, inverse = True)
    centered = f.view(channels, height * width) - mean
    return inv_covsqrt.mm(centered).view(channels, height, width)
def batch_whitening(f):
    """Apply :func:`whitening` independently to each sample of a BxCxHxW batch.

    Returns:
        A new tensor with ``f``'s shape, dtype and device.
    """
    b, c, h, w = f.shape
    out = f.new_zeros((b, c, h, w))
    for idx, sample in enumerate(f):
        out[idx] = whitening(sample)
    return out
def coloring(style, content):
    """Whiten ``style`` and re-colour it with the statistics of ``content``.

    Args:
        style: CxHxW feature map; it is whitened internally via
            :func:`whitening`.
        content: 3-D feature map with the same channel count, whose channel
            mean and scatter square root (from :func:`statistics`) are
            imposed on the whitened style.

    Returns:
        A tensor shaped like ``style``.
    """
    channels, height, width = style.shape
    content_mean, content_covsqrt = statistics(content, inverse = False)
    whitened = whitening(style).view(channels, height * width)
    colored = content_covsqrt.mm(whitened) + content_mean
    return colored.view(channels, height, width)
def batch_coloring(styles, content):
    """Apply :func:`coloring` pairwise, pairing ``styles[i]`` with ``content[i]``.

    Returns:
        A new tensor with ``styles``' shape, dtype and device.
    """
    out = torch.zeros_like(styles)
    for idx, style_map in enumerate(styles):
        out[idx] = coloring(style_map, content[idx])
    return out
def batch_wct(styles, content):
    """Whitening-and-colouring transform over a batch.

    Each ``styles[i]`` is whitened and then re-coloured with the channel
    statistics of ``content[i]`` (see :func:`coloring`).

    :func:`coloring` already whitens its input, so no separate
    ``batch_whitening`` pass is done here.  Whitening an already-whitened
    map is a no-op up to rounding and would only repeat every per-sample SVD.
    """
    return batch_coloring(styles, content)
class Image_Set(torch.utils.data.Dataset):
    """Dataset over the regular files in ``root_path``, loaded as RGB tensors.

    File names are listed once at construction, sorted by name.  Each item
    is opened with PIL, converted to RGB and passed through the pipeline
    built by :func:`transformer` from ``imsize`` and ``cropsize``.
    """

    def __init__(self, root_path, imsize, cropsize):
        super().__init__()
        self.root_path = root_path
        # Skip sub-directories: they cannot be opened as images.
        self.file_names = sorted(
            name for name in os.listdir(self.root_path)
            if os.path.isfile(os.path.join(self.root_path, name))
        )
        # ``files`` is kept as an alias of ``file_names`` for existing callers.
        self.files = self.file_names
        self.transformer = transformer(imsize, cropsize)

    def __len__(self):
        return len(self.file_names)

    def _path(self, index):
        """Return the full path of the ``index``-th file."""
        # Join as separate components so a root without a trailing
        # separator still yields a valid path.
        return os.path.join(self.root_path, self.file_names[index])

    def __getitem__(self, index):
        # Close the file handle as soon as the pixels have been read.
        with Image.open(self._path(index)) as img:
            image = img.convert("RGB")
        return self.transformer(image)