Spaces:
Running
on
T4
Running
on
T4
File size: 3,273 Bytes
4e3dd77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
from argparse import ArgumentParser
import time
import numpy as np
import os
import json
import sys
from PIL import Image
import multiprocessing as mp
import math
import torch
import torchvision.transforms as trans
sys.path.append(".")
sys.path.append("..")
from models.mtcnn.mtcnn import MTCNN
from models.encoders.model_irse import IR_101
from configs.paths_config import model_paths
# Checkpoint path looked up under the 'circular_face' key of configs.paths_config.model_paths;
# extract_on_paths() loads it with torch.load into an IR_101(input_size=112) network.
CIRCULAR_FACE_PATH = model_paths['circular_face']
def chunks(lst, n):
    """Generate consecutive slices of *lst*, each holding *n* items.

    The final slice holds whatever is left over, so it may be shorter.
    Slicing preserves the input's sequence type (lists give lists, strs give strs).
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))
def extract_on_paths(file_paths):
    """Score identity similarity for a list of (result, ground-truth) image pairs.

    Used as the ``multiprocessing.Pool.map`` worker in ``run``. Each call builds
    its own IR_101 network (weights from ``CIRCULAR_FACE_PATH``, moved to CUDA)
    and its own MTCNN aligner, so every pool worker loads the models itself.

    For every pair, both images are aligned with ``mtcnn.align`` and embedded
    with IR_101; the pair's score is the dot product of the two embeddings.
    A pair is skipped (with a printed message) when MTCNN returns no aligned
    face for either image.

    Args:
        file_paths: list of ``(res_path, gt_path)`` pairs of image file paths.

    Returns:
        dict mapping ``os.path.basename(gt_path)`` to the float score. Pairs
        that share a ground-truth basename overwrite each other.
    """
    facenet = IR_101(input_size=112)
    facenet.load_state_dict(torch.load(CIRCULAR_FACE_PATH))
    facenet.cuda()
    facenet.eval()
    mtcnn = MTCNN()
    id_transform = trans.Compose([
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    def embed(path):
        """Align the face in *path* and return its embedding, or None if MTCNN found none."""
        aligned, _ = mtcnn.align(Image.open(path))
        if aligned is None:
            return None
        # Add a batch dim of 1, run the network, and take the single output row.
        return facenet(id_transform(aligned).unsqueeze(0).cuda())[0]

    pid = mp.current_process().name
    print('\t{} is starting to extract on {} images'.format(pid, len(file_paths)))
    tot_count = len(file_paths)
    scores_dict = {}
    # Pure inference: disable autograd so no computation graph is built per image.
    with torch.no_grad():
        for count, (res_path, gt_path) in enumerate(file_paths, start=1):
            if count % 100 == 0:
                print('{} done with {}/{}'.format(pid, count, tot_count))
            input_id = embed(res_path)
            if input_id is None:
                print('{} skipping {}'.format(pid, res_path))
                continue
            result_id = embed(gt_path)
            if result_id is None:
                print('{} skipping {}'.format(pid, gt_path))
                continue
            # NOTE(review): this is a cosine similarity only if IR_101 L2-normalizes
            # its output — confirm in models/encoders/model_irse.
            scores_dict[os.path.basename(gt_path)] = float(input_id.dot(result_id))
    return scores_dict
def parse_args(argv=None):
    """Parse the command-line options for the identity-similarity evaluation.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``num_threads`` (int, default 4),
        ``data_path`` (str, default 'results') and ``gt_path``
        (str, default 'gt_images').

    Exits via ``parser.error`` (SystemExit) when ``--num_threads`` is below 1.
    """
    # add_help=False: -h/--help is deliberately not registered (kept from the original).
    parser = ArgumentParser(add_help=False)
    parser.add_argument('--num_threads', type=int, default=4,
                        help='number of worker processes')
    parser.add_argument('--data_path', type=str, default='results',
                        help='directory of generated images to score')
    parser.add_argument('--gt_path', type=str, default='gt_images',
                        help='directory of ground-truth images, matched by file name')
    args = parser.parse_args(argv)
    # run() divides the work by num_threads and mp.Pool needs at least one
    # process; reject bad values here with a usage message instead of a traceback.
    if args.num_threads < 1:
        parser.error('--num_threads must be at least 1')
    return args
def _collect_file_pairs(data_path, gt_dir):
    """Return sorted [result_path, gt_path] pairs for every .jpg/.png file in *data_path*.

    The ground truth has the same file name in *gt_dir*, except that a
    ``.png`` result is paired with a ``.jpg`` ground truth.
    """
    pairs = []
    # Sorted so chunking across workers (and the output order) is deterministic.
    for name in sorted(os.listdir(data_path)):
        if not name.endswith(('.jpg', '.png')):
            continue
        stem, ext = os.path.splitext(name)
        # Swap only this file's extension; a str.replace over the whole path
        # would also rewrite '.png' occurring inside directory names.
        gt_name = stem + '.jpg' if ext == '.png' else name
        pairs.append([os.path.join(data_path, name), os.path.join(gt_dir, gt_name)])
    return pairs


def _write_metrics(out_path, result_str, scores_dict):
    """Write *result_str* to stat_id.txt and *scores_dict* as JSON to scores_id.json in *out_path*."""
    os.makedirs(out_path, exist_ok=True)
    with open(os.path.join(out_path, 'stat_id.txt'), 'w') as f:
        f.write(result_str)
    with open(os.path.join(out_path, 'scores_id.json'), 'w') as f:
        json.dump(scores_dict, f)


def run(args):
    """Score every result image against its ground truth and save summary metrics.

    Each .jpg/.png file in ``args.data_path`` is paired with the same-named file
    in ``args.gt_path`` (.png results map to .jpg ground truths). The pairs are
    split into ``args.num_threads`` chunks and scored in a process pool by
    ``extract_on_paths``. The mean and std of all scores are printed and written
    to ``stat_id.txt``; per-image scores go to ``scores_id.json``. Both files are
    placed in an ``inference_metrics`` directory beside ``args.data_path``.
    If every pair is skipped, mean/std are NaN (numpy warns on empty input).

    Args:
        args: namespace with ``num_threads``, ``data_path`` and ``gt_path``
            (as produced by ``parse_args``).

    Raises:
        ValueError: if ``args.data_path`` contains no .jpg/.png images.
        FileNotFoundError: if ``args.data_path`` does not exist.
    """
    file_paths = _collect_file_pairs(args.data_path, args.gt_path)
    if not file_paths:
        # Without this guard the chunk size below is 0 and chunks() fails
        # with an opaque "range() arg 3 must not be zero".
        raise ValueError('No .jpg/.png images found in {}'.format(args.data_path))
    chunk_size = math.ceil(len(file_paths) / args.num_threads)
    file_chunks = list(chunks(file_paths, chunk_size))

    # The context manager shuts the worker processes down once map() has
    # returned, instead of leaking the pool.
    with mp.Pool(args.num_threads) as pool:
        print('Running on {} paths\nHere we goooo'.format(len(file_paths)))
        tic = time.time()
        results = pool.map(extract_on_paths, file_chunks)

    scores_dict = {}
    for partial_scores in results:
        scores_dict.update(partial_scores)
    all_scores = list(scores_dict.values())
    mean = np.mean(all_scores)
    std = np.std(all_scores)
    result_str = 'New Average score is {:.2f}+-{:.2f}'.format(mean, std)
    print(result_str)

    # normpath so 'exp/results/' and 'exp/results' both resolve to
    # 'exp/inference_metrics' (a trailing slash used to put it inside results/).
    parent_dir = os.path.dirname(os.path.normpath(args.data_path))
    _write_metrics(os.path.join(parent_dir, 'inference_metrics'), result_str, scores_dict)
    toc = time.time()
    print('Mischief managed in {}s'.format(toc - tic))
if __name__ == '__main__':
    # Script entry point: parse the command-line options, then run the evaluation.
    args = parse_args()
    run(args)
|