import logging from functools import lru_cache import torch import torch.nn.functional as F from torch import nn from models.utils import allgather_wgrad logger = logging.getLogger(__name__) def get_sim( vision_proj: torch.Tensor, text_proj: torch.Tensor, temp=1.0, agg_method="mean", ): """calculate pair-wise video-text similarity. Args: vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. text_proj (torch.Tensor): The text representation. Shape: [B,C]. temp (torch.Tensor): The temperature. Shape: []. Returns: The similarity between video and text. Shape: [B,B]. """ vision_proj = F.normalize(vision_proj, dim=-1) text_proj = F.normalize(text_proj, dim=-1) if vision_proj.ndim == 3: sim_v2t = torch.einsum("mld,nd->mln", vision_proj, text_proj) / temp # [B, L, B] sim_t2v = torch.einsum("nd,mld->nlm", text_proj, vision_proj) / temp # [B, L, B] if agg_method == "mean": sim_v2t = sim_v2t.mean(1) sim_t2v = sim_t2v.mean(1) elif agg_method == "max": sim_v2t = sim_v2t.max(1)[0] sim_t2v = sim_t2v.max(1)[0] elif text_proj.ndim == 3: sim_v2t = torch.einsum("nd,mld->nlm", vision_proj, text_proj) / temp # [B, L, B] sim_t2v = torch.einsum("nld,md->nlm", text_proj, vision_proj) / temp # [B, L, B] if agg_method == "mean": sim_v2t = sim_v2t.mean(1) sim_t2v = sim_t2v.mean(1) elif agg_method == "max": sim_v2t = sim_v2t.max(1)[0] sim_t2v = sim_t2v.max(1)[0] else: sim_v2t = vision_proj @ text_proj.T / temp sim_t2v = sim_v2t.T return sim_v2t, sim_t2v class VTC_VTM_Loss(nn.Module): """video-text contrastive and matching losses.""" def __init__(self, vtm_hard_neg): super().__init__() self.vtm_hard_neg = vtm_hard_neg def vtc_loss( self, vision_proj: torch.Tensor, text_proj: torch.Tensor, idx: torch.Tensor, temp=1.0, all_gather=True, agg_method="mean", ): """forward to calculate the loss Args: vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. text_proj (torch.Tensor): The text representation. Shape: [B,C]. idx (torch.Tensor): The index for each example. Shape: [B,]. temp (torch.Tensor): The temperature. Shape: []. all_gather (bool): If true, will gather samples across all the GPUs and calculate loss across the gathered samples. Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: []. """ if all_gather: gather_args = self.get_gather_args() vision_proj = allgather_wgrad(vision_proj, gather_args) text_proj = allgather_wgrad(text_proj, gather_args) if idx is not None: idx = allgather_wgrad(idx, gather_args) sim_v2t, sim_t2v = get_sim(vision_proj, text_proj, temp, agg_method=agg_method) with torch.no_grad(): sim_v2t_targets = self.get_mask(sim_v2t, idx=idx, normalize=True) sim_t2v_targets = sim_v2t_targets loss_i2t = -torch.sum(F.log_softmax(sim_v2t, dim=1) * sim_v2t_targets, dim=1).mean() loss_t2i = -torch.sum(F.log_softmax(sim_t2v, dim=1) * sim_t2v_targets, dim=1).mean() loss_vtc = (loss_i2t + loss_t2i) / 2 return loss_vtc def vtm_loss( self, multimodal_encoder, vtm_head: nn.Module, temp, vision_embeds: torch.Tensor, text_embeds: torch.Tensor, vision_proj: torch.Tensor, text_proj: torch.Tensor, text_atts: torch.Tensor, idx: torch.Tensor, ): """video-text matching loss. Args: multinomial_encoder (nn.Module): The multimodal_encoder. vtm_head (nn.Module): The head to produce the video-text matching score. temp (torch.Tensor): temporature for similarity calculation. vision_embeds (torch.Tensor): The features of all patches in the video. Shape: [B,T,L,C]. text_embeds (torch.Tensor): The features of all tokens in the text. Shape: [B,L,C]. vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C]. text_proj (torch.Tensor): The text representation. Shape: [B,C]. text_atts (torch.Tensor): The padded mask for text tokens. 0 is padded. Shape: [B,L]. idx (torch.Tensor): The index for each example. Shape: [B,]. Returns: TODO """ with torch.no_grad(): sim_v2t, sim_t2v = get_sim(vision_proj, text_proj, temp) vision_atts = torch.ones( vision_embeds.size()[:-1], dtype=torch.long, device=vision_embeds.device ) weights_v2t = F.softmax(sim_v2t + 1e-4, dim=1) # (N, N) weights_t2v = F.softmax(sim_t2v + 1e-4, dim=1) mask = self.get_mask(sim_v2t, idx=idx).bool() weights_v2t.masked_fill_(mask, 0) weights_t2v.masked_fill_(mask, 0) weights_v2t = torch.nan_to_num_(weights_v2t, nan=1e-2, posinf=1e-2, neginf=1e-2) weights_t2v = torch.nan_to_num_(weights_t2v, nan=1e-2, posinf=1e-2, neginf=1e-2) # select a negative image for each text if self.vtm_hard_neg: vision_neg_indices = torch.multinomial(weights_t2v, 1).squeeze() # NOTE bs != 1 txt_neg_indices = torch.multinomial(weights_v2t, 1).squeeze() else: vision_neg_indices = self.get_rand_indices(mask, 1).squeeze() txt_neg_indices = self.get_rand_indices(mask, 1).squeeze() vision_embeds_neg = vision_embeds[vision_neg_indices] # [B, T*L, c] text_embeds_neg = text_embeds[txt_neg_indices] # [B, L, d] text_atts_neg = text_atts[txt_neg_indices] # concat embeddings vision_embeds_all = torch.cat([vision_embeds, vision_embeds_neg, vision_embeds], dim=0) text_embeds_all = torch.cat([text_embeds, text_embeds, text_embeds_neg], dim=0) vision_atts_all = torch.cat([vision_atts, vision_atts, vision_atts], dim=0) text_atts_all = torch.cat([text_atts, text_atts, text_atts_neg], dim=0) output = multimodal_encoder( encoder_embeds=text_embeds_all, attention_mask=text_atts_all, encoder_hidden_states=vision_embeds_all, encoder_attention_mask=vision_atts_all, return_dict=True, mode="fusion", ) vtm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d) vtm_logits = vtm_head(vtm_embeds) # [3*B, 2] bs = vtm_logits.shape[0] // 3 vtm_labels = vtm_logits.new_ones(3 * bs, dtype=torch.long) vtm_labels[bs:] = 0 loss_vtm = F.cross_entropy(vtm_logits, vtm_labels) return loss_vtm def get_rand_indices(self, mask, k): """get rand indices according to mask. Args: mask (torch.Tensor): Shape: (N, L) 0 indicates the positions that we can sample, 1 otherwise k (int): the number indices to sample at each row. Returns: The sampled indices. Shape: [N,k]. (N, k) indices """ mask = mask.float() mask = mask - 10000 * mask mask += torch.randn_like(mask) _, indices = torch.sort(mask, dim=1, descending=True) indices = indices[:, :k].contiguous() return indices @torch.no_grad() def get_mask(self, sim, idx=None, normalize=False): """ Args: sim (torch.Tensor): The similarity between videos and texts. shape: (B, B). idx (torch.Tensor): The index for each video. Shape: [B]. normalize (bool): If true, make row sum equal to 1 """ if idx is not None: idx = idx.view(-1, 1) mask = torch.eq(idx, idx.T).to(sim.dtype) if normalize: mask = mask / mask.sum(1, keepdim=True) else: mask = torch.zeros_like(sim) mask.fill_diagonal_(1) return mask # `1` mark valid/matched location @lru_cache(maxsize=16) def get_gather_args(self): """obtain the args for all_gather Returns: dict. """ from utils.distributed import get_rank, get_world_size from utils.easydict import EasyDict return EasyDict({"world_size": get_world_size(), "rank": get_rank()}) class MLMLoss(nn.Module): """masked language modeling loss.""" def __init__(self, masking_prob, tokenizer): super(MLMLoss, self).__init__() self.tokenizer = tokenizer self.masking_prob = masking_prob def mlm_loss( self, text_encoder, text, vision_embeds, vision_atts, ): input_ids = text.input_ids.clone() labels = input_ids.clone() probability_matrix = torch.full(labels.shape, self.masking_prob) input_ids, labels = self.mask( input_ids, text_encoder.config.vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix, ) intermediate_mlm_output = text_encoder.bert( input_ids, attention_mask=text.attention_mask, encoder_hidden_states=vision_embeds, encoder_attention_mask=vision_atts, return_dict=True, mode="text", ) text_embeds = intermediate_mlm_output.last_hidden_state mlm_output = text_encoder( encoder_embeds=text_embeds, attention_mask=text.attention_mask, encoder_hidden_states=vision_embeds, encoder_attention_mask=vision_atts, return_dict=True, labels=labels, soft_labels=None, mode="fusion", ) return mlm_output.loss def simple_mlm_loss( self, text_encoder, text, text_embeds, vision_embeds, vision_atts, labels ): mlm_output = text_encoder( encoder_embeds=text_embeds, attention_mask=text.attention_mask, encoder_hidden_states=vision_embeds, encoder_attention_mask=vision_atts, return_dict=True, labels=labels, soft_labels=None, mode="fusion", ) return mlm_output.loss def mask( self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None, ): if masked_indices is None: masked_indices = torch.bernoulli(probability_matrix).bool() masked_indices[input_ids == self.tokenizer.pad_token_id] = False masked_indices[input_ids == self.tokenizer.cls_token_id] = False """make deepspeed happy!""" # _pad_mask = (input_ids == self.tokenizer.pad_token_id).to(masked_indices.device, non_blocking=True) # 0 # # print(_pad_mask.device) # masked_indices[_pad_mask] = False # _cls_mask = (input_ids == self.tokenizer.cls_token_id).to(masked_indices.device, non_blocking=True) # 101 # masked_indices[_cls_mask] = False if targets is not None: # We only compute loss on masked tokens targets[~masked_indices] = -100 # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) indices_replaced = ( torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices ) input_ids[indices_replaced] = self.tokenizer.mask_token_id # 10% of the time, we replace masked input tokens with random word indices_random = ( torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced ) random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) input_ids[indices_random] = random_words[indices_random] # The rest of the time (10% of the time) we keep the masked input tokens unchanged if targets is not None: return input_ids, targets else: return input_ids class UTA_Loss(nn.Module): """mask align clip loss.""" def __init__(self, uta_norm_type='l2', uta_loss_type='l2'): super().__init__() self.norm_type = uta_norm_type self.loss_type = uta_loss_type logger.info(f'Norm type: {uta_norm_type}') logger.info(f'Loss type: {uta_loss_type}') if uta_loss_type == 'mse': self.loss_func = nn.MSELoss() elif uta_loss_type == 'smooth_l1': self.loss_func = nn.SmoothL1Loss() def uta_loss(self, student_output, clip_output): """forward to calculate the loss Args: student_output (torch.Tensor): The student output. Shape: [K,B,N,C]. clip_output (torch.Tensor): The teacher representation. Shape: [K,B,N,C]. Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: []. """ if self.norm_type == 'l2': student_output = student_output / student_output.norm(dim=-1, keepdim=True) clip_output = clip_output / clip_output.norm(dim=-1, keepdim=True) elif self.norm_type == 'none': pass else: raise NotImplementedError if self.loss_type == 'l2': loss_uta = (2 - 2 * (student_output * clip_output).sum(dim=-1)).mean() elif self.loss_type in ['mse', 'smooth_l1']: loss_uta = self.loss_func(input=student_output, target=clip_output) else: raise NotImplementedError return loss_uta def uta_vision_loss(self, student_v_output, clip_v_output): """forward to calculate the loss Args: student_v_output (torch.Tensor): The student output. Shape: [B,T,C]. clip_v_output (torch.Tensor): The teacher representation. Shape: [B,T,C]. Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: []. """ if student_v_output.shape[1] != clip_v_output.shape[1]: student_v_output = student_v_output.mean(1, keepdim=True) clip_v_output = clip_v_output.mean(1, keepdim=True) if self.norm_type == 'l2': student_v_output = student_v_output / student_v_output.norm(dim=-1, keepdim=True) clip_v_output = clip_v_output / clip_v_output.norm(dim=-1, keepdim=True) elif self.norm_type == 'none': pass else: raise NotImplementedError if self.loss_type == 'l2': loss_uta = (2 - 2 * (student_v_output * clip_v_output).sum(dim=-1)).mean() elif self.loss_type in ['mse', 'smooth_l1']: loss_uta = self.loss_func(input=student_v_output, target=clip_v_output) else: raise NotImplementedError return loss_uta def uta_all_loss( self, student_v_output, clip_v_output, student_t_output, clip_t_output, ): """forward to calculate the loss Args: student_v_output (torch.Tensor): The student output. Shape: [B,T,C]. clip_v_output (torch.Tensor): The teacher representation. Shape: [B,T,C]. student_t_output (torch.Tensor): The student output. Shape: [B,1,C]. clip_t_output (torch.Tensor): The teacher representation. Shape: [B,1,C]. Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: []. """ if student_v_output.shape[1] != clip_v_output.shape[1]: student_v_output = student_v_output.mean(1, keepdim=True) clip_v_output = clip_v_output.mean(1, keepdim=True) if self.norm_type == 'l2': student_v_output = student_v_output / student_v_output.norm(dim=-1, keepdim=True) clip_v_output = clip_v_output / clip_v_output.norm(dim=-1, keepdim=True) student_t_output = student_t_output / student_t_output.norm(dim=-1, keepdim=True) clip_t_output = clip_t_output / clip_t_output.norm(dim=-1, keepdim=True) elif self.norm_type == 'none': pass else: raise NotImplementedError if self.loss_type == 'l2': loss_uta_v = (2 - 2 * (student_v_output * clip_v_output).sum(dim=-1)).mean() loss_uta_t = (2 - 2 * (student_t_output * clip_t_output).sum(dim=-1)).mean() elif self.loss_type in ['mse', 'smooth_l1']: loss_uta_v = self.loss_func(input=student_v_output, target=clip_v_output) loss_uta_t = self.loss_func(input=student_t_output, target=clip_t_output) else: raise NotImplementedError return (loss_uta_v + loss_uta_t) / 2. class new_UTA_Loss(nn.Module): """mask align clip loss.""" def __init__(self, distill_final_features=True, clip_loss_ratio=[1., 1.]): super().__init__() self.distill_final_features = distill_final_features self.clip_loss_ratio = clip_loss_ratio logger.info(f'distill_final_features: {distill_final_features}') logger.info(f'clip_loss_ratio: {clip_loss_ratio}') def uta_loss(self, student_output, student_output_final, targets_clip_middle_vis, targets_clip_final_vis): """forward to calculate the loss Args: student_output (torch.Tensor): The student output. Shape: [K,B,N,C]. clip_output (torch.Tensor): The teacher representation. Shape: [K,B,N,C]. Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: []. """ loss_clip_middle = (2 - 2 * (student_output * targets_clip_middle_vis).sum(dim=-1)).mean() if self.distill_final_features and self.clip_loss_ratio[1] > 0: loss_clip_final = (2 - 2 * (student_output_final * targets_clip_final_vis).sum(dim=-1)).mean() else: loss_clip_final = torch.zeros(1).type_as(loss_clip_middle).to(loss_clip_middle.device) loss_uta = loss_clip_middle * self.clip_loss_ratio[0] + loss_clip_final * self.clip_loss_ratio[1] return loss_uta