Spanicin commited on
Commit
b9a690d
·
verified ·
1 Parent(s): ad9d237

Update videoretalking/third_part/GPEN/face_detect/layers/modules/multibox_loss.py

Browse files
videoretalking/third_part/GPEN/face_detect/layers/modules/multibox_loss.py CHANGED
@@ -1,125 +1,125 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch.autograd import Variable
5
- from face_detect.utils.box_utils import match, log_sum_exp
6
- from face_detect.data import cfg_mnet
7
- GPU = cfg_mnet['gpu_train']
8
-
9
- class MultiBoxLoss(nn.Module):
10
- """SSD Weighted Loss Function
11
- Compute Targets:
12
- 1) Produce Confidence Target Indices by matching ground truth boxes
13
- with (default) 'priorboxes' that have jaccard index > threshold parameter
14
- (default threshold: 0.5).
15
- 2) Produce localization target by 'encoding' variance into offsets of ground
16
- truth boxes and their matched 'priorboxes'.
17
- 3) Hard negative mining to filter the excessive number of negative examples
18
- that comes with using a large number of default bounding boxes.
19
- (default negative:positive ratio 3:1)
20
- Objective Loss:
21
- L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
22
- Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
23
- weighted by α which is set to 1 by cross val.
24
- Args:
25
- c: class confidences,
26
- l: predicted boxes,
27
- g: ground truth boxes
28
- N: number of matched default boxes
29
- See: https://arxiv.org/pdf/1512.02325.pdf for more details.
30
- """
31
-
32
- def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
33
- super(MultiBoxLoss, self).__init__()
34
- self.num_classes = num_classes
35
- self.threshold = overlap_thresh
36
- self.background_label = bkg_label
37
- self.encode_target = encode_target
38
- self.use_prior_for_matching = prior_for_matching
39
- self.do_neg_mining = neg_mining
40
- self.negpos_ratio = neg_pos
41
- self.neg_overlap = neg_overlap
42
- self.variance = [0.1, 0.2]
43
-
44
- def forward(self, predictions, priors, targets):
45
- """Multibox Loss
46
- Args:
47
- predictions (tuple): A tuple containing loc preds, conf preds,
48
- and prior boxes from SSD net.
49
- conf shape: torch.size(batch_size,num_priors,num_classes)
50
- loc shape: torch.size(batch_size,num_priors,4)
51
- priors shape: torch.size(num_priors,4)
52
-
53
- ground_truth (tensor): Ground truth boxes and labels for a batch,
54
- shape: [batch_size,num_objs,5] (last idx is the label).
55
- """
56
-
57
- loc_data, conf_data, landm_data = predictions
58
- priors = priors
59
- num = loc_data.size(0)
60
- num_priors = (priors.size(0))
61
-
62
- # match priors (default boxes) and ground truth boxes
63
- loc_t = torch.Tensor(num, num_priors, 4)
64
- landm_t = torch.Tensor(num, num_priors, 10)
65
- conf_t = torch.LongTensor(num, num_priors)
66
- for idx in range(num):
67
- truths = targets[idx][:, :4].data
68
- labels = targets[idx][:, -1].data
69
- landms = targets[idx][:, 4:14].data
70
- defaults = priors.data
71
- match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
72
- if GPU:
73
- loc_t = loc_t.cuda()
74
- conf_t = conf_t.cuda()
75
- landm_t = landm_t.cuda()
76
-
77
- zeros = torch.tensor(0).cuda()
78
- # landm Loss (Smooth L1)
79
- # Shape: [batch,num_priors,10]
80
- pos1 = conf_t > zeros
81
- num_pos_landm = pos1.long().sum(1, keepdim=True)
82
- N1 = max(num_pos_landm.data.sum().float(), 1)
83
- pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
84
- landm_p = landm_data[pos_idx1].view(-1, 10)
85
- landm_t = landm_t[pos_idx1].view(-1, 10)
86
- loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
87
-
88
-
89
- pos = conf_t != zeros
90
- conf_t[pos] = 1
91
-
92
- # Localization Loss (Smooth L1)
93
- # Shape: [batch,num_priors,4]
94
- pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
95
- loc_p = loc_data[pos_idx].view(-1, 4)
96
- loc_t = loc_t[pos_idx].view(-1, 4)
97
- loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
98
-
99
- # Compute max conf across batch for hard negative mining
100
- batch_conf = conf_data.view(-1, self.num_classes)
101
- loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
102
-
103
- # Hard Negative Mining
104
- loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
105
- loss_c = loss_c.view(num, -1)
106
- _, loss_idx = loss_c.sort(1, descending=True)
107
- _, idx_rank = loss_idx.sort(1)
108
- num_pos = pos.long().sum(1, keepdim=True)
109
- num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
110
- neg = idx_rank < num_neg.expand_as(idx_rank)
111
-
112
- # Confidence Loss Including Positive and Negative Examples
113
- pos_idx = pos.unsqueeze(2).expand_as(conf_data)
114
- neg_idx = neg.unsqueeze(2).expand_as(conf_data)
115
- conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
116
- targets_weighted = conf_t[(pos+neg).gt(0)]
117
- loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
118
-
119
- # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
120
- N = max(num_pos.data.sum().float(), 1)
121
- loss_l /= N
122
- loss_c /= N
123
- loss_landm /= N1
124
-
125
- return loss_l, loss_c, loss_landm
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Variable
5
+ from videoretalking.third_part.GPEN.face_detect.utils.box_utils import match, log_sum_exp
6
+ from videoretalking.third_part.GPEN.face_detect.data import cfg_mnet
7
+ GPU = cfg_mnet['gpu_train']
8
+
9
+ class MultiBoxLoss(nn.Module):
10
+ """SSD Weighted Loss Function
11
+ Compute Targets:
12
+ 1) Produce Confidence Target Indices by matching ground truth boxes
13
+ with (default) 'priorboxes' that have jaccard index > threshold parameter
14
+ (default threshold: 0.5).
15
+ 2) Produce localization target by 'encoding' variance into offsets of ground
16
+ truth boxes and their matched 'priorboxes'.
17
+ 3) Hard negative mining to filter the excessive number of negative examples
18
+ that comes with using a large number of default bounding boxes.
19
+ (default negative:positive ratio 3:1)
20
+ Objective Loss:
21
+ L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
22
+ Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
23
+ weighted by α which is set to 1 by cross val.
24
+ Args:
25
+ c: class confidences,
26
+ l: predicted boxes,
27
+ g: ground truth boxes
28
+ N: number of matched default boxes
29
+ See: https://arxiv.org/pdf/1512.02325.pdf for more details.
30
+ """
31
+
32
+ def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
33
+ super(MultiBoxLoss, self).__init__()
34
+ self.num_classes = num_classes
35
+ self.threshold = overlap_thresh
36
+ self.background_label = bkg_label
37
+ self.encode_target = encode_target
38
+ self.use_prior_for_matching = prior_for_matching
39
+ self.do_neg_mining = neg_mining
40
+ self.negpos_ratio = neg_pos
41
+ self.neg_overlap = neg_overlap
42
+ self.variance = [0.1, 0.2]
43
+
44
+ def forward(self, predictions, priors, targets):
45
+ """Multibox Loss
46
+ Args:
47
+ predictions (tuple): A tuple containing loc preds, conf preds,
48
+ and prior boxes from SSD net.
49
+ conf shape: torch.size(batch_size,num_priors,num_classes)
50
+ loc shape: torch.size(batch_size,num_priors,4)
51
+ priors shape: torch.size(num_priors,4)
52
+
53
+ ground_truth (tensor): Ground truth boxes and labels for a batch,
54
+ shape: [batch_size,num_objs,5] (last idx is the label).
55
+ """
56
+
57
+ loc_data, conf_data, landm_data = predictions
58
+ priors = priors
59
+ num = loc_data.size(0)
60
+ num_priors = (priors.size(0))
61
+
62
+ # match priors (default boxes) and ground truth boxes
63
+ loc_t = torch.Tensor(num, num_priors, 4)
64
+ landm_t = torch.Tensor(num, num_priors, 10)
65
+ conf_t = torch.LongTensor(num, num_priors)
66
+ for idx in range(num):
67
+ truths = targets[idx][:, :4].data
68
+ labels = targets[idx][:, -1].data
69
+ landms = targets[idx][:, 4:14].data
70
+ defaults = priors.data
71
+ match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
72
+ if GPU:
73
+ loc_t = loc_t.cuda()
74
+ conf_t = conf_t.cuda()
75
+ landm_t = landm_t.cuda()
76
+
77
+ zeros = torch.tensor(0).cuda()
78
+ # landm Loss (Smooth L1)
79
+ # Shape: [batch,num_priors,10]
80
+ pos1 = conf_t > zeros
81
+ num_pos_landm = pos1.long().sum(1, keepdim=True)
82
+ N1 = max(num_pos_landm.data.sum().float(), 1)
83
+ pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
84
+ landm_p = landm_data[pos_idx1].view(-1, 10)
85
+ landm_t = landm_t[pos_idx1].view(-1, 10)
86
+ loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')
87
+
88
+
89
+ pos = conf_t != zeros
90
+ conf_t[pos] = 1
91
+
92
+ # Localization Loss (Smooth L1)
93
+ # Shape: [batch,num_priors,4]
94
+ pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
95
+ loc_p = loc_data[pos_idx].view(-1, 4)
96
+ loc_t = loc_t[pos_idx].view(-1, 4)
97
+ loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
98
+
99
+ # Compute max conf across batch for hard negative mining
100
+ batch_conf = conf_data.view(-1, self.num_classes)
101
+ loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
102
+
103
+ # Hard Negative Mining
104
+ loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now
105
+ loss_c = loss_c.view(num, -1)
106
+ _, loss_idx = loss_c.sort(1, descending=True)
107
+ _, idx_rank = loss_idx.sort(1)
108
+ num_pos = pos.long().sum(1, keepdim=True)
109
+ num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
110
+ neg = idx_rank < num_neg.expand_as(idx_rank)
111
+
112
+ # Confidence Loss Including Positive and Negative Examples
113
+ pos_idx = pos.unsqueeze(2).expand_as(conf_data)
114
+ neg_idx = neg.unsqueeze(2).expand_as(conf_data)
115
+ conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
116
+ targets_weighted = conf_t[(pos+neg).gt(0)]
117
+ loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
118
+
119
+ # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
120
+ N = max(num_pos.data.sum().float(), 1)
121
+ loss_l /= N
122
+ loss_c /= N
123
+ loss_landm /= N1
124
+
125
+ return loss_l, loss_c, loss_landm