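# NOTE: the lines below are header residue from the web file viewer this
# module was copied from; they are kept as comments so the file parses.
# cocosoda_ruby / model.py
# SalazarPevelll
# model
# 51c57f8
# raw
# history blame
# 19.2 kB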
import torch
import torch.nn as nn
from prettytable import PrettyTable
from torch.nn.modules.activation import Tanh
import copy
import logging
logger = logging.getLogger(__name__)
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
RobertaConfig, RobertaModel, RobertaTokenizer)
def whitening_torch_final(embeddings):
    """Whiten a 2-D batch of embeddings.

    The rows are centred on their column-wise mean, the (unnormalised)
    scatter matrix ``Xc^T Xc`` is decomposed with SVD, and the centred rows
    are projected by ``U @ diag(1 / sqrt(S))``.

    Args:
        embeddings: 2-D tensor; ``torch.mm`` is applied to it directly, so
            other ranks are rejected by PyTorch.

    Returns:
        Tensor with the same shape as ``embeddings`` holding the whitened
        rows.
    """
    centered = embeddings - torch.mean(embeddings, dim=0, keepdim=True)
    scatter = torch.mm(centered.t(), centered)
    # Only the left singular vectors and the singular values are needed.
    left_vectors, singular_values, _ = torch.svd(scatter)
    # No epsilon is added: a zero singular value yields inf in the scaling.
    projection = torch.mm(left_vectors, torch.diag(1 / torch.sqrt(singular_values)))
    return torch.mm(centered, projection)
class BaseModel(nn.Module):
    """Common ``nn.Module`` base that can tabulate its trainable parameters."""

    def __init__(self):
        super().__init__()

    def model_parameters(self):
        """Build a table describing every parameter that requires gradients.

        Returns:
            A ``PrettyTable`` with one row per trainable parameter holding
            its name, its shape (as a string) and its element count.
        """
        summary = PrettyTable()
        summary.field_names = ["Layer Name", "Output Shape", "Param #"]
        # Names are left-aligned, the numeric columns right-aligned.
        for column, side in (("Layer Name", "l"), ("Output Shape", "r"), ("Param #", "r")):
            summary.align[column] = side

        trainable = (
            (layer_name, tensor)
            for layer_name, tensor in self.named_parameters()
            if tensor.requires_grad
        )
        for layer_name, tensor in trainable:
            summary.add_row([layer_name, str(list(tensor.shape)), tensor.numel()])
        return summary
class Model(BaseModel):
    """Wraps a single encoder and returns L2-normalised, mean-pooled embeddings."""

    def __init__(self, encoder):
        super(Model, self).__init__()
        self.encoder = encoder

    def forward(self, code_inputs=None, nl_inputs=None):
        """Embed a batch of code token ids or, if none is given, NL token ids.

        Args:
            code_inputs: token-id tensor indexed as ``[batch, seq]``; used
                whenever it is not ``None``.
            nl_inputs: token-id tensor used only when ``code_inputs`` is
                ``None``.

        Returns:
            The encoder's first output averaged over positions whose token id
            differs from 1, then L2-normalised along dim 1.
        """
        token_ids = code_inputs if code_inputs is not None else nl_inputs
        # Token id 1 is treated as padding and excluded from attention/pooling.
        keep = token_ids.ne(1)
        hidden = self.encoder(token_ids, attention_mask=keep)[0]  # [bs, seq_len, dim]
        # Indexing with None inserts a new axis so the mask broadcasts over dim.
        summed = (hidden * keep[:, :, None]).sum(1)
        pooled = summed / keep.sum(-1)[:, None]
        return torch.nn.functional.normalize(pooled, p=2, dim=1)
class Multi_Loss_CoCoSoDa(BaseModel):
    """Momentum-contrast model over paired code / natural-language batches.

    A query encoder (shared between code and NL, since both ``*_encoder_q``
    attributes reference the same ``base_encoder`` object) is trained by
    gradient, while two deep-copied key encoders are updated as an
    exponential moving average of the query encoder. Four fixed-size queues
    (``code``, ``masked_code``, ``nl``, ``masked_nl``) store past embeddings
    that serve as additional negatives.
    """

    # Token id excluded from attention and from mean pooling.
    # NOTE(review): presumably the RoBERTa pad token id - confirm against the tokenizer.
    _PAD_TOKEN_ID = 1
    # Queue prefixes; buffers are named "<prefix>_queue" and "<prefix>_queue_ptr".
    _QUEUE_NAMES = ("code", "masked_code", "nl", "masked_nl")
    # Supported values of ``args.agg_way``.
    _AGG_WAYS = ("avg", "cls_pooler", "avg_cls_pooler")

    def __init__(self, base_encoder, args, mlp=False):
        """Set up encoders, hyper-parameters and the negative queues.

        Args:
            base_encoder: encoder module used as the query encoder for both
                code and NL; it is deep-copied twice for the key encoders.
            args: namespace providing ``moco_k``, ``moco_m``, ``moco_t``,
                ``moco_dim``, ``time_score``, ``do_whitening``,
                ``do_ineer_loss``, ``agg_way`` and (read in ``forward``)
                ``do_multi_lang_continue_pre_train``.
            mlp: stored as ``self.mlp``; see the review note in
                ``_momentum_update_key_encoder``.
        """
        super(Multi_Loss_CoCoSoDa, self).__init__()
        self.K = args.moco_k  # queue length
        self.m = args.moco_m  # momentum coefficient for the key encoders
        self.T = args.moco_t  # softmax temperature
        dim = args.moco_dim   # embedding dimension stored in the queues

        # Query encoders share one module; key encoders are independent copies.
        self.code_encoder_q = base_encoder
        self.code_encoder_k = copy.deepcopy(base_encoder)
        self.nl_encoder_q = base_encoder
        self.nl_encoder_k = copy.deepcopy(self.nl_encoder_q)

        self.mlp = mlp
        self.time_score = args.time_score
        self.do_whitening = args.do_whitening
        self.do_ineer_loss = args.do_ineer_loss
        self.agg_way = args.agg_way
        self.args = args

        # Key encoders start from the query weights and never receive gradients.
        encoder_pairs = (
            (self.code_encoder_q, self.code_encoder_k),
            (self.nl_encoder_q, self.nl_encoder_k),
        )
        for query_encoder, key_encoder in encoder_pairs:
            for param_q, param_k in zip(query_encoder.parameters(), key_encoder.parameters()):
                param_k.data.copy_(param_q.data)
                param_k.requires_grad = False

        # Fixed seed so the random queue contents are reproducible. This
        # resets the *global* torch RNG as a side effect.
        torch.manual_seed(3047)
        torch.cuda.manual_seed(3047)
        # Each queue is (dim, K) with unit-norm columns, plus a write pointer.
        # The creation order determines which random values each queue gets.
        for prefix in self._QUEUE_NAMES:
            initial = nn.functional.normalize(torch.randn(dim, self.K), dim=0)
            self.register_buffer(prefix + "_queue", initial)
            self.register_buffer(prefix + "_queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Move every key-encoder parameter towards its query counterpart.

        Each key parameter becomes ``m * key + (1 - m) * query``.
        """
        encoder_pairs = [
            (self.code_encoder_q, self.code_encoder_k),
            (self.nl_encoder_q, self.nl_encoder_k),
        ]
        if self.mlp:
            # NOTE(review): the ``*_fc`` projection heads are not created
            # anywhere in this class, so this path raises AttributeError when
            # ``mlp=True`` - confirm whether the heads are attached elsewhere.
            encoder_pairs.append((self.code_encoder_q_fc, self.code_encoder_k_fc))
            encoder_pairs.append((self.nl_encoder_q_fc, self.nl_encoder_k_fc))

        for query_module, key_module in encoder_pairs:
            for param_q, param_k in zip(query_module.parameters(), key_module.parameters()):
                param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, option='code'):
        """Overwrite the oldest ``batch_size`` columns of one queue with ``keys``.

        Args:
            keys: 2-D tensor whose rows are embeddings; its transpose is
                written into the queue columns.
            option: one of ``'code'``, ``'masked_code'``, ``'nl'`` or
                ``'masked_nl'``. Any other value leaves all queues untouched
                (a warning is logged).

        Raises:
            AssertionError: if the queue length is not a multiple of the
                batch size.
            RuntimeError: if ``keys`` cannot be written into the queue slice
                (re-raised after logging the pointer and shapes).
        """
        if option not in self._QUEUE_NAMES:
            logger.warning("Unknown queue option %r; no queue was updated", option)
            return

        queue = getattr(self, option + "_queue")
        queue_ptr = getattr(self, option + "_queue_ptr")
        batch_size = keys.shape[0]
        ptr = int(queue_ptr)
        # The pointer advances in whole batches, so a batch must never straddle
        # the end of the queue.
        assert self.K % batch_size == 0, (
            "queue size %d is not a multiple of batch size %d" % (self.K, batch_size)
        )

        try:
            queue[:, ptr:ptr + batch_size] = keys.T
        except RuntimeError:
            # Log the diagnostics and let the caller see the real error
            # instead of terminating the interpreter from library code.
            logger.error(
                "Failed to enqueue into %s queue: ptr=%d, batch_size=%d, keys.shape=%s",
                option, ptr, batch_size, tuple(keys.shape),
            )
            raise

        queue_ptr[0] = (ptr + batch_size) % self.K  # advance with wrap-around

    def _masked_mean(self, hidden, mask):
        """Average ``hidden`` over the sequence axis, counting only ``mask`` positions."""
        # Indexing with None inserts a new axis so the mask broadcasts over dim.
        return (hidden * mask[:, :, None]).sum(1) / mask.sum(-1)[:, None]

    def _embed(self, encoder, input_ids, agg_way):
        """Encode ``input_ids`` and pool to one L2-normalised vector per row.

        ``agg_way`` selects the pooling: ``"avg"`` (masked mean of output 0),
        ``"cls_pooler"`` (output 1 of the encoder) or ``"avg_cls_pooler"``
        (their sum).
        """
        if agg_way not in self._AGG_WAYS:
            raise ValueError(
                "Unsupported agg_way %r; expected one of %s" % (agg_way, self._AGG_WAYS)
            )
        mask = input_ids.ne(self._PAD_TOKEN_ID)
        outputs = encoder(input_ids, attention_mask=mask)
        if agg_way == "avg":
            pooled = self._masked_mean(outputs[0], mask)
        elif agg_way == "cls_pooler":
            pooled = outputs[1]
        else:  # "avg_cls_pooler"
            pooled = outputs[1] + self._masked_mean(outputs[0], mask)
        return torch.nn.functional.normalize(pooled, p=2, dim=1)

    def _contrastive_logits(self, queries, positives, queue):
        """Build temperature-scaled logits and labels for one contrastive pairing.

        Columns ``0..bs-1`` are similarities against the current batch
        (scaled by ``time_score``); the remaining ``K`` columns are
        similarities against a detached copy of ``queue``. Row ``i``'s
        positive is column ``i``.
        """
        in_batch = torch.einsum('nc,bc->nb', queries, positives)
        from_queue = torch.einsum('nc,ck->nk', queries, queue.clone().detach())
        logits = torch.cat([self.time_score * in_batch, from_queue], dim=1)
        logits /= self.T
        labels = torch.arange(logits.size(0), device=logits.device)
        return logits, labels

    def forward(self, source_code_q, source_code_k, nl_q, nl_k):
        """Compute contrastive logits and labels for a paired code/NL batch.

        Args:
            source_code_q: token ids of the code fed to the query encoder.
            source_code_k: token ids of the code fed to the key encoder
                (the queues that receive it are named "masked").
            nl_q: token ids of the NL text fed to the query encoder.
            nl_k: token ids of the NL text fed to the key encoder.

        Returns:
            ``(logits, labels, code_q, nl_q)``.

            * When ``args.do_multi_lang_continue_pre_train`` is false:
              in-batch code-to-NL logits of shape ``(bs, bs)`` divided by the
              temperature, ``arange(bs)`` labels, and ``None, None``. Mean
              pooling is always used on this path and no queue is touched.
            * Otherwise: logits stacked along dim 0 for four pairings
              (code->nl, code->key nl, nl->code, nl->key code), plus
              code->key code and nl->key nl when ``do_ineer_loss`` is set;
              each row has ``bs + K`` columns. The query embeddings of code
              and NL are returned as the last two items.

        Raises:
            ValueError: if ``agg_way`` is not a supported pooling mode on the
                pre-training path.
        """
        if not self.args.do_multi_lang_continue_pre_train:
            code_emb = self._embed(self.code_encoder_q, source_code_q, "avg")
            nl_emb = self._embed(self.nl_encoder_q, nl_q, "avg")
            code2nl_logits = torch.einsum("ab,cb->ac", code_emb, nl_emb)
            code2nl_logits /= self.T
            code2nl_label = torch.arange(code2nl_logits.size(0), device=code2nl_logits.device)
            return code2nl_logits, code2nl_label, None, None

        agg_way = self.agg_way
        if agg_way not in self._AGG_WAYS:
            # Fail before running any encoder or touching the key weights.
            raise ValueError(
                "Unsupported agg_way %r; expected one of %s" % (agg_way, self._AGG_WAYS)
            )

        # Query features (gradients flow through these).
        code_q = self._embed(self.code_encoder_q, source_code_q, agg_way)
        nl_q_emb = self._embed(self.nl_encoder_q, nl_q, agg_way)

        # Key features: refresh the key encoders first, no gradient to keys.
        with torch.no_grad():
            self._momentum_update_key_encoder()
            code_k = self._embed(self.code_encoder_k, source_code_k, agg_way)
            nl_k_emb = self._embed(self.nl_encoder_k, nl_k, agg_way)

        # (queries, positives, queue of negatives) for the cross-modal losses.
        pairings = [
            (code_q, nl_q_emb, self.nl_queue),            # code vs nl
            (code_q, nl_k_emb, self.masked_nl_queue),     # code vs key nl
            (nl_q_emb, code_q, self.code_queue),          # nl vs code
            (nl_q_emb, code_k, self.masked_code_queue),   # nl vs key code
        ]
        if self.do_ineer_loss:
            # Intra-modal losses: each query against the key view of its own modality.
            pairings.append((code_q, code_k, self.masked_code_queue))
            pairings.append((nl_q_emb, nl_k_emb, self.masked_nl_queue))

        all_logits = []
        all_labels = []
        for queries, positives, queue in pairings:
            logits, labels = self._contrastive_logits(queries, positives, queue)
            all_logits.append(logits)
            all_labels.append(labels)
        inter_logits = torch.cat(all_logits, dim=0)
        inter_labels = torch.cat(all_labels, dim=0)

        # Queues are updated only after every logit has been computed, so the
        # current batch is never compared against itself via a queue.
        self._dequeue_and_enqueue(code_q, option='code')
        self._dequeue_and_enqueue(nl_q_emb, option='nl')
        self._dequeue_and_enqueue(code_k, option='masked_code')
        self._dequeue_and_enqueue(nl_k_emb, option='masked_nl')

        return inter_logits, inter_labels, code_q, nl_q_emb