Spaces:
Running
Running
File size: 4,735 Bytes
a80d6bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import os
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.autograd import Variable
from network_v0.model import PointModel
from loss_function import KeypointLoss
class Trainer(object):
    """Train a ``PointModel`` with ``KeypointLoss`` and checkpoint it every epoch.

    The optimizer is Adam with a ``MultiStepLR`` schedule that is stepped once
    per epoch. Checkpoints are written to
    ``<ckpt_dir>/<ckpt_name>-<seed>_<epoch>.pth`` and hold the epoch number plus
    the model, optimizer and scheduler state dicts.

    Config attributes read: ``max_epoch``, ``start_epoch``, ``momentum``,
    ``init_lr``, ``lr_factor``, ``display``, ``use_gpu``, ``seed``, ``gpu``,
    ``ckpt_dir``, ``ckpt_name`` and, optionally, ``lr_milestones`` (defaults
    to ``[4, 8]``).
    """

    def __init__(self, config, train_loader=None):
        """Build model, loss, optimizer and scheduler; resume if ``start_epoch > 0``.

        Args:
            config: Options object exposing the attributes listed on the class.
            train_loader: Iterable of batch dicts with ``'image_aug'``,
                ``'image'`` and ``'homography'`` entries. May be ``None`` when
                the trainer is not used for training.
        """
        self.config = config
        # data parameters
        self.train_loader = train_loader
        # Iterations (batches) per epoch; 0 when no loader is supplied.
        self.num_train = len(train_loader) if train_loader is not None else 0
        # training parameters
        self.max_epoch = config.max_epoch
        # int() accepts a string value as well (e.g. from a CLI parser).
        self.start_epoch = int(config.start_epoch)
        self.momentum = config.momentum  # stored only; Adam below does not use it
        self.lr = config.init_lr  # initial LR; the live value lives in the optimizer
        self.lr_factor = config.lr_factor
        self.display = config.display
        # misc params
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = '{}-{}'.format(config.ckpt_name, config.seed)
        # build model
        self.model = PointModel(is_test=False)
        # training on GPU
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()
        print('Number of model parameters: {:,}'.format(
            sum(p.numel() for p in self.model.parameters())))
        # build loss function
        self.loss_func = KeypointLoss(config)
        # build optimizer and scheduler
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        milestones = getattr(config, 'lr_milestones', None) or [4, 8]
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=milestones, gamma=self.lr_factor)
        # resume
        if self.start_epoch > 0:
            (self.start_epoch, self.model, self.optimizer,
             self.lr_scheduler) = self.load_checkpoint(
                self.start_epoch, self.model, self.optimizer, self.lr_scheduler)
            # Keep the config in sync with the epoch actually restored.
            self.config.start_epoch = self.start_epoch

    def train(self):
        """Train from ``start_epoch`` up to ``max_epoch``, checkpointing each epoch."""
        print("\nTrain on {} batches per epoch".format(self.num_train))
        if self.start_epoch <= 0:
            # Snapshot of the untrained weights. Skipped on resume so the
            # epoch-0 file is not overwritten with already-trained weights.
            self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            print("\nEpoch: {}/{} --lr: {:.6f}".format(
                epoch + 1, self.max_epoch, self._current_lr()))
            # train for one epoch
            self.train_one_epoch(epoch)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.save_checkpoint(epoch + 1, self.model, self.optimizer, self.lr_scheduler)

    def train_one_epoch(self, epoch):
        """Run one pass over ``train_loader`` with one optimizer step per batch.

        Args:
            epoch: Zero-based epoch index (logged as ``epoch + 1``).
        """
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            source_img = self._to_device(data['image_aug'])
            target_img = self._to_device(data['image'])
            homography = self._to_device(data['homography'])
            # forward propagation
            output = self.model(source_img, target_img, homography)
            # compute loss
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)
            # compute gradients and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Format only on display iterations: converting losses to Python
            # floats forces a device sync (on GPU), wasted work otherwise.
            if self.display and i % self.display == 0:
                print("Epoch:{} Iter:{} lr:{:.4f} "
                      "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                      "loss={:.4f} ".format(
                          epoch + 1, i, self._current_lr(), float(loc_loss),
                          float(desc_loss), float(score_loss), float(corres_loss),
                          float(loss)))

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Save the epoch number and model/optimizer/scheduler state for ``epoch``.

        Creates ``ckpt_dir`` if it does not exist. ``lr_scheduler`` may be
        ``None``, in which case ``None`` is stored in its place.
        """
        _, path = self._ckpt_path(epoch)
        if self.ckpt_dir:
            os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(
            {'epoch': epoch,
             'model_state': model.state_dict(),
             'optimizer_state': optimizer.state_dict(),
             'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler is not None else None},
            path)

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Restore model/optimizer/scheduler state from the checkpoint for ``epoch``.

        Tensors are first mapped to the CPU so a GPU-saved checkpoint can be
        read on a CPU-only machine; ``load_state_dict`` then copies the values
        into the existing parameters/state.

        Returns:
            ``(epoch, model, optimizer, lr_scheduler)``, where ``epoch`` is the
            value stored in the checkpoint file.
        """
        filename, path = self._ckpt_path(epoch)
        ckpt = torch.load(path, map_location='cpu')
        model.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        if lr_scheduler is not None and ckpt.get('lr_scheduler') is not None:
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt['epoch']))
        return ckpt['epoch'], model, optimizer, lr_scheduler

    def _current_lr(self):
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]['lr']

    def _to_device(self, tensor):
        """Return ``tensor`` moved to the current CUDA device if ``use_gpu``, else as-is."""
        return tensor.cuda() if self.use_gpu else tensor

    def _ckpt_path(self, epoch):
        """Return ``(filename, full_path)`` of the checkpoint file for ``epoch``."""
        filename = self.ckpt_name + '_' + str(epoch) + '.pth'
        return filename, os.path.join(self.ckpt_dir, filename)
|