import os
import time
import datetime
import torch
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from torch.utils.tensorboard import SummaryWriter
from core.dsproc_mcls import MultiClassificationProcessor
from core.mengine import TrainEngine
from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train
from toolkit.yacs import CfgNode as CN
from timm.utils import ModelEmaV3
import warnings
warnings.filterwarnings("ignore")
# Environment sanity check: report the installed torch build and CUDA status.
print(torch.__version__)
print(torch.cuda.is_available())

# Top-level configuration container; new_allowed permits adding sections freely.
cfg = CN(new_allowed=True)

# Dataset list files (category labels plus train/val sample lists).
ctg_list = './dataset/label.txt'
train_list = './dataset/train.txt'
val_list = './dataset/val.txt'

# --- network section ---
cfg.network = CN(new_allowed=True)
cfg.network.name = 'replknet'
cfg.network.class_num = 2
cfg.network.input_size = 384

# --- training section ---
mean = (0.485, 0.456, 0.406)  # ImageNet channel means
std = (0.229, 0.224, 0.225)   # ImageNet channel stds
cfg.train = CN(new_allowed=True)
cfg.train.resume = False
cfg.train.resume_path = ''
cfg.train.params_path = ''
cfg.train.batch_size = 16
cfg.train.epoch_num = 20
cfg.train.epoch_start = 0
cfg.train.worker_num = 8

# --- optimizer section ---
# Both SGD-style momentum and Adam-style betas/eps are provided; which ones
# are consumed is decided by TrainEngine.create_env.
cfg.optimizer = CN(new_allowed=True)
cfg.optimizer.lr = 1e-4
cfg.optimizer.weight_decay = 1e-2
cfg.optimizer.momentum = 0.9
cfg.optimizer.beta1 = 0.9
cfg.optimizer.beta2 = 0.999
cfg.optimizer.eps = 1e-8

# --- scheduler section ---
cfg.scheduler = CN(new_allowed=True)
cfg.scheduler.min_lr = 1e-6
# Output directory: output/<YYYY-MM-DD-HH-MM-SS>_<network>_to_<task>_BinClass
task = 'competition'
# Single clock read — the original combined datetime.now() (date) with a second
# time.strftime() call (time), which could produce a mismatched stamp if the
# two reads straddled midnight.
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
log_root = 'output/' + run_stamp + '_' + cfg.network.name + '_' + f"to_{task}_BinClass"
# exist_ok=True avoids the exists()/makedirs() check-then-act race.
os.makedirs(log_root, exist_ok=True)
writer = SummaryWriter(log_root)
# Build the training engine: single process (rank 0, device 0), no DDP or
# SyncBatchNorm, then materialize model/optimizer/etc. from the config.
train_engine = TrainEngine(0, 0, DDP=False, SyncBatchNorm=False)
train_engine.create_env(cfg)

# Transform tables keyed by 0/1 (the meaning of each key is defined by the
# dataset processor); variant 1 adds JPEG-compression augmentation.
side = cfg.network.input_size
transform = {
    0: transforms_imagenet_train(img_size=(side, side)),
    1: transforms_imagenet_train(img_size=(side, side), jpeg_compression=1),
}
# NOTE(review): inference transforms are fixed at 512x512 while training uses
# cfg.network.input_size (384) — confirm this size mismatch is intentional.
transform_test = {
    0: create_transforms_inference(h=512, w=512),
    1: create_transforms_inference(h=512, w=512),
}

# Datasets: sample lists are parsed against the shared category list.
trainset = MultiClassificationProcessor(transform)
trainset.load_data_from_txt(train_list, ctg_list)
valset = MultiClassificationProcessor(transform_test)
valset.load_data_from_txt(val_list, ctg_list)
# DataLoaders: shuffle/drop_last only on the training side; worker processes
# plus pinned memory speed up host->GPU transfer.
train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                           batch_size=cfg.train.batch_size,
                                           num_workers=cfg.train.worker_num,
                                           shuffle=True,
                                           pin_memory=True,
                                           drop_last=True)
val_loader = torch.utils.data.DataLoader(dataset=valset,
                                         batch_size=cfg.train.batch_size,
                                         num_workers=cfg.train.worker_num,
                                         shuffle=False,
                                         pin_memory=True,
                                         drop_last=False)
# Plain-text metrics log. Encoding is pinned to UTF-8 so the file is portable
# across platforms (the original relied on the locale default). The handle is
# deliberately kept open for the whole run and flushed each epoch.
train_log_txtFile = os.path.join(log_root, "train_log.txt")
f_open = open(train_log_txtFile, "w", encoding="utf-8")
# Train / validate across epochs, logging metrics to both the text log and
# TensorBoard. When ema_start is set, an EMA copy of the model is validated
# alongside the raw weights each epoch.
best_test_mAP = 0.0  # NOTE(review): never updated below — confirm before relying on it
best_test_idx = 0.0
ema_start = True
# Attach an exponential-moving-average shadow of the network for EMA validation.
train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()
try:
    for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
        # One training pass.
        train_top1, train_loss, train_lr = train_engine.train_multi_class(
            train_loader=train_loader, epoch_idx=epoch_idx, ema_start=ema_start)
        # Validation with the raw weights.
        val_top1, val_loss, val_auc = train_engine.val_multi_class(
            val_loader=val_loader, epoch_idx=epoch_idx)
        # Validation with the EMA weights.
        if ema_start:
            ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(
                val_loader=val_loader, epoch_idx=epoch_idx)
        train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)
        if ema_start:
            outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n"
        else:
            outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n"
        print(outInfo)
        f_open.write(outInfo)
        # Flush each epoch so the log survives an interrupted run.
        f_open.flush()
        # TensorBoard curves: top-1 accuracy, loss, and learning rate.
        writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
        writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)
        writer.add_scalar('train_lr', train_lr, epoch_idx)
finally:
    # Release the log handles even if training raises — the original leaked
    # the file handle and never closed the SummaryWriter (risking unflushed
    # TensorBoard events).
    f_open.close()
    writer.close()