""" |
|
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). |
|
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned |
|
using a masked language modeling (MLM) loss. |
|
""" |
|
|
|
|
import torch.nn.functional as F |
|
import argparse |
|
import logging |
|
import os |
|
|
|
import pickle |
|
import random |
|
import torch |
|
import json |
|
from random import choice |
|
import numpy as np |
|
from itertools import cycle |
|
from model import Model,Multi_Loss_CoCoSoDa |
|
from torch.nn import CrossEntropyLoss |
|
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler |
|
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
                          RobertaConfig, RobertaModel, RobertaTokenizer,
                          T5Config, T5ForConditionalGeneration)
|
|
|
logger = logging.getLogger(__name__) |
|
from tqdm import tqdm |
|
import multiprocessing |
|
cpu_cont = 16 |
|
|
|
from parser import DFG_python,DFG_java,DFG_ruby,DFG_go,DFG_php,DFG_javascript |
|
from parser import (remove_comments_and_docstrings, |
|
tree_to_token_index, |
|
index_to_code_token, |
|
tree_to_variable_index) |
|
from tree_sitter import Language, Parser |
|
import sys |
|
sys.path.append("dataset") |
|
torch.cuda.set_per_process_memory_fraction(0.8) |
|
from utils import save_json_data, save_pickle_data |
|
dfg_function={ |
|
'python':DFG_python, |
|
'java':DFG_java, |
|
'ruby':DFG_ruby, |
|
'go':DFG_go, |
|
'php':DFG_php, |
|
'javascript':DFG_javascript |
|
} |
|
|
|
# Build a (tree-sitter parser, DFG extractor) pair for each supported language.
parsers = {}
for lang in dfg_function:
    LANGUAGE = Language('parser/my-languages.so', lang)
    parser = Parser()
    parser.set_language(LANGUAGE)
    parsers[lang] = [parser, dfg_function[lang]]
|
|
|
|
|
ruby_special_token = ['keyword', 'identifier', 'separators', 'simple_symbol', 'constant', 'instance_variable', |
|
'operator', 'string_content', 'integer', 'escape_sequence', 'comment', 'hash_key_symbol', |
|
'global_variable', 'heredoc_beginning', 'heredoc_content', 'heredoc_end', 'class_variable',] |
|
|
|
java_special_token = ['keyword', 'identifier', 'type_identifier', 'separators', 'operator', 'decimal_integer_literal', |
|
'void_type', 'string_literal', 'decimal_floating_point_literal', |
|
'boolean_type', 'null_literal', 'comment', 'hex_integer_literal', 'character_literal'] |
|
|
|
go_special_token = ['keyword', 'identifier', 'separators', 'type_identifier', 'int_literal', 'operator', |
|
'field_identifier', 'package_identifier', 'comment', 'escape_sequence', 'raw_string_literal', |
|
'rune_literal', 'label_name', 'float_literal'] |
|
|
|
javascript_special_token =['keyword', 'separators', 'identifier', 'property_identifier', 'operator', |
|
'number', 'string_fragment', 'comment', 'regex_pattern', 'shorthand_property_identifier_pattern', |
|
'shorthand_property_identifier', 'regex_flags', 'escape_sequence', 'statement_identifier'] |
|
|
|
php_special_token =['text', 'php_tag', 'name', 'operator', 'keyword', 'string', 'integer', 'separators', 'comment', |
|
'escape_sequence', 'ERROR', 'boolean', 'namespace', 'class', 'extends'] |
|
|
|
python_special_token =['keyword', 'identifier', 'separators', 'operator', '"', 'integer', |
|
'comment', 'none', 'escape_sequence'] |
|
|
|
|
|
special_token={ |
|
'python':python_special_token, |
|
'java':java_special_token, |
|
'ruby':ruby_special_token, |
|
'go':go_special_token, |
|
'php':php_special_token, |
|
'javascript':javascript_special_token |
|
} |
|
|
|
# Union of the node-type tokens across all languages; these are later added to the
# tokenizer as additional special tokens in create_model().
all_special_token = []
for key, value in special_token.items():
    all_special_token = list(set(all_special_token).union(set(value)))
|
|
|
def lalign(x, y, alpha=2): |
|
x = torch.tensor(x) |
|
y= torch.tensor(y) |
|
return (x - y).norm(dim=1).pow(alpha).mean() |
|
|
|
|
|
|
|
|
|
def lunif(x, t=2): |
|
x = torch.tensor(x) |
|
sq_pdist = torch.pdist(x, p=2).pow(2) |
|
return sq_pdist.mul(-t).exp().mean().log() |
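
# lalign/lunif follow the alignment and uniformity metrics of Wang & Isola (2020).
# Illustrative use (an assumption about intended usage, they are not called in this file;
# z_code and z_nl would be L2-normalized embeddings of matched code/NL pairs):
#   align = lalign(z_code, z_nl)              # lower = matched pairs lie closer together
#   unif = (lunif(z_code) + lunif(z_nl)) / 2  # lower = embeddings spread more uniformly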
|
|
|
|
|
|
|
def cal_r1_r5_r10(ranks):
    """Compute R@1/R@5/R@10 from a list of reciprocal ranks (1/rank, or 0 if the answer is not retrieved)."""
    r1, r5, r10 = 0, 0, 0
    data_len = len(ranks)
|
for item in ranks: |
|
if item >=1: |
|
r1 +=1 |
|
r5 += 1 |
|
r10 += 1 |
|
elif item >=0.2: |
|
r5+= 1 |
|
r10+=1 |
|
elif item >=0.1: |
|
r10 +=1 |
|
result = {"R@1":round(r1/data_len,3), "R@5": round(r5/data_len,3), "R@10": round(r10/data_len,3)} |
|
return result |
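
# Worked example (illustrative): reciprocal ranks [1.0, 0.5, 0.05] correspond to
# ranks 1, 2 and 20, so cal_r1_r5_r10([1.0, 0.5, 0.05]) returns
# {"R@1": 0.333, "R@5": 0.667, "R@10": 0.667}.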
|
|
|
|
|
def extract_dataflow(code, parser,lang): |
|
|
|
try: |
|
code=remove_comments_and_docstrings(code,lang) |
|
except: |
|
pass |
|
|
|
if lang=="php": |
|
code="<?php"+code+"?>" |
|
try: |
|
tree = parser[0].parse(bytes(code,'utf8')) |
|
root_node = tree.root_node |
|
tokens_index=tree_to_token_index(root_node) |
|
code=code.split('\n') |
|
code_tokens=[index_to_code_token(x,code) for x in tokens_index] |
|
index_to_code={} |
|
for idx,(index,code) in enumerate(zip(tokens_index,code_tokens)): |
|
index_to_code[index]=(idx,code) |
|
try: |
|
DFG,_=parser[1](root_node,index_to_code,{}) |
|
except: |
|
DFG=[] |
|
DFG=sorted(DFG,key=lambda x:x[1]) |
|
indexs=set() |
|
for d in DFG: |
|
if len(d[-1])!=0: |
|
indexs.add(d[1]) |
|
for x in d[-1]: |
|
indexs.add(x) |
|
new_DFG=[] |
|
for d in DFG: |
|
if d[1] in indexs: |
|
new_DFG.append(d) |
|
dfg=new_DFG |
|
    except:
        code_tokens = []
        dfg = []
|
return code_tokens,dfg |
|
|
|
|
|
def tokenizer_source_code(code, parser,lang): |
|
|
|
try: |
|
code=remove_comments_and_docstrings(code,lang) |
|
except: |
|
pass |
|
|
|
if lang=="php": |
|
code="<?php"+code+"?>" |
|
try: |
|
tree = parser[0].parse(bytes(code,'utf8')) |
|
root_node = tree.root_node |
|
tokens_index=tree_to_token_index(root_node) |
|
code=code.split('\n') |
|
code_tokens=[index_to_code_token(x,code) for x in tokens_index] |
|
    except:
        code_tokens = []
|
return code_tokens |
|
|
|
class InputFeatures(object):
    """A single set of training/test features for one example."""
    def __init__(self,
                 code_tokens,
                 code_ids,
                 nl_tokens,
                 nl_ids,
                 url,
    ):
        self.code_tokens = code_tokens
        self.code_ids = code_ids
        self.nl_tokens = nl_tokens
        self.nl_ids = nl_ids
        self.url = url
|
|
|
|
|
class TypeAugInputFeatures(object): |
|
"""A single training/test features for a example.""" |
|
def __init__(self, |
|
code_tokens, |
|
code_ids, |
|
|
|
code_type, |
|
code_type_ids, |
|
nl_tokens, |
|
nl_ids, |
|
url, |
|
|
|
): |
|
self.code_tokens = code_tokens |
|
self.code_ids = code_ids |
|
|
|
self.code_type=code_type |
|
self.code_type_ids=code_type_ids |
|
self.nl_tokens = nl_tokens |
|
self.nl_ids = nl_ids |
|
self.url=url |
|
|
|
def convert_examples_to_features(js): |
|
js,tokenizer,args=js |
|
|
|
if args.lang == "java_mini": |
|
parser=parsers["java"] |
|
else: |
|
parser=parsers[js["language"]] |
|
|
|
code_tokens=tokenizer_source_code(js['original_string'],parser,args.lang) |
|
code_tokens=" ".join(code_tokens[:args.code_length-2]) |
|
code_tokens=tokenizer.tokenize(code_tokens)[:args.code_length-2] |
|
code_tokens =[tokenizer.cls_token]+code_tokens+[tokenizer.sep_token] |
|
code_ids = tokenizer.convert_tokens_to_ids(code_tokens) |
|
padding_length = args.code_length - len(code_ids) |
|
code_ids+=[tokenizer.pad_token_id]*padding_length |
|
|
|
|
|
nl=' '.join(js['docstring_tokens']) |
|
nl_tokens=tokenizer.tokenize(nl)[:args.nl_length-2] |
|
nl_tokens =[tokenizer.cls_token]+nl_tokens+[tokenizer.sep_token] |
|
nl_ids = tokenizer.convert_tokens_to_ids(nl_tokens) |
|
padding_length = args.nl_length - len(nl_ids) |
|
nl_ids+=[tokenizer.pad_token_id]*padding_length |
|
|
|
return InputFeatures(code_tokens,code_ids,nl_tokens,nl_ids,js['url']) |
|
|
|
|
|
def convert_examples_to_features_aug_type(js): |
|
js,tokenizer,args=js |
|
|
|
if args.lang == "java_mini": |
|
parser=parsers["java"] |
|
else: |
|
parser=parsers[js["language"]] |
|
|
|
token_type_role = js[ 'bpe_token_type_role'] |
|
code_token = [item[0] for item in token_type_role] |
|
|
|
|
|
code_tokens = code_token[:args.code_length-4] |
|
code_tokens =[tokenizer.cls_token,"<encoder-only>",tokenizer.sep_token]+code_tokens+[tokenizer.sep_token] |
|
code_ids = tokenizer.convert_tokens_to_ids(code_tokens) |
|
padding_length = args.code_length - len(code_ids) |
|
code_ids += [tokenizer.pad_token_id]*padding_length |
|
|
|
|
|
code_type_token = [item[-1] for item in token_type_role] |
|
|
|
|
|
code_type_tokens = code_type_token[:args.code_length-4] |
|
code_type_tokens =[tokenizer.cls_token,"<encoder-only>",tokenizer.sep_token]+code_type_tokens+[tokenizer.sep_token] |
|
code_type_ids = tokenizer.convert_tokens_to_ids(code_type_tokens) |
|
padding_length = args.code_length - len(code_type_ids) |
|
code_type_ids += [tokenizer.pad_token_id]*padding_length |
|
|
|
|
|
nl=' '.join(js['docstring_tokens']) |
|
nl_tokens = tokenizer.tokenize(nl)[:args.nl_length-4] |
|
nl_tokens = [tokenizer.cls_token,"<encoder-only>",tokenizer.sep_token]+nl_tokens+[tokenizer.sep_token] |
|
nl_ids = tokenizer.convert_tokens_to_ids(nl_tokens) |
|
padding_length = args.nl_length - len(nl_ids) |
|
nl_ids += [tokenizer.pad_token_id]*padding_length |
|
|
|
return TypeAugInputFeatures(code_tokens,code_ids,code_type_tokens,code_type_ids,nl_tokens,nl_ids,js['url']) |
|
|
|
|
|
class TextDataset(Dataset): |
|
def __init__(self, tokenizer, args, file_path=None,pool=None): |
|
self.args=args |
|
prefix=file_path.split('/')[-1][:-6] |
|
cache_file=args.output_dir+'/'+prefix+'.pkl' |
|
n_debug_samples = args.n_debug_samples |
|
|
|
|
|
if 'train' in file_path: |
|
self.split = "train" |
|
else: |
|
self.split = "other" |
|
if os.path.exists(cache_file): |
|
self.examples=pickle.load(open(cache_file,'rb')) |
|
if args.debug: |
|
self.examples= self.examples[:n_debug_samples] |
|
else: |
|
self.examples = [] |
|
data=[] |
|
if args.debug: |
|
with open(file_path, encoding="utf-8") as f: |
|
for line in f: |
|
line=line.strip() |
|
js=json.loads(line) |
|
data.append((js,tokenizer,args)) |
|
if len(data) >= n_debug_samples: |
|
break |
|
else: |
|
with open(file_path, encoding="utf-8") as f: |
|
for line in f: |
|
line=line.strip() |
|
js=json.loads(line) |
|
data.append((js,tokenizer,args)) |
|
|
|
if self.args.data_aug_type == "replace_type": |
|
self.examples=pool.map(convert_examples_to_features_aug_type, tqdm(data,total=len(data))) |
|
else: |
|
self.examples=pool.map(convert_examples_to_features, tqdm(data,total=len(data))) |
|
|
|
if 'train' in file_path: |
|
for idx, example in enumerate(self.examples[:3]): |
|
logger.info("*** Example ***") |
|
logger.info("idx: {}".format(idx)) |
|
logger.info("code_tokens: {}".format([x.replace('\u0120','_') for x in example.code_tokens])) |
|
logger.info("code_ids: {}".format(' '.join(map(str, example.code_ids)))) |
|
logger.info("nl_tokens: {}".format([x.replace('\u0120','_') for x in example.nl_tokens])) |
|
logger.info("nl_ids: {}".format(' '.join(map(str, example.nl_ids)))) |
|
|
|
def __len__(self): |
|
return len(self.examples) |
|
|
|
def __getitem__(self, item): |
|
if self.args.data_aug_type == "replace_type": |
|
return (torch.tensor(self.examples[item].code_ids), |
|
torch.tensor(self.examples[item].code_type_ids), |
|
torch.tensor(self.examples[item].nl_ids)) |
|
        else:
            # code_tokens / nl_tokens are lists of strings and cannot be converted to
            # tensors, so only the id tensors are returned here.
            return (torch.tensor(self.examples[item].code_ids),
                    torch.tensor(self.examples[item].nl_ids))
|
|
|
|
|
def convert_examples_to_features_unixcoder(js,tokenizer,args): |
|
"""convert examples to token ids""" |
|
code = ' '.join(js['code_tokens']) if type(js['code_tokens']) is list else ' '.join(js['code_tokens'].split()) |
|
code_tokens = tokenizer.tokenize(code)[:args.code_length-4] |
|
code_tokens =[tokenizer.cls_token,"<encoder-only>",tokenizer.sep_token]+code_tokens+[tokenizer.sep_token] |
|
code_ids = tokenizer.convert_tokens_to_ids(code_tokens) |
|
padding_length = args.code_length - len(code_ids) |
|
code_ids += [tokenizer.pad_token_id]*padding_length |
|
|
|
nl = ' '.join(js['docstring_tokens']) if type(js['docstring_tokens']) is list else ' '.join(js['doc'].split()) |
|
nl_tokens = tokenizer.tokenize(nl)[:args.nl_length-4] |
|
nl_tokens = [tokenizer.cls_token,"<encoder-only>",tokenizer.sep_token]+nl_tokens+[tokenizer.sep_token] |
|
nl_ids = tokenizer.convert_tokens_to_ids(nl_tokens) |
|
padding_length = args.nl_length - len(nl_ids) |
|
nl_ids += [tokenizer.pad_token_id]*padding_length |
|
|
|
return InputFeatures(code_tokens,code_ids,nl_tokens,nl_ids,js['url'] if "url" in js else js["retrieval_idx"]) |
|
|
|
class TextDataset_unixcoder(Dataset): |
|
def __init__(self, tokenizer, args, file_path=None, pooler=None): |
|
self.examples = [] |
|
data = [] |
|
n_debug_samples = args.n_debug_samples |
|
with open(file_path) as f: |
|
if "jsonl" in file_path: |
|
for line in f: |
|
line = line.strip() |
|
js = json.loads(line) |
|
if 'function_tokens' in js: |
|
js['code_tokens'] = js['function_tokens'] |
|
data.append(js) |
|
if args.debug and len(data) >= n_debug_samples: |
|
break |
|
elif "codebase"in file_path or "code_idx_map" in file_path: |
|
js = json.load(f) |
|
for key in js: |
|
temp = {} |
|
temp['code_tokens'] = key.split() |
|
temp["retrieval_idx"] = js[key] |
|
temp['doc'] = "" |
|
temp['docstring_tokens'] = "" |
|
data.append(temp) |
|
if args.debug and len(data) >= n_debug_samples: |
|
break |
|
elif "json" in file_path: |
|
for js in json.load(f): |
|
data.append(js) |
|
if args.debug and len(data) >= n_debug_samples: |
|
break |
|
|
|
|
|
for js in data: |
|
self.examples.append(convert_examples_to_features_unixcoder(js,tokenizer,args)) |
|
|
|
if "train" in file_path: |
|
|
|
for idx, example in enumerate(self.examples[:3]): |
|
logger.info("*** Example ***") |
|
logger.info("idx: {}".format(idx)) |
|
logger.info("code_tokens: {}".format([x.replace('\u0120','_') for x in example.code_tokens])) |
|
logger.info("code_ids: {}".format(' '.join(map(str, example.code_ids)))) |
|
logger.info("nl_tokens: {}".format([x.replace('\u0120','_') for x in example.nl_tokens])) |
|
logger.info("nl_ids: {}".format(' '.join(map(str, example.nl_ids)))) |
|
|
|
def __len__(self): |
|
return len(self.examples) |
|
|
|
def __getitem__(self, i): |
|
return (torch.tensor(self.examples[i].code_ids),torch.tensor(self.examples[i].nl_ids)) |
|
|
|
|
|
|
|
|
|
|
|
def set_seed(seed=42): |
|
random.seed(seed) |
|
    os.environ['PYTHONHASHSEED'] = str(seed)
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
torch.backends.cudnn.deterministic = True |
|
|
|
|
|
def mask_tokens(inputs,tokenizer,mlm_probability): |
|
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ |
|
labels = inputs.clone() |
|
|
|
probability_matrix = torch.full(labels.shape, mlm_probability).to(inputs.device) |
|
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in |
|
labels.tolist()] |
|
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool).to(inputs.device), value=0.0) |
|
if tokenizer._pad_token is not None: |
|
padding_mask = labels.eq(tokenizer.pad_token_id) |
|
probability_matrix.masked_fill_(padding_mask, value=0.0) |
|
|
|
masked_indices = torch.bernoulli(probability_matrix).bool() |
|
labels[~masked_indices] = -100 |
|
|
|
|
|
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool().to(inputs.device) & masked_indices |
|
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) |
|
|
|
|
|
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool().to(inputs.device) & masked_indices & ~indices_replaced |
|
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long).to(inputs.device) |
|
inputs[indices_random] = random_words[indices_random] |
|
|
|
|
|
return inputs, labels |
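
# Illustrative call (an assumption about typical usage; batch_ids would be a LongTensor
# of token ids already on the target device):
#   masked_ids, mlm_labels = mask_tokens(batch_ids.clone(), tokenizer, args.mlm_probability)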
|
|
|
|
|
def replace_with_type_tokens(inputs,replaces,tokenizer,mlm_probability): |
|
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ |
|
labels = inputs.clone() |
|
|
|
probability_matrix = torch.full(labels.shape, mlm_probability).to(inputs.device) |
|
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in |
|
labels.tolist()] |
|
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool).to(inputs.device), value=0.0) |
|
if tokenizer._pad_token is not None: |
|
padding_mask = labels.eq(tokenizer.pad_token_id) |
|
probability_matrix.masked_fill_(padding_mask, value=0.0) |
|
|
|
masked_indices = torch.bernoulli(probability_matrix).bool() |
|
labels[~masked_indices] = -100 |
|
|
|
|
|
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool().to(inputs.device) & masked_indices |
|
inputs[indices_replaced] = replaces[indices_replaced] |
|
|
|
return inputs, labels |
|
|
|
def replace_special_token_with_type_tokens(inputs, speical_token_ids, tokenizer, mlm_probability): |
|
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ |
|
labels = inputs.clone() |
|
probability_matrix = torch.full(labels.shape,0.0).to(inputs.device) |
|
probability_matrix.masked_fill_(labels.eq(speical_token_ids).to(inputs.device), value=mlm_probability) |
|
masked_indices = torch.bernoulli(probability_matrix).bool() |
|
labels[~masked_indices] = -100 |
|
|
|
|
|
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool().to(inputs.device) & masked_indices |
|
inputs[indices_replaced] = speical_token_ids |
|
|
|
return inputs, labels |
|
|
|
def replace_special_token_with_mask(inputs, speical_token_ids, tokenizer, mlm_probability): |
|
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ |
|
labels = inputs.clone() |
|
probability_matrix = torch.full(labels.shape,0.0).to(inputs.device) |
|
probability_matrix.masked_fill_(labels.eq(speical_token_ids).to(inputs.device), value=mlm_probability) |
|
masked_indices = torch.bernoulli(probability_matrix).bool() |
|
labels[~masked_indices] = -100 |
|
|
|
|
|
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool().to(inputs.device) & masked_indices |
|
inputs[indices_replaced] =tokenizer.convert_tokens_to_ids(tokenizer.mask_token) |
|
|
|
return inputs, labels |
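

# `whitening_torch_final` is used in evaluate() when --do_whitening is set, but it is
# neither defined nor imported in this file. The sketch below is an assumption based on
# the WhiteningBERT approach referenced in the --do_whitening help text
# (https://github.com/Jun-jie-Huang/WhiteningBERT): center the embeddings and rotate them
# by the inverse square root of their covariance, obtained via SVD.
def whitening_torch_final(embeddings):
    mu = torch.mean(embeddings, dim=0, keepdim=True)
    cov = torch.mm((embeddings - mu).t(), embeddings - mu)
    u, s, vt = torch.svd(cov)
    W = torch.mm(u, torch.diag(1.0 / torch.sqrt(s)))
    return torch.mm(embeddings - mu, W)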
|
|
|
def train(args, model, tokenizer,pool): |
|
|
|
""" Train the model """ |
|
if args.data_aug_type == "replace_type" : |
|
train_dataset=TextDataset(tokenizer, args, args.train_data_file, pool) |
|
else: |
|
|
|
train_dataset=TextDataset_unixcoder(tokenizer, args, args.train_data_file, pool) |
|
|
|
|
|
train_sampler = RandomSampler(train_dataset) |
|
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,num_workers=4,drop_last=True) |
|
|
|
model.to(args.device) |
|
if args.local_rank not in [-1, 0]: |
|
torch.distributed.barrier() |
|
no_decay = ['bias', 'LayerNorm.weight'] |
|
optimizer_grouped_parameters = [ |
|
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], |
|
'weight_decay': args.weight_decay}, |
|
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} |
|
] |
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8) |
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)*args.num_train_epochs) |
|
|
|
|
|
if args.n_gpu > 1: |
|
model = torch.nn.DataParallel(model) |
|
|
|
|
|
logger.info("***** Running training *****") |
|
logger.info(" Num examples = %d", len(train_dataset)) |
|
logger.info(" Num Epochs = %d", args.num_train_epochs) |
|
logger.info(" Num quene = %d", args.moco_k) |
|
logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size//args.n_gpu) |
|
logger.info(" Total train batch size = %d", args.train_batch_size) |
|
logger.info(" Total optimization steps = %d", len(train_dataloader)*args.num_train_epochs) |
|
|
|
model.zero_grad() |
|
model.train() |
|
tr_num,tr_loss,best_mrr=0,0,-1 |
|
loss_fct = CrossEntropyLoss() |
|
|
|
if args.model_type in ["no_aug_cocosoda", "multi-loss-cocosoda"] : |
|
if args.do_continue_pre_trained: |
|
logger.info("do_continue_pre_trained") |
|
elif args.do_fine_tune: |
|
logger.info("do_fine_tune") |
|
special_token_list = special_token[args.lang] |
|
special_token_id_list = tokenizer.convert_tokens_to_ids(special_token_list) |
|
model_eval = model.module if hasattr(model,'module') else model |
|
for idx in range(args.num_train_epochs): |
|
print(idx) |
|
for step,batch in enumerate(train_dataloader): |
|
|
|
|
|
code_inputs = batch[0].to(args.device) |
|
nl_inputs = batch[1].to(args.device) |
|
|
|
nl_outputs = model_eval.nl_encoder_q(nl_inputs, attention_mask=nl_inputs.ne(1)) |
|
nl_vec =nl_outputs [1] |
|
code_outputs = model_eval.code_encoder_q(code_inputs, attention_mask=code_inputs.ne(1)) |
|
code_vec =code_outputs [1] |
|
|
|
|
|
torch.cuda.empty_cache() |
|
tr_num+=1 |
|
|
|
scores = torch.einsum("ab,cb->ac",nl_vec,code_vec) |
|
|
|
loss = loss_fct(scores*20, torch.arange(code_inputs.size(0), device=scores.device)) |
|
|
|
tr_loss += loss.item() |
|
|
|
if (step+1)% args.eval_frequency==0: |
|
logger.info("epoch {} step {} loss {}".format(idx,step+1,round(tr_loss/tr_num,5))) |
|
tr_loss=0 |
|
tr_num=0 |
|
|
|
|
|
loss.backward() |
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) |
|
optimizer.step() |
|
optimizer.zero_grad() |
|
scheduler.step() |
|
torch.cuda.empty_cache() |
|
|
|
results = evaluate(args, model, tokenizer,args.eval_data_file, pool, eval_when_training=True) |
|
for key, value in results.items(): |
|
logger.info(" %s = %s", key, round(value,4)) |
|
|
|
|
|
if results['eval_mrr']>best_mrr: |
|
best_mrr=results['eval_mrr'] |
|
logger.info(" "+"*"*20) |
|
logger.info(" Best mrr:%s",round(best_mrr,4)) |
|
logger.info(" "+"*"*20) |
|
|
|
checkpoint_prefix = 'checkpoint-best-mrr' |
|
output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) |
|
if not os.path.exists(output_dir): |
|
os.makedirs(output_dir) |
|
model_to_save = model.module if hasattr(model,'module') else model |
|
output_dir = os.path.join(output_dir, '{}'.format('model.bin')) |
|
torch.save(model_to_save.state_dict(), output_dir) |
|
logger.info("Saving model checkpoint to %s", output_dir) |
|
|
|
output_dir_epoch = os.path.join(args.output_dir, '{}'.format(idx)) |
|
if not os.path.exists(output_dir_epoch): |
|
os.makedirs(output_dir_epoch) |
|
model_to_save = model.module if hasattr(model,'module') else model |
|
output_dir_epoch = os.path.join(output_dir_epoch, '{}'.format('model.bin')) |
|
torch.save(model_to_save.state_dict(), output_dir_epoch) |
|
logger.info("Saving model checkpoint to %s", output_dir_epoch) |
|
|
|
def multi_lang_continue_pre_train(args, model, tokenizer,pool): |
|
""" Train the model """ |
|
|
|
if "unixcoder" in args.model_name_or_path: |
|
train_datasets = [] |
|
for train_data_file in args.couninue_pre_train_data_files: |
|
train_dataset=TextDataset_unixcoder(tokenizer, args, train_data_file, pool) |
|
train_datasets.append(train_dataset) |
|
else: |
|
train_datasets = [] |
|
for train_data_file in args.couninue_pre_train_data_files: |
|
train_dataset=TextDataset(tokenizer, args, train_data_file, pool) |
|
train_datasets.append(train_dataset) |
|
|
|
train_samplers = [RandomSampler(train_dataset) for train_dataset in train_datasets] |
|
|
|
train_dataloaders = [cycle(DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,drop_last=True)) for train_dataset,train_sampler in zip(train_datasets,train_samplers)] |
|
t_total = args.max_steps |
|
|
|
|
|
|
|
model.to(args.device) |
|
if args.local_rank not in [-1, 0]: |
|
torch.distributed.barrier() |
|
no_decay = ['bias', 'LayerNorm.weight'] |
|
optimizer_grouped_parameters = [ |
|
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], |
|
'weight_decay': 0.01}, |
|
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} |
|
] |
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8) |
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.num_warmup_steps,num_training_steps=t_total) |
|
|
|
|
|
training_data_length = sum ([len(item) for item in train_datasets]) |
|
logger.info("***** Running training *****") |
|
logger.info(" Num examples = %d", training_data_length) |
|
logger.info(" Num Epochs = %d", args.num_train_epochs) |
|
logger.info(" Num quene = %d", args.moco_k) |
|
logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size//args.n_gpu) |
|
logger.info(" Total train batch size = %d", args.train_batch_size) |
|
|
|
checkpoint_last = os.path.join(args.output_dir, 'checkpoint-last') |
|
scheduler_last = os.path.join(checkpoint_last, 'scheduler.pt') |
|
optimizer_last = os.path.join(checkpoint_last, 'optimizer.pt') |
|
if os.path.exists(scheduler_last): |
|
scheduler.load_state_dict(torch.load(scheduler_last, map_location="cpu")) |
|
if os.path.exists(optimizer_last): |
|
optimizer.load_state_dict(torch.load(optimizer_last, map_location="cpu")) |
|
if args.local_rank == 0: |
|
torch.distributed.barrier() |
|
if args.fp16: |
|
try: |
|
from apex import amp |
|
except ImportError: |
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") |
|
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) |
|
|
|
|
|
if args.n_gpu > 1: |
|
model = torch.nn.DataParallel(model) |
|
|
|
|
|
if args.local_rank != -1: |
|
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank%args.gpu_per_node], |
|
output_device=args.local_rank%args.gpu_per_node, |
|
find_unused_parameters=True) |
|
|
|
loss_fct = CrossEntropyLoss() |
|
set_seed(args.seed) |
|
probs=[len(x) for x in train_datasets] |
|
probs=[x/sum(probs) for x in probs] |
|
probs=[x**0.7 for x in probs] |
|
probs=[x/sum(probs) for x in probs] |
|
|
|
model.zero_grad() |
|
model.train() |
|
|
|
global_step = args.start_step |
|
step=0 |
|
tr_loss, logging_loss,avg_loss,tr_nb, best_mrr = 0.0, 0.0,0.0,0,-1 |
|
tr_num=0 |
|
special_token_list = all_special_token |
|
special_token_id_list = tokenizer.convert_tokens_to_ids(special_token_list) |
|
while True: |
|
|
|
train_dataloader=np.random.choice(train_dataloaders, 1, p=probs)[0] |
|
|
|
step+=1 |
|
batch=next(train_dataloader) |
|
|
|
model.train() |
|
|
|
code_inputs = batch[0].to(args.device) |
|
code_transformations_ids = code_inputs.clone() |
|
nl_inputs = batch[1].to(args.device) |
|
nl_transformations_ids= nl_inputs.clone() |
|
|
|
if step%4 == 0: |
|
code_transformations_ids[:, 3:], _ = mask_tokens(code_inputs.clone()[:, 3:] ,tokenizer,args.mlm_probability) |
|
nl_transformations_ids[:, 3:], _ = mask_tokens(nl_inputs.clone()[:, 3:] ,tokenizer,args.mlm_probability) |
|
elif step%4 == 1: |
|
code_types = code_inputs.clone() |
|
code_transformations_ids[:, 3:], _ = replace_with_type_tokens(code_inputs.clone()[:, 3:] ,code_types.clone()[:, 3:],tokenizer,args.mlm_probability) |
|
elif step%4 == 2: |
|
random.seed( step) |
|
choice_token_id = choice(special_token_id_list) |
|
code_transformations_ids[:, 3:], _ = replace_special_token_with_type_tokens(code_inputs.clone()[:, 3:], choice_token_id, tokenizer,args.mlm_probability) |
|
elif step%4 == 3: |
|
random.seed( step) |
|
choice_token_id = choice(special_token_id_list) |
|
code_transformations_ids[:, 3:], _ = replace_special_token_with_mask(code_inputs.clone()[:, 3:], choice_token_id, tokenizer,args.mlm_probability) |
|
|
|
|
|
tr_num+=1 |
|
inter_output, inter_target, _, _= model(source_code_q=code_inputs, source_code_k=code_transformations_ids, |
|
nl_q=nl_inputs , nl_k=nl_transformations_ids ) |
|
|
|
|
|
|
|
|
|
loss = loss_fct(20*inter_output, inter_target) |
|
|
|
if args.n_gpu > 1: |
|
loss = loss.mean() |
|
|
|
|
|
if args.gradient_accumulation_steps > 1: |
|
loss = loss / args.gradient_accumulation_steps |
|
|
|
if args.fp16: |
|
with amp.scale_loss(loss, optimizer) as scaled_loss: |
|
scaled_loss.backward() |
|
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) |
|
else: |
|
loss.backward() |
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) |
|
|
|
tr_loss += loss.item() |
|
if (step+1)% args.eval_frequency==0: |
|
logger.info("step {} loss {}".format(step+1,round(tr_loss/tr_num,5))) |
|
tr_loss=0 |
|
tr_num=0 |
|
|
|
if (step + 1) % args.gradient_accumulation_steps == 0: |
|
optimizer.step() |
|
optimizer.zero_grad() |
|
scheduler.step() |
|
global_step += 1 |
|
output_flag=True |
|
avg_loss=round((tr_loss - logging_loss) /(global_step- tr_nb),6) |
|
|
|
if global_step %100 == 0: |
|
logger.info(" global steps (step*gradient_accumulation_steps ): %s loss: %s", global_step, round(avg_loss,6)) |
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: |
|
logging_loss = tr_loss |
|
tr_nb=global_step |
|
|
|
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: |
|
checkpoint_prefix = 'checkpoint-mrr' |
|
|
|
results = evaluate(args, model, tokenizer,args.eval_data_file, pool, eval_when_training=True) |
|
|
|
|
|
|
|
logger.info(" %s = %s", 'eval_mrr', round(results['eval_mrr'],6)) |
|
|
|
if results['eval_mrr']>best_mrr: |
|
best_mrr=results['eval_mrr'] |
|
logger.info(" "+"*"*20) |
|
logger.info(" Best mrr:%s",round(best_mrr,4)) |
|
logger.info(" "+"*"*20) |
|
|
|
output_dir = os.path.join(args.output_dir, '{}'.format('checkpoint-best-mrr')) |
|
if not os.path.exists(output_dir): |
|
os.makedirs(output_dir) |
|
model_to_save = model.module if hasattr(model,'module') else model |
|
output_dir = os.path.join(output_dir, '{}'.format('model.bin')) |
|
torch.save(model_to_save.state_dict(), output_dir) |
|
logger.info("Saving model checkpoint to %s", output_dir) |
|
|
|
|
|
|
|
|
|
output_dir = os.path.join(args.output_dir, '{}-{}-{}'.format(checkpoint_prefix, global_step,round(results['eval_mrr'],6))) |
|
if not os.path.exists(output_dir): |
|
os.makedirs(output_dir) |
|
model_to_save = model.module.code_encoder_q if hasattr(model,'module') else model.code_encoder_q |
|
model_to_save.save_pretrained(output_dir) |
|
torch.save(args, os.path.join(output_dir, 'training_args.bin')) |
|
logger.info("Saving model checkpoint to %s", output_dir) |
|
|
|
|
|
|
|
last_output_dir = os.path.join(args.output_dir, 'checkpoint-last') |
|
if not os.path.exists(last_output_dir): |
|
os.makedirs(last_output_dir) |
|
model_to_save.save_pretrained(last_output_dir) |
|
idx_file = os.path.join(last_output_dir, 'idx_file.txt') |
|
with open(idx_file, 'w', encoding='utf-8') as idxf: |
|
idxf.write(str(0) + '\n') |
|
|
|
torch.save(optimizer.state_dict(), os.path.join(last_output_dir, "optimizer.pt")) |
|
torch.save(scheduler.state_dict(), os.path.join(last_output_dir, "scheduler.pt")) |
|
logger.info("Saving optimizer and scheduler states to %s", last_output_dir) |
|
|
|
step_file = os.path.join(last_output_dir, 'step_file.txt') |
|
with open(step_file, 'w', encoding='utf-8') as stepf: |
|
stepf.write(str(global_step) + '\n') |
|
|
|
if args.max_steps > 0 and global_step > args.max_steps: |
|
break |
|
|
|
|
|
def evaluate(args, model, tokenizer,file_name,pool, eval_when_training=False): |
|
|
|
dataset_class = TextDataset_unixcoder |
|
|
|
|
|
query_dataset = dataset_class(tokenizer, args, file_name, pool) |
|
query_sampler = SequentialSampler(query_dataset) |
|
query_dataloader = DataLoader(query_dataset, sampler=query_sampler, batch_size=args.eval_batch_size,num_workers=4) |
|
|
|
code_dataset = dataset_class(tokenizer, args, args.codebase_file, pool) |
|
code_sampler = SequentialSampler(code_dataset) |
|
code_dataloader = DataLoader(code_dataset, sampler=code_sampler, batch_size=args.eval_batch_size,num_workers=4) |
|
|
|
|
|
if args.n_gpu > 1 and eval_when_training is False: |
|
model = torch.nn.DataParallel(model) |
|
|
|
|
|
logger.info("***** Running evaluation on %s *****"%args.lang) |
|
logger.info(" Num queries = %d", len(query_dataset)) |
|
logger.info(" Num codes = %d", len(code_dataset)) |
|
logger.info(" Batch size = %d", args.eval_batch_size) |
|
|
|
|
|
model.eval() |
|
model_eval = model.module if hasattr(model,'module') else model |
|
code_vecs=[] |
|
nl_vecs=[] |
|
for batch in query_dataloader: |
|
nl_inputs = batch[-1].to(args.device) |
|
with torch.no_grad(): |
|
if args.model_type == "base" : |
|
nl_vec = model(nl_inputs=nl_inputs) |
|
|
|
elif args.model_type in ["cocosoda" ,"no_aug_cocosoda", "multi-loss-cocosoda"]: |
|
outputs = model_eval.nl_encoder_q(nl_inputs, attention_mask=nl_inputs.ne(1)) |
|
if args.agg_way == "avg": |
|
outputs = outputs [0] |
|
nl_vec = (outputs*nl_inputs.ne(1)[:,:,None]).sum(1)/nl_inputs.ne(1).sum(-1)[:,None] |
|
elif args.agg_way == "cls_pooler": |
|
nl_vec =outputs [1] |
|
elif args.agg_way == "avg_cls_pooler": |
|
nl_vec =outputs [1] + (outputs[0]*nl_inputs.ne(1)[:,:,None]).sum(1)/nl_inputs.ne(1).sum(-1)[:,None] |
|
nl_vec = torch.nn.functional.normalize( nl_vec, p=2, dim=1) |
|
if args.do_whitening: |
|
nl_vec=whitening_torch_final(nl_vec) |
|
|
|
|
|
|
|
nl_vecs.append(nl_vec.cpu().numpy()) |
|
|
|
for batch in code_dataloader: |
|
with torch.no_grad(): |
|
code_inputs = batch[0].to(args.device) |
|
if args.model_type == "base" : |
|
code_vec = model(code_inputs=code_inputs) |
|
elif args.model_type in ["cocosoda" ,"no_aug_cocosoda", "multi-loss-cocosoda"]: |
|
|
|
outputs = model_eval.code_encoder_q(code_inputs, attention_mask=code_inputs.ne(1)) |
|
if args.agg_way == "avg": |
|
outputs = outputs [0] |
|
code_vec = (outputs*code_inputs.ne(1)[:,:,None]).sum(1)/code_inputs.ne(1).sum(-1)[:,None] |
|
elif args.agg_way == "cls_pooler": |
|
code_vec=outputs [1] |
|
elif args.agg_way == "avg_cls_pooler": |
|
code_vec=outputs [1] + (outputs[0]*code_inputs.ne(1)[:,:,None]).sum(1)/code_inputs.ne(1).sum(-1)[:,None] |
|
code_vec = torch.nn.functional.normalize(code_vec, p=2, dim=1) |
|
if args.do_whitening: |
|
code_vec=whitening_torch_final(code_vec) |
|
|
|
|
|
|
|
code_vecs.append(code_vec.cpu().numpy()) |
|
|
|
model.train() |
|
code_vecs=np.concatenate(code_vecs,0) |
|
nl_vecs=np.concatenate(nl_vecs,0) |
|
|
|
scores=np.matmul(nl_vecs,code_vecs.T) |
|
|
|
sort_ids=np.argsort(scores, axis=-1, kind='quicksort', order=None)[:,::-1] |
|
|
|
nl_urls=[] |
|
code_urls=[] |
|
for example in query_dataset.examples: |
|
nl_urls.append(example.url) |
|
|
|
for example in code_dataset.examples: |
|
code_urls.append(example.url) |
|
|
|
ranks=[] |
|
for url, sort_id in zip(nl_urls,sort_ids): |
|
rank=0 |
|
find=False |
|
for idx in sort_id[:1000]: |
|
if find is False: |
|
rank+=1 |
|
if code_urls[idx]==url: |
|
find=True |
|
if find: |
|
ranks.append(1/rank) |
|
else: |
|
ranks.append(0) |
|
if args.save_evaluation_reuslt: |
|
evaluation_result = {"nl_urls":nl_urls, "code_urls":code_urls,"sort_ids":sort_ids[:,:10],"ranks":ranks} |
|
save_pickle_data(args.save_evaluation_reuslt_dir, "evaluation_result.pkl",evaluation_result) |
|
result = cal_r1_r5_r10(ranks) |
|
result["eval_mrr"] = round(float(np.mean(ranks)),3) |
|
return result |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser() |
|
|
|
    parser.add_argument('--data_aug_type',default="replace_type",choices=["replace_type", "random_mask" ,"other"], help="the type of soda (code data augmentation)",required=False)
    parser.add_argument('--aug_type_way',default="random_replace_type",choices=["random_replace_type", "replace_special_type" ,"replace_special_type_with_mask"], help="the way tokens are replaced with type tokens for soda",required=False)
|
parser.add_argument('--print_align_unif_loss', action='store_true', help='print_align_unif_loss', required=False) |
|
    parser.add_argument('--do_ineer_loss', action='store_true', help='do_ineer_loss', required=False)
    parser.add_argument('--only_save_the_nl_code_vec', action='store_true', help='only_save_the_nl_code_vec', required=False)
    parser.add_argument('--do_zero_short', action='store_true', help='do_zero_short', required=False)
    parser.add_argument('--agg_way',default="cls_pooler",choices=["avg", "cls_pooler","avg_cls_pooler" ], help="how to aggregate the encoder output into a single vector",required=False)
|
parser.add_argument('--weight_decay',default=0.01, type=float,required=False) |
|
parser.add_argument('--do_single_lang_continue_pre_train', action='store_true', help='do_single_lang_continue_pre_train', required=False) |
|
parser.add_argument('--save_evaluation_reuslt', action='store_true', help='save_evaluation_reuslt', required=False) |
|
parser.add_argument('--save_evaluation_reuslt_dir', type=str, help='save_evaluation_reuslt', required=False) |
|
    parser.add_argument('--epoch', type=int, default=50,
                        help="number of training epochs")
|
|
|
parser.add_argument('--fp16', action='store_true', |
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") |
|
parser.add_argument("--local_rank", type=int, default=-1, |
|
help="For distributed training: local_rank") |
|
parser.add_argument("--loaded_model_filename", type=str, required=False, |
|
help="loaded_model_filename") |
|
parser.add_argument("--loaded_codebert_model_filename", type=str, required=False, |
|
help="loaded_model_filename") |
|
parser.add_argument('--do_multi_lang_continue_pre_train', action='store_true', help='do_multi_lang_continue_pre_train', required=False) |
|
parser.add_argument("--couninue_pre_train_data_files", default=["dataset/ruby/train.jsonl", "dataset/java/train.jsonl",], type=str, nargs='+', required=False, |
|
help="The input training data files (some json files).") |
|
|
|
|
|
|
|
parser.add_argument('--do_continue_pre_trained', action='store_true', help='debug mode', required=False) |
|
parser.add_argument('--do_fine_tune', action='store_true', help='debug mode', required=False) |
|
parser.add_argument('--do_whitening', action='store_true', help='do_whitening https://github.com/Jun-jie-Huang/WhiteningBERT', required=False) |
|
parser.add_argument("--time_score", default=1, type=int,help="cosine value * time_score") |
|
parser.add_argument("--max_steps", default=100, type=int, help="If > 0: set total number of training steps to perform. Override num_train_epochs.") |
|
parser.add_argument("--num_warmup_steps", default=0, type=int, help="num_warmup_steps") |
|
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, |
|
help="Number of updates steps to accumulate before performing a backward/update pass.") |
|
parser.add_argument('--logging_steps', type=int, default=50, |
|
help="Log every X updates steps.") |
|
parser.add_argument('--save_steps', type=int, default=50, |
|
help="Save checkpoint every X updates steps.") |
|
|
|
    parser.add_argument('--moco_type',default="encoder_queue",choices=["encoder_queue","encoder_momentum_encoder_queue" ], help="MoCo variant: encoder with a queue, or encoder plus momentum encoder with a queue",required=False)
|
|
|
|
|
|
|
parser.add_argument('--use_best_mrr_model', action='store_true', help='cosine_space', required=False) |
|
parser.add_argument('--debug', action='store_true', help='debug mode', required=False) |
|
parser.add_argument('--n_debug_samples', type=int, default=100, required=False) |
|
parser.add_argument("--max_codeblock_num", default=10, type=int, |
|
help="Optional NL input sequence length after tokenization.") |
|
parser.add_argument('--hidden_size', type=int, default=768, required=False) |
|
parser.add_argument("--eval_frequency", default=1, type=int, required=False) |
|
parser.add_argument("--mlm_probability", default=0.1, type=float, required=False) |
|
|
|
|
|
parser.add_argument('--do_avg', action='store_true', help='avrage hidden status', required=False) |
|
parser.add_argument('--model_type',default="base",choices=["base", "cocosoda","multi-loss-cocosoda","no_aug_cocosoda"], help="base is codebert/graphcoder/unixcoder",required=False) |
|
|
|
|
|
parser.add_argument('--moco_dim', default=768, type=int, |
|
help='feature dimension (default: 768)') |
|
    parser.add_argument('--moco_k', default=32, type=int,
                        help='queue size; number of negative keys (default: 32)')
|
parser.add_argument('--moco_m', default=0.999, type=float, |
|
help='moco momentum of updating key encoder (default: 0.999)') |
|
parser.add_argument('--moco_t', default=0.07, type=float, |
|
help='softmax temperature (default: 0.07)') |
|
|
|
|
|
parser.add_argument('--mlp', action='store_true',help='use mlp head') |
|
|
|
|
|
parser.add_argument("--train_data_file", default="dataset/java/train.jsonl", type=str, required=False, |
|
help="The input training data file (a json file).") |
|
parser.add_argument("--output_dir", default="saved_models/pre-train", type=str, required=False, |
|
help="The output directory where the model predictions and checkpoints will be written.") |
|
parser.add_argument("--eval_data_file", default="dataset/java/valid.jsonl", type=str, |
|
help="An optional input evaluation data file to evaluate the MRR(a jsonl file).") |
|
parser.add_argument("--test_data_file", default="dataset/java/test.jsonl", type=str, |
|
help="An optional input test data file to test the MRR(a josnl file).") |
|
parser.add_argument("--codebase_file", default="dataset/java/codebase.jsonl", type=str, |
|
help="An optional input test data file to codebase (a jsonl file).") |
|
|
|
parser.add_argument("--lang", default="java", type=str, |
|
help="language.") |
|
|
|
parser.add_argument("--model_name_or_path", default="DeepSoftwareAnalytics/CoCoSoDa", type=str, |
|
help="The model checkpoint for weights initialization.") |
|
parser.add_argument("--config_name", default="DeepSoftwareAnalytics/CoCoSoDa", type=str, |
|
help="Optional pretrained config name or path if not the same as model_name_or_path") |
|
parser.add_argument("--tokenizer_name", default="DeepSoftwareAnalytics/CoCoSoDa", type=str, |
|
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") |
|
|
|
parser.add_argument("--nl_length", default=50, type=int, |
|
help="Optional NL input sequence length after tokenization.") |
|
parser.add_argument("--code_length", default=100, type=int, |
|
help="Optional Code input sequence length after tokenization.") |
|
parser.add_argument("--data_flow_length", default=0, type=int, |
|
help="Optional Data Flow input sequence length after tokenization.",required=False) |
|
|
|
parser.add_argument("--do_train", action='store_true', |
|
help="Whether to run training.") |
|
parser.add_argument("--do_eval", action='store_true', |
|
help="Whether to run eval on the dev set.") |
|
parser.add_argument("--do_test", action='store_true', |
|
help="Whether to run eval on the test set.") |
|
|
|
parser.add_argument("--train_batch_size", default=4, type=int, |
|
help="Batch size for training.") |
|
parser.add_argument("--eval_batch_size", default=4, type=int, |
|
help="Batch size for evaluation.") |
|
parser.add_argument("--learning_rate", default=2e-5, type=float, |
|
help="The initial learning rate for Adam.") |
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, |
|
help="Max gradient norm.") |
|
parser.add_argument("--num_train_epochs", default=4, type=int, |
|
help="Total number of training epochs to perform.") |
|
|
|
parser.add_argument('--seed', type=int, default=3407, |
|
help="random seed for initialization") |
|
|
|
|
|
args = parser.parse_args() |
|
return args |
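
# Example invocation (illustrative; the script name and paths are placeholders, not part
# of this repository's documented interface):
#   python run.py \
#       --model_type multi-loss-cocosoda --data_aug_type replace_type \
#       --do_train --do_eval --do_test \
#       --train_data_file dataset/ruby/train.jsonl \
#       --eval_data_file dataset/ruby/valid.jsonl \
#       --test_data_file dataset/ruby/test.jsonl \
#       --codebase_file dataset/ruby/codebase.jsonl \
#       --lang ruby --output_dir saved_models/fine_tune/ruby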
|
|
|
def create_model(args,model,tokenizer, config=None): |
|
|
|
|
|
if args.data_aug_type in ["replace_type" , "other"] and not args.only_save_the_nl_code_vec: |
|
special_tokens_dict = {'additional_special_tokens': all_special_token} |
|
logger.info(" new token %s"%(str(special_tokens_dict))) |
|
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict) |
|
model.resize_token_embeddings(len(tokenizer)) |
|
|
|
if (args.loaded_model_filename) and ("pytorch_model.bin" in args.loaded_model_filename): |
|
logger.info("reload pytorch model from {}".format(args.loaded_model_filename)) |
|
model.load_state_dict(torch.load(args.loaded_model_filename),strict=False) |
|
|
|
if args.model_type == "base" : |
|
model= Model(model) |
|
elif args.model_type == "multi-loss-cocosoda": |
|
model= Multi_Loss_CoCoSoDa(model,args, args.mlp) |
|
if (args.loaded_model_filename) and ("pytorch_model.bin" not in args.loaded_model_filename) : |
|
logger.info("reload model from {}".format(args.loaded_model_filename)) |
|
model.load_state_dict(torch.load(args.loaded_model_filename)) |
|
|
|
|
|
if (args.loaded_codebert_model_filename) : |
|
logger.info("reload pytorch model from {}".format(args.loaded_codebert_model_filename)) |
|
model.load_state_dict(torch.load(args.loaded_codebert_model_filename),strict=False) |
|
logger.info(model.model_parameters()) |
|
|
|
|
|
return model |
|
|
|
def main(): |
|
|
|
args = parse_args() |
|
|
|
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', |
|
datefmt='%m/%d/%Y %H:%M:%S',level=logging.INFO ) |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
args.n_gpu = torch.cuda.device_count() |
|
args.device = device |
|
logger.info("device: %s, n_gpu: %s",device, args.n_gpu) |
|
|
|
pool = multiprocessing.Pool(cpu_cont) |
|
|
|
|
|
set_seed(args.seed) |
|
|
|
|
|
|
|
if "codet5" in args.model_name_or_path: |
|
config = T5Config.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) |
|
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) |
|
model = T5ForConditionalGeneration.from_pretrained(args.model_name_or_path) |
|
model = model.encoder |
|
else: |
|
config = RobertaConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) |
|
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) |
|
model = RobertaModel.from_pretrained(args.model_name_or_path) |
|
model=create_model(args,model,tokenizer,config) |
|
|
|
logger.info("Training/evaluation parameters %s", args) |
|
args.start_step = 0 |
|
|
|
model.to(args.device) |
|
|
|
|
|
if args.do_multi_lang_continue_pre_train: |
|
multi_lang_continue_pre_train(args, model, tokenizer, pool) |
|
output_tokenizer_dir = os.path.join(args.output_dir,"tokenzier") |
|
if not os.path.exists(output_tokenizer_dir): |
|
os.makedirs( output_tokenizer_dir) |
|
tokenizer.save_pretrained( output_tokenizer_dir) |
|
if args.do_train: |
|
train(args, model, tokenizer, pool) |
|
|
|
|
|
|
|
results = {} |
|
|
|
if args.do_eval: |
|
checkpoint_prefix = 'checkpoint-best-mrr/model.bin' |
|
output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) |
|
if (not args.only_save_the_nl_code_vec) and (not args.do_zero_short) : |
|
model.load_state_dict(torch.load(output_dir),strict=False) |
|
model.to(args.device) |
|
result=evaluate(args, model, tokenizer,args.eval_data_file, pool) |
|
logger.info("***** Eval valid results *****") |
|
for key in sorted(result.keys()): |
|
logger.info(" %s = %s", key, str(round(result[key],4))) |
|
|
|
if args.do_test: |
|
|
|
logger.info("runnning test") |
|
checkpoint_prefix = 'checkpoint-best-mrr/model.bin' |
|
output_dir = os.path.join(args.output_dir, '{}'.format(checkpoint_prefix)) |
|
if (not args.only_save_the_nl_code_vec) and (not args.do_zero_short) : |
|
model.load_state_dict(torch.load(output_dir),strict=False) |
|
model.to(args.device) |
|
result=evaluate(args, model, tokenizer,args.test_data_file, pool) |
|
logger.info("***** Eval test results *****") |
|
for key in sorted(result.keys()): |
|
logger.info(" %s = %s", key, str(round(result[key],4))) |
|
save_json_data(args.output_dir, "result.jsonl", result) |
|
return results |
|
|
|
|
|
def gen_vector(): |
|
|
|
args = parse_args() |
|
|
|
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', |
|
datefmt='%m/%d/%Y %H:%M:%S',level=logging.INFO ) |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
args.n_gpu = torch.cuda.device_count() |
|
args.device = device |
|
logger.info("device: %s, n_gpu: %s",device, args.n_gpu) |
|
|
|
pool = multiprocessing.Pool(cpu_cont) |
|
|
|
|
|
set_seed(args.seed) |
|
if "codet5" in args.model_name_or_path: |
|
config = T5Config.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) |
|
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) |
|
model = T5ForConditionalGeneration.from_pretrained(args.model_name_or_path) |
|
model = model.encoder |
|
else: |
|
config = RobertaConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) |
|
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name) |
|
model = RobertaModel.from_pretrained(args.model_name_or_path) |
|
model=create_model(args,model,tokenizer,config) |
|
|
|
if args.data_aug_type == "replace_type" : |
|
train_dataset=TextDataset(tokenizer, args, args.train_data_file, pool) |
|
else: |
|
|
|
train_dataset=TextDataset_unixcoder(tokenizer, args, args.train_data_file, pool) |
|
|
|
|
|
train_sampler = SequentialSampler(train_dataset) |
|
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,num_workers=4,drop_last=False) |
|
|
|
for idx in range(args.num_train_epochs): |
|
output_dir_epoch = os.path.join(args.output_dir, '{}'.format(idx)) |
|
output_dir_epoch = os.path.join(output_dir_epoch, '{}'.format('model.bin')) |
|
|
|
model.load_state_dict(torch.load(output_dir_epoch),strict=False) |
|
model.to(args.device) |
|
|
|
model_eval = model.module if hasattr(model,'module') else model |
|
|
|
all_nl_vec = [] |
|
all_code_vec = [] |
|
|
|
for step,batch in enumerate(train_dataloader): |
|
code_inputs = batch[0].to(args.device) |
|
nl_inputs = batch[1].to(args.device) |
|
|
|
nl_outputs = model_eval.nl_encoder_q(nl_inputs, attention_mask=nl_inputs.ne(1)) |
|
nl_vec =nl_outputs [1] |
|
code_outputs = model_eval.code_encoder_q(code_inputs, attention_mask=code_inputs.ne(1)) |
|
code_vec =code_outputs [1] |
|
all_nl_vec.append(nl_vec.detach().cpu().numpy()) |
|
all_code_vec.append(code_vec.detach().cpu().numpy()) |
|
all_nl_vec = np.concatenate(all_nl_vec, axis=0) |
|
all_code_vec = np.concatenate(all_code_vec, axis=0) |
|
print(all_nl_vec.shape, all_code_vec.shape) |
|
np.save("/home/yiming/cocosoda/CoCoSoDa/saved_models/fine_tune/ruby/" + str(idx) + "/all_nl_vec.npy", all_nl_vec) |
|
np.save("/home/yiming/cocosoda/CoCoSoDa/saved_models/fine_tune/ruby/" + str(idx) + "/all_code_vec.npy", all_code_vec) |
|
idxs = [i for i in range(len(all_nl_vec))] |
|
for epoch in range(1,2): |
|
idxs_dir_path = "/home/yiming/cocosoda/CoCoSoDa/saved_models/codesearch_contrastive_learning/Model/Epoch_" + str(epoch) |
|
if os.path.exists(idxs_dir_path): |
|
pass |
|
else: |
|
os.mkdir(idxs_dir_path) |
|
idxs_path = idxs_dir_path + "/index.json" |
|
json_file = open(idxs_path, mode='w') |
|
json.dump(idxs, json_file, indent=4) |
|
|
|
if args.data_aug_type == "replace_type" : |
|
test_dataset=TextDataset(tokenizer, args, args.test_data_file, pool) |
|
else: |
|
|
|
test_dataset=TextDataset_unixcoder(tokenizer, args, args.test_data_file, pool) |
|
|
|
|
|
test_sampler = SequentialSampler(test_dataset) |
|
test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=args.train_batch_size,num_workers=4,drop_last=False) |
|
|
|
for idx in range(args.num_train_epochs): |
|
output_dir_epoch = os.path.join(args.output_dir, '{}'.format(idx)) |
|
output_dir_epoch = os.path.join(output_dir_epoch, '{}'.format('model.bin')) |
|
|
|
model.load_state_dict(torch.load(output_dir_epoch),strict=False) |
|
model.to(args.device) |
|
|
|
model_eval = model.module if hasattr(model,'module') else model |
|
|
|
all_nl_vec = [] |
|
all_code_vec = [] |
|
|
|
for step,batch in enumerate(test_dataloader): |
|
code_inputs = batch[0].to(args.device) |
|
nl_inputs = batch[1].to(args.device) |
|
|
|
nl_outputs = model_eval.nl_encoder_q(nl_inputs, attention_mask=nl_inputs.ne(1)) |
|
nl_vec =nl_outputs [1] |
|
code_outputs = model_eval.code_encoder_q(code_inputs, attention_mask=code_inputs.ne(1)) |
|
code_vec =code_outputs [1] |
|
all_nl_vec.append(nl_vec.detach().cpu().numpy()) |
|
all_code_vec.append(code_vec.detach().cpu().numpy()) |
|
all_nl_vec = np.concatenate(all_nl_vec, axis=0) |
|
all_code_vec = np.concatenate(all_code_vec, axis=0) |
|
print(all_nl_vec.shape, all_code_vec.shape) |
|
np.save("/home/yiming/cocosoda/CoCoSoDa/saved_models/fine_tune/ruby/" + str(idx) + "/test_all_nl_vec.npy", all_nl_vec) |
|
np.save("/home/yiming/cocosoda/CoCoSoDa/saved_models/fine_tune/ruby/" + str(idx) + "/test_all_code_vec.npy", all_code_vec) |
|
|
|
def gen_label(): |
|
|
|
args = parse_args() |
|
|
|
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', |
|
datefmt='%m/%d/%Y %H:%M:%S',level=logging.INFO ) |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
args.n_gpu = torch.cuda.device_count() |
|
args.device = device |
|
logger.info("device: %s, n_gpu: %s",device, args.n_gpu) |
|
|
|
pool = multiprocessing.Pool(cpu_cont) |
|
|
|
|
code_list = [] |
|
docstring_list = [] |
|
|
|
with open(args.train_data_file, 'rt') as gz_file: |
|
for line in gz_file: |
|
data = json.loads(line) |
|
code = data['code'] |
|
docstring = data['docstring'] |
|
|
|
|
|
code_list.append(code) |
|
docstring_list.append(docstring) |
|
|
|
print(len(code_list)) |
|
print(len(docstring_list)) |
|
|
|
|
|
|
|
|
|
code_output_file = '/home/yiming/cocosoda/CoCoSoDa/saved_models/fine_tune/ruby/code_list.json' |
|
docstring_output_file = '/home/yiming/cocosoda/CoCoSoDa/saved_models/fine_tune/ruby/docstring_list.json' |
|
|
|
|
|
with open(code_output_file, 'w') as file: |
|
json.dump(code_list, file) |
|
|
|
|
|
with open(docstring_output_file, 'w') as file: |
|
json.dump(docstring_list, file) |
|
|
|
code_list = [] |
|
docstring_list = [] |
|
|
|
with open(args.test_data_file, 'rt') as gz_file: |
|
for line in gz_file: |
|
data = json.loads(line) |
|
code = data['code'] |
|
docstring = data['docstring'] |
|
|
|
|
|
code_list.append(code) |
|
docstring_list.append(docstring) |
|
|
|
print(len(code_list)) |
|
print(len(docstring_list)) |
|
|
|
code_output_file = '/home/yiming/cocosoda/CoCoSoDa/saved_models/fine_tune/ruby/test_code_list.json' |
|
docstring_output_file = '/home/yiming/cocosoda/CoCoSoDa/saved_models/fine_tune/ruby/test_docstring_list.json' |
|
|
|
|
|
with open(code_output_file, 'w') as file: |
|
json.dump(code_list, file) |
|
|
|
|
|
with open(docstring_output_file, 'w') as file: |
|
json.dump(docstring_list, file) |
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
gen_label() |
|
|
|
|