File size: 6,095 Bytes
5019d3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import torch
import math
import torch.nn.functional as F
def log_sum_exp(x, axis=None):
    """
    Numerically stable log-sum-exp reduction.

    Args:
        x: Input tensor.
        axis: Axis over which to perform the sum. ``None`` (the default)
            reduces over every element of ``x``.

    Returns:
        torch.Tensor: ``log(sum(exp(x)))`` taken along ``axis``. The reduced
        axis is removed; with ``axis=None`` the result is a 0-dim tensor.
    """
    if axis is None:
        # Flatten so that "no axis" means a reduction over all elements.
        return torch.logsumexp(x.reshape(-1), dim=0)
    # torch.logsumexp subtracts the per-slice maximum internally, so it is
    # stable for large inputs and works for any axis, not only the leading one.
    return torch.logsumexp(x, dim=axis)
def get_positive_expectation(p_samples, measure='JSD', average=True):
    """
    Computes the positive part of a divergence / difference.

    Args:
        p_samples: Positive samples (tensor of scores).
        measure: Measure to compute for. One of 'GAN', 'JSD', 'X2', 'KL',
            'RKL', 'DV', 'H2', 'W1'.
        average: Average the result over samples.

    Returns:
        torch.Tensor: the mean over all elements when ``average`` is true,
        otherwise the element-wise values with the shape of ``p_samples``.

    Raises:
        ValueError: if ``measure`` is not one of the supported names.
    """
    if measure == 'GAN':
        Ep = -F.softplus(-p_samples)
    elif measure == 'JSD':
        Ep = math.log(2.) - F.softplus(-p_samples)
    elif measure == 'X2':
        Ep = p_samples ** 2
    elif measure == 'KL':
        Ep = p_samples + 1.
    elif measure == 'RKL':
        Ep = -torch.exp(-p_samples)
    elif measure in ('DV', 'W1'):
        # Both measures use the raw scores for the positive term.
        Ep = p_samples
    elif measure == 'H2':
        Ep = 1. - torch.exp(-p_samples)
    else:
        raise ValueError('Unknown measurement {}'.format(measure))

    if average:
        return Ep.mean()
    return Ep
def get_negative_expectation(q_samples, measure='JSD', average=True):
    """
    Computes the negative part of a divergence / difference.

    Args:
        q_samples: Negative samples (tensor of scores).
        measure: Measure to compute for. One of 'GAN', 'JSD', 'X2', 'KL',
            'RKL', 'DV', 'H2', 'W1'.
        average: Average the result over samples.

    Returns:
        torch.Tensor: the mean over all elements when ``average`` is true,
        otherwise the un-averaged values. For 'DV' the un-averaged value has
        already been reduced over axis 0 by the log-sum-exp.

    Raises:
        ValueError: if ``measure`` is not one of the supported names.
    """
    if measure == 'GAN':
        Eq = F.softplus(-q_samples) + q_samples
    elif measure == 'JSD':
        Eq = F.softplus(-q_samples) + q_samples - math.log(2.)
    elif measure == 'X2':
        # torch.abs instead of sqrt(q ** 2): identical forward values, but
        # the sqrt form has a NaN gradient at exactly 0, which is hit whenever
        # the caller zeroes scores with a mask before calling this function.
        Eq = -0.5 * ((torch.abs(q_samples) + 1.) ** 2)
    elif measure == 'KL':
        Eq = torch.exp(q_samples)
    elif measure == 'RKL':
        Eq = q_samples - 1.
    elif measure == 'DV':
        # log-mean-exp over axis 0: log-sum-exp minus log of the sample count.
        Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0))
    elif measure == 'H2':
        Eq = torch.exp(q_samples) - 1.
    elif measure == 'W1':
        Eq = q_samples
    else:
        raise ValueError('Unknown measurement {}'.format(measure))

    if average:
        return Eq.mean()
    return Eq
def batch_video_query_loss(video, query, match_labels, mask, measure='JSD'):
    """
    QV-CL module
    Computing the Contrastive Loss between the video and query.

    :param video: video rep (bsz, Lv, dim)
    :param query: query rep (bsz, dim)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: per-sample L_{qv} of shape (bsz, ); the value is NOT averaged
        over the batch, so callers reduce it themselves
    """
    # generate masks: positives are the labelled clips, negatives are the
    # remaining clips that are still inside the valid (unpadded) region
    pos_mask = match_labels.type(torch.float32)  # (bsz, Lv)
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask  # (bsz, Lv)

    # compute scores: dot product of every clip with the query vector
    query = query.unsqueeze(2)  # (bsz, dim, 1)
    res = torch.matmul(video, query).squeeze(2)  # (bsz, Lv)

    # expectation for the MI between the target moment (positive samples) and
    # the query; the 1e-12 guards against rows without any positive clip
    E_pos = get_positive_expectation(res * pos_mask, measure, average=False)
    E_pos = torch.sum(E_pos * pos_mask, dim=1) / (torch.sum(pos_mask, dim=1) + 1e-12)  # (bsz, )

    # expectation for the MI between the clips outside the target moment
    # (negative samples) and the query
    E_neg = get_negative_expectation(res * neg_mask, measure, average=False)
    E_neg = torch.sum(E_neg * neg_mask, dim=1) / (torch.sum(neg_mask, dim=1) + 1e-12)  # (bsz, )

    return E_neg - E_pos  # (bsz, )
def _masked_mean(values, selector):
    """Mean of ``values`` along dim 1 over the positions picked by ``selector``."""
    # the 1e-12 keeps rows whose selector is all zero from dividing by zero
    return torch.sum(values * selector, dim=1) / (torch.sum(selector, dim=1) + 1e-12)


def batch_video_video_loss(video, st_ed_indices, match_labels, mask, measure='JSD'):
    """
    VV-CL module
    Computing the Contrastive loss between the start/end clips and the video

    :param video: video rep (bsz, Lv, dim)
    :param st_ed_indices: (bsz, 2), start index in column 0, end index in column 1
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: L_{vv}, averaged over the batch (0-dim tensor)
    """
    # generate masks
    pos_mask = match_labels.type(torch.float32)  # (bsz, Lv)
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask  # (bsz, Lv)

    # select start and end indices features; the batch index is created on
    # the same device as `video` so the gather needs no cross-device copy
    st_indices, ed_indices = st_ed_indices[:, 0], st_ed_indices[:, 1]  # (bsz, )
    batch_indices = torch.arange(video.shape[0], device=video.device)  # (bsz, )
    video_s = video[batch_indices, st_indices, :]  # (bsz, dim)
    video_e = video[batch_indices, ed_indices, :]  # (bsz, dim)

    # compute scores: every clip against the start clip / the end clip
    res_s = torch.matmul(video, video_s.unsqueeze(2)).squeeze(2)  # (bsz, Lv)
    res_e = torch.matmul(video, video_e.unsqueeze(2)).squeeze(2)  # (bsz, Lv)

    # start / end clips: MI expectation for all positive samples
    E_s_pos = _masked_mean(
        get_positive_expectation(res_s * pos_mask, measure, average=False), pos_mask)  # (bsz, )
    E_e_pos = _masked_mean(
        get_positive_expectation(res_e * pos_mask, measure, average=False), pos_mask)  # (bsz, )
    E_pos = E_s_pos + E_e_pos

    # start / end clips: MI expectation for all negative samples
    E_s_neg = _masked_mean(
        get_negative_expectation(res_s * neg_mask, measure, average=False), neg_mask)  # (bsz, )
    E_e_neg = _masked_mean(
        get_negative_expectation(res_e * neg_mask, measure, average=False), neg_mask)  # (bsz, )
    E_neg = E_s_neg + E_e_neg

    E = E_neg - E_pos  # (bsz, )
    return torch.mean(E)
|