1-13-am committed on
Commit 1f7d4dd · 1 Parent(s): 6e33f45

Upload 6 files

Files changed (6)
  1. UI.py +57 -0
  2. check_point1_0.pth +3 -0
  3. deploy.ipynb +195 -0
  4. network.py +127 -0
  5. train.ipynb +503 -0
  6. utils.py +138 -0
UI.py ADDED
@@ -0,0 +1,57 @@
+ import gradio as gr
+ import torch
+ from utils import transformer, tensor_to_img
+ from network import Style_Transfer_Network
+
+ check_point = torch.load("/content/check_point.pth", map_location = torch.device('cpu'))
+ model = Style_Transfer_Network()
+ model.load_state_dict(check_point['state_dict'])
+
+ def style_transfer(content_img, style_strength, style_img_1 = None, iw_1 = 0, style_img_2 = None, iw_2 = 0, style_img_3 = None, iw_3 = 0, preserve_color = None):
+     transform = transformer(imsize = 512)
+
+     content = transform(content_img).unsqueeze(0)
+
+     iw = [iw_1, iw_2, iw_3]
+     interpolation_weights = [i / sum(iw) for i in iw] if sum(iw) > 0 else None  # None -> equal weights in the network; avoids dividing by zero when all sliders stay at 0
+
+     style_imgs = [style_img_1, style_img_2, style_img_3]
+     styles = []
+     for style_img in style_imgs:
+         if style_img is not None:
+             styles.append(transform(style_img).unsqueeze(0))
+     if preserve_color == "None": preserve_color = None
+     elif preserve_color == "Whitening": preserve_color = "whiten_and_color"
+     elif preserve_color == "Histogram matching": preserve_color = "histogram_matching"
+     with torch.no_grad():
+         stylized_img = model(content, styles, style_strength, interpolation_weights, preserve_color = preserve_color)
+     return tensor_to_img(stylized_img)
+
+ title = "Artistic Style Transfer"
+
+ content_img = gr.components.Image(label="Content image", type = "pil")
+
+ style_img_1 = gr.components.Image(label="Style images", type = "pil")
+ iw_1 = gr.components.Slider(0., 1., label = "Style 1 interpolation")
+ style_img_2 = gr.components.Image(label="Style images", type = "pil")
+ iw_2 = gr.components.Slider(0., 1., label = "Style 2 interpolation")
+ style_img_3 = gr.components.Image(label="Style images", type = "pil")
+ iw_3 = gr.components.Slider(0., 1., label = "Style 3 interpolation")
+ style_strength = gr.components.Slider(0., 1., label = "Adjust style strength")
+ preserve_color = gr.components.Dropdown(["None", "Whitening", "Histogram matching"], label = "Choose color preserving mode")
+
+ interface = gr.Interface(fn = style_transfer,
+                          inputs = [content_img,
+                                    style_strength,
+                                    style_img_1,
+                                    iw_1,
+                                    style_img_2,
+                                    iw_2,
+                                    style_img_3,
+                                    iw_3,
+                                    preserve_color],
+                          outputs = gr.components.Image(),
+                          title = title
+                          )
+ interface.queue()
+ interface.launch(share = True, debug = True)
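
For reference, the same pipeline can be driven without Gradio. This is a minimal sketch only, not part of the uploaded files: the image filenames are placeholders, and it assumes the LFS checkpoint from this commit has been pulled next to the scripts.

    import torch
    from PIL import Image
    from network import Style_Transfer_Network
    from utils import transformer, tensor_to_img

    model = Style_Transfer_Network()
    check_point = torch.load("check_point1_0.pth", map_location="cpu")  # checkpoint from this commit
    model.load_state_dict(check_point["state_dict"])

    transform = transformer(imsize=512)
    content = transform(Image.open("content.jpg").convert("RGB")).unsqueeze(0)  # placeholder path
    style = transform(Image.open("style.jpg").convert("RGB")).unsqueeze(0)      # placeholder path

    with torch.no_grad():
        stylized = model(content, [style], style_strength=1.0, interpolation_weights=[1.0])
    tensor_to_img(stylized).save("stylized.jpg")
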
check_point1_0.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b500176427a41788b7314c77b6fdbbc6d474fd255f94b7787f7ee123cc092056
+ size 28057273
deploy.ipynb ADDED
@@ -0,0 +1,195 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Note: you may need to restart the kernel to use updated packages.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Uncomment if you don't have the following modules\n",
+ "#pip install -qq gradio\n",
+ "#pip install -qq torch\n",
+ "#pip install -qq PIL\n",
+ "#pip install -qq torchvision"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from PIL import Image\n",
+ "import torch\n",
+ "import torchvision\n",
+ "import torchvision.transforms as transforms\n",
+ "from utils import transformer, tensor_to_img\n",
+ "from network import Style_Transfer_Network\n",
+ "import gradio as gr"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = \"cpu\"\n",
+ "if torch.cuda.is_available(): device = \"cuda\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\VICTUS\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torchvision\\models\\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
+ " warnings.warn(\n",
+ "C:\\Users\\VICTUS\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torchvision\\models\\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.\n",
+ " warnings.warn(msg)\n",
+ "C:\\Users\\VICTUS\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torchvision\\models\\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
+ " warnings.warn(msg)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "<All keys matched successfully>"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "#import gradio as gr\n",
+ "check_point = torch.load('check_point1_0.pth', map_location = device)\n",
+ "transfer_network = Style_Transfer_Network().to(device)\n",
+ "transfer_network.load_state_dict(check_point['state_dict'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running on local URL: http://127.0.0.1:7860\n",
+ "Running on public URL: https://b4e9024bf7c14725c6.gradio.live\n",
+ "\n",
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div><iframe src=\"https://b4e9024bf7c14725c6.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def style_transfer(content_img, style_strength, style_img_1 = None, iw_1 = 0, style_img_2 = None, iw_2 = 0, style_img_3 = None, iw_3 = 0, preserve_color = None):\n",
+ "    transform = transformer(imsize = 512)\n",
+ "\n",
+ "    content = transform(content_img).unsqueeze(0).to(device)\n",
+ "\n",
+ "    iw = [iw_1, iw_2, iw_3]\n",
+ "    interpolation_weights = [i / sum(iw) for i in iw] if sum(iw) > 0 else None\n",
+ "\n",
+ "    style_imgs = [style_img_1, style_img_2, style_img_3]\n",
+ "    styles = []\n",
+ "    for style_img in style_imgs:\n",
+ "        if style_img is not None:\n",
+ "            styles.append(transform(style_img).unsqueeze(0).to(device))\n",
+ "    if preserve_color == \"None\": preserve_color = None\n",
+ "    elif preserve_color == \"Whitening & Coloring\": preserve_color = \"whiten_and_color\"\n",
+ "    elif preserve_color == \"Histogram matching\": preserve_color = \"histogram_matching\"\n",
+ "    with torch.no_grad():\n",
+ "        stylized_img = transfer_network(content, styles, style_strength, interpolation_weights, preserve_color = preserve_color)\n",
+ "    return tensor_to_img(stylized_img)\n",
+ "\n",
+ "title = \"Artistic Style Transfer\"\n",
+ "\n",
+ "content_img = gr.components.Image(label=\"Content image\", type = \"pil\")\n",
+ "\n",
+ "style_img_1 = gr.components.Image(label=\"Style images\", type = \"pil\")\n",
+ "iw_1 = gr.components.Slider(0., 1., label = \"Style 1 interpolation\")\n",
+ "style_img_2 = gr.components.Image(label=\"Style images\", type = \"pil\")\n",
+ "iw_2 = gr.components.Slider(0., 1., label = \"Style 2 interpolation\")\n",
+ "style_img_3 = gr.components.Image(label=\"Style images\", type = \"pil\")\n",
+ "iw_3 = gr.components.Slider(0., 1., label = \"Style 3 interpolation\")\n",
+ "style_strength = gr.components.Slider(0., 1., label = \"Adjust style strength\")\n",
+ "preserve_color = gr.components.Dropdown([\"None\", \"Whitening & Coloring\", \"Histogram matching\"], label = \"Choose color preserving mode\")\n",
+ "\n",
+ "interface = gr.Interface(fn = style_transfer,\n",
+ "                         inputs = [content_img,\n",
+ "                                   style_strength,\n",
+ "                                   style_img_1,\n",
+ "                                   iw_1,\n",
+ "                                   style_img_2,\n",
+ "                                   iw_2,\n",
+ "                                   style_img_3,\n",
+ "                                   iw_3,\n",
+ "                                   preserve_color],\n",
+ "                         outputs = gr.components.Image(),\n",
+ "                         title = title,\n",
+ "                         \n",
+ "                         )\n",
+ "interface.queue()\n",
+ "interface.launch(share = True)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
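
Outside the notebook, the same device-aware loading pattern can be reproduced as below; a minimal sketch, assuming the checkpoint from this commit sits next to the scripts.

    import torch
    from network import Style_Transfer_Network

    device = "cuda" if torch.cuda.is_available() else "cpu"
    transfer_network = Style_Transfer_Network().to(device)
    check_point = torch.load("check_point1_0.pth", map_location=device)
    transfer_network.load_state_dict(check_point["state_dict"])
    transfer_network.eval()  # inference only; gradients are disabled separately with torch.no_grad()
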
network.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from torchvision.models import vgg19
+ import utils
+ from utils import batch_wct, batch_histogram_matching
+
+ class Encoder(nn.Module):
+     def __init__(self, layers = [1, 6, 11, 20]):
+         super(Encoder, self).__init__()
+         vgg = torchvision.models.vgg19(pretrained=True).features
+
+         self.encoder = nn.ModuleList()
+         temp_seq = nn.Sequential()
+         for i in range(max(layers)+1):
+             temp_seq.add_module(str(i), vgg[i])
+             if i in layers:
+                 self.encoder.append(temp_seq)
+                 temp_seq = nn.Sequential()
+
+     def forward(self, x):
+         features = []
+         for layer in self.encoder:
+             x = layer(x)
+             features.append(x)
+         return features
+
+ # need to copy the whole architecture because we will need outputs from the "layers" layers to compute the loss
+ class Decoder(nn.Module):
+     def __init__(self, layers=[1, 6, 11, 20]):
+         super(Decoder, self).__init__()
+         vgg = torchvision.models.vgg19(pretrained=False).features
+
+         self.decoder = nn.ModuleList()
+         temp_seq = nn.Sequential()
+         count = 0
+         for i in range(max(layers)-1, -1, -1):
+             if isinstance(vgg[i], nn.Conv2d):
+                 # get number of in/out channels
+                 out_channels = vgg[i].in_channels
+                 in_channels = vgg[i].out_channels
+                 kernel_size = vgg[i].kernel_size
+
+                 # make a [reflection pad + convolution + relu] layer
+                 temp_seq.add_module(str(count), nn.ReflectionPad2d(padding=(1,1,1,1)))
+                 count += 1
+                 temp_seq.add_module(str(count), nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size))
+                 count += 1
+                 temp_seq.add_module(str(count), nn.ReLU())
+                 count += 1
+
+             # change down-sampling(MaxPooling) --> upsampling
+             elif isinstance(vgg[i], nn.MaxPool2d):
+                 temp_seq.add_module(str(count), nn.Upsample(scale_factor=2))
+                 count += 1
+
+             if i in layers:
+                 self.decoder.append(temp_seq)
+                 temp_seq = nn.Sequential()
+
+         # append last conv layers without ReLU activation
+         self.decoder.append(temp_seq[:-1])
+
+     def forward(self, x):
+         y = x
+         for layer in self.decoder:
+             y = layer(y)
+         return y
+
+ class AdaIN(nn.Module):
+     def __init__(self):
+         super(AdaIN, self).__init__()
+
+     def forward(self, content, style, style_strength=1.0, eps=1e-5):
+         """
+         content: tensor of shape B * C * H * W
+         style: tensor of shape B * C * H * W
+         note that AdaIN does computation on a pair of content - style img"""
+         b, c, h, w = content.size()
+
+         content_std, content_mean = torch.std_mean(content.view(b, c, -1), dim=2, keepdim=True)
+         style_std, style_mean = torch.std_mean(style.view(b, c, -1), dim=2, keepdim=True)
+
+         normalized_content = (content.view(b, c, -1) - content_mean) / (content_std+eps)
+
+         stylized_content = (normalized_content * style_std) + style_mean
+
+         output = (1-style_strength) * content + style_strength * stylized_content.view(b, c, h, w)
+         return output
+
+ class Style_Transfer_Network(nn.Module):
+     def __init__(self, layers = [1, 6, 11, 20]):
+         super(Style_Transfer_Network, self).__init__()
+         self.encoder = Encoder(layers)
+         self.decoder = Decoder(layers)
+         self.adain = AdaIN()
+
+     def forward(self, content, styles, style_strength = 1., interpolation_weights = None, preserve_color = None, train = False):
+         if interpolation_weights is None:
+             interpolation_weights = [1/len(styles)] * len(styles)
+         # encode the content image
+         content_feature = self.encoder(content)
+
+         # encode style images
+         style_features = []
+         for style in styles:
+             if preserve_color == 'whiten_and_color' or preserve_color == 'histogram_matching':
+                 style = batch_wct(style, content)
+             style_features.append(self.encoder(style))
+
+         transformed_features = []
+         for style_feature, interpolation_weight in zip(style_features, interpolation_weights):
+             AdaIN_feature = self.adain(content_feature[-1], style_feature[-1], style_strength) * interpolation_weight
+             if preserve_color == 'histogram_matching':
+                 AdaIN_feature *= 0.9
+             transformed_features.append(AdaIN_feature)
+         transformed_feature = sum(transformed_features)
+
+         stylized_image = self.decoder(transformed_feature)
+         if preserve_color == "whiten_and_color":
+             stylized_image = batch_wct(stylized_image, content)
+         if preserve_color == "histogram_matching":
+             stylized_image = batch_histogram_matching(stylized_image, content)
+         if train:
+             return stylized_image, transformed_feature
+         else:
+             return stylized_image
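
As a quick standalone illustration of what the AdaIN module above does (random tensors, purely hypothetical shapes): with style_strength = 1 the per-channel mean and standard deviation of the output should approximately match the style's.

    import torch
    from network import AdaIN

    content = torch.randn(2, 8, 16, 16)
    style = torch.randn(2, 8, 16, 16)
    out = AdaIN()(content, style, style_strength=1.0)

    # Per-channel statistics of the output track the style statistics (up to eps).
    print(out.view(2, 8, -1).mean(dim=2) - style.view(2, 8, -1).mean(dim=2))  # ~0
    print(out.view(2, 8, -1).std(dim=2) - style.view(2, 8, -1).std(dim=2))    # ~0
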
train.ipynb ADDED
@@ -0,0 +1,503 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "_qsogBHiKtzF",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "datasets 2.4.0 requires dill<0.3.6, but you have dill 0.3.7 which is incompatible.\n",
+ "awscli 1.25.91 requires botocore==1.27.90, but you have botocore 1.31.17 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install -qq hub\n",
+ "!pip install -qq flask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "E8nHybN3KDIq",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import deeplake\n",
+ "from torch.utils.data import DataLoader\n",
+ "from torchvision import transforms\n",
+ "import torch.nn as nn\n",
+ "from network import Style_Transfer_Network, Encoder\n",
+ "from utils import save_img\n",
+ "import torchvision"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "rnAFLCiIKqkM",
+ "outputId": "81b8f1c3-3974-4ee3-a284-99186c1502c7",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "|"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Opening dataset in read-only mode as you don't have write permissions.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "-"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/wiki-art\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "-"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "hub://activeloop/wiki-art loaded successfully.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Opening dataset in read-only mode as you don't have write permissions.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\\"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/coco-test\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\\"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "hub://activeloop/coco-test loaded successfully.\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " "
+ ]
+ }
+ ],
+ "source": [
+ "reshape_size = 512\n",
+ "crop_size = 256\n",
+ "def any_to_rgb(img):\n",
+ "    return img.convert('RGB')\n",
+ "preprocess = transforms.Compose([\n",
+ "    transforms.Lambda(any_to_rgb),\n",
+ "    transforms.ToTensor(),\n",
+ "    transforms.Resize(reshape_size),\n",
+ "    transforms.RandomCrop(crop_size),\n",
+ "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
+ "    ])\n",
+ "wiki_art_dataset = deeplake.load('hub://activeloop/wiki-art')\n",
+ "coco_dataset = deeplake.load('hub://activeloop/coco-test')\n",
+ "\n",
+ "style_data_loader = wiki_art_dataset.pytorch(batch_size = 8, num_workers = 0,\n",
+ "                                             transform = {'images': preprocess, 'labels': None}, shuffle = True, decode_method = {'images':'pil'})\n",
+ "\n",
+ "cnt_data_loader = coco_dataset.pytorch(batch_size = 8, num_workers = 0,\n",
+ "                                       transform = {'images': preprocess}, shuffle = True, decode_method = {'images': 'pil'})\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "XKqi9mMyoNUy",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "mse_loss = nn.MSELoss(reduction = 'mean')\n",
+ "def content_loss(source, target):\n",
+ "    cnt_loss = mse_loss(source, target)\n",
+ "    return cnt_loss\n",
+ "\n",
+ "def style_loss(features, targets):\n",
+ "    loss = 0\n",
+ "    for feature, target in zip(features, targets):\n",
+ "        B, C, H, W = feature.shape\n",
+ "        feature_std, feature_mean = torch.std_mean(feature.view(B, C, -1), dim = 2)\n",
+ "        target_std, target_mean = torch.std_mean(target.view(B, C, -1), dim = 2)\n",
+ "        loss += mse_loss(feature_std, target_std) + mse_loss(feature_mean, target_mean)\n",
+ "    return loss * 1. / len(features)\n",
+ "\"\"\"\n",
+ "def style_loss(features, targets, weights=None):\n",
+ "    if weights is None:\n",
+ "        weights = [1/len(features)] * len(features)\n",
+ "    \n",
+ "    loss = 0\n",
+ "    for feature, target, weight in zip(features, targets, weights):\n",
+ "        b, c, h, w = feature.size()\n",
+ "        feature_std, feature_mean = torch.std_mean(feature.view(b, c, -1), dim=2)\n",
+ "        target_std, target_mean = torch.std_mean(target.view(b, c, -1), dim=2)\n",
+ "        loss += (mse_loss(feature_std, target_std) + mse_loss(feature_mean, target_mean))*weight\n",
+ "    return loss\n",
+ "\"\"\"\n",
+ "def total_variational_loss(images):\n",
+ "    loss = 0.0\n",
+ "    B = images.shape[0]\n",
+ "    vertical_up = images[:,:,:-1]\n",
+ "    vertical_down = images[:,:,1:]\n",
+ "\n",
+ "    horizontal_up = images[:,:,:,:-1]\n",
+ "    horizontal_down = images[:,:,:,1:]\n",
+ "\n",
+ "    loss = ((vertical_up - vertical_down) ** 2).sum() + \\\n",
+ "           ((horizontal_up - horizontal_down) ** 2).sum()\n",
+ "\n",
+ "    return loss * 1.0 / B"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "JAeuZ2Sq6E-0",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "if torch.cuda.is_available():\n",
+ "    device = \"cuda\"\n",
+ "else: device = \"cpu\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<All keys matched successfully>"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "style_transfer_network = Style_Transfer_Network().to(device)\n",
+ "check_point = torch.load(\"/notebooks/Style_transfer_with_ADAin/check_point.pth\", map_location = 'cuda')\n",
+ "style_transfer_network.load_state_dict(check_point['state_dict'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def denormalize():\n",
+ "    # out = (x - mean) / std\n",
+ "    MEAN = [0.485, 0.456, 0.406]\n",
+ "    STD = [0.229, 0.224, 0.225]\n",
+ "    MEAN = [-mean/std for mean, std in zip(MEAN, STD)]\n",
+ "    STD = [1/std for std in STD]\n",
+ "    return transforms.Normalize(mean=MEAN, std=STD)\n",
+ "\n",
+ "def save_img(tensor, path):\n",
+ "    denormalizer = denormalize() \n",
+ "    if tensor.is_cuda:\n",
+ "        tensor = tensor.cpu()\n",
+ "    tensor = torchvision.utils.make_grid(tensor)\n",
+ "    torchvision.utils.save_image(denormalizer(tensor).clamp_(0.0, 1.0), path) \n",
+ "    return None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1Y-JrlNquBwn",
+ "outputId": "31d5fe14-5315-40cd-8946-99c34ff41726",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "def train_network(iteration, loss_weight = [0.0, 0.0, 0.0001], check_iter = 1, test_iter = 10):\n",
+ "    for param in style_transfer_network.encoder.parameters():\n",
+ "        # freeze parameter in the encoder network\n",
+ "        param.requires_grad = False\n",
+ "    optimizer = torch.optim.Adam(style_transfer_network.decoder.parameters(), lr = 1e-6)\n",
+ "\n",
+ "    encoder_net = Encoder().to(device)\n",
+ "    for param in encoder_net.parameters():\n",
+ "        param.requires_grad = False\n",
+ "    for i in range(iteration):\n",
+ "        content_imgs = next(iter(cnt_data_loader))['images'].to(device)\n",
+ "        style_imgs = next(iter(style_data_loader))['images'].to(device)\n",
+ "\n",
+ "        output_imgs, transformed_features = style_transfer_network(content_imgs, style_imgs, train = True)\n",
+ "\n",
+ "        output_features = encoder_net(output_imgs)\n",
+ "        style_features = encoder_net(style_imgs)\n",
+ "\n",
+ "        cnt_loss = content_loss(transformed_features, output_features[-1])\n",
+ "        st_loss = style_loss(output_features, style_features)\n",
+ "        tv_loss = total_variational_loss(output_imgs)\n",
+ "        cnt_w, style_w, tv_w = loss_weight\n",
+ "        total_loss = cnt_w * cnt_loss + style_w * st_loss + tv_w * tv_loss\n",
+ "\n",
+ "        optimizer.zero_grad()\n",
+ "        total_loss.backward()\n",
+ "        optimizer.step()\n",
+ "\n",
+ "        if i % check_iter == 0:\n",
+ "            print('-' * 80)\n",
+ "            print(\"Iteration {} loss: {}\".format(i, total_loss))\n",
+ "\n",
+ "        if i % test_iter == 0:\n",
+ "            #save_img(torch.cat([content_imgs[0], style_imgs[0], output_imgs[0]], dim = 0), \"training_image.png\")\n",
+ "            torch.save({'iteration':iteration+1,\n",
+ "                        'state_dict':style_transfer_network.state_dict()},\n",
+ "                       'check_point1.pth')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 0 loss: 0.8845198750495911\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 1 loss: 1.8098524808883667\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 2 loss: 1.868203043937683\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 3 loss: 1.1070071458816528\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 4 loss: 2.0751609802246094\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 5 loss: 2.7107627391815186\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 6 loss: 1.4618340730667114\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 7 loss: 1.2351319789886475\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 8 loss: 1.3090686798095703\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 9 loss: 1.7165802717208862\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 10 loss: 1.9655226469039917\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 11 loss: 1.8032971620559692\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 12 loss: 1.757157802581787\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 13 loss: 1.2641586065292358\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 14 loss: 1.230526328086853\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 15 loss: 1.8332327604293823\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 16 loss: 2.347355365753174\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 17 loss: 0.8620480298995972\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 18 loss: 1.572771668434143\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 19 loss: 2.281660795211792\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 20 loss: 1.417534589767456\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 21 loss: 1.848774790763855\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 22 loss: 1.1456807851791382\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 23 loss: 1.2357560396194458\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 24 loss: 0.6565238833427429\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 25 loss: 1.2375402450561523\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 26 loss: 2.1140313148498535\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 27 loss: 1.0238616466522217\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 28 loss: 2.618056058883667\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 29 loss: 1.1616159677505493\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 30 loss: 1.919601559638977\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 31 loss: 1.0250651836395264\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 32 loss: 1.1823596954345703\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 33 loss: 0.8185012936592102\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 34 loss: 1.1374247074127197\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 35 loss: 1.9250235557556152\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 36 loss: 1.466286540031433\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.9/dist-packages/PIL/Image.py:3035: DecompressionBombWarning: Image size (99962094 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 37 loss: 0.7055997848510742\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 38 loss: 1.3557121753692627\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 39 loss: 1.0668007135391235\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 40 loss: 1.1934823989868164\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 41 loss: 0.7692145109176636\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 42 loss: 1.141457438468933\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 43 loss: 1.5705242156982422\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 44 loss: 1.7851486206054688\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 45 loss: 0.7252503633499146\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 46 loss: 1.1291860342025757\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 47 loss: 1.3588659763336182\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 48 loss: 0.9960977435112\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 49 loss: 0.9272828102111816\n",
+ "--------------------------------------------------------------------------------\n",
+ "Iteration 50 loss: 2.4692296981811523\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_network(iteration = 300)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
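
The total-variation term in the training cell penalizes squared differences between neighbouring pixels, summed over each image and averaged over the batch. A self-contained restatement of that term (the tensor shape below is illustrative only):

    import torch

    def total_variation(images: torch.Tensor) -> torch.Tensor:
        # Squared differences between vertically and horizontally adjacent pixels,
        # summed over the image and averaged over the batch (mirrors the notebook's
        # total_variational_loss).
        b = images.shape[0]
        vertical = ((images[:, :, :-1, :] - images[:, :, 1:, :]) ** 2).sum()
        horizontal = ((images[:, :, :, :-1] - images[:, :, :, 1:]) ** 2).sum()
        return (vertical + horizontal) / b

    print(total_variation(torch.rand(2, 3, 8, 8)))
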
utils.py ADDED
@@ -0,0 +1,138 @@
+ from skimage.exposure import match_histograms
+ from skimage import io
+ import os
+ from PIL import Image
+ import torch
+ import torchvision
+ import torchvision.transforms as transforms
+
+ def normalize():
+     MEAN = [0.485, 0.456, 0.406]
+     STD = [0.229, 0.224, 0.225]
+     return transforms.Normalize(mean = MEAN, std = STD)
+
+ def denormalize():
+     # out = (x - mean) / std
+     MEAN = [0.485, 0.456, 0.406]
+     STD = [0.229, 0.224, 0.225]
+     MEAN = [-mean/std for mean, std in zip(MEAN, STD)]
+     STD = [1/std for std in STD]
+     return transforms.Normalize(mean=MEAN, std=STD)
+
+ def transformer(imsize = None, cropsize = None):
+     transformer = []
+     if imsize:
+         transformer.append(transforms.Resize(imsize))
+     if cropsize:
+         transformer.append(transforms.RandomCrop(cropsize))
+
+     transformer.append(transforms.ToTensor())
+     transformer.append(normalize())
+     return transforms.Compose(transformer)
+
+ def load_img(path, imsize = None, cropsize = None):
+     transform = transformer(imsize = imsize, cropsize = cropsize)
+     # torchvision.transforms supports PIL Images
+     return transform(Image.open(path).convert("RGB")).unsqueeze(0)
+
+ def tensor_to_img(tensor):
+     denormalizer = denormalize()
+     if tensor.is_cuda:
+         tensor = tensor.cpu()
+     #
+     tensor = torchvision.utils.make_grid(denormalizer(tensor.squeeze()))
+     image = transforms.functional.to_pil_image(tensor.clamp_(0., 1.))
+     return image
+
+ def save_img(tensor, path):
+     pass
+
+ def histogram_matching(image, reference):
+     """
+     img: style image
+     reference: original img
+     output: style image that resembles original img's color histogram
+     """
+     device = image.device
+     reference = reference.cpu().permute(1, 2, 0).numpy()
+     image = image.cpu().permute(1, 2, 0).numpy()
+     output = match_histograms(image, reference, multichannel = True)
+     return torch.Tensor(output).permute(2, 0, 1).to(device)
+
+ def batch_histogram_matching(images, reference):
+     """
+     images of shape BxCxHxW
+     reference of shape 1xCxHxW
+     """
+     reference = reference.squeeze()
+     output = torch.zeros_like(images, dtype = images.dtype)
+     B = images.shape[0]
+     for i in range(B):
+         output[i] = histogram_matching(images[i], reference)
+     return output
+
+ def statistics(f, inverse = False, eps = 1e-10):
+     c, h, w = f.shape
+     f_mean = torch.mean(f.view(c, h*w), dim=1, keepdim=True)
+     f_zeromean = f.view(c, h*w) - f_mean
+     f_cov = torch.mm(f_zeromean, f_zeromean.t())
+
+     u, s, v = torch.svd(f_cov)
+
+     k = c
+     for i in range(c):
+         if s[i] < eps:
+             k = i
+             break
+     if inverse:
+         p = -0.5
+     else:
+         p = 0.5
+
+     f_covsqrt = torch.mm(torch.mm(u[:, 0:k], torch.diag(s[0:k].pow(p))), v[:, 0:k].t())
+     return f_mean, f_covsqrt
+
+ def whitening(f):
+     c, h, w = f.shape
+     f_mean, f_inv_covsqrt = statistics(f, inverse = True)
+     whitened_f = torch.mm(f_inv_covsqrt, f.view(c, h*w) - f_mean)
+     return whitened_f.view(c, h, w)
+
+ def batch_whitening(f):
+     b, c, h, w = f.shape
+     whitened_f = torch.zeros(size = (b, c, h, w), dtype = f.dtype, device = f.device)
+     for i in range(b):
+         whitened_f[i] = whitening(f[i])
+     return whitened_f
+
+ def coloring(style, content):
+     s_c, s_h, s_w = style.shape
+     c_mean, c_covsqrt = statistics(content, inverse = False)
+     colored_s = torch.mm(c_covsqrt, whitening(style).view(s_c, s_h * s_w)) + c_mean
+     return colored_s.view(s_c, s_h, s_w)
+
+ def batch_coloring(styles, content):
+     colored_styles = torch.zeros_like(styles, dtype = styles.dtype, device = styles.device)
+     for i, style in enumerate(styles):
+         colored_styles[i] = coloring(style, content[i])
+
+     return colored_styles
+
+ def batch_wct(styles, content):
+     whitened_styles = batch_whitening(styles)
+     return batch_coloring(whitened_styles, content)
+
+ class Image_Set(torch.utils.data.Dataset):
+     def __init__(self, root_path, imsize, cropsize):
+         super(Image_Set, self).__init__()
+         self.root_path = root_path
+         self.file_names = sorted(os.listdir(self.root_path))
+         self.transformer = transformer(imsize, cropsize)
+
+     def __len__(self):
+         return len(self.file_names)
+
+     def __getitem__(self, index):
+         image = Image.open(os.path.join(self.root_path, self.file_names[index])).convert("RGB")
+         return self.transformer(image)
+
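
The whitening-and-coloring helpers above match second-order statistics: after batch_wct, the (unnormalized) channel covariance of the styles should approximately equal that of the content. A quick sanity sketch on random tensors, purely illustrative and not part of the upload:

    import torch
    from utils import batch_wct

    content = torch.randn(1, 3, 32, 32)
    styles = torch.randn(1, 3, 32, 32)
    recolored = batch_wct(styles, content)

    def channel_cov(x):
        # Unnormalized channel covariance, matching the convention in utils.statistics.
        f = x.reshape(x.shape[1], -1)
        f = f - f.mean(dim=1, keepdim=True)
        return f @ f.t()

    print(channel_cov(recolored))
    print(channel_cov(content))  # the two 3x3 matrices should be close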