Update videoretalking/third_part/GPEN/face_detect/layers/modules/multibox_loss.py
videoretalking/third_part/GPEN/face_detect/layers/modules/multibox_loss.py
CHANGED

@@ -1,125 +1,125 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable
-from face_detect.utils.box_utils import match, log_sum_exp
-from face_detect.data import cfg_mnet
+from videoretalking.third_part.GPEN.face_detect.utils.box_utils import match, log_sum_exp
+from videoretalking.third_part.GPEN.face_detect.data import cfg_mnet
 GPU = cfg_mnet['gpu_train']
 
 class MultiBoxLoss(nn.Module):
     """SSD Weighted Loss Function
     Compute Targets:
         1) Produce Confidence Target Indices by matching ground truth boxes
            with (default) 'priorboxes' that have jaccard index > threshold parameter
            (default threshold: 0.5).
         2) Produce localization target by 'encoding' variance into offsets of ground
            truth boxes and their matched 'priorboxes'.
         3) Hard negative mining to filter the excessive number of negative examples
            that comes with using a large number of default bounding boxes.
            (default negative:positive ratio 3:1)
     Objective Loss:
         L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
         Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
         weighted by α which is set to 1 by cross val.
         Args:
             c: class confidences,
             l: predicted boxes,
             g: ground truth boxes
             N: number of matched default boxes
         See: https://arxiv.org/pdf/1512.02325.pdf for more details.
     """
 
     def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
         super(MultiBoxLoss, self).__init__()
         self.num_classes = num_classes
         self.threshold = overlap_thresh
         self.background_label = bkg_label
         self.encode_target = encode_target
         self.use_prior_for_matching = prior_for_matching
         self.do_neg_mining = neg_mining
         self.negpos_ratio = neg_pos
         self.neg_overlap = neg_overlap
         self.variance = [0.1, 0.2]
 
     def forward(self, predictions, priors, targets):
         """Multibox Loss
         Args:
             predictions (tuple): A tuple containing loc preds, conf preds,
                 and prior boxes from SSD net.
                 conf shape: torch.size(batch_size,num_priors,num_classes)
                 loc shape: torch.size(batch_size,num_priors,4)
                 priors shape: torch.size(num_priors,4)
 
             ground_truth (tensor): Ground truth boxes and labels for a batch,
                 shape: [batch_size,num_objs,5] (last idx is the label).
         """
 
         loc_data, conf_data, landm_data = predictions
         priors = priors
         num = loc_data.size(0)
         num_priors = (priors.size(0))
 
         # match priors (default boxes) and ground truth boxes
         loc_t = torch.Tensor(num, num_priors, 4)
         landm_t = torch.Tensor(num, num_priors, 10)
         conf_t = torch.LongTensor(num, num_priors)
         for idx in range(num):
             truths = targets[idx][:, :4].data
             labels = targets[idx][:, -1].data
             landms = targets[idx][:, 4:14].data
             defaults = priors.data
             match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
         if GPU:
             loc_t = loc_t.cuda()
             conf_t = conf_t.cuda()
             landm_t = landm_t.cuda()
 
         zeros = torch.tensor(0).cuda()
         # landm Loss (Smooth L1)
         # Shape: [batch,num_priors,10]
         pos1 = conf_t > zeros
         num_pos_landm = pos1.long().sum(1, keepdim=True)
         N1 = max(num_pos_landm.data.sum().float(), 1)
         pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
         landm_p = landm_data[pos_idx1].view(-1, 10)
         landm_t = landm_t[pos_idx1].view(-1, 10)
         loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
 
 
         pos = conf_t != zeros
         conf_t[pos] = 1
 
         # Localization Loss (Smooth L1)
         # Shape: [batch,num_priors,4]
         pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
         loc_p = loc_data[pos_idx].view(-1, 4)
         loc_t = loc_t[pos_idx].view(-1, 4)
         loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
 
         # Compute max conf across batch for hard negative mining
         batch_conf = conf_data.view(-1, self.num_classes)
         loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
 
         # Hard Negative Mining
         loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
         loss_c = loss_c.view(num, -1)
         _, loss_idx = loss_c.sort(1, descending=True)
         _, idx_rank = loss_idx.sort(1)
         num_pos = pos.long().sum(1, keepdim=True)
         num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
         neg = idx_rank < num_neg.expand_as(idx_rank)
 
         # Confidence Loss Including Positive and Negative Examples
         pos_idx = pos.unsqueeze(2).expand_as(conf_data)
         neg_idx = neg.unsqueeze(2).expand_as(conf_data)
         conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
         targets_weighted = conf_t[(pos+neg).gt(0)]
         loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
 
         # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
         N = max(num_pos.data.sum().float(), 1)
         loss_l /= N
         loss_c /= N
         loss_landm /= N1
 
         return loss_l, loss_c, loss_landm
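
The only functional change in this commit is the pair of import paths on lines 5-6: the bare face_detect.* imports become fully qualified videoretalking.third_part.GPEN.face_detect.* imports, so the module resolves when Python runs with the repository root on sys.path rather than from inside the GPEN directory. Everything else in the file is unchanged.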
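
For reference, a minimal sketch of how this loss might be driven after the path change. The constructor signature and the per-row target layout (4 box coordinates, 10 landmark coordinates, 1 label) come from the diff above; the hyperparameter values are assumptions typical of RetinaFace-style training setups, not taken from this repository's config. Note that forward() calls .cuda() unconditionally (zeros = torch.tensor(0).cuda()), so the sketch only runs when CUDA is available and assumes cfg_mnet['gpu_train'] is True.

    # Minimal usage sketch (not part of the commit); hyperparameters are assumed.
    import torch
    from videoretalking.third_part.GPEN.face_detect.layers.modules.multibox_loss import MultiBoxLoss

    criterion = MultiBoxLoss(
        num_classes=2,           # background + face
        overlap_thresh=0.35,     # jaccard threshold used when matching priors
        prior_for_matching=True,
        bkg_label=0,
        neg_mining=True,
        neg_pos=7,               # negative:positive hard-mining ratio
        neg_overlap=0.35,
        encode_target=False,
    )

    def fake_targets(num_objs):
        # Each row: [x1, y1, x2, y2, ten landmark coords, label]; label 1 = face.
        xy1 = torch.rand(num_objs, 2) * 0.5
        wh = torch.rand(num_objs, 2) * 0.4 + 0.05
        return torch.cat([xy1, xy1 + wh, torch.rand(num_objs, 10),
                          torch.ones(num_objs, 1)], dim=1)

    if torch.cuda.is_available():  # forward() requires CUDA, see zeros above
        batch, num_priors = 2, 16800
        loc = torch.randn(batch, num_priors, 4).cuda()     # box regression preds
        conf = torch.randn(batch, num_priors, 2).cuda()    # class logits
        landm = torch.randn(batch, num_priors, 10).cuda()  # landmark preds
        priors = torch.rand(num_priors, 4).cuda()          # (cx, cy, w, h) priors
        targets = [fake_targets(3).cuda() for _ in range(batch)]
        loss_l, loss_c, loss_landm = criterion((loc, conf, landm), priors, targets)
        # A common RetinaFace-style total (the 2.0 box weight is an assumption):
        total = 2.0 * loss_l + loss_c + loss_landm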