File size: 4,735 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import torch
import torch.optim as optim
from tqdm import tqdm

from torch.autograd import Variable

from network_v0.model import PointModel
from loss_function import KeypointLoss

class Trainer(object):
    """Train a ``PointModel`` with ``KeypointLoss`` and checkpoint every epoch.

    Reads from ``config``: max_epoch, start_epoch, momentum, init_lr, lr_factor,
    display, use_gpu, seed, gpu, ckpt_dir, ckpt_name (plus whatever
    ``KeypointLoss`` reads). Checkpoints are written to
    ``<ckpt_dir>/<ckpt_name>-<seed>_<epoch>.pth`` as a dict with keys
    'epoch', 'model_state', 'optimizer_state' and 'lr_scheduler'.
    """

    def __init__(self, config, train_loader=None):
        self.config = config
        # data parameters
        self.train_loader = train_loader
        # NOTE(review): if train_loader is a torch DataLoader, len() counts
        # batches, not samples - the "Train on {} samples" message may mislead.
        self.num_train = len(train_loader) if train_loader is not None else 0

        # training parameters
        self.max_epoch = config.max_epoch
        # config values may arrive as strings (the resume check casts with int()).
        self.start_epoch = int(config.start_epoch)
        self.momentum = config.momentum  # stored only; the Adam optimizer below ignores it
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display

        # misc params
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = '{}-{}'.format(config.ckpt_name, config.seed)

        # build model
        self.model = PointModel(is_test=False)

        # training on GPU
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()

        n_params = sum(p.numel() for p in self.model.parameters())
        print('Number of model parameters: {:,}'.format(n_params))

        # build loss functional
        self.loss_func = KeypointLoss(config)

        # build optimizer and scheduler; the scheduler is stepped once per epoch
        # in train(), so the LR is multiplied by lr_factor after epochs 4 and 8.
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[4, 8], gamma=self.lr_factor)

        # resume
        if self.start_epoch > 0:
            epoch, self.model, self.optimizer, self.lr_scheduler = self.load_checkpoint(
                self.start_epoch, self.model, self.optimizer, self.lr_scheduler)
            # Keep both the config and the training loop in sync with the checkpoint.
            self.config.start_epoch = epoch
            self.start_epoch = epoch
            # The restored optimizer carries the (possibly decayed) learning rate.
            self.lr = self._current_lr()

    def _current_lr(self):
        """Return the learning rate of the optimizer's first param group."""
        return self.optimizer.param_groups[0]['lr']

    def _ckpt_path(self, epoch):
        """Return the checkpoint file path for ``epoch``."""
        filename = self.ckpt_name + '_' + str(epoch) + '.pth'
        return os.path.join(self.ckpt_dir, filename)

    def train(self):
        """Run epochs ``start_epoch .. max_epoch-1``, checkpointing after each one."""
        print("\nTrain on {} samples".format(self.num_train))
        start = int(self.start_epoch)
        # Snapshot the initial weights only on a fresh run; when resuming this
        # would overwrite the original epoch-0 checkpoint with restored weights.
        if start <= 0:
            self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(start, self.max_epoch):
            # Refresh so the displayed LR reflects scheduler decay.
            self.lr = self._current_lr()
            print("\nEpoch: {}/{} --lr: {:.6f}".format(epoch+1, self.max_epoch, self.lr))
            # train for one epoch
            self.train_one_epoch(epoch)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.save_checkpoint(epoch+1, self.model, self.optimizer, self.lr_scheduler)

    def train_one_epoch(self, epoch):
        """Run one optimization pass over ``train_loader``.

        Each batch is expected to be a mapping with 'image_aug', 'image' and
        'homography' entries; tensors are moved to the GPU only when use_gpu is set.
        """
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            source_img = data['image_aug']
            target_img = data['image']
            homography = data['homography']
            if self.use_gpu:
                source_img = source_img.cuda()
                target_img = target_img.cuda()
                homography = homography.cuda()

            # forward propagation
            output = self.model(source_img, target_img, homography)

            # compute loss
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            # compute gradients and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # print training info; formatted only when shown because .item()
            # forces a device-to-host sync.
            if (i % self.display) == 0:
                msg_batch = ("Epoch:{} Iter:{} lr:{:.4f} "
                             "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                             "loss={:.4f} ").format(
                                 epoch + 1, i, self.lr,
                                 loc_loss.item(), desc_loss.item(), score_loss.item(),
                                 corres_loss.item(), loss.item())
                print(msg_batch)

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Save model/optimizer/scheduler state for ``epoch``, creating ckpt_dir if needed."""
        os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(
            {'epoch': epoch,
             'model_state': model.state_dict(),
             'optimizer_state': optimizer.state_dict(),
             'lr_scheduler': lr_scheduler.state_dict()},
            self._ckpt_path(epoch))

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Restore state from the checkpoint saved for ``epoch``.

        Tensors are loaded onto the CPU first so checkpoints written on another
        device still load; ``load_state_dict`` then copies them onto the devices
        of the existing model/optimizer parameters.

        Returns:
            (epoch stored in the checkpoint, model, optimizer, lr_scheduler)
        """
        path = self._ckpt_path(epoch)
        ckpt = torch.load(path, map_location='cpu')
        epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        lr_scheduler.load_state_dict(ckpt['lr_scheduler'])

        print("[*] Loaded {} checkpoint @ epoch {}".format(os.path.basename(path), ckpt['epoch']))

        return epoch, model, optimizer, lr_scheduler