# main_train_single_gpu.py
# Provenance: deo/main_train_single_gpu.py, author jinyin_chen, branch "test", commit e8b0040
import os
import time
import datetime
import torch
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from torch.utils.tensorboard import SummaryWriter
from core.dsproc_mcls import MultiClassificationProcessor
from core.mengine import TrainEngine
from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train
from toolkit.yacs import CfgNode as CN
from timm.utils import ModelEmaV3
import warnings
warnings.filterwarnings("ignore")
# Sanity check: report the installed torch build and whether CUDA is usable.
for diagnostic in (torch.__version__, torch.cuda.is_available()):
    print(diagnostic)
# init: global experiment configuration (yacs CfgNode; new_allowed=True lets
# sections and keys be added on the fly below and by downstream code)
cfg = CN(new_allowed=True)
# dataset dir: category/label map plus train and val sample lists (txt files)
ctg_list = './dataset/label.txt'
train_list = './dataset/train.txt'
val_list = './dataset/val.txt'
# : network
cfg.network = CN(new_allowed=True)
cfg.network.name = 'replknet'
cfg.network.class_num = 2  # binary classification task
cfg.network.input_size = 384  # training crop size (inference transforms below use 512)
# : train params
# NOTE(review): mean/std look like the standard ImageNet normalization constants,
# but nothing in this file reads them — confirm the transform helpers pick them up.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
cfg.train = CN(new_allowed=True)
cfg.train.resume = False       # resume training from a checkpoint
cfg.train.resume_path = ''     # checkpoint path used when resume is True
cfg.train.params_path = ''
cfg.train.batch_size = 16
cfg.train.epoch_num = 20
cfg.train.epoch_start = 0      # first epoch index (nonzero when resuming)
cfg.train.worker_num = 8       # DataLoader worker processes
# : optimizer params
cfg.optimizer = CN(new_allowed=True)
cfg.optimizer.lr = 1e-4 * 1
cfg.optimizer.weight_decay = 1e-2
cfg.optimizer.momentum = 0.9
cfg.optimizer.beta1 = 0.9      # Adam-style betas — presumably consumed by TrainEngine
cfg.optimizer.beta2 = 0.999
cfg.optimizer.eps = 1e-8
# : scheduler params
cfg.scheduler = CN(new_allowed=True)
cfg.scheduler.min_lr = 1e-6    # LR floor for the schedule
# init path: per-run output directory, e.g.
#   output/2024-01-01-12-00-00_replknet_to_competition_BinClass
task = 'competition'
# Single clock read (the original read datetime.now() and time.strftime()
# separately, which could straddle a boundary and produce a mismatched name).
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
log_root = os.path.join('output', f"{run_stamp}_{cfg.network.name}_to_{task}_BinClass")
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(log_root, exist_ok=True)
writer = SummaryWriter(log_root)  # TensorBoard event writer for this run
# Build the training engine: single-GPU setup (rank 0 / local rank 0,
# no DDP, no SyncBatchNorm), then materialize model/optimizer/etc. from cfg.
train_engine = TrainEngine(0, 0, DDP=False, SyncBatchNorm=False)
train_engine.create_env(cfg)

# Training augmentations: slot 0 is the plain ImageNet-style pipeline,
# slot 1 additionally applies JPEG-compression augmentation.
train_size = (cfg.network.input_size, cfg.network.input_size)
transform = {
    0: transforms_imagenet_train(img_size=train_size),
    1: transforms_imagenet_train(img_size=train_size, jpeg_compression=1),
}
# Inference-time transforms for validation.
# NOTE(review): these use a fixed 512x512 rather than cfg.network.input_size
# (384) — confirm the train/val resolution mismatch is intentional.
transform_test = {
    0: create_transforms_inference(h=512, w=512),
    1: create_transforms_inference(h=512, w=512),
}

# Datasets: samples and categories are read from the txt lists above.
trainset = MultiClassificationProcessor(transform)
trainset.load_data_from_txt(train_list, ctg_list)
valset = MultiClassificationProcessor(transform_test)
valset.load_data_from_txt(val_list, ctg_list)

# Loaders share batch size / workers / pinning; only the training loader
# shuffles and drops the ragged final batch.
loader_kwargs = dict(batch_size=cfg.train.batch_size,
                     num_workers=cfg.train.worker_num,
                     pin_memory=True)
train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                           shuffle=True,
                                           drop_last=True,
                                           **loader_kwargs)
val_loader = torch.utils.data.DataLoader(dataset=valset,
                                         shuffle=False,
                                         drop_last=False,
                                         **loader_kwargs)
# ---------------------------------------------------------------------------
# Training loop: each epoch trains, validates (plain and EMA weights), saves a
# checkpoint, and logs metrics to both a text file and TensorBoard.
# ---------------------------------------------------------------------------
train_log_txtFile = os.path.join(log_root, "train_log.txt")

ema_start = True
# EMA shadow of the engine's network; assumes TrainEngine exposes `netloc_`
# as the underlying module (see core.mengine) — TODO confirm attribute name.
train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()

# `with` guarantees the log file is closed even if an epoch raises
# (the original handle was opened and never closed).
with open(train_log_txtFile, "w", encoding="utf-8") as f_open:
    for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
        # train one epoch
        train_top1, train_loss, train_lr = train_engine.train_multi_class(
            train_loader=train_loader, epoch_idx=epoch_idx, ema_start=ema_start)
        # validate with the live weights
        val_top1, val_loss, val_auc = train_engine.val_multi_class(
            val_loader=val_loader, epoch_idx=epoch_idx)
        # validate with the EMA weights as well
        if ema_start:
            ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(
                val_loader=val_loader, epoch_idx=epoch_idx)
        train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)
        if ema_start:
            outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n"
        else:
            outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n"
        print(outInfo)
        f_open.write(outInfo)
        # flush so the log survives a mid-training crash
        f_open.flush()
        # TensorBoard curves: top-1 accuracy, loss, and learning rate
        writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
        writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)
        writer.add_scalar('train_lr', train_lr, epoch_idx)

# Flush pending TensorBoard events and release the writer.
writer.close()