import logging
from functools import lru_cache
import torch
import torch.nn.functional as F
from torch import nn
from models.utils import allgather_wgrad
logger = logging.getLogger(__name__)
def get_sim(
vision_proj: torch.Tensor,
text_proj: torch.Tensor,
temp=1.0,
agg_method="mean",
):
"""calculate pair-wise video-text similarity.
Args:
vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
text_proj (torch.Tensor): The text representation. Shape: [B,C].
temp (torch.Tensor): The temperature. Shape: [].
Returns: The similarity between video and text. Shape: [B,B].
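    Example (a minimal sketch; batch size, frame count and embedding dim are assumed):
        >>> v = torch.randn(4, 8, 256)   # [B,T,C] frame-level projections
        >>> t = torch.randn(4, 256)      # [B,C] text projections
        >>> s_v2t, s_t2v = get_sim(v, t, temp=0.07, agg_method="mean")
        >>> s_v2t.shape, s_t2v.shape
        (torch.Size([4, 4]), torch.Size([4, 4]))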
"""
vision_proj = F.normalize(vision_proj, dim=-1)
text_proj = F.normalize(text_proj, dim=-1)
if vision_proj.ndim == 3:
sim_v2t = torch.einsum("mld,nd->mln", vision_proj, text_proj) / temp # [B, L, B]
sim_t2v = torch.einsum("nd,mld->nlm", text_proj, vision_proj) / temp # [B, L, B]
if agg_method == "mean":
sim_v2t = sim_v2t.mean(1)
sim_t2v = sim_t2v.mean(1)
elif agg_method == "max":
sim_v2t = sim_v2t.max(1)[0]
sim_t2v = sim_t2v.max(1)[0]
elif text_proj.ndim == 3:
sim_v2t = torch.einsum("nd,mld->nlm", vision_proj, text_proj) / temp # [B, L, B]
sim_t2v = torch.einsum("nld,md->nlm", text_proj, vision_proj) / temp # [B, L, B]
if agg_method == "mean":
sim_v2t = sim_v2t.mean(1)
sim_t2v = sim_t2v.mean(1)
elif agg_method == "max":
sim_v2t = sim_v2t.max(1)[0]
sim_t2v = sim_t2v.max(1)[0]
else:
sim_v2t = vision_proj @ text_proj.T / temp
sim_t2v = sim_v2t.T
return sim_v2t, sim_t2v
class VTC_VTM_Loss(nn.Module):
"""video-text contrastive and matching losses."""
def __init__(self, vtm_hard_neg):
super().__init__()
self.vtm_hard_neg = vtm_hard_neg
def vtc_loss(
self,
vision_proj: torch.Tensor,
text_proj: torch.Tensor,
idx: torch.Tensor,
temp=1.0,
all_gather=True,
agg_method="mean",
):
"""forward to calculate the loss
Args:
vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
text_proj (torch.Tensor): The text representation. Shape: [B,C].
idx (torch.Tensor): The index for each example. Shape: [B,].
temp (torch.Tensor): The temperature. Shape: [].
            all_gather (bool): If True, gather samples across all GPUs and compute the loss over the gathered batch.
            agg_method (str): How to aggregate frame-level similarities, "mean" or "max".
Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: [].
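        Example (a minimal single-process sketch; shapes and the temperature are assumed):
            >>> crit = VTC_VTM_Loss(vtm_hard_neg=False)
            >>> v = torch.randn(4, 8, 256)    # [B,T,C]
            >>> t = torch.randn(4, 256)       # [B,C]
            >>> idx = torch.arange(4)         # each sample is its own positive
            >>> loss = crit.vtc_loss(v, t, idx, temp=0.07, all_gather=False)
            >>> loss.ndim
            0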
"""
if all_gather:
gather_args = self.get_gather_args()
vision_proj = allgather_wgrad(vision_proj, gather_args)
text_proj = allgather_wgrad(text_proj, gather_args)
if idx is not None:
idx = allgather_wgrad(idx, gather_args)
sim_v2t, sim_t2v = get_sim(vision_proj, text_proj, temp, agg_method=agg_method)
with torch.no_grad():
sim_v2t_targets = self.get_mask(sim_v2t, idx=idx, normalize=True)
sim_t2v_targets = sim_v2t_targets
loss_i2t = -torch.sum(F.log_softmax(sim_v2t, dim=1) * sim_v2t_targets, dim=1).mean()
loss_t2i = -torch.sum(F.log_softmax(sim_t2v, dim=1) * sim_t2v_targets, dim=1).mean()
loss_vtc = (loss_i2t + loss_t2i) / 2
return loss_vtc
def vtm_loss(
self,
multimodal_encoder,
vtm_head: nn.Module,
temp,
vision_embeds: torch.Tensor,
text_embeds: torch.Tensor,
vision_proj: torch.Tensor,
text_proj: torch.Tensor,
text_atts: torch.Tensor,
idx: torch.Tensor,
):
"""video-text matching loss.
Args:
            multimodal_encoder (nn.Module): The multimodal encoder that fuses text and vision ("fusion" mode).
vtm_head (nn.Module): The head to produce the video-text matching score.
            temp (torch.Tensor): The temperature for similarity calculation.
vision_embeds (torch.Tensor): The features of all patches in the video. Shape: [B,T,L,C].
text_embeds (torch.Tensor): The features of all tokens in the text. Shape: [B,L,C].
vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
text_proj (torch.Tensor): The text representation. Shape: [B,C].
text_atts (torch.Tensor): The padded mask for text tokens. 0 is padded. Shape: [B,L].
idx (torch.Tensor): The index for each example. Shape: [B,].
        Returns: loss_vtm (torch.Tensor): The video-text matching loss. Shape: [].
"""
with torch.no_grad():
sim_v2t, sim_t2v = get_sim(vision_proj, text_proj, temp)
vision_atts = torch.ones(
vision_embeds.size()[:-1], dtype=torch.long, device=vision_embeds.device
)
weights_v2t = F.softmax(sim_v2t + 1e-4, dim=1) # (N, N)
weights_t2v = F.softmax(sim_t2v + 1e-4, dim=1)
mask = self.get_mask(sim_v2t, idx=idx).bool()
weights_v2t.masked_fill_(mask, 0)
weights_t2v.masked_fill_(mask, 0)
weights_v2t = torch.nan_to_num_(weights_v2t, nan=1e-2, posinf=1e-2, neginf=1e-2)
weights_t2v = torch.nan_to_num_(weights_t2v, nan=1e-2, posinf=1e-2, neginf=1e-2)
# select a negative image for each text
if self.vtm_hard_neg:
            vision_neg_indices = torch.multinomial(weights_t2v, 1).squeeze()  # NOTE: squeeze() assumes batch size > 1
txt_neg_indices = torch.multinomial(weights_v2t, 1).squeeze()
else:
vision_neg_indices = self.get_rand_indices(mask, 1).squeeze()
txt_neg_indices = self.get_rand_indices(mask, 1).squeeze()
vision_embeds_neg = vision_embeds[vision_neg_indices] # [B, T*L, c]
text_embeds_neg = text_embeds[txt_neg_indices] # [B, L, d]
text_atts_neg = text_atts[txt_neg_indices]
# concat embeddings
vision_embeds_all = torch.cat([vision_embeds, vision_embeds_neg, vision_embeds], dim=0)
text_embeds_all = torch.cat([text_embeds, text_embeds, text_embeds_neg], dim=0)
vision_atts_all = torch.cat([vision_atts, vision_atts, vision_atts], dim=0)
text_atts_all = torch.cat([text_atts, text_atts, text_atts_neg], dim=0)
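        # The fused batch below is laid out as [B positive pairs; B (negative video, positive text);
        # B (positive video, negative text)], matching the labels later: the first B entries are 1 (match),
        # the remaining 2B are 0.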
output = multimodal_encoder(
encoder_embeds=text_embeds_all,
attention_mask=text_atts_all,
encoder_hidden_states=vision_embeds_all,
encoder_attention_mask=vision_atts_all,
return_dict=True,
mode="fusion",
)
vtm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
vtm_logits = vtm_head(vtm_embeds) # [3*B, 2]
bs = vtm_logits.shape[0] // 3
vtm_labels = vtm_logits.new_ones(3 * bs, dtype=torch.long)
vtm_labels[bs:] = 0
loss_vtm = F.cross_entropy(vtm_logits, vtm_labels)
return loss_vtm
def get_rand_indices(self, mask, k):
"""get rand indices according to mask.
Args:
            mask (torch.Tensor): Shape: (N, L). 0 indicates positions that can be sampled, 1 otherwise.
            k (int): The number of indices to sample in each row.
        Returns:
            The sampled indices. Shape: (N, k).
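        Example (an illustrative call; the eye mask marks the diagonal as non-sampleable):
            >>> m = torch.eye(3)
            >>> VTC_VTM_Loss(vtm_hard_neg=False).get_rand_indices(m, 1).shape
            torch.Size([3, 1])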
"""
        mask = mask.float()
        # push non-sampleable (mask == 1) positions down to a large negative score ...
        mask = mask - 10000 * mask
        # ... then add Gaussian noise so sorting picks random indices among the allowed (mask == 0) positions
        mask += torch.randn_like(mask)
_, indices = torch.sort(mask, dim=1, descending=True)
indices = indices[:, :k].contiguous()
return indices
@torch.no_grad()
def get_mask(self, sim, idx=None, normalize=False):
"""
Args:
sim (torch.Tensor): The similarity between videos and texts. shape: (B, B).
idx (torch.Tensor): The index for each video. Shape: [B].
            normalize (bool): If True, normalize each row so that it sums to 1.
        Returns:
            The target mask. Shape: (B, B). 1 marks a matched (positive) pair.
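        Example (an illustrative call; two samples share an index and are mutual positives):
            >>> sim = torch.zeros(3, 3)
            >>> idx = torch.tensor([0, 1, 1])
            >>> VTC_VTM_Loss(vtm_hard_neg=False).get_mask(sim, idx=idx).sum(1)
            tensor([1., 2., 2.])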
"""
if idx is not None:
idx = idx.view(-1, 1)
mask = torch.eq(idx, idx.T).to(sim.dtype)
if normalize:
mask = mask / mask.sum(1, keepdim=True)
else:
mask = torch.zeros_like(sim)
mask.fill_diagonal_(1)
return mask # `1` mark valid/matched location
@lru_cache(maxsize=16)
def get_gather_args(self):
"""obtain the args for all_gather
Returns: dict.
"""
from utils.distributed import get_rank, get_world_size
from utils.easydict import EasyDict
return EasyDict({"world_size": get_world_size(), "rank": get_rank()})
class MLMLoss(nn.Module):
"""masked language modeling loss."""
def __init__(self, masking_prob, tokenizer):
super(MLMLoss, self).__init__()
self.tokenizer = tokenizer
self.masking_prob = masking_prob
def mlm_loss(
self,
text_encoder,
text,
vision_embeds,
vision_atts,
):
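        """Masked language modeling loss with the video as cross-attention context.
        A BERT-style masking is applied to `text.input_ids`, the masked text is encoded in
        "text" mode, and the loss is computed by the encoder in "fusion" mode against `labels`.
        Args:
            text_encoder (nn.Module): Text/fusion encoder exposing a `bert` submodule and an MLM head.
            text: Tokenizer output with `input_ids` and `attention_mask`. Shape of input_ids: [B,L].
            vision_embeds (torch.Tensor): The features of all patches in the video.
            vision_atts (torch.Tensor): Attention mask for the vision features.
        Returns: The masked language modeling loss. Shape: [].
        """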
input_ids = text.input_ids.clone()
labels = input_ids.clone()
probability_matrix = torch.full(labels.shape, self.masking_prob)
input_ids, labels = self.mask(
input_ids,
text_encoder.config.vocab_size,
input_ids.device,
targets=labels,
probability_matrix=probability_matrix,
)
intermediate_mlm_output = text_encoder.bert(
input_ids,
attention_mask=text.attention_mask,
encoder_hidden_states=vision_embeds,
encoder_attention_mask=vision_atts,
return_dict=True,
mode="text",
)
text_embeds = intermediate_mlm_output.last_hidden_state
mlm_output = text_encoder(
encoder_embeds=text_embeds,
attention_mask=text.attention_mask,
encoder_hidden_states=vision_embeds,
encoder_attention_mask=vision_atts,
return_dict=True,
labels=labels,
soft_labels=None,
mode="fusion",
)
return mlm_output.loss
def simple_mlm_loss(
self,
text_encoder,
text,
text_embeds,
vision_embeds,
vision_atts,
labels
):
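        """Masked language modeling loss on already-masked, already-encoded text.
        Unlike `mlm_loss`, the caller provides the masked text embeddings and the matching
        `labels`, so only the "fusion" forward pass is run here.
        Returns: The masked language modeling loss. Shape: [].
        """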
mlm_output = text_encoder(
encoder_embeds=text_embeds,
attention_mask=text.attention_mask,
encoder_hidden_states=vision_embeds,
encoder_attention_mask=vision_atts,
return_dict=True,
labels=labels,
soft_labels=None,
mode="fusion",
)
return mlm_output.loss
def mask(
self,
input_ids,
vocab_size,
device,
targets=None,
masked_indices=None,
probability_matrix=None,
):
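        """BERT-style token masking.
        Tokens are selected via `probability_matrix` (or a provided `masked_indices`), excluding
        [PAD] and [CLS]. Of the selected tokens, 80% are replaced by [MASK], 10% by a random
        vocabulary token, and 10% are left unchanged; if `targets` is given, it is set to -100
        everywhere else so the loss is only computed on masked positions.
        """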
if masked_indices is None:
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
"""make deepspeed happy!"""
# _pad_mask = (input_ids == self.tokenizer.pad_token_id).to(masked_indices.device, non_blocking=True) # 0
# # print(_pad_mask.device)
# masked_indices[_pad_mask] = False
# _cls_mask = (input_ids == self.tokenizer.cls_token_id).to(masked_indices.device, non_blocking=True) # 101
# masked_indices[_cls_mask] = False
if targets is not None:
# We only compute loss on masked tokens
targets[~masked_indices] = -100
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = (
torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
)
input_ids[indices_replaced] = self.tokenizer.mask_token_id
        # 10% of the time, we replace masked input tokens with a random word
        # (0.5 of the ~20% not replaced above, i.e. ~10% of all masked tokens)
indices_random = (
torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool()
& masked_indices
& ~indices_replaced
)
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
input_ids[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
if targets is not None:
return input_ids, targets
else:
return input_ids
class UTA_Loss(nn.Module):
"""mask align clip loss."""
def __init__(self, uta_norm_type='l2', uta_loss_type='l2'):
super().__init__()
self.norm_type = uta_norm_type
self.loss_type = uta_loss_type
logger.info(f'Norm type: {uta_norm_type}')
logger.info(f'Loss type: {uta_loss_type}')
if uta_loss_type == 'mse':
self.loss_func = nn.MSELoss()
elif uta_loss_type == 'smooth_l1':
self.loss_func = nn.SmoothL1Loss()
def uta_loss(self, student_output, clip_output):
"""forward to calculate the loss
Args:
student_output (torch.Tensor): The student output. Shape: [K,B,N,C].
clip_output (torch.Tensor): The teacher representation. Shape: [K,B,N,C].
Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
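        Example (a minimal sketch; shapes are assumed and values are random):
            >>> crit = UTA_Loss(uta_norm_type='l2', uta_loss_type='l2')
            >>> student = torch.randn(2, 4, 16, 64)   # [K,B,N,C]
            >>> teacher = torch.randn(2, 4, 16, 64)
            >>> crit.uta_loss(student, teacher).ndim
            0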
"""
if self.norm_type == 'l2':
student_output = student_output / student_output.norm(dim=-1, keepdim=True)
clip_output = clip_output / clip_output.norm(dim=-1, keepdim=True)
elif self.norm_type == 'none':
pass
else:
raise NotImplementedError
if self.loss_type == 'l2':
loss_uta = (2 - 2 * (student_output * clip_output).sum(dim=-1)).mean()
elif self.loss_type in ['mse', 'smooth_l1']:
loss_uta = self.loss_func(input=student_output, target=clip_output)
else:
raise NotImplementedError
return loss_uta
def uta_vision_loss(self, student_v_output, clip_v_output):
"""forward to calculate the loss
Args:
student_v_output (torch.Tensor): The student output. Shape: [B,T,C].
clip_v_output (torch.Tensor): The teacher representation. Shape: [B,T,C].
Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
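        Example (a minimal sketch; when the temporal sizes differ, both sides are mean-pooled over time first):
            >>> crit = UTA_Loss(uta_norm_type='l2', uta_loss_type='l2')
            >>> student = torch.randn(4, 8, 64)   # [B,T,C]
            >>> teacher = torch.randn(4, 1, 64)   # a single (pooled) teacher step
            >>> crit.uta_vision_loss(student, teacher).ndim
            0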
"""
if student_v_output.shape[1] != clip_v_output.shape[1]:
student_v_output = student_v_output.mean(1, keepdim=True)
clip_v_output = clip_v_output.mean(1, keepdim=True)
if self.norm_type == 'l2':
student_v_output = student_v_output / student_v_output.norm(dim=-1, keepdim=True)
clip_v_output = clip_v_output / clip_v_output.norm(dim=-1, keepdim=True)
elif self.norm_type == 'none':
pass
else:
raise NotImplementedError
if self.loss_type == 'l2':
loss_uta = (2 - 2 * (student_v_output * clip_v_output).sum(dim=-1)).mean()
elif self.loss_type in ['mse', 'smooth_l1']:
loss_uta = self.loss_func(input=student_v_output, target=clip_v_output)
else:
raise NotImplementedError
return loss_uta
def uta_all_loss(
self,
student_v_output, clip_v_output,
student_t_output, clip_t_output,
):
"""forward to calculate the loss
Args:
student_v_output (torch.Tensor): The student output. Shape: [B,T,C].
clip_v_output (torch.Tensor): The teacher representation. Shape: [B,T,C].
student_t_output (torch.Tensor): The student output. Shape: [B,1,C].
clip_t_output (torch.Tensor): The teacher representation. Shape: [B,1,C].
Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
"""
if student_v_output.shape[1] != clip_v_output.shape[1]:
student_v_output = student_v_output.mean(1, keepdim=True)
clip_v_output = clip_v_output.mean(1, keepdim=True)
if self.norm_type == 'l2':
student_v_output = student_v_output / student_v_output.norm(dim=-1, keepdim=True)
clip_v_output = clip_v_output / clip_v_output.norm(dim=-1, keepdim=True)
student_t_output = student_t_output / student_t_output.norm(dim=-1, keepdim=True)
clip_t_output = clip_t_output / clip_t_output.norm(dim=-1, keepdim=True)
elif self.norm_type == 'none':
pass
else:
raise NotImplementedError
if self.loss_type == 'l2':
loss_uta_v = (2 - 2 * (student_v_output * clip_v_output).sum(dim=-1)).mean()
loss_uta_t = (2 - 2 * (student_t_output * clip_t_output).sum(dim=-1)).mean()
elif self.loss_type in ['mse', 'smooth_l1']:
loss_uta_v = self.loss_func(input=student_v_output, target=clip_v_output)
loss_uta_t = self.loss_func(input=student_t_output, target=clip_t_output)
else:
raise NotImplementedError
return (loss_uta_v + loss_uta_t) / 2.
class new_UTA_Loss(nn.Module):
"""mask align clip loss."""
def __init__(self, distill_final_features=True, clip_loss_ratio=[1., 1.]):
super().__init__()
self.distill_final_features = distill_final_features
self.clip_loss_ratio = clip_loss_ratio
logger.info(f'distill_final_features: {distill_final_features}')
logger.info(f'clip_loss_ratio: {clip_loss_ratio}')
def uta_loss(self, student_output, student_output_final,
targets_clip_middle_vis, targets_clip_final_vis):
"""forward to calculate the loss
Args:
            student_output (torch.Tensor): Student tokens aligned to the teacher's intermediate features. Shape: [K,B,N,C].
            student_output_final (torch.Tensor): Student tokens aligned to the teacher's final features.
            targets_clip_middle_vis (torch.Tensor): The CLIP teacher's intermediate visual features. Shape: [K,B,N,C].
            targets_clip_final_vis (torch.Tensor): The CLIP teacher's final visual features.
Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
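        Example (a minimal sketch; shapes are assumed and the inputs are expected to be already L2-normalized):
            >>> crit = new_UTA_Loss(distill_final_features=True, clip_loss_ratio=[1., 1.])
            >>> s_mid = F.normalize(torch.randn(2, 4, 16, 64), dim=-1)   # [K,B,N,C]
            >>> t_mid = F.normalize(torch.randn(2, 4, 16, 64), dim=-1)
            >>> s_fin = F.normalize(torch.randn(4, 16, 64), dim=-1)
            >>> t_fin = F.normalize(torch.randn(4, 16, 64), dim=-1)
            >>> crit.uta_loss(s_mid, s_fin, t_mid, t_fin).ndim
            0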
"""
loss_clip_middle = (2 - 2 * (student_output * targets_clip_middle_vis).sum(dim=-1)).mean()
if self.distill_final_features and self.clip_loss_ratio[1] > 0:
loss_clip_final = (2 - 2 * (student_output_final * targets_clip_final_vis).sum(dim=-1)).mean()
else:
loss_clip_final = torch.zeros(1).type_as(loss_clip_middle).to(loss_clip_middle.device)
loss_uta = loss_clip_middle * self.clip_loss_ratio[0] + loss_clip_final * self.clip_loss_ratio[1]
return loss_uta