import os

import torch
import torch.optim as optim
from tqdm import tqdm

from network_v0.model import PointModel
from loss_function import KeypointLoss


class Trainer(object):

    def __init__(self, config, train_loader=None):
        self.config = config
        self.train_loader = train_loader
        self.num_train = len(self.train_loader)

        # Training schedule and logging settings.
        self.max_epoch = config.max_epoch
        self.start_epoch = config.start_epoch
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display

        # Device, seed and checkpoint settings.
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = '{}-{}'.format(config.ckpt_name, config.seed)

        # Build the model and move it to the selected GPU if requested.
        self.model = PointModel(is_test=False)
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()

        print('Number of model parameters: {:,}'.format(
            sum(p.nelement() for p in self.model.parameters())))

        # Loss, optimizer and learning-rate schedule.
        self.loss_func = KeypointLoss(config)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[4, 8], gamma=self.lr_factor)

        # Optionally resume from a saved checkpoint; the loaded epoch drives
        # where the training loop restarts.
        if int(self.config.start_epoch) > 0:
            self.start_epoch, self.model, self.optimizer, self.lr_scheduler = \
                self.load_checkpoint(int(self.config.start_epoch), self.model,
                                     self.optimizer, self.lr_scheduler)

    def train(self):
        print("\nTrain on {} samples".format(self.num_train))
        self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            # Report the learning rate actually used by the optimizer, which
            # changes once the scheduler steps.
            cur_lr = self.optimizer.param_groups[0]['lr']
            print("\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, cur_lr))

            self.train_one_epoch(epoch)
            if self.lr_scheduler:
                self.lr_scheduler.step()
            self.save_checkpoint(epoch + 1, self.model, self.optimizer, self.lr_scheduler)

    def train_one_epoch(self, epoch):
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            # Unpack the batch; move tensors to the GPU only when requested so
            # the CPU path does not reference undefined variables.
            source_img = data['image_aug']
            target_img = data['image']
            homography = data['homography']
            if self.use_gpu:
                source_img = source_img.cuda()
                target_img = target_img.cuda()
                homography = homography.cuda()

            # Forward pass and loss computation.
            output = self.model(source_img, target_img, homography)
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            # Backward pass and parameter update.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if (i % self.display) == 0:
                cur_lr = self.optimizer.param_groups[0]['lr']
                msg_batch = ("Epoch:{} Iter:{} lr:{:.4f} "
                             "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} "
                             "corres_loss={:.4f} loss={:.4f}"
                             .format(epoch + 1, i, cur_lr,
                                     loc_loss.item(), desc_loss.item(),
                                     score_loss.item(), corres_loss.item(),
                                     loss.item()))
                print(msg_batch)
        return

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        filename = self.ckpt_name + '_' + str(epoch) + '.pth'
        torch.save(
            {'epoch': epoch,
             'model_state': model.state_dict(),
             'optimizer_state': optimizer.state_dict(),
             'lr_scheduler': lr_scheduler.state_dict()},
            os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        filename = self.ckpt_name + '_' + str(epoch) + '.pth'
        ckpt = torch.load(os.path.join(self.ckpt_dir, filename))
        epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        lr_scheduler.load_state_dict(ckpt['lr_scheduler'])

        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt['epoch']))

        return epoch, model, optimizer, lr_scheduler
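

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes a config object exposing the attributes read by Trainer above
# (max_epoch, start_epoch, momentum, init_lr, lr_factor, display, use_gpu,
# seed, gpu, ckpt_dir, ckpt_name; KeypointLoss may read additional fields)
# and a DataLoader whose batches are dicts with the keys 'image', 'image_aug'
# and 'homography'. _RandomPairs and the tensor shapes below are hypothetical
# placeholders for a real dataset.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    from torch.utils.data import DataLoader, Dataset

    class _RandomPairs(Dataset):
        """Random image pairs with identity homographies for a smoke test."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {'image': torch.rand(1, 240, 320),
                    'image_aug': torch.rand(1, 240, 320),
                    'homography': torch.eye(3)}

    config = Namespace(max_epoch=2, start_epoch=0, momentum=0.9,
                       init_lr=1e-3, lr_factor=0.5, display=10,
                       use_gpu=torch.cuda.is_available(), seed=0, gpu=0,
                       ckpt_dir='./checkpoints', ckpt_name='pointmodel')

    os.makedirs(config.ckpt_dir, exist_ok=True)
    loader = DataLoader(_RandomPairs(), batch_size=2, shuffle=True)
    Trainer(config, train_loader=loader).train()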