# English
# TVR-Ranking/modules/contrastive.py
# Liangrj5
# init
# 5019d3f
import torch
import math
import torch.nn.functional as F
def log_sum_exp(x, axis=None):
    """
    Numerically stable log-sum-exp reduction.

    Args:
        x: Input tensor.
        axis: Dimension to reduce over. ``None`` (the default) reduces over
            every element of ``x`` and yields a 0-dim tensor.

    Returns:
        torch.Tensor: ``log(sum(exp(x)))`` along ``axis``, with the reduced
        dimension removed.
    """
    if axis is None:
        # torch.logsumexp needs an explicit dim, so flatten to reduce over
        # all elements at once.
        return torch.logsumexp(x.reshape(-1), dim=0)
    # The library routine subtracts the per-slice maximum internally and
    # broadcasts it correctly for any dim, not just dim 0.
    return torch.logsumexp(x, dim=axis)
def get_positive_expectation(p_samples, measure='JSD', average=True):
    """
    Evaluate the positive-sample term of the selected divergence estimator.

    Args:
        p_samples: Scores of the positive samples.
        measure: Estimator name; one of 'GAN', 'JSD', 'X2', 'KL', 'RKL',
            'DV', 'H2' or 'W1'.
        average: When true, reduce the result to its mean; otherwise return
            the element-wise values.

    Returns:
        torch.Tensor

    Raises:
        ValueError: If ``measure`` is not one of the names listed above.
    """
    # One formula per estimator. Lambdas keep evaluation lazy, so only the
    # requested formula is ever computed.
    formulas = {
        'GAN': lambda s: -F.softplus(-s),
        'JSD': lambda s: math.log(2.) - F.softplus(-s),
        'X2': lambda s: s ** 2,
        'KL': lambda s: s + 1.,
        'RKL': lambda s: -torch.exp(-s),
        'DV': lambda s: s,
        'H2': lambda s: torch.ones_like(s) - torch.exp(-s),
        'W1': lambda s: s,
    }
    if measure not in formulas:
        raise ValueError('Unknown measurement {}'.format(measure))

    expectation = formulas[measure](p_samples)
    return expectation.mean() if average else expectation
def get_negative_expectation(q_samples, measure='JSD', average=True):
    """
    Computes the negative part of a divergence / difference.

    Args:
        q_samples: Scores of the negative samples.
        measure: Estimator name; one of 'GAN', 'JSD', 'X2', 'KL', 'RKL',
            'DV', 'H2' or 'W1'.
        average: When true, reduce the result to its mean; otherwise return
            the un-averaged values.

    Returns:
        torch.Tensor

    Raises:
        ValueError: If ``measure`` is not one of the names listed above.
    """
    if measure == 'GAN':
        Eq = F.softplus(-q_samples) + q_samples
    elif measure == 'JSD':
        Eq = F.softplus(-q_samples) + q_samples - math.log(2.)
    elif measure == 'X2':
        # abs() rather than sqrt(q ** 2): the forward value is the same, but
        # the derivative of sqrt is infinite at 0, which turns the gradient
        # of every exactly-zero score (e.g. masked-out positions) into NaN.
        Eq = -0.5 * ((torch.abs(q_samples) + 1.) ** 2)
    elif measure == 'KL':
        Eq = torch.exp(q_samples)
    elif measure == 'RKL':
        Eq = q_samples - 1.
    elif measure == 'DV':
        # Log of the mean of exp(q) taken over dim 0.
        Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0))
    elif measure == 'H2':
        Eq = torch.exp(q_samples) - 1.
    elif measure == 'W1':
        Eq = q_samples
    else:
        raise ValueError('Unknown measurement {}'.format(measure))

    if average:
        return Eq.mean()
    return Eq
def batch_video_query_loss(video, query, match_labels, mask, measure='JSD'):
    """
    QV-CL module
    Contrastive loss between the clips of a video and its query.
    :param video: video rep (bsz, Lv, dim)
    :param query: query rep (bsz, dim)
    :param match_labels: match labels (bsz, Lv); non-zero marks a positive clip
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: per-sample L_{qv} of shape (bsz, ); no batch reduction is applied
    """
    # Positive clips come from the labels; negatives are the remaining
    # clips that survive `mask`.
    positives = match_labels.type(torch.float32)  # (bsz, Lv)
    negatives = (torch.ones_like(positives) - positives) * mask  # (bsz, Lv)

    # Dot product of every clip with the query vector.
    scores = torch.matmul(video, query.unsqueeze(2)).squeeze(2)  # (bsz, Lv)

    # Positive term: masked mean of the estimator over the positive clips.
    pos_terms = get_positive_expectation(scores * positives, measure, average=False)
    pos_mean = torch.sum(pos_terms * positives, dim=1) / (torch.sum(positives, dim=1) + 1e-12)  # (bsz, )

    # Negative term: masked mean of the estimator over the negative clips.
    neg_terms = get_negative_expectation(scores * negatives, measure, average=False)
    neg_mean = torch.sum(neg_terms * negatives, dim=1) / (torch.sum(negatives, dim=1) + 1e-12)  # (bsz, )

    return neg_mean - pos_mean  # (bsz, )
def _masked_mean(values, mask):
    """Mean of ``values`` over dim 1, counting only entries where ``mask`` is non-zero."""
    # The epsilon keeps the division finite when a row has no selected entry.
    return torch.sum(values * mask, dim=1) / (torch.sum(mask, dim=1) + 1e-12)


def batch_video_video_loss(video, st_ed_indices, match_labels, mask, measure='JSD'):
    """
    VV-CL module
    Computing the Contrastive loss between the start/end clips and the video
    :param video: video rep (bsz, Lv, dim)
    :param st_ed_indices: start and end clip index of each sample (bsz, 2)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: L_{vv}, averaged over the batch (0-dim tensor)
    """
    # generate mask
    pos_mask = match_labels.type(torch.float32)  # (bsz, Lv)
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask  # (bsz, Lv)

    # select start and end indices features
    st_indices, ed_indices = st_ed_indices[:, 0], st_ed_indices[:, 1]  # (bsz, )
    # Build the batch index on the same device as `video` rather than on CPU.
    batch_indices = torch.arange(video.shape[0], device=video.device)  # (bsz, )
    video_s = video[batch_indices, st_indices, :]  # (bsz, dim)
    video_e = video[batch_indices, ed_indices, :]  # (bsz, dim)

    # compute scores: every clip against the start clip and against the end clip
    res_s = torch.matmul(video, video_s.unsqueeze(2)).squeeze(2)  # (bsz, Lv)
    res_e = torch.matmul(video, video_e.unsqueeze(2)).squeeze(2)  # (bsz, Lv)

    # MI expectation over the positive samples, start clip + end clip
    E_s_pos = _masked_mean(get_positive_expectation(res_s * pos_mask, measure, average=False), pos_mask)
    E_e_pos = _masked_mean(get_positive_expectation(res_e * pos_mask, measure, average=False), pos_mask)
    E_pos = E_s_pos + E_e_pos  # (bsz, )

    # MI expectation over the negative samples, start clip + end clip
    E_s_neg = _masked_mean(get_negative_expectation(res_s * neg_mask, measure, average=False), neg_mask)
    E_e_neg = _masked_mean(get_negative_expectation(res_e * neg_mask, measure, average=False), neg_mask)
    E_neg = E_s_neg + E_e_neg  # (bsz, )

    E = E_neg - E_pos  # (bsz, )
    return torch.mean(E)