|
import torch |
|
import torch.optim as optim |
|
from tqdm import trange |
|
import os |
|
from tensorboardX import SummaryWriter |
|
import numpy as np |
|
import cv2 |
|
from loss import SGMLoss, SGLoss |
|
from valid import valid, dump_train_vis |
|
|
|
import sys |
|
|
|
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
|
sys.path.insert(0, ROOT_DIR) |
|
|
|
|
|
from utils import train_utils |
|
|
|
|
|
def train_step(optimizer, model, match_loss, data, step, pre_avg_loss):
    """Run a single forward/backward pass and (conditionally) an optimizer step.

    The parameter update is skipped when the current total loss spikes to
    more than 7x the running average loss — a guard against rare unstable
    batches — except during the first 200 steps (or while the running
    average is still 0), when the statistics are not yet meaningful.

    Returns:
        (loss_res, skip_update): the all-reduced loss dictionary and a bool
        that is True when the update was skipped because of an unusual loss.
    """
    data["step"] = step
    prediction = model(data, test_mode=False)
    loss_res = match_loss.run(data, prediction)

    optimizer.zero_grad()
    loss_res["total_loss"].backward()

    # Average every loss term across distributed workers for logging.
    for key in loss_res.keys():
        loss_res[key] = train_utils.reduce_tensor(loss_res[key], "mean")

    loss_is_usual = (
        loss_res["total_loss"] < 7 * pre_avg_loss
        or step < 200
        or pre_avg_loss == 0
    )
    if loss_is_usual:
        optimizer.step()
        skip_update = False
    else:
        # Discard this batch's gradients entirely.
        optimizer.zero_grad()
        skip_update = True
    return loss_res, skip_update
|
|
|
|
|
def train(model, train_loader, valid_loader, config, model_config):
    """Distributed training loop with checkpointing, validation and TB logging.

    Args:
        model: DDP-wrapped network (``model.module`` is accessed for logging).
        train_loader / valid_loader: distributed data loaders; ``train_loader``
            must expose a ``sampler.set_epoch`` method.
        config: run configuration (lr schedule, intervals, paths, local_rank…).
        model_config: model-specific configuration forwarded to the loss.

    Side effects: writes TensorBoard logs, ``checkpoint.pth`` and
    ``model_best.pth`` under ``config.log_base``, and periodic train
    visualizations under ``config.train_vis_folder``.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=config.train_lr)

    # Pick the loss matching the model variant.
    if config.model_name == "SGM":
        match_loss = SGMLoss(config, model_config)
    elif config.model_name == "SG":
        match_loss = SGLoss(config, model_config)
    else:
        raise NotImplementedError

    # Resume from an existing checkpoint if one is present in the log dir.
    checkpoint_path = os.path.join(config.log_base, "checkpoint.pth")
    config.resume = os.path.isfile(checkpoint_path)
    if config.resume:
        if config.local_rank == 0:
            print("==> Resuming from checkpoint..")
        checkpoint = torch.load(
            checkpoint_path, map_location="cuda:{}".format(config.local_rank)
        )
        model.load_state_dict(checkpoint["state_dict"])
        best_acc = checkpoint["best_acc"]
        start_step = checkpoint["step"]
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        best_acc = -1
        start_step = 0
    train_loader_iter = iter(train_loader)

    # Only rank 0 writes TensorBoard logs / prints / saves checkpoints.
    if config.local_rank == 0:
        writer = SummaryWriter(os.path.join(config.log_base, "log_file"))

    # Seed the distributed sampler with the epoch implied by start_step.
    train_loader.sampler.set_epoch(
        start_step * config.train_batch_size // len(train_loader.dataset)
    )
    pre_avg_loss = 0

    progress_bar = (
        trange(start_step, config.train_iter, ncols=config.tqdm_width)
        if config.local_rank == 0
        else range(start_step, config.train_iter)
    )
    for step in progress_bar:
        # Pull the next batch, rolling over to a new epoch when exhausted.
        try:
            train_data = next(train_loader_iter)
        except StopIteration:
            if config.local_rank == 0:
                print(
                    "epoch: ",
                    step * config.train_batch_size // len(train_loader.dataset),
                )
            train_loader.sampler.set_epoch(
                step * config.train_batch_size // len(train_loader.dataset)
            )
            train_loader_iter = iter(train_loader)
            train_data = next(train_loader_iter)

        train_data = train_utils.tocuda(train_data)
        # Exponential decay after config.decay_iter; before that the min()
        # clamps the (>1) factor back to the base learning rate.
        lr = min(
            config.train_lr * config.decay_rate ** (step - config.decay_iter),
            config.train_lr,
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        loss_res, unusual_loss = train_step(
            optimizer, model, match_loss, train_data, step - start_step, pre_avg_loss
        )
        # Maintain an EMA of the loss used by train_step's spike detector;
        # warm it up with the raw loss during the first 200 local steps.
        if (step - start_step) <= 200:
            pre_avg_loss = loss_res["total_loss"].data
        if (step - start_step) > 200 and not unusual_loss:
            pre_avg_loss = pre_avg_loss.data * 0.9 + loss_res["total_loss"].data * 0.1
        if unusual_loss and config.local_rank == 0:
            print(
                "unusual loss! pre_avg_loss: ",
                pre_avg_loss,
                "cur_loss: ",
                loss_res["total_loss"].data,
            )

        # Scalar logging (skipped for spiking batches whose update was dropped).
        if config.local_rank == 0 and step % config.log_intv == 0 and not unusual_loss:
            writer.add_scalar("TotalLoss", loss_res["total_loss"], step)
            writer.add_scalar("CorrLoss", loss_res["loss_corr"], step)
            writer.add_scalar("InCorrLoss", loss_res["loss_incorr"], step)
            writer.add_scalar("dustbin", model.module.dustbin, step)

            if config.model_name == "SGM":
                writer.add_scalar("SeedConfLoss", loss_res["loss_seed_conf"], step)
                writer.add_scalar("MidCorrLoss", loss_res["loss_corr_mid"].sum(), step)
                writer.add_scalar(
                    "MidInCorrLoss", loss_res["loss_incorr_mid"].sum(), step
                )

        b_save = ((step + 1) % config.save_intv) == 0
        b_validate = ((step + 1) % config.val_intv) == 0
        if b_validate:
            (
                total_loss,
                acc_corr,
                acc_incorr,
                seed_precision_tower,
                seed_recall_tower,
                acc_mid,
            ) = valid(valid_loader, model, match_loss, config, model_config)
            if config.local_rank == 0:
                writer.add_scalar("ValidAcc", acc_corr, step)
                writer.add_scalar("ValidLoss", total_loss, step)

                if config.model_name == "SGM":
                    for i in range(len(seed_recall_tower)):
                        writer.add_scalar(
                            "seed_conf_pre_%d" % i, seed_precision_tower[i], step
                        )
                        # BUGFIX: previously logged precision under the
                        # recall tag, so recall was never recorded.
                        writer.add_scalar(
                            "seed_conf_recall_%d" % i, seed_recall_tower[i], step
                        )
                    for i in range(len(acc_mid)):
                        writer.add_scalar("acc_mid%d" % i, acc_mid[i], step)
                    print(
                        "acc_corr: ",
                        acc_corr.data,
                        "acc_incorr: ",
                        acc_incorr.data,
                        "seed_conf_pre: ",
                        seed_precision_tower.mean().data,
                        "seed_conf_recall: ",
                        seed_recall_tower.mean().data,
                        "acc_mid: ",
                        acc_mid.mean().data,
                    )
                else:
                    print("acc_corr: ", acc_corr.data, "acc_incorr: ", acc_incorr.data)

                # Track the best validation accuracy and snapshot that model.
                if acc_corr > best_acc:
                    print("Saving best model with va_res = {}".format(acc_corr))
                    best_acc = acc_corr
                    save_dict = {
                        "step": step + 1,
                        "state_dict": model.state_dict(),
                        "best_acc": best_acc,
                        "optimizer": optimizer.state_dict(),
                    }
                    torch.save(
                        save_dict, os.path.join(config.log_base, "model_best.pth")
                    )

        if b_save:
            if config.local_rank == 0:
                save_dict = {
                    "step": step + 1,
                    "state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "optimizer": optimizer.state_dict(),
                }
                torch.save(save_dict, checkpoint_path)

            # Dump a qualitative visualization of the current batch.
            # NOTE(review): config.log_base is joined as a sub-path here; if it
            # is an absolute path, os.path.join will discard the prefix — verify.
            model.eval()
            with torch.no_grad():
                if config.local_rank == 0:
                    # makedirs(exist_ok=True) replaces the old exists/mkdir
                    # chain and no longer crashes when resuming onto an
                    # already-existing per-step directory.
                    os.makedirs(
                        os.path.join(
                            config.train_vis_folder,
                            "train_vis",
                            config.log_base,
                            str(step),
                        ),
                        exist_ok=True,
                    )
                res = model(train_data)  # default test_mode (inference path)
                dump_train_vis(res, train_data, step, config)
            model.train()

    if config.local_rank == 0:
        writer.close()
|
|