import os

import torch
import torch.optim as optim
from tqdm import tqdm

from network_v0.model import PointModel
from loss_function import KeypointLoss

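
# Trainer wraps the full optimization loop for PointModel: data loading,
# KeypointLoss computation, Adam + MultiStepLR scheduling, and checkpointing.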
class Trainer(object):
    def __init__(self, config, train_loader=None):
        self.config = config

        # data loader (len() gives the number of batches per epoch)
        self.train_loader = train_loader
        self.num_train = len(self.train_loader)

        # training hyperparameters
        self.max_epoch = config.max_epoch
        self.start_epoch = config.start_epoch
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display

        # device and checkpoint settings
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = "{}-{}".format(config.ckpt_name, config.seed)

        # build the keypoint model
        self.model = PointModel(is_test=False)

        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()

        print(
            "Number of model parameters: {:,}".format(
                sum(p.numel() for p in self.model.parameters())
            )
        )

        self.loss_func = KeypointLoss(config)

        # Adam optimizer; the LR is decayed by lr_factor after epochs 4 and 8
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[4, 8], gamma=self.lr_factor
        )

        # resume from an existing checkpoint when a start epoch is given
        if int(self.config.start_epoch) > 0:
            (
                self.start_epoch,
                self.model,
                self.optimizer,
                self.lr_scheduler,
            ) = self.load_checkpoint(
                int(self.config.start_epoch),
                self.model,
                self.optimizer,
                self.lr_scheduler,
            )

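    # A checkpoint is written before training starts (epoch 0) and after every
    # epoch, so training can later be resumed from any epoch via config.start_epoch.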
    def train(self):
        print("\nTrain on {} batches".format(self.num_train))
        self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            # read the current LR from the optimizer so the log reflects any decay
            self.lr = self.optimizer.param_groups[0]["lr"]
            print(
                "\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, self.lr)
            )

            self.train_one_epoch(epoch)
            if self.lr_scheduler:
                self.lr_scheduler.step()
            self.save_checkpoint(
                epoch + 1, self.model, self.optimizer, self.lr_scheduler
            )

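    # Each batch is expected to be a dict holding "image", its augmented
    # counterpart "image_aug", and the "homography" relating the two views.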
    def train_one_epoch(self, epoch):
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            source_img = data["image_aug"]
            target_img = data["image"]
            homography = data["homography"]

            if self.use_gpu:
                source_img = source_img.cuda()
                target_img = target_img.cuda()
                homography = homography.cuda()

            output = self.model(source_img, target_img, homography)

            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            msg_batch = (
                "Epoch:{} Iter:{} lr:{:.4f} "
                "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                "loss={:.4f} ".format(
                    (epoch + 1),
                    i,
                    self.lr,
                    loc_loss.item(),
                    desc_loss.item(),
                    score_loss.item(),
                    corres_loss.item(),
                    loss.item(),
                )
            )

            if (i % self.display) == 0:
                print(msg_batch)
        return

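    # Checkpoints bundle the epoch index with the model, optimizer, and scheduler
    # state dicts so load_checkpoint can restore training exactly where it stopped.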
    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
            },
            os.path.join(self.ckpt_dir, filename),
        )

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
        ckpt = torch.load(os.path.join(self.ckpt_dir, filename))
        epoch = ckpt["epoch"]
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        lr_scheduler.load_state_dict(ckpt["lr_scheduler"])

        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))

        return epoch, model, optimizer, lr_scheduler
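

# Example usage (a minimal sketch; `get_config` and `make_train_loader` are
# hypothetical helpers standing in for whatever builds the config object and
# the training DataLoader in this project):
#
#   from config import get_config
#   from dataset import make_train_loader
#
#   config = get_config()
#   train_loader = make_train_loader(config)
#   trainer = Trainer(config, train_loader=train_loader)
#   trainer.train()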