Spaces:
Running
Running
from skimage.exposure import match_histograms | |
from skimage import io | |
import os | |
from PIL import Image | |
import torch | |
import torchvision | |
import torchvision.transforms as transforms | |
def normalize():
    """Build the channel-wise normalisation transform applied to loaded images.

    Returns:
        A ``transforms.Normalize`` mapping each channel ``x`` to
        ``(x - mean) / std`` with the widely used ImageNet mean/std values.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
def denormalize():
    """Build the inverse of :func:`normalize`.

    ``transforms.Normalize`` computes ``(x - mean) / std``. Feeding it
    ``mean' = -mean / std`` and ``std' = 1 / std`` turns that into
    ``(x + mean / std) * std == x * std + mean``, which undoes the
    normalisation.
    """
    base_mean = (0.485, 0.456, 0.406)
    base_std = (0.229, 0.224, 0.225)
    inv_mean = [-m / s for m, s in zip(base_mean, base_std)]
    inv_std = [1 / s for s in base_std]
    return transforms.Normalize(mean=inv_mean, std=inv_std)
def transformer(imsize=None, cropsize=None):
    """Compose the preprocessing pipeline applied to every loaded image.

    Args:
        imsize: if truthy, passed to ``transforms.Resize`` as the first step.
        cropsize: if truthy, passed to ``transforms.RandomCrop`` after resizing.

    Returns:
        ``transforms.Compose`` of the optional resize/crop, ``ToTensor`` and
        the normalisation returned by :func:`normalize`.
    """
    steps = [transforms.Resize(imsize)] if imsize else []
    if cropsize:
        steps.append(transforms.RandomCrop(cropsize))
    steps += [transforms.ToTensor(), normalize()]
    return transforms.Compose(steps)
def load_img(path, imsize=None, cropsize=None):
    """Load an image file as a normalised tensor with a batch dimension.

    Args:
        path: filesystem path of the image.
        imsize: forwarded to :func:`transformer` (optional resize).
        cropsize: forwarded to :func:`transformer` (optional random crop).

    Returns:
        Tensor of shape 1xCxHxW (C == 3, since the image is converted to RGB).
    """
    transform = transformer(imsize=imsize, cropsize=cropsize)
    # PIL opens files lazily; the context manager guarantees the file handle
    # is released. ``convert`` forces decoding while the file is still open
    # and returns an independent in-memory image.
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    # torchvision.transforms supports PIL Images
    return transform(rgb).unsqueeze(0)
def tensor_to_img(tensor):
    """Convert a normalised image tensor back into a PIL image.

    Args:
        tensor: normalised tensor such as the output of :func:`load_img`
            (1xCxHxW). Singleton dims are squeezed and the result is passed
            through ``make_grid`` before conversion.

    Returns:
        ``PIL.Image.Image`` built from the de-normalised values clamped to [0, 1].
    """
    denormalizer = denormalize()
    # The old check ``tensor.device == "cuda"`` compared a torch.device with
    # a plain string; a CUDA tensor reports e.g. "cuda:0", so the branch never
    # ran and GPU tensors reached the PIL conversion. ``.cpu()`` is a no-op
    # for CPU tensors, and ``detach`` drops autograd history (e.g. an
    # optimised output image) so the data can be converted.
    tensor = tensor.detach().cpu()
    grid = torchvision.utils.make_grid(denormalizer(tensor.squeeze()))
    return transforms.functional.to_pil_image(grid.clamp(0., 1.))
def save_img(tensor, path):
    """Write a normalised image tensor to ``path``.

    The tensor is converted with :func:`tensor_to_img`. When no explicit
    format is given, PIL picks the file format from the extension of ``path``.
    The previous body was a bare ``pass``, which silently discarded the image.
    """
    tensor_to_img(tensor).save(path)
def histogram_matching(image, reference):
    """Match the per-channel colour histogram of ``image`` to ``reference``.

    Args:
        image: CxHxW tensor to recolour (the style image).
        reference: CxHxW tensor providing the target histogram (the original
            image).

    Returns:
        CxHxW tensor on ``image``'s device whose per-channel histograms
        resemble those of ``reference``.
    """
    device = image.device
    # scikit-image works on HxWxC numpy arrays. ``detach`` lets tensors that
    # require grad be converted to numpy.
    reference_np = reference.detach().cpu().permute(1, 2, 0).numpy()
    image_np = image.detach().cpu().permute(1, 2, 0).numpy()
    try:
        # scikit-image >= 0.19: ``multichannel`` was deprecated in favour of
        # ``channel_axis`` and removed entirely in 0.21.
        output = match_histograms(image_np, reference_np, channel_axis=-1)
    except TypeError:
        # Older scikit-image releases that predate ``channel_axis``.
        output = match_histograms(image_np, reference_np, multichannel=True)
    return torch.Tensor(output).permute(2, 0, 1).to(device)
def batch_histogram_matching(images, reference):
    """Apply :func:`histogram_matching` to every image of a batch.

    Args:
        images: BxCxHxW tensor of images to recolour.
        reference: 1xCxHxW tensor; every image is matched against it.

    Returns:
        BxCxHxW tensor with the same dtype and device as ``images``.
    """
    target = reference.squeeze()
    matched = torch.zeros_like(images, dtype=images.dtype)
    for idx, img in enumerate(images):
        matched[idx] = histogram_matching(img, target)
    return matched
def statistics(f, inverse=False, eps=1e-10):
    """Return the channel mean and the (inverse) square root of the channel scatter.

    Args:
        f: CxHxW tensor; each channel is treated as one variable observed at
            H*W positions.
        inverse: if True, return the inverse square root instead of the
            square root.
        eps: singular values below this threshold, and every smaller one
            after it, are dropped from the reconstruction.

    Returns:
        ``(f_mean, f_covsqrt)``.
        - ``f_mean`` is Cx1.
        - ``f_covsqrt`` is CxC. It is built from the unnormalised scatter
          matrix ``Z @ Z.T`` of the zero-mean features ``Z`` (not divided by
          the number of pixels).
    """
    c, h, w = f.shape
    # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs such as
    # transposed or channels-last tensors.
    flat = f.reshape(c, h * w)
    f_mean = torch.mean(flat, dim=1, keepdim=True)
    f_zeromean = flat - f_mean
    f_cov = torch.mm(f_zeromean, f_zeromean.t())
    u, s, v = torch.svd(f_cov)
    # torch.svd returns singular values in descending order, so keep every
    # component before the first one that falls below ``eps``.
    k = c
    for i in range(c):
        if s[i] < eps:
            k = i
            break
    p = -0.5 if inverse else 0.5
    f_covsqrt = torch.mm(torch.mm(u[:, :k], torch.diag(s[:k].pow(p))), v[:, :k].t())
    return f_mean, f_covsqrt
def whitening(f):
    """Decorrelate the channels of a CxHxW tensor.

    Subtracts the per-channel mean and multiplies by the inverse square root
    of the channel scatter matrix from :func:`statistics`. In exact
    arithmetic, the output's scatter matrix is the identity over the retained
    components.

    Returns:
        CxHxW tensor of whitened values.
    """
    c, h, w = f.shape
    f_mean, f_inv_covsqrt = statistics(f, inverse=True)
    # ``reshape`` tolerates non-contiguous inputs where ``view`` would raise.
    whitened_f = torch.mm(f_inv_covsqrt, f.reshape(c, h * w) - f_mean)
    # torch.mm returns a contiguous tensor, so ``view`` is safe here.
    return whitened_f.view(c, h, w)
def batch_whitening(f):
    """Apply :func:`whitening` independently to each sample of a BxCxHxW batch.

    Returns:
        New BxCxHxW tensor with ``f``'s dtype and device.
    """
    b, c, h, w = f.shape
    out = torch.zeros(size=(b, c, h, w), dtype=f.dtype, device=f.device)
    for n, sample in enumerate(f):
        out[n] = whitening(sample)
    return out
def coloring(style, content):
    """Give ``style`` the channel mean and covariance of ``content``.

    ``style`` is whitened with :func:`whitening` and re-coloured with the
    square root of ``content``'s channel scatter matrix from
    :func:`statistics`. The content mean is then added back.

    Args:
        style: CxHxW tensor to recolour.
        content: C x H' x W' tensor providing the target statistics; its
            spatial size may differ from ``style``'s.

    Returns:
        CxHxW tensor with ``style``'s spatial layout.
    """
    s_c, s_h, s_w = style.shape
    _, c_h, c_w = content.shape
    c_mean, c_covsqrt = statistics(content, inverse=False)
    whitened = whitening(style).reshape(s_c, s_h * s_w)
    # ``statistics`` uses unnormalised scatter matrices (sums over pixels), so
    # the coloured result carries content's total scatter spread over style's
    # pixel count. Its per-pixel covariance would therefore be off by
    # N_content / N_style. Rescale by sqrt(N_style / N_content) so it matches
    # content's. The factor is exactly 1.0 when both have the same size.
    scale = ((s_h * s_w) / (c_h * c_w)) ** 0.5
    colored_s = torch.mm(c_covsqrt, whitened) * scale + c_mean
    return colored_s.view(s_c, s_h, s_w)
def batch_coloring(styles, content):
    """Apply :func:`coloring` to each style of a batch.

    Args:
        styles: BxCxHxW tensor of images to recolour.
        content: batch of target tensors indexed in step with ``styles``. A
            batch of size 1 is broadcast to every style, mirroring how
            ``batch_histogram_matching`` takes a single reference.

    Returns:
        BxCxHxW tensor with ``styles``' dtype and device.
    """
    colored_styles = torch.zeros_like(styles, dtype=styles.dtype, device=styles.device)
    # With a single content entry, ``content[i]`` would raise IndexError for
    # i > 0; reuse that entry for every style instead.
    shared = content[0] if len(content) == 1 else None
    for i, style in enumerate(styles):
        target = content[i] if shared is None else shared
        colored_styles[i] = coloring(style, target)
    return colored_styles
def batch_wct(styles, content):
    """Whitening-and-colouring transform for a batch of styles.

    Each style is whitened and then recoloured with the statistics of the
    matching ``content`` entry via :func:`batch_coloring`.
    """
    # ``coloring`` (called by ``batch_coloring``) already whitens its input.
    # A preceding ``batch_whitening`` pass therefore repeated the SVD work.
    # Whitening is idempotent in exact arithmetic, but in float32 a second
    # pass can re-amplify round-off in directions the first pass truncated.
    return batch_coloring(styles, content)
class Image_Set(torch.utils.data.Dataset):
    """Dataset over the regular files of one directory, loaded as RGB tensors.

    Args:
        root_path: directory containing the images (not searched recursively).
        imsize: forwarded to :func:`transformer` (optional resize).
        cropsize: forwarded to :func:`transformer` (optional random crop).
    """

    def __init__(self, root_path, imsize, cropsize):
        super(Image_Set, self).__init__()
        self.root_path = root_path
        # Skip sub-directories: ``Image.open`` cannot read them.
        self.files = sorted(
            name
            for name in os.listdir(self.root_path)
            if os.path.isfile(os.path.join(self.root_path, name))
        )
        # Alias for callers that read ``file_names``; the methods previously
        # used this name although it was never assigned.
        self.file_names = self.files
        self.transformer = transformer(imsize, cropsize)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Pass directory and name separately so ``os.path.join`` inserts the
        # separator even when ``root_path`` has no trailing slash.
        path = os.path.join(self.root_path, self.files[index])
        # The context manager closes the file handle PIL opened lazily.
        with Image.open(path) as img:
            image = img.convert("RGB")
        return self.transformer(image)