|
import torch |
|
import math |
|
import torch.nn.functional as F |
|
|
|
|
|
def log_sum_exp(x, axis=None):
    """
    Numerically stable log-sum-exp.

    Args:
        x: Input tensor.
        axis: Axis over which to perform the sum. If None, the reduction
            is taken over all elements of ``x``.

    Returns:
        torch.Tensor: log(sum(exp(x))) along ``axis``.
    """
    # The previous implementation crashed for the default axis=None
    # (torch.max does not accept dim=None); flatten and reduce over dim 0.
    if axis is None:
        x = x.reshape(-1)
        axis = 0
    # keepdim=True so the subtracted max broadcasts along the reduced axis.
    # Without it, for axis >= 1 the (shorter) max tensor lined up against the
    # wrong dimension and paired the wrong elements.
    x_max = torch.max(x, dim=axis, keepdim=True)[0]
    y = torch.log(torch.exp(x - x_max).sum(dim=axis)) + x_max.squeeze(axis)
    return y
|
|
|
|
|
def get_positive_expectation(p_samples, measure='JSD', average=True):
    """
    Computes the positive part of a divergence / difference.

    Args:
        p_samples: Positive samples.
        measure: Measure to compute for. One of 'GAN', 'JSD', 'X2', 'KL',
            'RKL', 'DV', 'H2', 'W1'.
        average: Average the result over samples.

    Returns:
        torch.Tensor

    Raises:
        ValueError: If ``measure`` is not a recognized estimator.
    """
    # Positive-sample terms of the supported f-divergence / MI estimators,
    # keyed by measure name and applied lazily.
    estimators = {
        'GAN': lambda s: -F.softplus(-s),
        'JSD': lambda s: math.log(2.) - F.softplus(-s),
        'X2': lambda s: s ** 2,
        'KL': lambda s: s + 1.,
        'RKL': lambda s: -torch.exp(-s),
        'DV': lambda s: s,
        'H2': lambda s: torch.ones_like(s) - torch.exp(-s),
        'W1': lambda s: s,
    }
    if measure not in estimators:
        raise ValueError('Unknown measurement {}'.format(measure))
    Ep = estimators[measure](p_samples)
    return Ep.mean() if average else Ep
|
|
|
|
|
def get_negative_expectation(q_samples, measure='JSD', average=True):
    """
    Computes the negative part of a divergence / difference.

    Args:
        q_samples: Negative samples.
        measure: Measure to compute for. One of 'GAN', 'JSD', 'X2', 'KL',
            'RKL', 'DV', 'H2', 'W1'.
        average: Average the result over samples.

    Returns:
        torch.Tensor

    Raises:
        ValueError: If ``measure`` is not a recognized estimator.
    """
    # Negative-sample terms of the supported f-divergence / MI estimators,
    # keyed by measure name and applied lazily.
    estimators = {
        'GAN': lambda s: F.softplus(-s) + s,
        'JSD': lambda s: F.softplus(-s) + s - math.log(2.),
        'X2': lambda s: -0.5 * ((torch.sqrt(s ** 2) + 1.) ** 2),
        'KL': lambda s: torch.exp(s),
        'RKL': lambda s: s - 1.,
        'DV': lambda s: log_sum_exp(s, 0) - math.log(s.size(0)),
        'H2': lambda s: torch.exp(s) - 1.,
        'W1': lambda s: s,
    }
    if measure not in estimators:
        raise ValueError('Unknown measurement {}'.format(measure))
    Eq = estimators[measure](q_samples)
    return Eq.mean() if average else Eq
|
|
|
|
|
def batch_video_query_loss(video, query, match_labels, mask, measure='JSD'):
    """
    QV-CL module
    Computing the Contrastive Loss between the video and query.
    :param video: video rep (bsz, Lv, dim)
    :param query: query rep (bsz, dim)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: L_{qv}
    """
    def _masked_mean(scores, m):
        # Mean of the per-clip scores over the positions selected by m;
        # the epsilon guards against an all-zero mask row.
        return torch.sum(scores * m, dim=1) / (torch.sum(m, dim=1) + 1e-12)

    pos_mask = match_labels.type(torch.float32)
    # Negatives: every unmatched clip that is still inside the valid mask.
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask

    # Clip-vs-query similarity scores: (bsz, Lv, dim) x (bsz, dim, 1) -> (bsz, Lv).
    res = torch.matmul(video, query.unsqueeze(2)).squeeze(2)

    # Positive / negative expectations of the chosen MI estimator, averaged
    # over their respective mask positions.
    E_pos = _masked_mean(get_positive_expectation(res * pos_mask, measure, average=False), pos_mask)
    E_neg = _masked_mean(get_negative_expectation(res * neg_mask, measure, average=False), neg_mask)

    # Per-sample loss (no batch reduction here).
    return E_neg - E_pos
|
|
|
|
|
def batch_video_video_loss(video, st_ed_indices, match_labels, mask, measure='JSD'):
    """
    VV-CL module
    Computing the Contrastive loss between the start/end clips and the video
    :param video: video rep (bsz, Lv, dim)
    :param st_ed_indices: (bsz, 2)
    :param match_labels: match labels (bsz, Lv)
    :param mask: mask (bsz, Lv)
    :param measure: estimator of the mutual information
    :return: L_{vv}
    """
    def _masked_mean(scores, m):
        # Mean of the per-clip scores over the positions selected by m;
        # the epsilon guards against an all-zero mask row.
        return torch.sum(scores * m, dim=1) / (torch.sum(m, dim=1) + 1e-12)

    def _clip_scores(anchor):
        # Similarity of every clip against one anchor clip per batch element:
        # (bsz, Lv, dim) x (bsz, dim, 1) -> (bsz, Lv).
        return torch.matmul(video, anchor.unsqueeze(2)).squeeze(2)

    pos_mask = match_labels.type(torch.float32)
    # Negatives: every unmatched clip that is still inside the valid mask.
    neg_mask = (torch.ones_like(pos_mask) - pos_mask) * mask

    # Gather the start / end clip representations for each batch element.
    batch_indices = torch.arange(0, video.shape[0]).long()
    res_s = _clip_scores(video[batch_indices, st_ed_indices[:, 0], :])
    res_e = _clip_scores(video[batch_indices, st_ed_indices[:, 1], :])

    # Sum the start- and end-anchored expectations for positives and negatives.
    E_pos = (_masked_mean(get_positive_expectation(res_s * pos_mask, measure, average=False), pos_mask)
             + _masked_mean(get_positive_expectation(res_e * pos_mask, measure, average=False), pos_mask))
    E_neg = (_masked_mean(get_negative_expectation(res_s * neg_mask, measure, average=False), neg_mask)
             + _masked_mean(get_negative_expectation(res_e * neg_mask, measure, average=False), neg_mask))

    # Scalar loss: batch mean of the negative/positive expectation gap.
    return torch.mean(E_neg - E_pos)
|
|