PKUWilliamYang commited on
Commit
4e3dd77
·
1 Parent(s): 3b98894

Upload 7 files

Browse files
scripts/align_all_parallel.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
3
+ author: lzhbrian (https://lzhbrian.me)
4
+ date: 2020.1.5
5
+ note: code is heavily borrowed from
6
+ https://github.com/NVlabs/ffhq-dataset
7
+ http://dlib.net/face_landmark_detection.py.html
8
+
9
+ requirements:
10
+ apt install cmake
11
+ conda install Pillow numpy scipy
12
+ pip install dlib
13
+ # download face landmark model from:
14
+ # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
15
+ """
16
+ from argparse import ArgumentParser
17
+ import time
18
+ import numpy as np
19
+ import PIL
20
+ import PIL.Image
21
+ import os
22
+ import scipy
23
+ import scipy.ndimage
24
+ import dlib
25
+ import multiprocessing as mp
26
+ import math
27
+
28
+ from configs.paths_config import model_paths
29
+ SHAPE_PREDICTOR_PATH = model_paths["shape_predictor"]
30
+
31
+
32
+ def get_landmark(filepath, predictor):
33
+ """get landmark with dlib
34
+ :return: np.array shape=(68, 2)
35
+ """
36
+ detector = dlib.get_frontal_face_detector()
37
+ if type(filepath) == str:
38
+ img = dlib.load_rgb_image(filepath)
39
+ else:
40
+ img = filepath
41
+ dets = detector(img, 1)
42
+
43
+ if len(dets) == 0:
44
+ print('Error: no face detected! If you are sure there are faces in your input, you may rerun the code or change the image several times until the face is detected. Sometimes the detector is unstable.')
45
+ return None
46
+
47
+ shape = None
48
+ for k, d in enumerate(dets):
49
+ shape = predictor(img, d)
50
+
51
+ t = list(shape.parts())
52
+ a = []
53
+ for tt in t:
54
+ a.append([tt.x, tt.y])
55
+ lm = np.array(a)
56
+ return lm
57
+
58
+
59
+ def align_face(filepath, predictor):
60
+ """
61
+ :param filepath: str
62
+ :return: PIL Image
63
+ """
64
+
65
+ lm = get_landmark(filepath, predictor)
66
+ if lm is None:
67
+ return None
68
+
69
+ lm_chin = lm[0: 17] # left-right
70
+ lm_eyebrow_left = lm[17: 22] # left-right
71
+ lm_eyebrow_right = lm[22: 27] # left-right
72
+ lm_nose = lm[27: 31] # top-down
73
+ lm_nostrils = lm[31: 36] # top-down
74
+ lm_eye_left = lm[36: 42] # left-clockwise
75
+ lm_eye_right = lm[42: 48] # left-clockwise
76
+ lm_mouth_outer = lm[48: 60] # left-clockwise
77
+ lm_mouth_inner = lm[60: 68] # left-clockwise
78
+
79
+ # Calculate auxiliary vectors.
80
+ eye_left = np.mean(lm_eye_left, axis=0)
81
+ eye_right = np.mean(lm_eye_right, axis=0)
82
+ eye_avg = (eye_left + eye_right) * 0.5
83
+ eye_to_eye = eye_right - eye_left
84
+ mouth_left = lm_mouth_outer[0]
85
+ mouth_right = lm_mouth_outer[6]
86
+ mouth_avg = (mouth_left + mouth_right) * 0.5
87
+ eye_to_mouth = mouth_avg - eye_avg
88
+
89
+ # Choose oriented crop rectangle.
90
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
91
+ x /= np.hypot(*x)
92
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
93
+ y = np.flipud(x) * [-1, 1]
94
+ c = eye_avg + eye_to_mouth * 0.1
95
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
96
+ qsize = np.hypot(*x) * 2
97
+
98
+ # read image
99
+ if type(filepath) == str:
100
+ img = PIL.Image.open(filepath)
101
+ else:
102
+ img = PIL.Image.fromarray(filepath)
103
+
104
+ output_size = 256
105
+ transform_size = 256
106
+ enable_padding = True
107
+
108
+ # Shrink.
109
+ shrink = int(np.floor(qsize / output_size * 0.5))
110
+ if shrink > 1:
111
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
112
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
113
+ quad /= shrink
114
+ qsize /= shrink
115
+
116
+ # Crop.
117
+ border = max(int(np.rint(qsize * 0.1)), 3)
118
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
119
+ int(np.ceil(max(quad[:, 1]))))
120
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
121
+ min(crop[3] + border, img.size[1]))
122
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
123
+ img = img.crop(crop)
124
+ quad -= crop[0:2]
125
+
126
+ # Pad.
127
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
128
+ int(np.ceil(max(quad[:, 1]))))
129
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
130
+ max(pad[3] - img.size[1] + border, 0))
131
+ if enable_padding and max(pad) > border - 4:
132
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
133
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
134
+ h, w, _ = img.shape
135
+ y, x, _ = np.ogrid[:h, :w, :1]
136
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
137
+ 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
138
+ blur = qsize * 0.02
139
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
140
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
141
+ img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
142
+ quad += pad[:2]
143
+
144
+ # Transform.
145
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
146
+ if output_size < transform_size:
147
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
148
+
149
+ # Save aligned image.
150
+ return img
151
+
152
+
153
+ def chunks(lst, n):
154
+ """Yield successive n-sized chunks from lst."""
155
+ for i in range(0, len(lst), n):
156
+ yield lst[i:i + n]
157
+
158
+
159
+ def extract_on_paths(file_paths):
160
+ predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
161
+ pid = mp.current_process().name
162
+ print('\t{} is starting to extract on #{} images'.format(pid, len(file_paths)))
163
+ tot_count = len(file_paths)
164
+ count = 0
165
+ for file_path, res_path in file_paths:
166
+ count += 1
167
+ if count % 100 == 0:
168
+ print('{} done with {}/{}'.format(pid, count, tot_count))
169
+ try:
170
+ res = align_face(file_path, predictor)
171
+ res = res.convert('RGB')
172
+ os.makedirs(os.path.dirname(res_path), exist_ok=True)
173
+ res.save(res_path)
174
+ except Exception:
175
+ continue
176
+ print('\tDone!')
177
+
178
+
179
+ def parse_args():
180
+ parser = ArgumentParser(add_help=False)
181
+ parser.add_argument('--num_threads', type=int, default=1)
182
+ parser.add_argument('--root_path', type=str, default='')
183
+ args = parser.parse_args()
184
+ return args
185
+
186
+
187
+ def run(args):
188
+ root_path = args.root_path
189
+ out_crops_path = root_path + '_crops'
190
+ if not os.path.exists(out_crops_path):
191
+ os.makedirs(out_crops_path, exist_ok=True)
192
+
193
+ file_paths = []
194
+ for root, dirs, files in os.walk(root_path):
195
+ for file in files:
196
+ file_path = os.path.join(root, file)
197
+ fname = os.path.join(out_crops_path, os.path.relpath(file_path, root_path))
198
+ res_path = '{}.jpg'.format(os.path.splitext(fname)[0])
199
+ if os.path.splitext(file_path)[1] == '.txt' or os.path.exists(res_path):
200
+ continue
201
+ file_paths.append((file_path, res_path))
202
+
203
+ file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads))))
204
+ print(len(file_chunks))
205
+ pool = mp.Pool(args.num_threads)
206
+ print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
207
+ tic = time.time()
208
+ pool.map(extract_on_paths, file_chunks)
209
+ toc = time.time()
210
+ print('Mischief managed in {}s'.format(toc - tic))
211
+
212
+
213
+ if __name__ == '__main__':
214
+ args = parse_args()
215
+ run(args)
scripts/calc_id_loss_parallel.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ import time
3
+ import numpy as np
4
+ import os
5
+ import json
6
+ import sys
7
+ from PIL import Image
8
+ import multiprocessing as mp
9
+ import math
10
+ import torch
11
+ import torchvision.transforms as trans
12
+
13
+ sys.path.append(".")
14
+ sys.path.append("..")
15
+
16
+ from models.mtcnn.mtcnn import MTCNN
17
+ from models.encoders.model_irse import IR_101
18
+ from configs.paths_config import model_paths
19
+ CIRCULAR_FACE_PATH = model_paths['circular_face']
20
+
21
+
22
+ def chunks(lst, n):
23
+ """Yield successive n-sized chunks from lst."""
24
+ for i in range(0, len(lst), n):
25
+ yield lst[i:i + n]
26
+
27
+
28
+ def extract_on_paths(file_paths):
29
+ facenet = IR_101(input_size=112)
30
+ facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH))
31
+ facenet.cuda()
32
+ facenet.eval()
33
+ mtcnn = MTCNN()
34
+ id_transform = trans.Compose([
35
+ trans.ToTensor(),
36
+ trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
37
+ ])
38
+
39
+ pid = mp.current_process().name
40
+ print('\t{} is starting to extract on {} images'.format(pid, len(file_paths)))
41
+ tot_count = len(file_paths)
42
+ count = 0
43
+
44
+ scores_dict = {}
45
+ for res_path, gt_path in file_paths:
46
+ count += 1
47
+ if count % 100 == 0:
48
+ print('{} done with {}/{}'.format(pid, count, tot_count))
49
+ if True:
50
+ input_im = Image.open(res_path)
51
+ input_im, _ = mtcnn.align(input_im)
52
+ if input_im is None:
53
+ print('{} skipping {}'.format(pid, res_path))
54
+ continue
55
+
56
+ input_id = facenet(id_transform(input_im).unsqueeze(0).cuda())[0]
57
+
58
+ result_im = Image.open(gt_path)
59
+ result_im, _ = mtcnn.align(result_im)
60
+ if result_im is None:
61
+ print('{} skipping {}'.format(pid, gt_path))
62
+ continue
63
+
64
+ result_id = facenet(id_transform(result_im).unsqueeze(0).cuda())[0]
65
+ score = float(input_id.dot(result_id))
66
+ scores_dict[os.path.basename(gt_path)] = score
67
+
68
+ return scores_dict
69
+
70
+
71
+ def parse_args():
72
+ parser = ArgumentParser(add_help=False)
73
+ parser.add_argument('--num_threads', type=int, default=4)
74
+ parser.add_argument('--data_path', type=str, default='results')
75
+ parser.add_argument('--gt_path', type=str, default='gt_images')
76
+ args = parser.parse_args()
77
+ return args
78
+
79
+
80
+ def run(args):
81
+ file_paths = []
82
+ for f in os.listdir(args.data_path):
83
+ image_path = os.path.join(args.data_path, f)
84
+ gt_path = os.path.join(args.gt_path, f)
85
+ if f.endswith(".jpg") or f.endswith('.png'):
86
+ file_paths.append([image_path, gt_path.replace('.png','.jpg')])
87
+
88
+ file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads))))
89
+ pool = mp.Pool(args.num_threads)
90
+ print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
91
+
92
+ tic = time.time()
93
+ results = pool.map(extract_on_paths, file_chunks)
94
+ scores_dict = {}
95
+ for d in results:
96
+ scores_dict.update(d)
97
+
98
+ all_scores = list(scores_dict.values())
99
+ mean = np.mean(all_scores)
100
+ std = np.std(all_scores)
101
+ result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std)
102
+ print(result_str)
103
+
104
+ out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
105
+ if not os.path.exists(out_path):
106
+ os.makedirs(out_path)
107
+
108
+ with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f:
109
+ f.write(result_str)
110
+ with open(os.path.join(out_path, 'scores_id.json'), 'w') as f:
111
+ json.dump(scores_dict, f)
112
+
113
+ toc = time.time()
114
+ print('Mischief managed in {}s'.format(toc - tic))
115
+
116
+
117
+ if __name__ == '__main__':
118
+ args = parse_args()
119
+ run(args)
scripts/calc_losses_on_images.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ import os
3
+ import json
4
+ import sys
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ import torchvision.transforms as transforms
10
+
11
+ sys.path.append(".")
12
+ sys.path.append("..")
13
+
14
+ from criteria.lpips.lpips import LPIPS
15
+ from datasets.gt_res_dataset import GTResDataset
16
+
17
+
18
+ def parse_args():
19
+ parser = ArgumentParser(add_help=False)
20
+ parser.add_argument('--mode', type=str, default='lpips', choices=['lpips', 'l2'])
21
+ parser.add_argument('--data_path', type=str, default='results')
22
+ parser.add_argument('--gt_path', type=str, default='gt_images')
23
+ parser.add_argument('--workers', type=int, default=4)
24
+ parser.add_argument('--batch_size', type=int, default=4)
25
+ args = parser.parse_args()
26
+ return args
27
+
28
+
29
+ def run(args):
30
+
31
+ transform = transforms.Compose([transforms.Resize((256, 256)),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
34
+
35
+ print('Loading dataset')
36
+ dataset = GTResDataset(root_path=args.data_path,
37
+ gt_dir=args.gt_path,
38
+ transform=transform)
39
+
40
+ dataloader = DataLoader(dataset,
41
+ batch_size=args.batch_size,
42
+ shuffle=False,
43
+ num_workers=int(args.workers),
44
+ drop_last=True)
45
+
46
+ if args.mode == 'lpips':
47
+ loss_func = LPIPS(net_type='alex')
48
+ elif args.mode == 'l2':
49
+ loss_func = torch.nn.MSELoss()
50
+ else:
51
+ raise Exception('Not a valid mode!')
52
+ loss_func.cuda()
53
+
54
+ global_i = 0
55
+ scores_dict = {}
56
+ all_scores = []
57
+ for result_batch, gt_batch in tqdm(dataloader):
58
+ for i in range(args.batch_size):
59
+ loss = float(loss_func(result_batch[i:i+1].cuda(), gt_batch[i:i+1].cuda()))
60
+ all_scores.append(loss)
61
+ im_path = dataset.pairs[global_i][0]
62
+ scores_dict[os.path.basename(im_path)] = loss
63
+ global_i += 1
64
+
65
+ all_scores = list(scores_dict.values())
66
+ mean = np.mean(all_scores)
67
+ std = np.std(all_scores)
68
+ result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
69
+ print('Finished with ', args.data_path)
70
+ print(result_str)
71
+
72
+ out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
73
+ if not os.path.exists(out_path):
74
+ os.makedirs(out_path)
75
+
76
+ with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w') as f:
77
+ f.write(result_str)
78
+ with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w') as f:
79
+ json.dump(scores_dict, f)
80
+
81
+
82
+ if __name__ == '__main__':
83
+ args = parse_args()
84
+ run(args)
scripts/generate_sketch_data.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ from torchvision.utils import save_image
3
+ from torch.utils.serialization import load_lua
4
+ import os
5
+ import cv2
6
+ import numpy as np
7
+
8
+ """
9
+ NOTE!: Must have torch==0.4.1 and torchvision==0.2.1
10
+ The sketch simplification model (sketch_gan.t7) from Simo Serra et al. can be downloaded from their official implementation:
11
+ https://github.com/bobbens/sketch_simplification
12
+ """
13
+
14
+
15
+ def sobel(img):
16
+ opImgx = cv2.Sobel(img, cv2.CV_8U, 0, 1, ksize=3)
17
+ opImgy = cv2.Sobel(img, cv2.CV_8U, 1, 0, ksize=3)
18
+ return cv2.bitwise_or(opImgx, opImgy)
19
+
20
+
21
+ def sketch(frame):
22
+ frame = cv2.GaussianBlur(frame, (3, 3), 0)
23
+ invImg = 255 - frame
24
+ edgImg0 = sobel(frame)
25
+ edgImg1 = sobel(invImg)
26
+ edgImg = cv2.addWeighted(edgImg0, 0.75, edgImg1, 0.75, 0)
27
+ opImg = 255 - edgImg
28
+ return opImg
29
+
30
+
31
+ def get_sketch_image(image_path):
32
+ original = cv2.imread(image_path)
33
+ original = cv2.cvtColor(original, cv2.COLOR_BGR2GRAY)
34
+ sketch_image = sketch(original)
35
+ return sketch_image[:, :, np.newaxis]
36
+
37
+
38
+ use_cuda = True
39
+
40
+ cache = load_lua("/path/to/sketch_gan.t7")
41
+ model = cache.model
42
+ immean = cache.mean
43
+ imstd = cache.std
44
+ model.evaluate()
45
+
46
+ data_path = "/path/to/data/imgs"
47
+ images = [os.path.join(data_path, f) for f in os.listdir(data_path)]
48
+
49
+ output_dir = "/path/to/data/edges"
50
+ if not os.path.exists(output_dir):
51
+ os.makedirs(output_dir)
52
+
53
+ for idx, image_path in enumerate(images):
54
+ if idx % 50 == 0:
55
+ print("{} out of {}".format(idx, len(images)))
56
+ data = get_sketch_image(image_path)
57
+ data = ((transforms.ToTensor()(data) - immean) / imstd).unsqueeze(0)
58
+ if use_cuda:
59
+ pred = model.cuda().forward(data.cuda()).float()
60
+ else:
61
+ pred = model.forward(data)
62
+ save_image(pred[0], os.path.join(output_dir, "{}_edges.jpg".format(image_path.split("/")[-1].split('.')[0])))
scripts/inference.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from argparse import Namespace
3
+
4
+ from tqdm import tqdm
5
+ import time
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+ from torch.utils.data import DataLoader
10
+ import sys
11
+
12
+ sys.path.append(".")
13
+ sys.path.append("..")
14
+
15
+ from configs import data_configs
16
+ from datasets.inference_dataset import InferenceDataset
17
+ from utils.common import tensor2im, log_input_image
18
+ from options.test_options import TestOptions
19
+ from models.psp import pSp
20
+
21
+
22
+ def run():
23
+ test_opts = TestOptions().parse()
24
+
25
+ if test_opts.resize_factors is not None:
26
+ assert len(
27
+ test_opts.resize_factors.split(',')) == 1, "When running inference, provide a single downsampling factor!"
28
+ out_path_results = os.path.join(test_opts.exp_dir, 'inference_results',
29
+ 'downsampling_{}'.format(test_opts.resize_factors))
30
+ out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled',
31
+ 'downsampling_{}'.format(test_opts.resize_factors))
32
+ else:
33
+ out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
34
+ out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')
35
+
36
+ os.makedirs(out_path_results, exist_ok=True)
37
+ os.makedirs(out_path_coupled, exist_ok=True)
38
+
39
+ # update test options with options used during training
40
+ ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
41
+ opts = ckpt['opts']
42
+ opts.update(vars(test_opts))
43
+ if 'learn_in_w' not in opts:
44
+ opts['learn_in_w'] = False
45
+ if 'output_size' not in opts:
46
+ opts['output_size'] = 1024
47
+ opts = Namespace(**opts)
48
+
49
+ net = pSp(opts)
50
+ net.eval()
51
+ net.cuda()
52
+
53
+ print('Loading dataset for {}'.format(opts.dataset_type))
54
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
55
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
56
+ dataset = InferenceDataset(root=opts.data_path,
57
+ transform=transforms_dict['transform_inference'],
58
+ opts=opts)
59
+ dataloader = DataLoader(dataset,
60
+ batch_size=opts.test_batch_size,
61
+ shuffle=False,
62
+ num_workers=int(opts.test_workers),
63
+ drop_last=True)
64
+
65
+ if opts.n_images is None:
66
+ opts.n_images = len(dataset)
67
+
68
+ global_i = 0
69
+ global_time = []
70
+ for input_batch in tqdm(dataloader):
71
+ if global_i >= opts.n_images:
72
+ break
73
+ with torch.no_grad():
74
+ input_cuda = input_batch.cuda().float()
75
+ tic = time.time()
76
+ result_batch = run_on_batch(input_cuda, net, opts)
77
+ toc = time.time()
78
+ global_time.append(toc - tic)
79
+
80
+ for i in range(opts.test_batch_size):
81
+ result = tensor2im(result_batch[i])
82
+ im_path = dataset.paths[global_i]
83
+
84
+ if opts.couple_outputs or global_i % 100 == 0:
85
+ input_im = log_input_image(input_batch[i], opts)
86
+ resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)
87
+ if opts.resize_factors is not None:
88
+ # for super resolution, save the original, down-sampled, and output
89
+ source = Image.open(im_path)
90
+ res = np.concatenate([np.array(source.resize(resize_amount)),
91
+ np.array(input_im.resize(resize_amount, resample=Image.NEAREST)),
92
+ np.array(result.resize(resize_amount))], axis=1)
93
+ else:
94
+ # otherwise, save the original and output
95
+ res = np.concatenate([np.array(input_im.resize(resize_amount)),
96
+ np.array(result.resize(resize_amount))], axis=1)
97
+ Image.fromarray(res).save(os.path.join(out_path_coupled, os.path.basename(im_path)))
98
+
99
+ im_save_path = os.path.join(out_path_results, os.path.basename(im_path))
100
+ Image.fromarray(np.array(result)).save(im_save_path)
101
+
102
+ global_i += 1
103
+
104
+ stats_path = os.path.join(opts.exp_dir, 'stats.txt')
105
+ result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time))
106
+ print(result_str)
107
+
108
+ with open(stats_path, 'w') as f:
109
+ f.write(result_str)
110
+
111
+
112
+ def run_on_batch(inputs, net, opts):
113
+ if opts.latent_mask is None:
114
+ result_batch = net(inputs, randomize_noise=False, resize=opts.resize_outputs)
115
+ else:
116
+ latent_mask = [int(l) for l in opts.latent_mask.split(",")]
117
+ result_batch = []
118
+ for image_idx, input_image in enumerate(inputs):
119
+ # get latent vector to inject into our input image
120
+ vec_to_inject = np.random.randn(1, 512).astype('float32')
121
+ _, latent_to_inject = net(torch.from_numpy(vec_to_inject).to("cuda"),
122
+ input_code=True,
123
+ return_latents=True)
124
+ # get output image with injected style vector
125
+ res = net(input_image.unsqueeze(0).to("cuda").float(),
126
+ latent_mask=latent_mask,
127
+ inject_latent=latent_to_inject,
128
+ alpha=opts.mix_alpha,
129
+ resize=opts.resize_outputs)
130
+ result_batch.append(res)
131
+ result_batch = torch.cat(result_batch, dim=0)
132
+ return result_batch
133
+
134
+
135
+ if __name__ == '__main__':
136
+ run()
scripts/style_mixing.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from argparse import Namespace
3
+
4
+ from tqdm import tqdm
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ import sys
10
+
11
+ sys.path.append(".")
12
+ sys.path.append("..")
13
+
14
+ from configs import data_configs
15
+ from datasets.inference_dataset import InferenceDataset
16
+ from utils.common import tensor2im, log_input_image
17
+ from options.test_options import TestOptions
18
+ from models.psp import pSp
19
+
20
+
21
+ def run():
22
+ test_opts = TestOptions().parse()
23
+
24
+ if test_opts.resize_factors is not None:
25
+ factors = test_opts.resize_factors.split(',')
26
+ assert len(factors) == 1, "When running inference, please provide a single downsampling factor!"
27
+ mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing',
28
+ 'downsampling_{}'.format(test_opts.resize_factors))
29
+ else:
30
+ mixed_path_results = os.path.join(test_opts.exp_dir, 'style_mixing')
31
+ os.makedirs(mixed_path_results, exist_ok=True)
32
+
33
+ # update test options with options used during training
34
+ ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
35
+ opts = ckpt['opts']
36
+ opts.update(vars(test_opts))
37
+ if 'learn_in_w' not in opts:
38
+ opts['learn_in_w'] = False
39
+ if 'output_size' not in opts:
40
+ opts['output_size'] = 1024
41
+ opts = Namespace(**opts)
42
+
43
+ net = pSp(opts)
44
+ net.eval()
45
+ net.cuda()
46
+
47
+ print('Loading dataset for {}'.format(opts.dataset_type))
48
+ dataset_args = data_configs.DATASETS[opts.dataset_type]
49
+ transforms_dict = dataset_args['transforms'](opts).get_transforms()
50
+ dataset = InferenceDataset(root=opts.data_path,
51
+ transform=transforms_dict['transform_inference'],
52
+ opts=opts)
53
+ dataloader = DataLoader(dataset,
54
+ batch_size=opts.test_batch_size,
55
+ shuffle=False,
56
+ num_workers=int(opts.test_workers),
57
+ drop_last=True)
58
+
59
+ latent_mask = [int(l) for l in opts.latent_mask.split(",")]
60
+ if opts.n_images is None:
61
+ opts.n_images = len(dataset)
62
+
63
+ global_i = 0
64
+ for input_batch in tqdm(dataloader):
65
+ if global_i >= opts.n_images:
66
+ break
67
+ with torch.no_grad():
68
+ input_batch = input_batch.cuda()
69
+ for image_idx, input_image in enumerate(input_batch):
70
+ # generate random vectors to inject into input image
71
+ vecs_to_inject = np.random.randn(opts.n_outputs_to_generate, 512).astype('float32')
72
+ multi_modal_outputs = []
73
+ for vec_to_inject in vecs_to_inject:
74
+ cur_vec = torch.from_numpy(vec_to_inject).unsqueeze(0).to("cuda")
75
+ # get latent vector to inject into our input image
76
+ _, latent_to_inject = net(cur_vec,
77
+ input_code=True,
78
+ return_latents=True)
79
+ # get output image with injected style vector
80
+ res = net(input_image.unsqueeze(0).to("cuda").float(),
81
+ latent_mask=latent_mask,
82
+ inject_latent=latent_to_inject,
83
+ alpha=opts.mix_alpha,
84
+ resize=opts.resize_outputs)
85
+ multi_modal_outputs.append(res[0])
86
+
87
+ # visualize multi modal outputs
88
+ input_im_path = dataset.paths[global_i]
89
+ image = input_batch[image_idx]
90
+ input_image = log_input_image(image, opts)
91
+ resize_amount = (256, 256) if opts.resize_outputs else (opts.output_size, opts.output_size)
92
+ res = np.array(input_image.resize(resize_amount))
93
+ for output in multi_modal_outputs:
94
+ output = tensor2im(output)
95
+ res = np.concatenate([res, np.array(output.resize(resize_amount))], axis=1)
96
+ Image.fromarray(res).save(os.path.join(mixed_path_results, os.path.basename(input_im_path)))
97
+ global_i += 1
98
+
99
+
100
+ if __name__ == '__main__':
101
+ run()
scripts/train.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file runs the main training/val loop
3
+ """
4
+ import os
5
+ import json
6
+ import sys
7
+ import pprint
8
+
9
+ sys.path.append(".")
10
+ sys.path.append("..")
11
+
12
+ from options.train_options import TrainOptions
13
+ from training.coach import Coach
14
+
15
+
16
+ def main():
17
+ opts = TrainOptions().parse()
18
+ if os.path.exists(opts.exp_dir):
19
+ raise Exception('Oops... {} already exists'.format(opts.exp_dir))
20
+ os.makedirs(opts.exp_dir)
21
+
22
+ opts_dict = vars(opts)
23
+ pprint.pprint(opts_dict)
24
+ with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
25
+ json.dump(opts_dict, f, indent=4, sort_keys=True)
26
+
27
+ coach = Coach(opts)
28
+ coach.train()
29
+
30
+
31
+ if __name__ == '__main__':
32
+ main()