Meloo commited on
Commit
b53be1a
1 Parent(s): 2620c02

Create utils/color_fix.py

Browse files
Files changed (1) hide show
  1. utils/color_fix.py +115 -0
utils/color_fix.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from torch import Tensor
4
+ from torch.nn import functional as F
5
+
6
+ from torchvision.transforms import ToTensor, ToPILImage
7
+
8
+ def adain_color_fix(target: Image, source: Image):
9
+ # Convert images to tensors
10
+ to_tensor = ToTensor()
11
+ target_tensor = to_tensor(target).unsqueeze(0)
12
+ source_tensor = to_tensor(source).unsqueeze(0)
13
+
14
+ # Apply adaptive instance normalization
15
+ result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
16
+
17
+ # Convert tensor back to image
18
+ to_image = ToPILImage()
19
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
20
+
21
+ return result_image
22
+
23
+ def wavelet_color_fix(target: Image, source: Image):
24
+ if target.size() != source.size():
25
+ source = source.resize((target.size()[-2], target.size()[-1]), Image.LANCZOS)
26
+ # Convert images to tensors
27
+ to_tensor = ToTensor()
28
+ target_tensor = to_tensor(target).unsqueeze(0)
29
+ source_tensor = to_tensor(source).unsqueeze(0)
30
+
31
+ # Apply wavelet reconstruction
32
+ result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
33
+
34
+ # Convert tensor back to image
35
+ to_image = ToPILImage()
36
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
37
+
38
+ return result_image
39
+
40
+ def calc_mean_std(feat: Tensor, eps=1e-5):
41
+ """Calculate mean and std for adaptive_instance_normalization.
42
+ Args:
43
+ feat (Tensor): 4D tensor.
44
+ eps (float): A small value added to the variance to avoid
45
+ divide-by-zero. Default: 1e-5.
46
+ """
47
+ size = feat.size()
48
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
49
+ b, c = size[:2]
50
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
51
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
52
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
53
+ return feat_mean, feat_std
54
+
55
+ def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
56
+ """Adaptive instance normalization.
57
+ Adjust the reference features to have the similar color and illuminations
58
+ as those in the degradate features.
59
+ Args:
60
+ content_feat (Tensor): The reference feature.
61
+ style_feat (Tensor): The degradate features.
62
+ """
63
+ size = content_feat.size()
64
+ style_mean, style_std = calc_mean_std(style_feat)
65
+ content_mean, content_std = calc_mean_std(content_feat)
66
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
67
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
68
+
69
+ def wavelet_blur(image: Tensor, radius: int):
70
+ """
71
+ Apply wavelet blur to the input tensor.
72
+ """
73
+ # input shape: (1, 3, H, W)
74
+ # convolution kernel
75
+ kernel_vals = [
76
+ [0.0625, 0.125, 0.0625],
77
+ [0.125, 0.25, 0.125],
78
+ [0.0625, 0.125, 0.0625],
79
+ ]
80
+ kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
81
+ # add channel dimensions to the kernel to make it a 4D tensor
82
+ kernel = kernel[None, None]
83
+ # repeat the kernel across all input channels
84
+ kernel = kernel.repeat(3, 1, 1, 1)
85
+ image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
86
+ # apply convolution
87
+ output = F.conv2d(image, kernel, groups=3, dilation=radius)
88
+ return output
89
+
90
+ def wavelet_decomposition(image: Tensor, levels=5):
91
+ """
92
+ Apply wavelet decomposition to the input tensor.
93
+ This function only returns the low frequency & the high frequency.
94
+ """
95
+ high_freq = torch.zeros_like(image)
96
+ for i in range(levels):
97
+ radius = 2 ** i
98
+ low_freq = wavelet_blur(image, radius)
99
+ high_freq += (image - low_freq)
100
+ image = low_freq
101
+
102
+ return high_freq, low_freq
103
+
104
+ def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
105
+ """
106
+ Apply wavelet decomposition, so that the content will have the same color as the style.
107
+ """
108
+ # calculate the wavelet decomposition of the content feature
109
+ content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
110
+ del content_low_freq
111
+ # calculate the wavelet decomposition of the style feature
112
+ style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
113
+ del style_high_freq
114
+ # reconstruct the content feature with the style's high frequency
115
+ return content_high_freq + style_low_freq