import time
import math
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

from tencentpretrain.model_loader import load_model
from tencentpretrain.model_saver import save_model
from tencentpretrain.model_builder import build_model
from tencentpretrain.utils.logging import init_logger
from tencentpretrain.utils.optimizers import *
from tencentpretrain.utils import *
from tencentpretrain.utils.seed import set_seed
from tqdm import tqdm


def train_and_validate(args):
    set_seed(args.seed)

    # Load vocabulary.
    if args.data_processor == "mt":
        args.tgt_tokenizer = str2tokenizer[args.tgt_tokenizer](args, is_src=False)
        args.tgt_vocab = args.tgt_tokenizer.vocab

    args.tokenizer = str2tokenizer[args.tokenizer](args)
    args.vocab = args.tokenizer.vocab

    # Build model.
    model_for_training = build_model(args)

    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model_for_training = load_model(model_for_training, args.pretrained_model_path)
    else:
        # Initialize with normal distribution.
        if args.deep_init:
            # Scale the residual projections by 1 / sqrt(2 * layers_num).
            scaled_factor = 1 / math.sqrt(2.0 * args.layers_num)
            for n, p in list(model_for_training.named_parameters()):
                if "gamma" not in n and "beta" not in n:
                    if "linear_2.weight" in n or "final_linear.weight" in n:
                        p.data.normal_(0, 0.02 * scaled_factor)
                    elif "linear_2.bias" in n or "final_linear.bias" in n:
                        p.data.zero_()
                    else:
                        p.data.normal_(0, 0.02)
        else:
            for n, p in list(model_for_training.named_parameters()):
                if "gamma" not in n and "beta" not in n:
                    p.data.normal_(0, 0.02)

    if args.vqgan_model_path is not None:
        from tencentpretrain.utils.image_tokenizer import build_vqgan_model
        model_for_dataloader = build_vqgan_model(args)
    else:
        model_for_dataloader = None

    if args.deepspeed:
        worker(args.local_rank, None, args, model_for_training, model_for_dataloader)
    elif args.dist_train:
        # Multiprocessing distributed mode.
        mp.spawn(worker, nprocs=args.ranks_num,
                 args=(args.gpu_ranks, args, model_for_training, model_for_dataloader),
                 daemon=False)
    elif args.single_gpu:
        # Single GPU mode.
        worker(args.gpu_id, None, args, model_for_training, model_for_dataloader)
    else:
        # CPU mode.
        worker(None, None, args, model_for_training, model_for_dataloader)
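
# Usage sketch (illustrative, not part of this module): train_and_validate()
# expects an argparse-style namespace whose attributes match the ones read
# above (seed, data_processor, pretrained_model_path, deepspeed, dist_train,
# single_gpu, ...). The entry-script shape below is an assumption shown only
# for orientation, not the repository's actual CLI:
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--dataset_path", default="dataset.pt")
#   parser.add_argument("--output_model_path", default="models/output_model.bin")
#   parser.add_argument("--total_steps", type=int, default=100000)
#   # ... tokenizer, model, optimizer, and device options ...
#   args = parser.parse_args()
#   train_and_validate(args)
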
class Trainer(object):
    def __init__(self, args):
        self.current_step = 1
        self.total_steps = args.total_steps
        self.accumulation_steps = args.accumulation_steps
        self.report_steps = args.report_steps
        self.save_checkpoint_steps = args.save_checkpoint_steps

        self.output_model_path = args.output_model_path

        self.start_time = time.time()
        self.total_loss = 0.0
        self.best_loss = float("inf")

        self.dist_train = args.dist_train
        self.batch_size = args.batch_size
        self.world_size = args.world_size
        self.logger = args.logger

    def forward_propagation(self, batch, model):
        raise NotImplementedError

    def report_and_reset_stats(self):
        raise NotImplementedError

    def train(self, args, gpu_id, rank, loader, model, optimizer, scheduler):
        model.train()
        loader_iter = iter(loader)
        while True:
            # for step in tqdm(range(self.current_step, self.total_steps + 1)):
            if self.current_step == self.total_steps + 1:
                break
            batch = list(next(loader_iter))
            self.seq_length = batch[0].size(1)
            if gpu_id is not None:
                for i in range(len(batch)):
                    if torch.is_tensor(batch[i]):
                        batch[i] = batch[i].cuda(gpu_id)
            # print(batch[0].shape, gpu_id)
            loss = self.forward_propagation(batch, model)

            if args.deepspeed:
                model.backward(loss)
            else:
                if args.fp16:
                    with args.amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            if self.current_step % self.accumulation_steps == 0:
                if args.deepspeed:
                    model.step()
                else:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()

            if self.current_step % self.report_steps == 0 and \
                    (not self.dist_train or (self.dist_train and rank == 0)):
                self.report_and_reset_stats()
                self.start_time = time.time()

            if args.deepspeed:
                if self.current_step % self.save_checkpoint_steps == 0:
                    model.save_checkpoint(self.output_model_path, str(self.current_step))
                if loss.item() < self.best_loss:
                    self.best_loss = loss.item()
                    # model.save_checkpoint(self.output_model_path, "-best")
            else:
                if self.current_step % self.save_checkpoint_steps == 0 and \
                        (not self.dist_train or (self.dist_train and rank == 0)):
                    save_model(model, self.output_model_path + "-" + str(self.current_step))
                if loss.item() < self.best_loss:
                    self.best_loss = loss.item()
                    # print("save best model! loss:" + str(self.best_loss))
                    # save_model(model, self.output_model_path + "-best")

            self.current_step += 1


class MlmTrainer(Trainer):
    def __init__(self, args):
        super(MlmTrainer, self).__init__(args)
        self.total_correct = 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt, seg = batch
        loss_info = model(src, tgt, seg)
        loss, correct, denominator = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_denominator += denominator.item()
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size
        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| acc: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_correct / self.total_denominator))
        self.total_loss = 0.0
        self.total_correct = 0.0
        self.total_denominator = 0.0


class BertTrainer(Trainer):
    def __init__(self, args):
        super(BertTrainer, self).__init__(args)
        self.total_loss_sp = 0.0
        self.total_correct_sp = 0.0
        self.total_instances = 0.0

        self.total_loss_mlm = 0.0
        self.total_correct_mlm = 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt_mlm, tgt_sp, seg = batch
        tgt = {"mlm": tgt_mlm, "sp": tgt_sp}
        loss_info = model(src, tgt, seg)
        loss_mlm, correct_mlm, denominator = loss_info["mlm"]
        loss_sp, correct_sp = loss_info["sp"]
        loss = loss_mlm + loss_sp
        self.total_loss += loss.item()
        self.total_loss_mlm += loss_mlm.item()
        self.total_loss_sp += loss_sp.item()
        self.total_correct_mlm += correct_mlm.item()
        self.total_correct_sp += correct_sp.item()
        self.total_denominator += denominator.item()
        self.total_instances += src.size(0)
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size
        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| loss_mlm: {:3.3f}"
                         "| loss_sp: {:3.3f}"
                         "| acc_mlm: {:3.3f}"
                         "| acc_sp: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_loss_mlm / self.report_steps,
                             self.total_loss_sp / self.report_steps,
                             self.total_correct_mlm / self.total_denominator,
                             self.total_correct_sp / self.total_instances))
        self.total_loss, self.total_loss_mlm, self.total_loss_sp = 0.0, 0.0, 0.0
        self.total_correct_mlm, self.total_denominator = 0.0, 0.0
        self.total_correct_sp, self.total_instances = 0.0, 0.0


class AlbertTrainer(BertTrainer):
    pass


class LmTrainer(MlmTrainer):
    pass
class BilmTrainer(Trainer):
    def __init__(self, args):
        super(BilmTrainer, self).__init__(args)
        self.total_loss_forward, self.total_loss_backward = 0.0, 0.0
        self.total_correct_forward, self.total_correct_backward = 0.0, 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt_forward, tgt_backward, seg = batch
        loss_info = model(src, (tgt_forward, tgt_backward), seg)
        loss_forward, loss_backward, correct_forward, correct_backward, denominator = loss_info
        loss = loss_forward + loss_backward
        self.total_loss += loss.item()
        self.total_loss_forward += loss_forward.item()
        self.total_loss_backward += loss_backward.item()
        self.total_correct_forward += correct_forward.item()
        self.total_correct_backward += correct_backward.item()
        self.total_denominator += denominator.item()
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size
        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| loss_forward {:3.3f}"
                         "| loss_backward {:3.3f}"
                         "| acc_forward: {:3.3f}"
                         "| acc_backward: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_loss_forward / self.report_steps,
                             self.total_loss_backward / self.report_steps,
                             self.total_correct_forward / self.total_denominator,
                             self.total_correct_backward / self.total_denominator))
        self.total_loss, self.total_loss_forward, self.total_loss_backward = 0.0, 0.0, 0.0
        self.total_correct_forward, self.total_correct_backward, self.total_denominator = 0.0, 0.0, 0.0


class ClsTrainer(Trainer):
    def __init__(self, args):
        super(ClsTrainer, self).__init__(args)
        self.total_correct = 0.0
        self.total_instances = 0.0

    def forward_propagation(self, batch, model):
        src, tgt, seg = batch
        loss_info = model(src, tgt, seg)
        loss, correct = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_instances += src.size(0)
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size
        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| acc: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_correct / self.total_instances))
        self.total_loss = 0.0
        self.total_correct = 0.0
        self.total_instances = 0.0


class MtTrainer(Trainer):
    def __init__(self, args):
        super(MtTrainer, self).__init__(args)
        self.total_correct = 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt_out, seg, tgt_in, tgt_seg = batch
        loss_info = model(src, tgt_out, seg, tgt_in, tgt_seg)
        loss, correct, denominator = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_denominator += denominator.item()
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size
        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| acc: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_correct / self.total_denominator))
        self.total_loss = 0.0
        self.total_correct = 0.0
        self.total_denominator = 0.0


class ClsMlmTrainer(Trainer):
    def __init__(self, args):
        super(ClsMlmTrainer, self).__init__(args)
        self.total_loss_cls = 0.0
        self.total_correct_cls = 0.0
        self.total_instances = 0.0

        self.total_loss_mlm = 0.0
        self.total_correct_mlm = 0.0
        self.total_denominator = 0.0

    def forward_propagation(self, batch, model):
        src, tgt_mlm, tgt_cls, seg = batch
        tgt = {"mlm": tgt_mlm, "cls": tgt_cls}
        loss_info = model(src, tgt, seg)
        loss_mlm, correct_mlm, denominator = loss_info["mlm"]
        loss_cls, correct_cls = loss_info["cls"]
        loss = loss_mlm + loss_cls
        self.total_loss += loss.item()
        self.total_loss_mlm += loss_mlm.item()
        self.total_loss_cls += loss_cls.item()
        self.total_correct_mlm += correct_mlm.item()
        self.total_correct_cls += correct_cls.item()
        self.total_denominator += denominator.item()
        self.total_instances += src.size(0)
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size
        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| loss_mlm: {:3.3f}"
                         "| loss_cls: {:3.3f}"
                         "| acc_mlm: {:3.3f}"
                         "| acc_cls: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_loss_mlm / self.report_steps,
                             self.total_loss_cls / self.report_steps,
                             self.total_correct_mlm / self.total_denominator,
                             self.total_correct_cls / self.total_instances))
        self.total_loss, self.total_loss_mlm, self.total_loss_cls = 0.0, 0.0, 0.0
        self.total_correct_mlm, self.total_denominator = 0.0, 0.0
        self.total_correct_cls, self.total_instances = 0.0, 0.0


class T5Trainer(MtTrainer):
    pass


class GsgTrainer(MtTrainer):
    pass


class BartTrainer(MtTrainer):
    pass


class PrefixlmTrainer(MlmTrainer):
    pass


class VitTrainer(ClsTrainer):
    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size
        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} patches/s"
                         "| loss {:7.2f}"
                         "| acc: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_correct / self.total_instances))
        self.total_loss = 0.0
        self.total_correct = 0.0
        self.total_instances = 0.0


class ViltTrainer(BertTrainer):
    def forward_propagation(self, batch, model):
        src_text, src_image, tgt_mlm, tgt_match, seg = batch
        tgt = {"mlm": tgt_mlm, "sp": tgt_match}
        loss_info = model((src_text, src_image), tgt, seg)
        loss_mlm, correct_mlm, denominator = loss_info["mlm"]
        loss_match, correct_match = loss_info["sp"]
        loss = loss_mlm + loss_match
        self.total_loss += loss.item()
        self.total_loss_mlm += loss_mlm.item()
        self.total_loss_sp += loss_match.item()
        self.total_correct_mlm += correct_mlm.item()
        self.total_correct_sp += correct_match.item()
        self.total_denominator += denominator.item()
        self.total_instances += src_text.size(0)
        loss = loss / self.accumulation_steps
        return loss

    def report_and_reset_stats(self):
        done_tokens = self.batch_size * self.seq_length * self.report_steps
        if self.dist_train:
            done_tokens *= self.world_size
        self.logger.info("| {:8d}/{:8d} steps"
                         "| {:8.2f} tokens/s"
                         "| loss {:7.2f}"
                         "| loss_mlm: {:3.3f}"
                         "| loss_match: {:3.3f}"
                         "| acc_mlm: {:3.3f}"
                         "| acc_match: {:3.3f}".format(
                             self.current_step,
                             self.total_steps,
                             done_tokens / (time.time() - self.start_time),
                             self.total_loss / self.report_steps,
                             self.total_loss_mlm / self.report_steps,
                             self.total_loss_sp / self.report_steps,
                             self.total_correct_mlm / self.total_denominator,
                             self.total_correct_sp / self.total_instances))
        self.total_loss, self.total_loss_mlm, self.total_loss_sp = 0.0, 0.0, 0.0
        self.total_correct_mlm, self.total_denominator = 0.0, 0.0
        self.total_correct_sp, self.total_instances = 0.0, 0.0


class ClipTrainer(ClsTrainer):
    def forward_propagation(self, batch, model):
        src_text, src_img, seg_text, seg_img = batch
        loss_info = model((src_text, src_img), None, (seg_text, seg_img))
        loss, correct = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_instances += src_text.size(0)
        loss = loss / self.accumulation_steps
        return loss


class S2tTrainer(MtTrainer):
    pass


class BeitTrainer(MlmTrainer):
    def forward_propagation(self, batch, model):
        src, tgt, seg, mask = batch
        loss_info = model((src, mask), tgt, seg)
        loss, correct, denominator = loss_info
        self.total_loss += loss.item()
        self.total_correct += correct.item()
        self.total_denominator += denominator.item()
        loss = loss / self.accumulation_steps
        return loss


class DalleTrainer(LmTrainer):
    pass


str2trainer = {"bert": BertTrainer, "mlm": MlmTrainer, "lm": LmTrainer,
               "albert": AlbertTrainer, "bilm": BilmTrainer, "cls": ClsTrainer,
               "mt": MtTrainer, "t5": T5Trainer, "gsg": GsgTrainer,
               "bart": BartTrainer, "prefixlm": PrefixlmTrainer, "cls_mlm": ClsMlmTrainer,
               "vit": VitTrainer, "vilt": ViltTrainer, "clip": ClipTrainer,
               "s2t": S2tTrainer, "beit": BeitTrainer, "dalle": DalleTrainer}
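
# Dispatch note: args.data_processor is the shared key for both lookup tables,
# str2trainer above and str2dataloader (brought in by the wildcard imports at
# the top of this file), so every target needs a matching entry in each. A
# hypothetical extension (names are assumptions, not existing classes):
#
#   str2trainer["example"] = ExampleTrainer        # assumed custom trainer
#   str2dataloader["example"] = ExampleDataloader  # assumed matching dataloader
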
def worker(proc_id, gpu_ranks, args, model_for_training, model_for_dataloader=None):
    """
    Args:
        proc_id: The id of GPU for single GPU mode;
                 the id of process (and GPU) for multiprocessing distributed mode.
        gpu_ranks: List of ranks of each process.
    """
    set_seed(args.seed)

    # Get logger.
    args.logger = init_logger(args)

    if args.deepspeed:
        import deepspeed
        deepspeed.init_distributed(dist_backend=args.backend)
        rank = dist.get_rank()
        gpu_id = proc_id
    elif args.dist_train:
        rank = gpu_ranks[proc_id]
        gpu_id = proc_id
    elif args.single_gpu:
        rank = None
        gpu_id = proc_id
    else:
        rank = None
        gpu_id = None

    # Build optimizer.
    param_optimizer = list(model_for_training.named_parameters())
    no_decay = ["bias", "gamma", "beta"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
    ]

    if args.optimizer in ["adamw"]:
        custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters,
                                                         lr=args.learning_rate,
                                                         correct_bias=False)
    else:
        custom_optimizer = str2optimizer[args.optimizer](optimizer_grouped_parameters,
                                                         lr=args.learning_rate,
                                                         scale_parameter=False,
                                                         relative_step=False)

    if args.scheduler in ["constant"]:
        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer)
    elif args.scheduler in ["constant_with_warmup"]:
        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup)
    elif args.scheduler in ["tri_stage"]:
        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup,
                                                         args.total_steps*args.decay, args.total_steps)
    else:
        custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup,
                                                         args.total_steps)

    if args.deepspeed:
        model_for_training, optimizer, _, scheduler = deepspeed.initialize(
            model=model_for_training,
            model_parameters=optimizer_grouped_parameters,
            args=args,
            optimizer=custom_optimizer,
            lr_scheduler=custom_scheduler,
            mpu=None,
            dist_init_required=False)
    else:
        if gpu_id is not None:
            model_for_training.cuda(gpu_id)
            if model_for_dataloader is not None:
                model_for_dataloader.cuda(gpu_id)
        optimizer = custom_optimizer
        scheduler = custom_scheduler
        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model_for_training, optimizer = amp.initialize(model_for_training, optimizer,
                                                           opt_level=args.fp16_opt_level)
            args.amp = amp

        if args.dist_train:
            # Initialize multiprocessing distributed training environment.
            dist.init_process_group(backend=args.backend,
                                    init_method=args.master_ip,
                                    world_size=args.world_size,
                                    rank=rank)
            model_for_training = DistributedDataParallel(model_for_training, device_ids=[gpu_id],
                                                         find_unused_parameters=True)
            if model_for_dataloader is not None:
                model_for_dataloader = DistributedDataParallel(model_for_dataloader, device_ids=[gpu_id],
                                                               find_unused_parameters=False)
            args.logger.info("Worker %d is training ... " % rank)
        else:
            args.logger.info("Worker is training ...")

    if args.dist_train:
        if model_for_dataloader is not None:
            model_for_dataloader = model_for_dataloader.module
        train_loader = str2dataloader[args.data_processor](args, args.dataset_path, args.batch_size,
                                                           rank, args.world_size, gpu_id, True, model_for_dataloader)
    else:
        train_loader = str2dataloader[args.data_processor](args, args.dataset_path, args.batch_size,
                                                           0, 1, gpu_id, True, model_for_dataloader)

    trainer = str2trainer[args.data_processor](args)
    trainer.train(args, gpu_id, rank, train_loader, model_for_training, optimizer, scheduler)
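
# Note on the weight-decay grouping in worker(): parameters whose names contain
# "bias", "gamma", or "beta" (LayerNorm weights use the gamma/beta naming in
# this codebase) get weight_decay=0.0; everything else gets 0.01. A minimal,
# self-contained illustration of that split with a plain torch.nn module
# (its parameter names differ from TencentPretrain's, shown only as an example):
#
#   import torch
#   no_decay = ["bias", "gamma", "beta"]
#   module = torch.nn.Linear(4, 4)
#   decay = [n for n, _ in module.named_parameters() if not any(nd in n for nd in no_decay)]
#   skip = [n for n, _ in module.named_parameters() if any(nd in n for nd in no_decay)]
#   # decay == ["weight"], skip == ["bias"]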