Spaces: Running on Zero
import logging | |
from functools import lru_cache | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from models.utils import allgather_wgrad | |
# Module-level logger named after this module, so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
def _aggregate_over_frames(sim, agg_method):
    """Reduce a [B, L, B] similarity tensor over its middle axis.

    Args:
        sim (torch.Tensor): Similarity scores. Shape: [B, L, B].
        agg_method (str): "mean" averages over axis 1; "max" takes the maximum over axis 1.

    Returns: torch.Tensor of shape [B, B].

    Raises:
        ValueError: if ``agg_method`` is neither "mean" nor "max".
    """
    if agg_method == "mean":
        return sim.mean(1)
    if agg_method == "max":
        return sim.max(1)[0]
    raise ValueError(f"Unsupported agg_method: {agg_method!r}; expected 'mean' or 'max'.")


def get_sim(
    vision_proj: torch.Tensor,
    text_proj: torch.Tensor,
    temp=1.0,
    agg_method="mean",
):
    """Calculate pair-wise video-text similarity.

    Both inputs are L2-normalized along their last dimension before the dot product.
    If the vision input (checked first) or the text input carries an extra middle
    axis, per-element similarities are computed and then reduced over that axis
    with ``agg_method``.

    Args:
        vision_proj (torch.Tensor): The vision representation. Shape: [B,C] or [B,T,C].
        text_proj (torch.Tensor): The text representation. Shape: [B,C] or [B,L,C].
        temp (float or torch.Tensor): The temperature; similarities are divided by it.
        agg_method (str): Reduction over the extra axis of a 3-D input, "mean" or "max".
            Ignored when both inputs are 2-D.

    Returns:
        tuple(torch.Tensor, torch.Tensor): ``(sim_v2t, sim_t2v)`` where ``sim_v2t`` has
        shape [B_v, B_t] and ``sim_t2v`` has shape [B_t, B_v].

    Raises:
        ValueError: if a 3-D input is given and ``agg_method`` is not "mean" or "max".
    """
    vision_proj = F.normalize(vision_proj, dim=-1)
    text_proj = F.normalize(text_proj, dim=-1)

    if vision_proj.ndim == 3:
        sim_v2t = torch.einsum("mld,nd->mln", vision_proj, text_proj) / temp  # [B_v, T, B_t]
        sim_t2v = torch.einsum("nd,mld->nlm", text_proj, vision_proj) / temp  # [B_t, T, B_v]
    elif text_proj.ndim == 3:
        sim_v2t = torch.einsum("nd,mld->nlm", vision_proj, text_proj) / temp  # [B_v, L, B_t]
        sim_t2v = torch.einsum("nld,md->nlm", text_proj, vision_proj) / temp  # [B_t, L, B_v]
    else:
        sim_v2t = vision_proj @ text_proj.T / temp
        return sim_v2t, sim_v2t.T

    # An unknown agg_method used to fall through and silently return the 3-D
    # tensors; it now fails loudly instead.
    return (
        _aggregate_over_frames(sim_v2t, agg_method),
        _aggregate_over_frames(sim_t2v, agg_method),
    )
class VTC_VTM_Loss(nn.Module):
    """Video-text contrastive (VTC) and video-text matching (VTM) losses."""

    def __init__(self, vtm_hard_neg):
        """
        Args:
            vtm_hard_neg (bool): If true, VTM negatives are drawn with ``torch.multinomial``
                from softmax-of-similarity weights; otherwise they are drawn uniformly at
                random among non-matching pairs.
        """
        super().__init__()
        self.vtm_hard_neg = vtm_hard_neg

    def vtc_loss(
        self,
        vision_proj: torch.Tensor,
        text_proj: torch.Tensor,
        idx: torch.Tensor,
        temp=1.0,
        all_gather=True,
        agg_method="mean",
    ):
        """Compute the video-text contrastive loss.

        Args:
            vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
            text_proj (torch.Tensor): The text representation. Shape: [B,C].
            idx (torch.Tensor): The index for each example. Shape: [B,]. If None, only the
                diagonal pairs count as positives.
            temp (torch.Tensor): The temperature. Shape: [].
            all_gather (bool): If true, gather samples across all the GPUs and compute the
                loss over the gathered samples.
            agg_method (str): Forwarded to ``get_sim`` ("mean" or "max").

        Returns: loss_vtc (torch.Tensor): The video-text contrastive loss. Shape: [].
        """
        if all_gather:
            gather_args = self.get_gather_args()
            vision_proj = allgather_wgrad(vision_proj, gather_args)
            text_proj = allgather_wgrad(text_proj, gather_args)
            if idx is not None:
                idx = allgather_wgrad(idx, gather_args)

        sim_v2t, sim_t2v = get_sim(vision_proj, text_proj, temp, agg_method=agg_method)

        with torch.no_grad():
            sim_v2t_targets = self.get_mask(sim_v2t, idx=idx, normalize=True)
            # The idx-equality mask is symmetric, so the same soft targets serve both directions.
            sim_t2v_targets = sim_v2t_targets

        loss_i2t = -torch.sum(F.log_softmax(sim_v2t, dim=1) * sim_v2t_targets, dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2v, dim=1) * sim_t2v_targets, dim=1).mean()

        loss_vtc = (loss_i2t + loss_t2i) / 2
        return loss_vtc

    def vtm_loss(
        self,
        multimodal_encoder,
        vtm_head: nn.Module,
        temp,
        vision_embeds: torch.Tensor,
        text_embeds: torch.Tensor,
        vision_proj: torch.Tensor,
        text_proj: torch.Tensor,
        text_atts: torch.Tensor,
        idx: torch.Tensor,
    ):
        """Compute the video-text matching loss.

        Each positive pair is fused together with one sampled negative video per text and
        one sampled negative text per video. ``vtm_head`` then classifies every fused pair
        as matched (label 1) or not (label 0).

        Args:
            multimodal_encoder (nn.Module): The multimodal encoder, called in "fusion" mode.
            vtm_head (nn.Module): The head to produce the video-text matching score.
            temp (torch.Tensor): Temperature for the similarity calculation.
            vision_embeds (torch.Tensor): The features of all patches in the video; assumed
                [B, N, C] with N vision tokens (e.g. T*L). TODO confirm against callers.
            text_embeds (torch.Tensor): The features of all tokens in the text. Shape: [B,L,C].
            vision_proj (torch.Tensor): The vision representation. Shape: [B,T,C].
            text_proj (torch.Tensor): The text representation. Shape: [B,C].
            text_atts (torch.Tensor): The padded mask for text tokens. 0 is padded. Shape: [B,L].
            idx (torch.Tensor): The index for each example. Shape: [B,].

        Returns: loss_vtm (torch.Tensor): Cross-entropy over 3*B fused pairs (the first B are
            positives, the remaining 2B negatives). Shape: [].
        """
        with torch.no_grad():
            sim_v2t, sim_t2v = get_sim(vision_proj, text_proj, temp)
            vision_atts = torch.ones(
                vision_embeds.size()[:-1], dtype=torch.long, device=vision_embeds.device
            )
            # The epsilon goes *after* the softmax so every candidate keeps a non-zero
            # probability. Inside the softmax it is a no-op, and a row could underflow to
            # all zeros after masking, which makes torch.multinomial raise.
            weights_v2t = F.softmax(sim_v2t, dim=1) + 1e-4  # (N, N)
            weights_t2v = F.softmax(sim_t2v, dim=1) + 1e-4

            mask = self.get_mask(sim_v2t, idx=idx).bool()
            # Positives (same idx) must never be drawn as negatives.
            weights_v2t.masked_fill_(mask, 0)
            weights_t2v.masked_fill_(mask, 0)
            weights_v2t = torch.nan_to_num_(weights_v2t, nan=1e-2, posinf=1e-2, neginf=1e-2)
            weights_t2v = torch.nan_to_num_(weights_t2v, nan=1e-2, posinf=1e-2, neginf=1e-2)

        # Select one negative video for each text and one negative text for each video.
        # squeeze(1) (not squeeze()) keeps the batch axis when B == 1.
        if self.vtm_hard_neg:
            vision_neg_indices = torch.multinomial(weights_t2v, 1).squeeze(1)  # [B]
            txt_neg_indices = torch.multinomial(weights_v2t, 1).squeeze(1)  # [B]
        else:
            vision_neg_indices = self.get_rand_indices(mask, 1).squeeze(1)  # [B]
            txt_neg_indices = self.get_rand_indices(mask, 1).squeeze(1)  # [B]

        vision_embeds_neg = vision_embeds[vision_neg_indices]  # [B, T*L, c]
        text_embeds_neg = text_embeds[txt_neg_indices]  # [B, L, d]
        text_atts_neg = text_atts[txt_neg_indices]

        # Batch layout: [positives | (negative video, text) | (video, negative text)].
        vision_embeds_all = torch.cat([vision_embeds, vision_embeds_neg, vision_embeds], dim=0)
        text_embeds_all = torch.cat([text_embeds, text_embeds, text_embeds_neg], dim=0)
        vision_atts_all = torch.cat([vision_atts, vision_atts, vision_atts], dim=0)
        text_atts_all = torch.cat([text_atts, text_atts, text_atts_neg], dim=0)

        output = multimodal_encoder(
            encoder_embeds=text_embeds_all,
            attention_mask=text_atts_all,
            encoder_hidden_states=vision_embeds_all,
            encoder_attention_mask=vision_atts_all,
            return_dict=True,
            mode="fusion",
        )

        vtm_embeds = output.last_hidden_state[:, 0]  # pos (N, d) + neg (2N, d)

        vtm_logits = vtm_head(vtm_embeds)  # [3*B, 2]

        bs = vtm_logits.shape[0] // 3
        vtm_labels = vtm_logits.new_ones(3 * bs, dtype=torch.long)
        vtm_labels[bs:] = 0
        loss_vtm = F.cross_entropy(vtm_logits, vtm_labels)
        return loss_vtm

    def get_rand_indices(self, mask, k):
        """Sample random indices per row, avoiding masked positions.

        Masked entries are shifted down by 9999 and Gaussian noise is added to every entry,
        so the top-k after a descending sort are random picks among the unmasked positions
        (as long as a row has at least k of them).

        Args:
            mask (torch.Tensor): Shape: (N, L). 0 marks positions that may be sampled, 1 otherwise.
            k (int): The number of indices to sample in each row.

        Returns:
            torch.Tensor: The sampled indices. Shape: [N, k].
        """
        mask = mask.float()
        mask = mask - 10000 * mask
        mask += torch.randn_like(mask)
        _, indices = torch.sort(mask, dim=1, descending=True)
        indices = indices[:, :k].contiguous()
        return indices

    def get_mask(self, sim, idx=None, normalize=False):
        """Build the positive-pair mask for a similarity matrix.

        Args:
            sim (torch.Tensor): The similarity between videos and texts. shape: (B, B).
            idx (torch.Tensor): The index for each video. Shape: [B]. Pairs sharing an index
                are positives; if None, only the diagonal is positive.
            normalize (bool): If true, make each row sum to 1 (only applied when idx is given;
                the diagonal mask already sums to 1 per row).

        Returns:
            torch.Tensor: Mask of ``sim``'s dtype, where 1 marks a matched location.
        """
        if idx is not None:
            idx = idx.view(-1, 1)
            mask = torch.eq(idx, idx.T).to(sim.dtype)
            if normalize:
                mask = mask / mask.sum(1, keepdim=True)
        else:
            mask = torch.zeros_like(sim)
            mask.fill_diagonal_(1)
        return mask  # `1` mark valid/matched location

    def get_gather_args(self):
        """Obtain the world size and rank needed by ``allgather_wgrad``.

        Returns: EasyDict with keys ``world_size`` and ``rank``.
        """
        from utils.distributed import get_rank, get_world_size
        from utils.easydict import EasyDict

        return EasyDict({"world_size": get_world_size(), "rank": get_rank()})
class MLMLoss(nn.Module):
    """Masked language modeling (MLM) loss."""

    def __init__(self, masking_prob, tokenizer):
        """
        Args:
            masking_prob (float): Probability that each eligible token is selected for masking.
            tokenizer: Object exposing ``pad_token_id``, ``cls_token_id`` and ``mask_token_id``.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.masking_prob = masking_prob

    def mlm_loss(
        self,
        text_encoder,
        text,
        vision_embeds,
        vision_atts,
    ):
        """Mask the text, encode it, fuse it with vision features and return the MLM loss.

        Args:
            text_encoder: Encoder exposing ``config.vocab_size`` and a ``bert`` sub-module
                (run in "text" mode); the encoder itself is run in "fusion" mode with labels.
            text: Batch with ``input_ids`` and ``attention_mask`` tensors.
            vision_embeds (torch.Tensor): Vision features passed as encoder hidden states.
            vision_atts (torch.Tensor): Attention mask for ``vision_embeds``.

        Returns: The ``loss`` attribute of the fusion-mode encoder output.
        """
        input_ids = text.input_ids.clone()
        labels = input_ids.clone()
        # Build the sampling probabilities on the ids' device so the boolean masks
        # derived from them can index the ids directly.
        probability_matrix = torch.full(
            labels.shape, float(self.masking_prob), device=labels.device
        )
        input_ids, labels = self.mask(
            input_ids,
            text_encoder.config.vocab_size,
            input_ids.device,
            targets=labels,
            probability_matrix=probability_matrix,
        )

        intermediate_mlm_output = text_encoder.bert(
            input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=vision_embeds,
            encoder_attention_mask=vision_atts,
            return_dict=True,
            mode="text",
        )

        text_embeds = intermediate_mlm_output.last_hidden_state

        mlm_output = text_encoder(
            encoder_embeds=text_embeds,
            attention_mask=text.attention_mask,
            encoder_hidden_states=vision_embeds,
            encoder_attention_mask=vision_atts,
            return_dict=True,
            labels=labels,
            soft_labels=None,
            mode="fusion",
        )
        return mlm_output.loss

    def simple_mlm_loss(
        self,
        text_encoder,
        text,
        text_embeds,
        vision_embeds,
        vision_atts,
        labels
    ):
        """Run only the fusion-mode encoder on precomputed text embeddings and return its loss.

        Args:
            text_encoder: Encoder called in "fusion" mode with ``labels``.
            text: Batch providing ``attention_mask``.
            text_embeds (torch.Tensor): Already-encoded text features.
            vision_embeds (torch.Tensor): Vision features passed as encoder hidden states.
            vision_atts (torch.Tensor): Attention mask for ``vision_embeds``.
            labels (torch.Tensor): MLM targets (-100 marks positions ignored by the loss).

        Returns: The ``loss`` attribute of the encoder output.
        """
        mlm_output = text_encoder(
            encoder_embeds=text_embeds,
            attention_mask=text.attention_mask,
            encoder_hidden_states=vision_embeds,
            encoder_attention_mask=vision_atts,
            return_dict=True,
            labels=labels,
            soft_labels=None,
            mode="fusion",
        )
        return mlm_output.loss

    def mask(
        self,
        input_ids,
        vocab_size,
        device,
        targets=None,
        masked_indices=None,
        probability_matrix=None,
    ):
        """Apply BERT-style masking to ``input_ids`` in place.

        Tokens are selected by ``masked_indices``, or by a Bernoulli draw from
        ``probability_matrix`` when it is None. Pad and CLS tokens are never selected.
        Of the selected tokens, ~80% become ``mask_token_id``, ~10% become a random id
        in ``[0, vocab_size)``, and ~10% are left unchanged.

        Args:
            input_ids (torch.Tensor): Token ids; modified in place.
            vocab_size (int): Upper bound (exclusive) for random replacement ids.
            device: Device on which the random replacement ids are generated.
            targets (torch.Tensor, optional): Labels; modified in place so that
                unselected positions become -100.
            masked_indices (torch.Tensor, optional): Precomputed boolean selection mask.
            probability_matrix (torch.Tensor, optional): Per-token selection probability,
                used when ``masked_indices`` is None.

        Returns:
            ``(input_ids, targets)`` if ``targets`` is given, else ``input_ids``.
        """
        mask_device = input_ids.device
        if masked_indices is None:
            masked_indices = torch.bernoulli(probability_matrix).bool()
        # Keep every mask on input_ids' device: indexing a CPU tensor with a CUDA boolean
        # mask raises (the former "make deepspeed happy" workaround targeted this).
        masked_indices = masked_indices.to(mask_device)

        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
        masked_indices[input_ids == self.tokenizer.cls_token_id] = False

        if targets is not None:
            # We only compute loss on masked tokens.
            targets[~masked_indices] = -100

        # 80% of the time, replace masked input tokens with tokenizer.mask_token ([MASK]).
        indices_replaced = (
            torch.bernoulli(torch.full(input_ids.shape, 0.8, device=mask_device)).bool()
            & masked_indices
        )
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time (half of the remaining 20%), replace with a random word.
        indices_random = (
            torch.bernoulli(torch.full(input_ids.shape, 0.5, device=mask_device)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long, device=device)
        input_ids[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged.

        if targets is not None:
            return input_ids, targets
        return input_ids
class UTA_Loss(nn.Module):
    """Mask align CLIP loss: aligns student features with CLIP teacher features."""

    def __init__(self, uta_norm_type='l2', uta_loss_type='l2'):
        """
        Args:
            uta_norm_type (str): 'l2' to L2-normalize features along the last axis, or 'none'.
            uta_loss_type (str): 'l2' (2 - 2 * dot product), 'mse', or 'smooth_l1'.
        """
        super().__init__()
        self.norm_type = uta_norm_type
        self.loss_type = uta_loss_type
        logger.info(f'Norm type: {uta_norm_type}')
        logger.info(f'Loss type: {uta_loss_type}')

        if uta_loss_type == 'mse':
            self.loss_func = nn.MSELoss()
        elif uta_loss_type == 'smooth_l1':
            self.loss_func = nn.SmoothL1Loss()

    def _normalize(self, x):
        """Apply the configured feature normalization along the last axis.

        Raises:
            NotImplementedError: for an unknown ``norm_type``.
        """
        if self.norm_type == 'l2':
            # F.normalize clamps the norm at a tiny eps, so an all-zero vector yields
            # zeros instead of NaN; for non-zero vectors it equals x / ||x||.
            return F.normalize(x, dim=-1)
        if self.norm_type == 'none':
            return x
        raise NotImplementedError

    def _distance(self, student, teacher):
        """Compute the configured alignment loss between two feature tensors.

        Raises:
            NotImplementedError: for an unknown ``loss_type``.
        """
        if self.loss_type == 'l2':
            # Equals the mean squared Euclidean distance when both inputs are unit-norm.
            return (2 - 2 * (student * teacher).sum(dim=-1)).mean()
        if self.loss_type in ['mse', 'smooth_l1']:
            return self.loss_func(input=student, target=teacher)
        raise NotImplementedError

    @staticmethod
    def _pool_if_mismatched(student, teacher):
        """Mean-pool both tensors over axis 1 when their axis-1 lengths differ."""
        if student.shape[1] != teacher.shape[1]:
            return student.mean(1, keepdim=True), teacher.mean(1, keepdim=True)
        return student, teacher

    def uta_loss(self, student_output, clip_output):
        """Compute the alignment loss.

        Args:
            student_output (torch.Tensor): The student output. Shape: [K,B,N,C].
            clip_output (torch.Tensor): The teacher representation. Shape: [K,B,N,C].

        Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
        """
        student_output = self._normalize(student_output)
        clip_output = self._normalize(clip_output)
        return self._distance(student_output, clip_output)

    def uta_vision_loss(self, student_v_output, clip_v_output):
        """Compute the alignment loss on vision features.

        Args:
            student_v_output (torch.Tensor): The student output. Shape: [B,T,C].
            clip_v_output (torch.Tensor): The teacher representation. Shape: [B,T,C].
                If the T axes differ, both are mean-pooled over T first.

        Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
        """
        student_v_output, clip_v_output = self._pool_if_mismatched(student_v_output, clip_v_output)
        student_v_output = self._normalize(student_v_output)
        clip_v_output = self._normalize(clip_v_output)
        return self._distance(student_v_output, clip_v_output)

    def uta_all_loss(
        self,
        student_v_output, clip_v_output,
        student_t_output, clip_t_output,
    ):
        """Average the vision and text alignment losses.

        Args:
            student_v_output (torch.Tensor): The student output. Shape: [B,T,C].
            clip_v_output (torch.Tensor): The teacher representation. Shape: [B,T,C].
                If the T axes differ, both are mean-pooled over T first.
            student_t_output (torch.Tensor): The student output. Shape: [B,1,C].
            clip_t_output (torch.Tensor): The teacher representation. Shape: [B,1,C].

        Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
        """
        student_v_output, clip_v_output = self._pool_if_mismatched(student_v_output, clip_v_output)
        # Normalize all four inputs before computing either loss (same error order as before).
        student_v_output = self._normalize(student_v_output)
        clip_v_output = self._normalize(clip_v_output)
        student_t_output = self._normalize(student_t_output)
        clip_t_output = self._normalize(clip_t_output)

        loss_uta_v = self._distance(student_v_output, clip_v_output)
        loss_uta_t = self._distance(student_t_output, clip_t_output)
        return (loss_uta_v + loss_uta_t) / 2.
class new_UTA_Loss(nn.Module):
    """Mask align CLIP loss over intermediate and, optionally, final student features."""

    def __init__(self, distill_final_features=True, clip_loss_ratio=None):
        """
        Args:
            distill_final_features (bool): If true, also align the final-layer features.
            clip_loss_ratio (sequence of two floats, optional): Weights of the intermediate
                and final alignment terms. Defaults to ``[1., 1.]``.
        """
        super().__init__()
        if clip_loss_ratio is None:
            # Fresh list per instance; a list literal as the default would be shared.
            clip_loss_ratio = [1., 1.]
        self.distill_final_features = distill_final_features
        self.clip_loss_ratio = clip_loss_ratio
        logger.info(f'distill_final_features: {distill_final_features}')
        logger.info(f'clip_loss_ratio: {clip_loss_ratio}')

    def uta_loss(self, student_output, student_output_final,
                 targets_clip_middle_vis, targets_clip_final_vis):
        """Compute the weighted alignment loss.

        Each term is ``mean(2 - 2 * <student, target>)`` over the last axis. That is the
        mean squared Euclidean distance if the inputs are unit-norm; this method does not
        normalize them itself.

        Args:
            student_output (torch.Tensor): Student intermediate features.
            student_output_final (torch.Tensor): Student final features.
            targets_clip_middle_vis (torch.Tensor): Teacher targets for ``student_output``.
            targets_clip_final_vis (torch.Tensor): Teacher targets for ``student_output_final``.

        Returns: loss_uta (torch.Tensor): The mask clip alignment loss. Shape: [].
        """
        loss_clip_middle = (2 - 2 * (student_output * targets_clip_middle_vis).sum(dim=-1)).mean()
        if self.distill_final_features and self.clip_loss_ratio[1] > 0:
            loss_clip_final = (2 - 2 * (student_output_final * targets_clip_final_vis).sum(dim=-1)).mean()
        else:
            # 0-d zero with matching dtype/device keeps the result a scalar;
            # torch.zeros(1) would broadcast the total loss to shape [1].
            loss_clip_final = loss_clip_middle.new_zeros(())
        loss_uta = loss_clip_middle * self.clip_loss_ratio[0] + loss_clip_final * self.clip_loss_ratio[1]
        return loss_uta