|
|
|
|
|
|
|
|
|
|
|
from logger import setup_logger |
|
|
from model import BiSeNet |
|
|
from face_dataset import FaceMask |
|
|
from loss import OhemCELoss |
|
|
from evaluate import evaluate |
|
|
from optimizer import Optimizer |
|
|
import cv2 |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader |
|
|
import torch.nn.functional as F |
|
|
import torch.distributed as dist |
|
|
|
|
|
import os |
|
|
import os.path as osp |
|
|
import logging |
|
|
import time |
|
|
import datetime |
|
|
import argparse |
|
|
|
|
|
|
|
|
# Output directory for logs and checkpoints; created eagerly at import time.
respth = './res'
# exist_ok avoids the racy exists()-then-makedirs check of the original.
os.makedirs(respth, exist_ok=True)

# Root logger; handlers are attached later by setup_logger(respth).
logger = logging.getLogger()
|
|
|
|
|
|
|
|
def parse_args():
    """Build and parse the command-line options for this script.

    Returns:
        argparse.Namespace with a single ``local_rank`` int attribute
        (defaults to -1; normally injected by ``torch.distributed.launch``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', dest='local_rank', type=int, default=-1)
    return parser.parse_args()
|
|
|
|
|
|
|
|
def train():
    """Distributed training entry point for BiSeNet face parsing.

    Parses ``--local_rank``, initialises an NCCL process group, then runs
    ``max_iter`` SGD iterations with OHEM cross-entropy applied to the three
    network outputs (principal + two auxiliary heads). Progress is logged
    every ``msg_iter`` iterations; rank 0 saves a checkpoint and runs
    ``evaluate`` every 5000 iterations, and saves the final model at the end.

    Side effects: creates files under ``respth``; requires CUDA, NCCL and a
    reachable rendezvous address (hard-coded tcp://127.0.0.1:33241).
    """
    args = parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:33241',
        world_size=torch.cuda.device_count(),
        rank=args.local_rank
    )
    setup_logger(respth)

    # dataset
    n_classes = 19
    n_img_per_gpu = 16
    n_workers = 8
    cropsize = [448, 448]
    data_root = '/home/zll/data/CelebAMask-HQ/'  # NOTE(review): hard-coded path

    ds = FaceMask(data_root, cropsize=cropsize, mode='train')
    # DistributedSampler shards the dataset across ranks; shuffle must be
    # False on the DataLoader because the sampler handles ordering.
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size=n_img_per_gpu,
                    shuffle=False,
                    sampler=sampler,
                    num_workers=n_workers,
                    pin_memory=True,
                    drop_last=True)

    # model
    ignore_idx = -100
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    net.train()
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[args.local_rank, ],
                                              output_device=args.local_rank
                                              )
    score_thres = 0.7
    # OHEM keeps at least the hardest 1/16 of the pixels in each batch.
    n_min = n_img_per_gpu * cropsize[0] * cropsize[1] // 16
    LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)

    # optimizer (project-local wrapper: poly LR schedule with warmup)
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = 80000
    power = 0.9
    warmup_steps = 1000
    warmup_start_lr = 1e-5
    optim = Optimizer(
        model=net.module,
        lr0=lr_start,
        momentum=momentum,
        wd=weight_decay,
        warmup_steps=warmup_steps,
        warmup_start_lr=warmup_start_lr,
        max_iter=max_iter,
        power=power)

    # train loop
    msg_iter = 50
    loss_avg = []
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
    for it in range(max_iter):
        try:
            im, lb = next(diter)
            # A short final batch means the epoch is over; restart the
            # iterator (drop_last=True makes this a defensive check).
            if not im.size()[0] == n_img_per_gpu:
                raise StopIteration
        except StopIteration:
            epoch += 1
            # Re-seed the sampler so each epoch sees a different shuffle.
            sampler.set_epoch(epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        lb = torch.squeeze(lb, 1)  # (N, 1, H, W) -> (N, H, W) for CE loss

        optim.zero_grad()
        out, out16, out32 = net(im)
        lossp = LossP(out, lb)
        loss2 = Loss2(out16, lb)
        loss3 = Loss3(out32, lb)
        loss = lossp + loss2 + loss3
        loss.backward()
        optim.step()

        loss_avg.append(loss.item())

        # print training log message
        if (it+1) % msg_iter == 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            eta = int((max_iter - it) * (glob_t_intv / it))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                    'it: {it}/{max_it}',
                    'lr: {lr:.4f}',  # was '{lr:4f}': width-4, not 4 decimals
                    'loss: {loss:.4f}',
                    'eta: {eta}',
                    'time: {time:.4f}',
                ]).format(
                    it=it+1,
                    max_it=max_iter,
                    lr=lr,
                    loss=loss_avg,
                    time=t_intv,
                    eta=eta
                )
            logger.info(msg)
            loss_avg = []
            st = ed

        # periodic checkpoint + evaluation, rank 0 only
        if (it+1) % 5000 == 0 and dist.get_rank() == 0:
            state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
            cp_dir = osp.join(respth, 'cp')
            # Original code crashed here when ./res/cp did not exist.
            os.makedirs(cp_dir, exist_ok=True)
            torch.save(state, osp.join(cp_dir, '{}_iter.pth'.format(it)))
            evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it))

    # dump the final model
    save_pth = osp.join(respth, 'model_final_diss.pth')
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))
|
|
|
|
|
|
|
|
# Script entry point: launch distributed training when run directly.
if __name__ == "__main__":
    train()
|
|
|