import os
import sys

import torch
import torch.optim as optim
from tqdm import trange
from tensorboardX import SummaryWriter

from loss import SGMLoss, SGLoss
from valid import valid, dump_train_vis

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)

from utils import train_utils


def train_step(optimizer, model, match_loss, data, step, pre_avg_loss):
    data['step'] = step
    result = model(data, test_mode=False)
    loss_res = match_loss.run(data, result)

    optimizer.zero_grad()
    loss_res['total_loss'].backward()

    # Average all recorded loss tensors across DDP ranks so logging and the
    # spike check below see the same values on every process.
    for key in loss_res.keys():
        loss_res[key] = train_utils.reduce_tensor(loss_res[key], 'mean')

    # Loss-spike guard: skip the parameter update when the current loss
    # exceeds 7x the running average. Updates are always applied during the
    # first 200 steps, while the running average is not yet reliable.
    if loss_res['total_loss'] < 7 * pre_avg_loss or step < 200 or pre_avg_loss == 0:
        optimizer.step()
        unusual_loss = False
    else:
        optimizer.zero_grad()
        unusual_loss = True
    return loss_res, unusual_loss


def train(model, train_loader, valid_loader, config, model_config):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=config.train_lr)

    if config.model_name == 'SGM':
        match_loss = SGMLoss(config, model_config)
    elif config.model_name == 'SG':
        match_loss = SGLoss(config, model_config)
    else:
        raise NotImplementedError

    checkpoint_path = os.path.join(config.log_base, 'checkpoint.pth')
    config.resume = os.path.isfile(checkpoint_path)
    if config.resume:
        if config.local_rank == 0:
            print('==> Resuming from checkpoint...')
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(config.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = checkpoint['best_acc']
        start_step = checkpoint['step']
        optimizer.load_state_dict(checkpoint['optimizer'])
    else:
        best_acc = -1
        start_step = 0

    train_loader_iter = iter(train_loader)
    if config.local_rank == 0:
        writer = SummaryWriter(os.path.join(config.log_base, 'log_file'))

    # Restore the distributed sampler to the epoch implied by the resume step
    # so shuffling continues deterministically after a restart.
    train_loader.sampler.set_epoch(start_step * config.train_batch_size // len(train_loader.dataset))
    pre_avg_loss = 0

    if config.local_rank == 0:
        progress_bar = trange(start_step, config.train_iter, ncols=config.tqdm_width)
    else:
        progress_bar = range(start_step, config.train_iter)

    for step in progress_bar:
        try:
            train_data = next(train_loader_iter)
        except StopIteration:
            # Loader exhausted: advance the sampler epoch and restart it.
            if config.local_rank == 0:
                print('epoch: ', step * config.train_batch_size // len(train_loader.dataset))
            train_loader.sampler.set_epoch(step * config.train_batch_size // len(train_loader.dataset))
            train_loader_iter = iter(train_loader)
            train_data = next(train_loader_iter)

        train_data = train_utils.tocuda(train_data)

        # Exponential lr decay after config.decay_iter; before that point the
        # min() clamps the schedule to the base learning rate.
        lr = min(config.train_lr * config.decay_rate ** (step - config.decay_iter), config.train_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # run training
        loss_res, unusual_loss = train_step(optimizer, model, match_loss, train_data,
                                            step - start_step, pre_avg_loss)

        # Running average of the loss: plain copy during the first 200 steps,
        # then an exponential moving average (0.9/0.1) that ignores skipped
        # (unusual) steps.
        if (step - start_step) <= 200:
            pre_avg_loss = loss_res['total_loss'].data
        if (step - start_step) > 200 and not unusual_loss:
            pre_avg_loss = pre_avg_loss.data * 0.9 + loss_res['total_loss'].data * 0.1
        if unusual_loss and config.local_rank == 0:
            print('unusual loss! pre_avg_loss: ', pre_avg_loss,
                  'cur_loss: ', loss_res['total_loss'].data)

        # log
        if config.local_rank == 0 and step % config.log_intv == 0 and not unusual_loss:
            writer.add_scalar('TotalLoss', loss_res['total_loss'], step)
            writer.add_scalar('CorrLoss', loss_res['loss_corr'], step)
            writer.add_scalar('InCorrLoss', loss_res['loss_incorr'], step)
            writer.add_scalar('dustbin', model.module.dustbin, step)
            if config.model_name == 'SGM':
                writer.add_scalar('SeedConfLoss', loss_res['loss_seed_conf'], step)
                writer.add_scalar('MidCorrLoss', loss_res['loss_corr_mid'].sum(), step)
                writer.add_scalar('MidInCorrLoss', loss_res['loss_incorr_mid'].sum(), step)

        # validate and save
        b_save = ((step + 1) % config.save_intv) == 0
        b_validate = ((step + 1) % config.val_intv) == 0

        if b_validate:
            total_loss, acc_corr, acc_incorr, seed_precision_tower, seed_recall_tower, acc_mid = valid(
                valid_loader, model, match_loss, config, model_config)
            if config.local_rank == 0:
                writer.add_scalar('ValidAcc', acc_corr, step)
                writer.add_scalar('ValidLoss', total_loss, step)
                if config.model_name == 'SGM':
                    for i in range(len(seed_recall_tower)):
                        writer.add_scalar('seed_conf_pre_%d' % i, seed_precision_tower[i], step)
                        writer.add_scalar('seed_conf_recall_%d' % i, seed_recall_tower[i], step)
                    for i in range(len(acc_mid)):
                        writer.add_scalar('acc_mid%d' % i, acc_mid[i], step)
                    print('acc_corr: ', acc_corr.data, 'acc_incorr: ', acc_incorr.data,
                          'seed_conf_pre: ', seed_precision_tower.mean().data,
                          'seed_conf_recall: ', seed_recall_tower.mean().data,
                          'acc_mid: ', acc_mid.mean().data)
                else:
                    print('acc_corr: ', acc_corr.data, 'acc_incorr: ', acc_incorr.data)

                # saving best
                if acc_corr > best_acc:
                    print('Saving best model with va_res = {}'.format(acc_corr))
                    best_acc = acc_corr
                    save_dict = {'step': step + 1,
                                 'state_dict': model.state_dict(),
                                 'best_acc': best_acc,
                                 'optimizer': optimizer.state_dict()}
                    torch.save(save_dict, os.path.join(config.log_base, 'model_best.pth'))

        if b_save:
            if config.local_rank == 0:
                save_dict = {'step': step + 1,
                             'state_dict': model.state_dict(),
                             'best_acc': best_acc,
                             'optimizer': optimizer.state_dict()}
                torch.save(save_dict, checkpoint_path)

            # draw match results
            model.eval()
            with torch.no_grad():
                if config.local_rank == 0:
                    vis_dir = os.path.join(config.train_vis_folder, 'train_vis',
                                           config.log_base, str(step))
                    os.makedirs(vis_dir, exist_ok=True)
                res = model(train_data)
                dump_train_vis(res, train_data, step, config)
            model.train()

    if config.local_rank == 0:
        writer.close()
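

if __name__ == '__main__':
    # Illustrative sanity check only -- this module's real entry point lives
    # elsewhere in the repo. It replays the loss-spike guard used by
    # train_step()/train() on a synthetic loss sequence: an update is skipped
    # once the loss exceeds 7x the running average, and the average is only
    # updated by non-skipped steps. `should_skip` is a hypothetical helper
    # (not part of the repo), and the warm-up is shortened from the 200 steps
    # used above so the example stays small.
    def should_skip(cur_loss, avg_loss, step, warmup=3, ratio=7.0):
        # Mirrors the gate in train_step: never skip during warm-up or while
        # the running average is still zero.
        return step >= warmup and avg_loss > 0 and cur_loss >= ratio * avg_loss

    avg = 0.0
    for s, cur in enumerate([1.0, 1.1, 0.9, 55.0, 1.0, 1.2]):
        skipped = should_skip(cur, avg, s)
        if s < 3:
            avg = cur                      # warm-up: track the raw loss
        elif not skipped:
            avg = 0.9 * avg + 0.1 * cur    # EMA update, matching train()
        print('step %d  loss %.1f  skipped %s  avg %.3f' % (s, cur, skipped, avg))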