|
import copy
import logging

import torch
import torch.nn as nn
from prettytable import PrettyTable

from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
                          RobertaConfig, RobertaModel, RobertaTokenizer)

logger = logging.getLogger(__name__)
|
def whitening_torch_final(embeddings):
    """Whiten a batch of embeddings so they have zero mean and decorrelated dimensions."""
    mu = torch.mean(embeddings, dim=0, keepdim=True)
    cov = torch.mm((embeddings - mu).t(), embeddings - mu)   # (unnormalized) covariance
    u, s, vt = torch.svd(cov)
    W = torch.mm(u, torch.diag(1 / torch.sqrt(s)))           # whitening transform
    embeddings = torch.mm(embeddings - mu, W)
    return embeddings
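# Illustrative usage sketch (assumes a (batch, hidden) matrix of pooled embeddings;
# not part of the training pipeline):
#
#     embs = torch.randn(128, 768)
#     whitened = whitening_torch_final(embs)  # same shape, zero mean, decorrelated dimensions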
|
|
|
class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()

    def model_parameters(self):
        """Return a PrettyTable listing each trainable parameter with its shape and size."""
        table = PrettyTable()
        table.field_names = ["Layer Name", "Output Shape", "Param #"]
        table.align["Layer Name"] = "l"
        table.align["Output Shape"] = "r"
        table.align["Param #"] = "r"
        for name, parameters in self.named_parameters():
            if parameters.requires_grad:
                table.add_row([name, str(list(parameters.shape)), parameters.numel()])
        return table
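    # Illustrative usage (hypothetical `model` instance): `print(model.model_parameters())`
    # prints one row per trainable parameter with its shape and element count.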
|
class Model(BaseModel):
    def __init__(self, encoder):
        super(Model, self).__init__()
        self.encoder = encoder

    def forward(self, code_inputs=None, nl_inputs=None):
        # Token id 1 is the padding id, so `ne(1)` serves both as the attention mask and
        # as the mask for mean pooling over non-padding positions.
        if code_inputs is not None:
            outputs = self.encoder(code_inputs, attention_mask=code_inputs.ne(1))[0]
            outputs = (outputs * code_inputs.ne(1)[:, :, None]).sum(1) / code_inputs.ne(1).sum(-1)[:, None]
            return torch.nn.functional.normalize(outputs, p=2, dim=1)
        else:
            outputs = self.encoder(nl_inputs, attention_mask=nl_inputs.ne(1))[0]
            outputs = (outputs * nl_inputs.ne(1)[:, :, None]).sum(1) / nl_inputs.ne(1).sum(-1)[:, None]
            return torch.nn.functional.normalize(outputs, p=2, dim=1)
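    # Usage sketch (illustrative): code and NL embeddings live in the same space, so
    # retrieval scores are cosine similarities of the L2-normalized vectors, e.g.:
    #
    #     code_vecs = model(code_inputs=code_ids)   # (num_codes, hidden)
    #     nl_vecs = model(nl_inputs=query_ids)      # (num_queries, hidden)
    #     scores = nl_vecs @ code_vecs.T            # (num_queries, num_codes)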
|
|
|
|
|
class Multi_Loss_CoCoSoDa(BaseModel):
|
|
|
def __init__(self, base_encoder, args, mlp=False): |
|
super(Multi_Loss_CoCoSoDa, self).__init__() |
|
|
|
        self.K = args.moco_k   # queue size (number of negative keys)
        self.m = args.moco_m   # momentum coefficient for the key encoders
        self.T = args.moco_t   # softmax temperature
        dim = args.moco_dim    # embedding dimension (matches the encoder hidden size)

        # The code and NL query encoders share the same underlying encoder; the key
        # encoders are momentum-updated deep copies of it.
        self.code_encoder_q = base_encoder
        self.code_encoder_k = copy.deepcopy(base_encoder)
        self.nl_encoder_q = base_encoder
        self.nl_encoder_k = copy.deepcopy(self.nl_encoder_q)

        self.mlp = mlp
        self.time_score = args.time_score
        self.do_whitening = args.do_whitening
        self.do_ineer_loss = args.do_ineer_loss   # extra intra-modal contrastive loss
        self.agg_way = args.agg_way
        self.args = args
|
|
|
        # Initialize the key encoders from the query encoders and freeze them; they are
        # only updated by momentum in _momentum_update_key_encoder().
        for param_q, param_k in zip(self.code_encoder_q.parameters(), self.code_encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        for param_q, param_k in zip(self.nl_encoder_q.parameters(), self.nl_encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
|
|
|
|
|
        torch.manual_seed(3047)
        torch.cuda.manual_seed(3047)

        # Four MoCo-style queues of negative keys (one per view), each of shape (dim, K),
        # plus a pointer buffer that tracks the next column to overwrite.
        self.register_buffer("code_queue", torch.randn(dim, self.K))
        self.code_queue = nn.functional.normalize(self.code_queue, dim=0)
        self.register_buffer("code_queue_ptr", torch.zeros(1, dtype=torch.long))

        self.register_buffer("masked_code_queue", torch.randn(dim, self.K))
        self.masked_code_queue = nn.functional.normalize(self.masked_code_queue, dim=0)
        self.register_buffer("masked_code_queue_ptr", torch.zeros(1, dtype=torch.long))

        self.register_buffer("nl_queue", torch.randn(dim, self.K))
        self.nl_queue = nn.functional.normalize(self.nl_queue, dim=0)
        self.register_buffer("nl_queue_ptr", torch.zeros(1, dtype=torch.long))

        self.register_buffer("masked_nl_queue", torch.randn(dim, self.K))
        self.masked_nl_queue = nn.functional.normalize(self.masked_nl_queue, dim=0)
        self.register_buffer("masked_nl_queue_ptr", torch.zeros(1, dtype=torch.long))
|
    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum update of the key encoders: theta_k <- m * theta_k + (1 - m) * theta_q."""
        for param_q, param_k in zip(self.code_encoder_q.parameters(), self.code_encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        for param_q, param_k in zip(self.nl_encoder_q.parameters(), self.nl_encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        if self.mlp:
            # Note: the *_fc projection heads are expected to be attached externally when
            # mlp=True; they are not created in this class's __init__.
            for param_q, param_k in zip(self.code_encoder_q_fc.parameters(), self.code_encoder_k_fc.parameters()):
                param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
            for param_q, param_k in zip(self.nl_encoder_q_fc.parameters(), self.nl_encoder_k_fc.parameters()):
                param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
|
|
|
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, option='code'):
        """Enqueue the current batch of keys into the selected queue, overwriting the oldest entries."""
        batch_size = keys.shape[0]
        # For simplicity the queue size must be divisible by the batch size.
        assert self.K % batch_size == 0

        if option == 'code':
            queue, queue_ptr = self.code_queue, self.code_queue_ptr
        elif option == 'masked_code':
            queue, queue_ptr = self.masked_code_queue, self.masked_code_queue_ptr
        elif option == 'nl':
            queue, queue_ptr = self.nl_queue, self.nl_queue_ptr
        elif option == 'masked_nl':
            queue, queue_ptr = self.masked_nl_queue, self.masked_nl_queue_ptr
        else:
            raise ValueError("unknown queue option: %s" % option)

        ptr = int(queue_ptr)
        queue[:, ptr:ptr + batch_size] = keys.T        # enqueue the keys column-wise
        queue_ptr[0] = (ptr + batch_size) % self.K     # advance the pointer (wrap around)
|
    def forward(self, source_code_q, source_code_k, nl_q, nl_k):
        """
        Input:
            source_code_q: a batch of code token ids (query view)
            source_code_k: a batch of masked/augmented code token ids (key view)
            nl_q: a batch of natural-language token ids (query view)
            nl_k: a batch of masked/augmented natural-language token ids (key view)
        Output:
            logits, labels, and the query code / NL embeddings
        """
|
        if not self.args.do_multi_lang_continue_pre_train:
            # Plain in-batch contrastive learning: no momentum encoders, no queues.
            outputs = self.code_encoder_q(source_code_q, attention_mask=source_code_q.ne(1))[0]
            code_q = (outputs * source_code_q.ne(1)[:, :, None]).sum(1) / source_code_q.ne(1).sum(-1)[:, None]
            code_q = torch.nn.functional.normalize(code_q, p=2, dim=1)

            outputs = self.nl_encoder_q(nl_q, attention_mask=nl_q.ne(1))[0]
            nl_q = (outputs * nl_q.ne(1)[:, :, None]).sum(1) / nl_q.ne(1).sum(-1)[:, None]
            nl_q = torch.nn.functional.normalize(nl_q, p=2, dim=1)

            code2nl_logits = torch.einsum("ab,cb->ac", code_q, nl_q)
            code2nl_logits /= self.T
            code2nl_label = torch.arange(code2nl_logits.size(0), device=code2nl_logits.device)
            return code2nl_logits, code2nl_label, None, None
|
if self.agg_way == "avg": |
|
|
|
outputs = self.code_encoder_q(source_code_q, attention_mask=source_code_q.ne(1))[0] |
|
code_q = (outputs*source_code_q.ne(1)[:,:,None]).sum(1)/source_code_q.ne(1).sum(-1)[:,None] |
|
code_q = torch.nn.functional.normalize(code_q, p=2, dim=1) |
|
|
|
outputs= self.nl_encoder_q(nl_q, attention_mask=nl_q.ne(1))[0] |
|
nl_q = (outputs*nl_q.ne(1)[:,:,None]).sum(1)/nl_q.ne(1).sum(-1)[:,None] |
|
nl_q = torch.nn.functional.normalize(nl_q, p=2, dim=1) |
|
|
|
|
|
with torch.no_grad(): |
|
self._momentum_update_key_encoder() |
|
|
|
|
|
|
|
|
|
|
|
outputs = self.code_encoder_k(source_code_k, attention_mask=source_code_k.ne(1))[0] |
|
code_k = (outputs*source_code_k.ne(1)[:,:,None]).sum(1)/source_code_k.ne(1).sum(-1)[:,None] |
|
code_k = torch.nn.functional.normalize( code_k, p=2, dim=1) |
|
|
|
outputs = self.nl_encoder_k(nl_k, attention_mask=nl_k.ne(1))[0] |
|
nl_k = (outputs*nl_k.ne(1)[:,:,None]).sum(1)/nl_k.ne(1).sum(-1)[:,None] |
|
nl_k = torch.nn.functional.normalize(nl_k, p=2, dim=1) |
|
|
|
elif self.agg_way == "cls_pooler": |
|
|
|
|
|
outputs = self.code_encoder_q(source_code_q, attention_mask=source_code_q.ne(1))[1] |
|
code_q = torch.nn.functional.normalize(code_q, p=2, dim=1) |
|
|
|
outputs= self.nl_encoder_q(nl_q, attention_mask=nl_q.ne(1))[1] |
|
nl_q = torch.nn.functional.normalize(nl_q, p=2, dim=1) |
|
|
|
|
|
with torch.no_grad(): |
|
self._momentum_update_key_encoder() |
|
|
|
|
|
|
|
|
|
|
|
outputs = self.code_encoder_k(source_code_k, attention_mask=source_code_k.ne(1))[1] |
|
code_k = torch.nn.functional.normalize( code_k, p=2, dim=1) |
|
|
|
outputs = self.nl_encoder_k(nl_k, attention_mask=nl_k.ne(1))[1] |
|
nl_k = torch.nn.functional.normalize(nl_k, p=2, dim=1) |
|
|
|
elif self.agg_way == "avg_cls_pooler": |
|
|
|
outputs = self.code_encoder_q(source_code_q, attention_mask=source_code_q.ne(1)) |
|
code_q_cls = outputs[1] |
|
outputs = outputs[0] |
|
code_q_avg = (outputs*source_code_q.ne(1)[:,:,None]).sum(1)/source_code_q.ne(1).sum(-1)[:,None] |
|
code_q = code_q_cls + code_q_avg |
|
code_q = torch.nn.functional.normalize(code_q, p=2, dim=1) |
|
|
|
outputs= self.nl_encoder_q(nl_q, attention_mask=nl_q.ne(1)) |
|
nl_q_cls = outputs[1] |
|
outputs= outputs[0] |
|
nl_q_avg = (outputs*nl_q.ne(1)[:,:,None]).sum(1)/nl_q.ne(1).sum(-1)[:,None] |
|
nl_q = nl_q_avg + nl_q_cls |
|
nl_q = torch.nn.functional.normalize(nl_q, p=2, dim=1) |
|
|
|
|
|
with torch.no_grad(): |
|
self._momentum_update_key_encoder() |
|
|
|
|
|
|
|
|
|
|
|
|
|
outputs = self.code_encoder_k(source_code_k, attention_mask=source_code_k.ne(1)) |
|
code_k_cls = outputs[1] |
|
outputs = outputs[0] |
|
code_k_avg = (outputs*source_code_k.ne(1)[:,:,None]).sum(1)/source_code_k.ne(1).sum(-1)[:,None] |
|
code_k = code_k_cls + code_k_avg |
|
code_k = torch.nn.functional.normalize( code_k, p=2, dim=1) |
|
|
|
outputs = self.nl_encoder_k(nl_k, attention_mask=nl_k.ne(1)) |
|
nl_k_cls = outputs[1] |
|
outputs = outputs[0] |
|
nl_k_avg = (outputs*nl_k.ne(1)[:,:,None]).sum(1)/nl_k.ne(1).sum(-1)[:,None] |
|
nl_k = nl_k_cls + nl_k_avg |
|
nl_k = torch.nn.functional.normalize(nl_k, p=2, dim=1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Inter-modal contrastive logits: each query is compared against its in-batch
        # positives (scaled by time_score) and against the negatives stored in the queues.
        code2nl_pos = torch.einsum('nc,bc->nb', [code_q, nl_q])
        code2nl_neg = torch.einsum('nc,ck->nk', [code_q, self.nl_queue.clone().detach()])
        code2nl_logits = torch.cat([self.time_score * code2nl_pos, code2nl_neg], dim=1)
        code2nl_logits /= self.T
        code2nl_label = torch.arange(code2nl_logits.size(0), device=code2nl_logits.device)

        code2maskednl_pos = torch.einsum('nc,bc->nb', [code_q, nl_k])
        code2maskednl_neg = torch.einsum('nc,ck->nk', [code_q, self.masked_nl_queue.clone().detach()])
        code2maskednl_logits = torch.cat([self.time_score * code2maskednl_pos, code2maskednl_neg], dim=1)
        code2maskednl_logits /= self.T
        code2maskednl_label = torch.arange(code2maskednl_logits.size(0), device=code2maskednl_logits.device)

        nl2code_pos = torch.einsum('nc,bc->nb', [nl_q, code_q])
        nl2code_neg = torch.einsum('nc,ck->nk', [nl_q, self.code_queue.clone().detach()])
        nl2code_logits = torch.cat([self.time_score * nl2code_pos, nl2code_neg], dim=1)
        nl2code_logits /= self.T
        nl2code_label = torch.arange(nl2code_logits.size(0), device=nl2code_logits.device)

        nl2maskedcode_pos = torch.einsum('nc,bc->nb', [nl_q, code_k])
        nl2maskedcode_neg = torch.einsum('nc,ck->nk', [nl_q, self.masked_code_queue.clone().detach()])
        nl2maskedcode_logits = torch.cat([self.time_score * nl2maskedcode_pos, nl2maskedcode_neg], dim=1)
        nl2maskedcode_logits /= self.T
        nl2maskedcode_label = torch.arange(nl2maskedcode_logits.size(0), device=nl2maskedcode_logits.device)

        inter_logits = torch.cat((code2nl_logits, code2maskednl_logits, nl2code_logits, nl2maskedcode_logits), dim=0)
        inter_labels = torch.cat((code2nl_label, code2maskednl_label, nl2code_label, nl2maskedcode_label), dim=0)
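        # At this point inter_logits has shape (4 * batch_size, batch_size + K) and
        # inter_labels has shape (4 * batch_size,); the optional block below appends two
        # intra-modal views, and the caller typically feeds the pair to
        # nn.CrossEntropyLoss (see the sketch at the end of this file).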
|
|
|
        if self.do_ineer_loss:
            # Optional intra-modal ("inner") loss: original code vs. masked code and
            # original NL vs. masked NL.
            code2maskedcode_pos = torch.einsum('nc,bc->nb', [code_q, code_k])
            code2maskedcode_neg = torch.einsum('nc,ck->nk', [code_q, self.masked_code_queue.clone().detach()])
            code2maskedcode_logits = torch.cat([self.time_score * code2maskedcode_pos, code2maskedcode_neg], dim=1)
            code2maskedcode_logits /= self.T
            code2maskedcode_label = torch.arange(code2maskedcode_logits.size(0), device=code2maskedcode_logits.device)

            nl2maskednl_pos = torch.einsum('nc,bc->nb', [nl_q, nl_k])
            nl2maskednl_neg = torch.einsum('nc,ck->nk', [nl_q, self.masked_nl_queue.clone().detach()])
            nl2maskednl_logits = torch.cat([self.time_score * nl2maskednl_pos, nl2maskednl_neg], dim=1)
            nl2maskednl_logits /= self.T
            nl2maskednl_label = torch.arange(nl2maskednl_logits.size(0), device=nl2maskednl_logits.device)

            inter_logits = torch.cat((inter_logits, code2maskedcode_logits, nl2maskednl_logits), dim=0)
            inter_labels = torch.cat((inter_labels, code2maskedcode_label, nl2maskednl_label), dim=0)

        # Update the queues with the embeddings from this batch.
        self._dequeue_and_enqueue(code_q, option='code')
        self._dequeue_and_enqueue(nl_q, option='nl')
        self._dequeue_and_enqueue(code_k, option='masked_code')
        self._dequeue_and_enqueue(nl_k, option='masked_nl')

        return inter_logits, inter_labels, code_q, nl_q
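

if __name__ == "__main__":
    # Minimal smoke-test sketch, not the project's training entry point. It uses a tiny
    # randomly initialized RobertaModel and made-up hyperparameter values purely to show
    # the expected input/output shapes; real training drives this model from the
    # project's training script with a pretrained encoder.
    from argparse import Namespace

    config = RobertaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                           num_attention_heads=2, intermediate_size=64,
                           max_position_embeddings=64)
    encoder = RobertaModel(config)
    args = Namespace(moco_k=16, moco_m=0.999, moco_t=0.07, moco_dim=32,
                     time_score=1, do_whitening=False, do_ineer_loss=True,
                     agg_way="avg", do_multi_lang_continue_pre_train=True)
    model = Multi_Loss_CoCoSoDa(encoder, args)

    batch_size, seq_len = 4, 10
    # Token id 1 is treated as padding throughout this file, so sample ids >= 2.
    code_q = torch.randint(2, 100, (batch_size, seq_len))
    code_k = torch.randint(2, 100, (batch_size, seq_len))
    nl_q = torch.randint(2, 100, (batch_size, seq_len))
    nl_k = torch.randint(2, 100, (batch_size, seq_len))

    logits, labels, code_vec, nl_vec = model(code_q, code_k, nl_q, nl_k)
    loss = nn.CrossEntropyLoss()(logits, labels)
    print(logits.shape, labels.shape, loss.item())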
|
|
|
|