import os
import torch
import torch.optim as optim
from tqdm import tqdm

from network_v0.model import PointModel
from loss_function import KeypointLoss


class Trainer(object):
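    """Training harness for PointModel.

    Builds the model, the KeypointLoss criterion, an Adam optimizer and a
    MultiStepLR scheduler from `config`, then runs epoch-based training with
    per-epoch checkpointing and optional resume from a saved checkpoint.
    """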
    def __init__(self, config, train_loader=None):
        self.config = config

        # data parameters
        self.train_loader = train_loader
        self.num_train = len(self.train_loader)

        # training parameters
        self.max_epoch = config.max_epoch
        self.start_epoch = config.start_epoch
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display

        # misc params
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = '{}-{}'.format(config.ckpt_name, config.seed)

        # build model
        self.model = PointModel(is_test=False)

        # training on GPU
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()

        print('Number of model parameters: {:,}'.format(
            sum(p.numel() for p in self.model.parameters())))

        # build loss functional
        self.loss_func = KeypointLoss(config)

        # build optimizer and scheduler
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[4, 8], gamma=self.lr_factor)

        # resume from a saved checkpoint if a start epoch is given
        if int(self.config.start_epoch) > 0:
            (self.config.start_epoch, self.model,
             self.optimizer, self.lr_scheduler) = self.load_checkpoint(
                int(self.config.start_epoch), self.model,
                self.optimizer, self.lr_scheduler)

    def train(self):
        print("\nTrain on {} batches".format(self.num_train))
        self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            # report the learning rate currently used by the optimizer
            cur_lr = self.optimizer.param_groups[0]['lr']
            print("\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, cur_lr))

            # train for one epoch
            self.train_one_epoch(epoch)

            if self.lr_scheduler:
                self.lr_scheduler.step()

            self.save_checkpoint(epoch + 1, self.model, self.optimizer, self.lr_scheduler)

    def train_one_epoch(self, epoch):
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            source_img = data['image_aug']
            target_img = data['image']
            homography = data['homography']
            if self.use_gpu:
                source_img = source_img.cuda()
                target_img = target_img.cuda()
                homography = homography.cuda()

            # forward propagation
            output = self.model(source_img, target_img, homography)

            # compute loss
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            # compute gradients and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # print training info
            if (i % self.display) == 0:
                cur_lr = self.optimizer.param_groups[0]['lr']
                msg_batch = ("Epoch:{} Iter:{} lr:{:.4f} "
                             "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                             "loss={:.4f}").format(
                    epoch + 1, i, cur_lr, loc_loss.item(), desc_loss.item(),
                    score_loss.item(), corres_loss.item(), loss.item())
                print(msg_batch)
        return

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        filename = self.ckpt_name + '_' + str(epoch) + '.pth'
        torch.save(
            {'epoch': epoch,
             'model_state': model.state_dict(),
             'optimizer_state': optimizer.state_dict(),
             'lr_scheduler': lr_scheduler.state_dict()},
            os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        filename = self.ckpt_name + '_' + str(epoch) + '.pth'
        ckpt = torch.load(os.path.join(self.ckpt_dir, filename))
        epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt['epoch']))
        return epoch, model, optimizer, lr_scheduler
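

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original project): shows how the
# Trainer above could be driven. The config fields mirror those read in
# __init__; the DataLoader construction and the batch_size/path values are
# assumptions made for this example only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    config = SimpleNamespace(
        max_epoch=10, start_epoch=0, momentum=0.9, init_lr=1e-3, lr_factor=0.5,
        display=100, use_gpu=torch.cuda.is_available(), seed=0, gpu=0,
        ckpt_dir='./checkpoints', ckpt_name='pointmodel')

    # A real run would pass a torch.utils.data.DataLoader whose batches are
    # dicts with 'image', 'image_aug' and 'homography' tensors, as consumed by
    # train_one_epoch(); the dataset class itself is project-specific.
    # train_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
    # trainer = Trainer(config, train_loader=train_loader)
    # trainer.train()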