Spanicin commited on
Commit
b5d59e5
·
verified ·
1 Parent(s): 745f35f

Update videoretalking/third_part/GPEN/face_detect/retinaface_detection.py

Browse files
videoretalking/third_part/GPEN/face_detect/retinaface_detection.py CHANGED
@@ -1,193 +1,193 @@
1
- '''
2
- @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3
- @author: yangxy (yangtao9009@gmail.com)
4
- '''
5
- import os
6
- import torch
7
- import torch.backends.cudnn as cudnn
8
- import numpy as np
9
- from face_detect.data import cfg_re50
10
- from face_detect.layers.functions.prior_box import PriorBox
11
- from face_detect.utils.nms.py_cpu_nms import py_cpu_nms
12
- import cv2
13
- from face_detect.facemodels.retinaface import RetinaFace
14
- from face_detect.utils.box_utils import decode, decode_landm
15
- import time
16
- import torch.nn.functional as F
17
-
18
-
19
- class RetinaFaceDetection(object):
20
- def __init__(self, base_dir, device='cuda', network='RetinaFace-R50'):
21
- torch.set_grad_enabled(False)
22
- cudnn.benchmark = True
23
- self.pretrained_path = os.path.join(base_dir, network+'.pth')
24
- self.device = device #torch.cuda.current_device()
25
- self.cfg = cfg_re50
26
- self.net = RetinaFace(cfg=self.cfg, phase='test')
27
- self.load_model()
28
- self.net = self.net.to(device)
29
-
30
- self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device)
31
-
32
- def check_keys(self, pretrained_state_dict):
33
- ckpt_keys = set(pretrained_state_dict.keys())
34
- model_keys = set(self.net.state_dict().keys())
35
- used_pretrained_keys = model_keys & ckpt_keys
36
- unused_pretrained_keys = ckpt_keys - model_keys
37
- missing_keys = model_keys - ckpt_keys
38
- assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
39
- return True
40
-
41
- def remove_prefix(self, state_dict, prefix):
42
- ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
43
- f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
44
- return {f(key): value for key, value in state_dict.items()}
45
-
46
- def load_model(self, load_to_cpu=False):
47
- #if load_to_cpu:
48
- # pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage)
49
- #else:
50
- # pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage.cuda())
51
- pretrained_dict = torch.load(self.pretrained_path, map_location=torch.device('cpu'))
52
- if "state_dict" in pretrained_dict.keys():
53
- pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], 'module.')
54
- else:
55
- pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')
56
- self.check_keys(pretrained_dict)
57
- self.net.load_state_dict(pretrained_dict, strict=False)
58
- self.net.eval()
59
-
60
- def detect(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False):
61
- img = np.float32(img_raw)
62
-
63
- im_height, im_width = img.shape[:2]
64
- ss = 1.0
65
- # tricky
66
- if max(im_height, im_width) > 1500:
67
- ss = 1000.0/max(im_height, im_width)
68
- img = cv2.resize(img, (0,0), fx=ss, fy=ss)
69
- im_height, im_width = img.shape[:2]
70
-
71
- scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
72
- img -= (104, 117, 123)
73
- img = img.transpose(2, 0, 1)
74
- img = torch.from_numpy(img).unsqueeze(0)
75
- img = img.to(self.device)
76
- scale = scale.to(self.device)
77
-
78
- with torch.no_grad():
79
- loc, conf, landms = self.net(img) # forward pass
80
-
81
- priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
82
- priors = priorbox.forward()
83
- priors = priors.to(self.device)
84
- prior_data = priors.data
85
- boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
86
- boxes = boxes * scale / resize
87
- boxes = boxes.cpu().numpy()
88
- scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
89
- landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance'])
90
- scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
91
- img.shape[3], img.shape[2], img.shape[3], img.shape[2],
92
- img.shape[3], img.shape[2]])
93
- scale1 = scale1.to(self.device)
94
- landms = landms * scale1 / resize
95
- landms = landms.cpu().numpy()
96
-
97
- # ignore low scores
98
- inds = np.where(scores > confidence_threshold)[0]
99
- boxes = boxes[inds]
100
- landms = landms[inds]
101
- scores = scores[inds]
102
-
103
- # keep top-K before NMS
104
- order = scores.argsort()[::-1][:top_k]
105
- boxes = boxes[order]
106
- landms = landms[order]
107
- scores = scores[order]
108
-
109
- # do NMS
110
- dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
111
- keep = py_cpu_nms(dets, nms_threshold)
112
- # keep = nms(dets, nms_threshold,force_cpu=args.cpu)
113
- dets = dets[keep, :]
114
- landms = landms[keep]
115
-
116
- # keep top-K faster NMS
117
- dets = dets[:keep_top_k, :]
118
- landms = landms[:keep_top_k, :]
119
-
120
- # sort faces(delete)
121
- '''
122
- fscores = [det[4] for det in dets]
123
- sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index
124
- tmp = [landms[idx] for idx in sorted_idx]
125
- landms = np.asarray(tmp)
126
- '''
127
-
128
- landms = landms.reshape((-1, 5, 2))
129
- landms = landms.transpose((0, 2, 1))
130
- landms = landms.reshape(-1, 10, )
131
- return dets/ss, landms/ss
132
-
133
- def detect_tensor(self, img, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False):
134
- im_height, im_width = img.shape[-2:]
135
- ss = 1000/max(im_height, im_width)
136
- img = F.interpolate(img, scale_factor=ss)
137
- im_height, im_width = img.shape[-2:]
138
- scale = torch.Tensor([im_width, im_height, im_width, im_height]).to(self.device)
139
- img -= self.mean
140
-
141
- loc, conf, landms = self.net(img) # forward pass
142
-
143
- priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
144
- priors = priorbox.forward()
145
- priors = priors.to(self.device)
146
- prior_data = priors.data
147
- boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
148
- boxes = boxes * scale / resize
149
- boxes = boxes.cpu().numpy()
150
- scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
151
- landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance'])
152
- scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
153
- img.shape[3], img.shape[2], img.shape[3], img.shape[2],
154
- img.shape[3], img.shape[2]])
155
- scale1 = scale1.to(self.device)
156
- landms = landms * scale1 / resize
157
- landms = landms.cpu().numpy()
158
-
159
- # ignore low scores
160
- inds = np.where(scores > confidence_threshold)[0]
161
- boxes = boxes[inds]
162
- landms = landms[inds]
163
- scores = scores[inds]
164
-
165
- # keep top-K before NMS
166
- order = scores.argsort()[::-1][:top_k]
167
- boxes = boxes[order]
168
- landms = landms[order]
169
- scores = scores[order]
170
-
171
- # do NMS
172
- dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
173
- keep = py_cpu_nms(dets, nms_threshold)
174
- # keep = nms(dets, nms_threshold,force_cpu=args.cpu)
175
- dets = dets[keep, :]
176
- landms = landms[keep]
177
-
178
- # keep top-K faster NMS
179
- dets = dets[:keep_top_k, :]
180
- landms = landms[:keep_top_k, :]
181
-
182
- # sort faces(delete)
183
- '''
184
- fscores = [det[4] for det in dets]
185
- sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index
186
- tmp = [landms[idx] for idx in sorted_idx]
187
- landms = np.asarray(tmp)
188
- '''
189
-
190
- landms = landms.reshape((-1, 5, 2))
191
- landms = landms.transpose((0, 2, 1))
192
- landms = landms.reshape(-1, 10, )
193
- return dets/ss, landms/ss
 
1
+ '''
2
+ @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
3
+ @author: yangxy (yangtao9009@gmail.com)
4
+ '''
5
+ import os
6
+ import torch
7
+ import torch.backends.cudnn as cudnn
8
+ import numpy as np
9
+ from videoretalking.third_part.GPEN.face_detect.data import cfg_re50
10
+ from videoretalking.third_part.GPEN.face_detect.layers.functions.prior_box import PriorBox
11
+ from videoretalking.third_part.GPEN.face_detect.utils.nms.py_cpu_nms import py_cpu_nms
12
+ import cv2
13
+ from videoretalking.third_part.GPEN.face_detect.facemodels.retinaface import RetinaFace
14
+ from videoretalking.third_part.GPEN.face_detect.utils.box_utils import decode, decode_landm
15
+ import time
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class RetinaFaceDetection(object):
20
+ def __init__(self, base_dir, device='cuda', network='RetinaFace-R50'):
21
+ torch.set_grad_enabled(False)
22
+ cudnn.benchmark = True
23
+ self.pretrained_path = os.path.join(base_dir, network+'.pth')
24
+ self.device = device #torch.cuda.current_device()
25
+ self.cfg = cfg_re50
26
+ self.net = RetinaFace(cfg=self.cfg, phase='test')
27
+ self.load_model()
28
+ self.net = self.net.to(device)
29
+
30
+ self.mean = torch.tensor([[[[104]], [[117]], [[123]]]]).to(device)
31
+
32
+ def check_keys(self, pretrained_state_dict):
33
+ ckpt_keys = set(pretrained_state_dict.keys())
34
+ model_keys = set(self.net.state_dict().keys())
35
+ used_pretrained_keys = model_keys & ckpt_keys
36
+ unused_pretrained_keys = ckpt_keys - model_keys
37
+ missing_keys = model_keys - ckpt_keys
38
+ assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'
39
+ return True
40
+
41
+ def remove_prefix(self, state_dict, prefix):
42
+ ''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''
43
+ f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else x
44
+ return {f(key): value for key, value in state_dict.items()}
45
+
46
+ def load_model(self, load_to_cpu=False):
47
+ #if load_to_cpu:
48
+ # pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage)
49
+ #else:
50
+ # pretrained_dict = torch.load(self.pretrained_path, map_location=lambda storage, loc: storage.cuda())
51
+ pretrained_dict = torch.load(self.pretrained_path, map_location=torch.device('cpu'))
52
+ if "state_dict" in pretrained_dict.keys():
53
+ pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], 'module.')
54
+ else:
55
+ pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')
56
+ self.check_keys(pretrained_dict)
57
+ self.net.load_state_dict(pretrained_dict, strict=False)
58
+ self.net.eval()
59
+
60
+ def detect(self, img_raw, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False):
61
+ img = np.float32(img_raw)
62
+
63
+ im_height, im_width = img.shape[:2]
64
+ ss = 1.0
65
+ # tricky
66
+ if max(im_height, im_width) > 1500:
67
+ ss = 1000.0/max(im_height, im_width)
68
+ img = cv2.resize(img, (0,0), fx=ss, fy=ss)
69
+ im_height, im_width = img.shape[:2]
70
+
71
+ scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
72
+ img -= (104, 117, 123)
73
+ img = img.transpose(2, 0, 1)
74
+ img = torch.from_numpy(img).unsqueeze(0)
75
+ img = img.to(self.device)
76
+ scale = scale.to(self.device)
77
+
78
+ with torch.no_grad():
79
+ loc, conf, landms = self.net(img) # forward pass
80
+
81
+ priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
82
+ priors = priorbox.forward()
83
+ priors = priors.to(self.device)
84
+ prior_data = priors.data
85
+ boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
86
+ boxes = boxes * scale / resize
87
+ boxes = boxes.cpu().numpy()
88
+ scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
89
+ landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance'])
90
+ scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
91
+ img.shape[3], img.shape[2], img.shape[3], img.shape[2],
92
+ img.shape[3], img.shape[2]])
93
+ scale1 = scale1.to(self.device)
94
+ landms = landms * scale1 / resize
95
+ landms = landms.cpu().numpy()
96
+
97
+ # ignore low scores
98
+ inds = np.where(scores > confidence_threshold)[0]
99
+ boxes = boxes[inds]
100
+ landms = landms[inds]
101
+ scores = scores[inds]
102
+
103
+ # keep top-K before NMS
104
+ order = scores.argsort()[::-1][:top_k]
105
+ boxes = boxes[order]
106
+ landms = landms[order]
107
+ scores = scores[order]
108
+
109
+ # do NMS
110
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
111
+ keep = py_cpu_nms(dets, nms_threshold)
112
+ # keep = nms(dets, nms_threshold,force_cpu=args.cpu)
113
+ dets = dets[keep, :]
114
+ landms = landms[keep]
115
+
116
+ # keep top-K faster NMS
117
+ dets = dets[:keep_top_k, :]
118
+ landms = landms[:keep_top_k, :]
119
+
120
+ # sort faces(delete)
121
+ '''
122
+ fscores = [det[4] for det in dets]
123
+ sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index
124
+ tmp = [landms[idx] for idx in sorted_idx]
125
+ landms = np.asarray(tmp)
126
+ '''
127
+
128
+ landms = landms.reshape((-1, 5, 2))
129
+ landms = landms.transpose((0, 2, 1))
130
+ landms = landms.reshape(-1, 10, )
131
+ return dets/ss, landms/ss
132
+
133
+ def detect_tensor(self, img, resize=1, confidence_threshold=0.9, nms_threshold=0.4, top_k=5000, keep_top_k=750, save_image=False):
134
+ im_height, im_width = img.shape[-2:]
135
+ ss = 1000/max(im_height, im_width)
136
+ img = F.interpolate(img, scale_factor=ss)
137
+ im_height, im_width = img.shape[-2:]
138
+ scale = torch.Tensor([im_width, im_height, im_width, im_height]).to(self.device)
139
+ img -= self.mean
140
+
141
+ loc, conf, landms = self.net(img) # forward pass
142
+
143
+ priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
144
+ priors = priorbox.forward()
145
+ priors = priors.to(self.device)
146
+ prior_data = priors.data
147
+ boxes = decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
148
+ boxes = boxes * scale / resize
149
+ boxes = boxes.cpu().numpy()
150
+ scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
151
+ landms = decode_landm(landms.data.squeeze(0), prior_data, self.cfg['variance'])
152
+ scale1 = torch.Tensor([img.shape[3], img.shape[2], img.shape[3], img.shape[2],
153
+ img.shape[3], img.shape[2], img.shape[3], img.shape[2],
154
+ img.shape[3], img.shape[2]])
155
+ scale1 = scale1.to(self.device)
156
+ landms = landms * scale1 / resize
157
+ landms = landms.cpu().numpy()
158
+
159
+ # ignore low scores
160
+ inds = np.where(scores > confidence_threshold)[0]
161
+ boxes = boxes[inds]
162
+ landms = landms[inds]
163
+ scores = scores[inds]
164
+
165
+ # keep top-K before NMS
166
+ order = scores.argsort()[::-1][:top_k]
167
+ boxes = boxes[order]
168
+ landms = landms[order]
169
+ scores = scores[order]
170
+
171
+ # do NMS
172
+ dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
173
+ keep = py_cpu_nms(dets, nms_threshold)
174
+ # keep = nms(dets, nms_threshold,force_cpu=args.cpu)
175
+ dets = dets[keep, :]
176
+ landms = landms[keep]
177
+
178
+ # keep top-K faster NMS
179
+ dets = dets[:keep_top_k, :]
180
+ landms = landms[:keep_top_k, :]
181
+
182
+ # sort faces(delete)
183
+ '''
184
+ fscores = [det[4] for det in dets]
185
+ sorted_idx = sorted(range(len(fscores)), key=lambda k:fscores[k], reverse=False) # sort index
186
+ tmp = [landms[idx] for idx in sorted_idx]
187
+ landms = np.asarray(tmp)
188
+ '''
189
+
190
+ landms = landms.reshape((-1, 5, 2))
191
+ landms = landms.transpose((0, 2, 1))
192
+ landms = landms.reshape(-1, 10, )
193
+ return dets/ss, landms/ss