File size: 3,273 Bytes
4e3dd77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from argparse import ArgumentParser
import time
import numpy as np
import os
import json
import sys
from PIL import Image
import multiprocessing as mp
import math
import torch
import torchvision.transforms as trans

# Make the project-local packages imported below (models, configs) importable.
# Relative sys.path entries resolve against the current working directory, not
# this file's location, so this presumably targets launching the script from the
# repo root or from one directory below it — confirm.
sys.path.append(".")
sys.path.append("..")

from models.mtcnn.mtcnn import MTCNN
from models.encoders.model_irse import IR_101
from configs.paths_config import model_paths
# Checkpoint (a state dict) for the IR-101 identity network that
# extract_on_paths loads in every worker process.
# NOTE(review): 'circular_face' presumably names CurricularFace weights — verify in configs/paths_config.
CIRCULAR_FACE_PATH = model_paths['circular_face']


def chunks(lst, n):
	"""Lazily yield consecutive slices of *lst*, each of length *n*.

	The last slice carries the remainder and may be shorter than *n*; an empty
	*lst* yields nothing. Slices have the same type as *lst*.
	"""
	yield from (lst[lo:lo + n] for lo in range(0, len(lst), n))


def _load_rgb(path):
	"""Open the image at *path* and return a fully loaded RGB copy.

	The file handle is closed before returning. ``convert`` always returns a new
	image, so the result does not depend on the closed file.
	"""
	with Image.open(path) as im:
		return im.convert('RGB')


def extract_on_paths(file_paths):
	"""Score the identity similarity of each (result, ground-truth) image pair.

	Meant to run as a ``multiprocessing.Pool.map`` worker. Each call builds its
	own IR-101 identity network (weights from ``CIRCULAR_FACE_PATH``) and MTCNN
	aligner, and moves the network to CUDA.

	Args:
		file_paths: sequence of ``(res_path, gt_path)`` pairs (lists or tuples).

	Returns:
		dict mapping ``os.path.basename(gt_path)`` to the float dot product of the
		two face embeddings. That equals cosine similarity only if IR_101
		L2-normalizes its output — NOTE(review): confirm. Pairs where MTCNN
		alignment returns None for either image are logged and omitted.
	"""
	facenet = IR_101(input_size=112)
	facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH))
	facenet.cuda()
	facenet.eval()
	mtcnn = MTCNN()
	# Three-channel normalization, so inputs must be RGB (hence _load_rgb).
	id_transform = trans.Compose([
		trans.ToTensor(),
		trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
	])

	pid = mp.current_process().name
	print('\t{} is starting to extract on {} images'.format(pid, len(file_paths)))
	tot_count = len(file_paths)

	def embed(path):
		"""Align the face in *path* and return its embedding, or None if alignment fails."""
		aligned, _ = mtcnn.align(_load_rgb(path))
		if aligned is None:
			print('{} skipping {}'.format(pid, path))
			return None
		return facenet(id_transform(aligned).unsqueeze(0).cuda())[0]

	scores_dict = {}
	# Inference only: no_grad stops autograd from recording a graph per forward pass.
	with torch.no_grad():
		for count, (res_path, gt_path) in enumerate(file_paths, start=1):
			if count % 100 == 0:
				print('{} done with {}/{}'.format(pid, count, tot_count))

			input_id = embed(res_path)
			if input_id is None:
				continue
			result_id = embed(gt_path)
			if result_id is None:
				continue

			scores_dict[os.path.basename(gt_path)] = float(input_id.dot(result_id))

	return scores_dict


def parse_args():
	"""Parse the command-line options for the identity-similarity evaluation.

	Returns:
		argparse.Namespace with ``num_threads`` (int, guaranteed >= 1),
		``data_path`` and ``gt_path`` (str).

	Exits with status 2 (via ``parser.error``) on invalid options.
	"""
	parser = ArgumentParser(
		description='Score identity similarity between result images and ground-truth images.')
	parser.add_argument('--num_threads', type=int, default=4,
		help='number of worker processes (default: %(default)s)')
	parser.add_argument('--data_path', type=str, default='results',
		help='directory of .jpg/.png result images to score (default: %(default)s)')
	parser.add_argument('--gt_path', type=str, default='gt_images',
		help='directory of ground-truth .jpg images, matched by file name (default: %(default)s)')
	args = parser.parse_args()
	# Zero or negative workers would otherwise crash later (chunk-size division / Pool()).
	if args.num_threads < 1:
		parser.error('--num_threads must be at least 1')
	return args


def run(args):
	file_paths = []
	for f in os.listdir(args.data_path):
		image_path = os.path.join(args.data_path, f)
		gt_path = os.path.join(args.gt_path, f)
		if f.endswith(".jpg") or f.endswith('.png'):
			file_paths.append([image_path, gt_path.replace('.png','.jpg')])

	file_chunks = list(chunks(file_paths, int(math.ceil(len(file_paths) / args.num_threads))))
	pool = mp.Pool(args.num_threads)
	print('Running on {} paths\nHere we goooo'.format(len(file_paths)))

	tic = time.time()
	results = pool.map(extract_on_paths, file_chunks)
	scores_dict = {}
	for d in results:
		scores_dict.update(d)

	all_scores = list(scores_dict.values())
	mean = np.mean(all_scores)
	std = np.std(all_scores)
	result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std)
	print(result_str)

	out_path = os.path.join(os.path.dirname(args.data_path), 'inference_metrics')
	if not os.path.exists(out_path):
		os.makedirs(out_path)

	with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f:
		f.write(result_str)
	with open(os.path.join(out_path, 'scores_id.json'), 'w') as f:
		json.dump(scores_dict, f)

	toc = time.time()
	print('Mischief managed in {}s'.format(toc - tic))


if __name__ == '__main__':
	# Script entry point: parse the CLI options, then score every image pair and
	# write the metrics files (see run).
	args = parse_args()
	run(args)