|
import os |
|
import time |
|
import datetime |
|
import torch |
|
import sys |
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
from torch.utils.tensorboard import SummaryWriter |
|
from core.dsproc_mcls import MultiClassificationProcessor |
|
from core.mengine import TrainEngine |
|
from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train |
|
from toolkit.yacs import CfgNode as CN |
|
from timm.utils import ModelEmaV3 |
|
|
|
import warnings |
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
print(torch.__version__) |
|
print(torch.cuda.is_available()) |
|
|
|
|
|
# Root configuration node; new_allowed=True lets arbitrary sub-keys be
# attached below without pre-declaring a schema.
cfg = CN(new_allowed=True)



# Dataset index files: category names, and the train / validation sample
# lists (exact line format is defined by
# MultiClassificationProcessor.load_data_from_txt -- TODO confirm).
ctg_list = './dataset/label.txt'

train_list = './dataset/train.txt'

val_list = './dataset/val.txt'
|
|
|
|
|
# --- network / model settings ---
cfg.network = CN(new_allowed=True)

cfg.network.name = 'replknet'  # backbone key consumed by TrainEngine.create_env

cfg.network.class_num = 2  # binary classification

cfg.network.input_size = 384  # square training resolution (pixels)



# ImageNet channel statistics.
# NOTE(review): not referenced anywhere in this script -- presumably the
# transform helpers apply their own normalization; verify before removing.
mean = (0.485, 0.456, 0.406)

std = (0.229, 0.224, 0.225)
|
|
|
# --- training loop settings ---
cfg.train = CN(new_allowed=True)

cfg.train.resume = False  # resume from a previous checkpoint?

cfg.train.resume_path = ''

cfg.train.params_path = ''

cfg.train.batch_size = 16

cfg.train.epoch_num = 20  # loop runs epochs [epoch_start, epoch_num)

cfg.train.epoch_start = 0

cfg.train.worker_num = 8  # DataLoader worker processes per loader



# --- optimizer settings ---
# Carries both Adam-style (betas/eps) and SGD-style (momentum) fields;
# which ones are read depends on the optimizer TrainEngine builds
# (not visible here -- TODO confirm).
cfg.optimizer = CN(new_allowed=True)

cfg.optimizer.lr = 1e-4 * 1

cfg.optimizer.weight_decay = 1e-2

cfg.optimizer.momentum = 0.9

cfg.optimizer.beta1 = 0.9

cfg.optimizer.beta2 = 0.999

cfg.optimizer.eps = 1e-8



# --- LR scheduler settings ---
cfg.scheduler = CN(new_allowed=True)

cfg.scheduler.min_lr = 1e-6  # floor for the decayed learning rate
|
|
|
|
|
# Human-readable tag describing what this run targets; embedded in log_root.
task = 'competition'

# Per-run output directory: output/<date>-<time>_<arch>_to_<task>_BinClass.
# A single strftime over one clock reading replaces the original pair of
# separate datetime.now()/time.strftime() calls, which could disagree when
# taken either side of a second/day boundary.  The resulting string format
# is unchanged.
log_root = ('output/' + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            + '_' + cfg.network.name + '_' + f"to_{task}_BinClass")

# exist_ok=True avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(log_root, exist_ok=True)

# TensorBoard writer; scalars are appended at the end of every epoch below.
writer = SummaryWriter(log_root)
|
|
|
|
|
# Single-process engine (local rank 0, world size 0): DistributedDataParallel
# and SyncBatchNorm are both disabled.
train_engine = TrainEngine(0, 0, DDP=False, SyncBatchNorm=False)

# Builds the model / optimizer / scheduler from cfg (project-defined).
train_engine.create_env(cfg)
|
|
|
|
|
# Training-time augmentation pipelines, keyed by an integer id that the
# dataset processor presumably uses to select a pipeline per sample
# (TODO confirm); id 1 additionally applies JPEG-compression augmentation.
transforms_dict = {

    0: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size)),

    1: transforms_imagenet_train(img_size=(cfg.network.input_size, cfg.network.input_size), jpeg_compression=1),

}

# Inference-time transforms for validation.
# NOTE(review): hard-coded 512x512 here while training uses
# cfg.network.input_size (384) -- confirm this train/val resolution
# mismatch is intentional.
transforms_dict_test = {

    0: create_transforms_inference(h=512, w=512),

    1: create_transforms_inference(h=512, w=512),

}

# Aliases consumed by the dataset construction below.
transform = transforms_dict

transform_test = transforms_dict_test
|
|
|
|
|
# Training dataset: samples listed in train_list, category names in ctg_list.
trainset = MultiClassificationProcessor(transform)

trainset.load_data_from_txt(train_list, ctg_list)

# Validation dataset, built with the inference-time transforms.
valset = MultiClassificationProcessor(transform_test)

valset.load_data_from_txt(val_list, ctg_list)
|
|
|
|
|
# DataLoader settings shared by both loaders.
_loader_common = dict(
    batch_size=cfg.train.batch_size,
    num_workers=cfg.train.worker_num,
    pin_memory=True,
)

# Training loader: shuffled each epoch; the incomplete final batch is
# dropped so every optimization step sees a full batch.
train_loader = torch.utils.data.DataLoader(
    dataset=trainset, shuffle=True, drop_last=True, **_loader_common)

# Validation loader: fixed order, every sample evaluated.
val_loader = torch.utils.data.DataLoader(
    dataset=valset, shuffle=False, drop_last=False, **_loader_common)
|
|
|
# ---------------------------------------------------------------------------
# Training loop.  Each epoch: train, validate the live weights, validate the
# EMA shadow weights, checkpoint, then append a summary line to the text log
# and emit TensorBoard scalars.  The log file is now held in a context
# manager so the handle is closed even if an epoch raises (the original
# `open` was never paired with a `close`).
# ---------------------------------------------------------------------------
train_log_txtFile = log_root + "/" + "train_log.txt"

best_test_mAP = 0.0  # NOTE(review): never updated below -- kept for compatibility
best_test_idx = 0.0  # NOTE(review): never updated below -- kept for compatibility
ema_start = True  # maintain + evaluate an EMA copy of the weights

# EMA shadow model (timm ModelEmaV3) wrapping the engine's network.
train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()

with open(train_log_txtFile, "w") as f_open:

    for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):

        # One full pass over the training data.
        train_top1, train_loss, train_lr = train_engine.train_multi_class(
            train_loader=train_loader, epoch_idx=epoch_idx, ema_start=ema_start)

        # Validate the live (non-averaged) weights.
        val_top1, val_loss, val_auc = train_engine.val_multi_class(
            val_loader=val_loader, epoch_idx=epoch_idx)

        # Validate the EMA weights as well.
        if ema_start:
            ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(
                val_loader=val_loader, epoch_idx=epoch_idx)

        train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)

        # Per-epoch summary line (runtime string format preserved verbatim).
        if ema_start:
            outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}, ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss}, ema_val_auc={ema_val_auc} \n"
        else:
            outInfo = f"epoch_idx = {epoch_idx}, train_top1={train_top1}, train_loss={train_loss},val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc} \n"

        print(outInfo)

        f_open.write(outInfo)

        # Flush per epoch so the on-disk log stays current during long runs.
        f_open.flush()

        # TensorBoard scalars for this epoch.
        writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
        writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)

        writer.add_scalar('train_lr', train_lr, epoch_idx)
|
|