# File size: 4,969 Bytes
# e8b0040
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import time
import datetime
import torch
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from torch.utils.tensorboard import SummaryWriter
from core.dsproc_mcls import MultiClassificationProcessor
from core.mengine import TrainEngine
from toolkit.dtransform import create_transforms_inference, transforms_imagenet_train
from toolkit.yacs import CfgNode as CN
from timm.utils import ModelEmaV3

import warnings

# Install a blanket "ignore" filter: no warning category or module is given,
# so every warning raised after this point is suppressed.
warnings.filterwarnings("ignore")

# Environment check: print the installed torch version and whether CUDA is usable.
print(torch.__version__)
print(torch.cuda.is_available())

# Root configuration node; the network/train/optimizer/scheduler sections
# below are attached to it and the whole node is handed to the train engine.
cfg = CN(new_allowed=True)

# Dataset list files, given relative to the current working directory.
# ctg_list is passed as the second argument of load_data_from_txt for both
# the training and the validation split (presumably the category/label
# list -- confirm against MultiClassificationProcessor).
ctg_list = './dataset/label.txt'
train_list = './dataset/train.txt'
val_list = './dataset/val.txt'

# Network section. input_size is reused further down as the (square) image
# size of the training transforms.
cfg.network = CN(new_allowed=True)
cfg.network.name = 'replknet'
cfg.network.class_num = 2
cfg.network.input_size = 384

# Per-channel mean/std; the values match the commonly used ImageNet
# statistics. NOTE(review): neither name is referenced anywhere else in the
# visible script -- confirm whether they are still needed.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Training section. batch_size and worker_num feed both DataLoaders;
# epoch_start/epoch_num bound the epoch loop at the bottom of the script.
# NOTE(review): resume/resume_path/params_path are presumably read by the
# train engine when it is set up from cfg -- confirm.
cfg.train = CN(new_allowed=True)
cfg.train.resume = False
cfg.train.resume_path = ''
cfg.train.params_path = ''
cfg.train.batch_size = 16
cfg.train.epoch_num = 20
cfg.train.epoch_start = 0
cfg.train.worker_num = 8

# Optimizer section. The "* 1" on lr is a no-op multiplier (presumably kept
# as a manual scaling knob). How these fields are consumed is not visible
# here; they are only stored on cfg.
cfg.optimizer = CN(new_allowed=True)
cfg.optimizer.lr = 1e-4 * 1
cfg.optimizer.weight_decay = 1e-2
cfg.optimizer.momentum = 0.9
cfg.optimizer.beta1 = 0.9
cfg.optimizer.beta2 = 0.999
cfg.optimizer.eps = 1e-8

# Scheduler section: only a minimum learning rate is configured.
cfg.scheduler = CN(new_allowed=True)
cfg.scheduler.min_lr = 1e-6

# Build the per-run output directory:
#   output/<YYYY-mm-dd>-<HH-MM-SS>_<network name>_to_<task>_BinClass
task = 'competition'
# Format the date and the time from ONE clock reading. Reading the clock
# twice (once for the date, once for the time) can yield a mismatched stamp
# when the script starts right at midnight.
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
log_root = f"output/{run_stamp}_{cfg.network.name}_to_{task}_BinClass"

# exist_ok=True makes directory creation idempotent and removes the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(log_root, exist_ok=True)

# TensorBoard writer rooted at the run directory.
writer = SummaryWriter(log_root)

# Instantiate the training engine and let it set itself up from the config.
# NOTE(review): the two leading positional zeros are presumably rank/device
# indices -- confirm against TrainEngine's signature.
train_engine = TrainEngine(0, 0, DDP=False, SyncBatchNorm=False)
train_engine.create_env(cfg)

# Training transforms keyed by integer id. Both entries use a square image
# size taken from the network config; entry 1 additionally passes
# jpeg_compression=1 (presumably enabling a JPEG-compression augmentation --
# confirm against transforms_imagenet_train).
_train_img_size = (cfg.network.input_size,) * 2
transforms_dict = {}
transforms_dict[0] = transforms_imagenet_train(img_size=_train_img_size)
transforms_dict[1] = transforms_imagenet_train(img_size=_train_img_size, jpeg_compression=1)

# Evaluation transforms: a separate 512x512 inference pipeline is built for
# each of the same two keys.
transforms_dict_test = {key: create_transforms_inference(h=512, w=512) for key in (0, 1)}

# Aliases used when constructing the datasets.
transform, transform_test = transforms_dict, transforms_dict_test

# Datasets: each split wraps its transform mapping and is then populated
# from its list file together with the shared category list file.
trainset = MultiClassificationProcessor(transform)
trainset.load_data_from_txt(train_list, ctg_list)

valset = MultiClassificationProcessor(transform_test)
valset.load_data_from_txt(val_list, ctg_list)

# DataLoader settings common to both splits.
_loader_common = dict(
    batch_size=cfg.train.batch_size,
    num_workers=cfg.train.worker_num,
    pin_memory=True,
)

# Training batches are shuffled and the trailing incomplete batch is dropped.
train_loader = torch.utils.data.DataLoader(
    dataset=trainset, shuffle=True, drop_last=True, **_loader_common
)

# Validation keeps the original order and every sample.
val_loader = torch.utils.data.DataLoader(
    dataset=valset, shuffle=False, drop_last=False, **_loader_common
)

# Plain-text log with one line of metrics per epoch, stored in the run directory.
train_log_txtFile = log_root + "/" + "train_log.txt"

# Train & validate.
# NOTE(review): best_test_mAP / best_test_idx are initialised here but never
# updated in the visible script -- kept so existing references keep working.
best_test_mAP = 0.0
best_test_idx = 0.0
ema_start = True
# Attach an EMA copy of the engine's network and move it to the GPU.
train_engine.ema_model = ModelEmaV3(train_engine.netloc_).cuda()

# The context manager closes the log file even if an epoch raises; the
# finally clause does the same for the TensorBoard writer so buffered events
# are flushed to disk on both normal exit and failure.
with open(train_log_txtFile, "w") as f_open:
    try:
        for epoch_idx in range(cfg.train.epoch_start, cfg.train.epoch_num):
            # train
            train_top1, train_loss, train_lr = train_engine.train_multi_class(
                train_loader=train_loader, epoch_idx=epoch_idx, ema_start=ema_start)
            # val
            val_top1, val_loss, val_auc = train_engine.val_multi_class(
                val_loader=val_loader, epoch_idx=epoch_idx)
            # val with the EMA weights
            if ema_start:
                ema_val_top1, ema_val_loss, ema_val_auc = train_engine.val_ema(
                    val_loader=val_loader, epoch_idx=epoch_idx)

            train_engine.save_checkpoint(log_root, epoch_idx, train_top1, val_top1, ema_start)

            # Assemble the log line once; the EMA metrics are appended only
            # when EMA validation ran. The emitted text is unchanged.
            outInfo = (f"epoch_idx = {epoch_idx},  train_top1={train_top1}, train_loss={train_loss},"
                       f"val_top1={val_top1},val_loss={val_loss}, val_auc={val_auc}")
            if ema_start:
                outInfo += (f", ema_val_top1={ema_val_top1}, ema_val_loss={ema_val_loss},"
                            f" ema_val_auc={ema_val_auc}")
            outInfo += " \n"

            print(outInfo)

            f_open.write(outInfo)
            # Flush after every epoch so the log can be followed while training runs.
            f_open.flush()

            # Accuracy and loss curves (train vs. validation).
            writer.add_scalars('top1', {'train': train_top1, 'valid': val_top1}, epoch_idx)
            writer.add_scalars('loss', {'train': train_loss, 'valid': val_loss}, epoch_idx)

            # Learning-rate curve.
            writer.add_scalar('train_lr', train_lr, epoch_idx)
    finally:
        writer.close()