File size: 5,060 Bytes
10b4a5f
 
 
 
 
 
 
 
 
 
358ab8f
10b4a5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358ab8f
 
10b4a5f
 
358ab8f
10b4a5f
 
 
 
 
358ab8f
 
 
 
 
 
10b4a5f
 
358ab8f
10b4a5f
 
358ab8f
 
 
10b4a5f
 
 
358ab8f
 
 
 
 
 
 
 
 
 
 
 
10b4a5f
 
 
 
358ab8f
 
 
10b4a5f
 
 
 
358ab8f
 
 
 
10b4a5f
 
 
 
 
358ab8f
 
 
 
10b4a5f
 
 
358ab8f
10b4a5f
 
358ab8f
10b4a5f
 
 
 
 
 
 
 
 
358ab8f
 
 
 
 
 
 
 
 
 
 
 
 
 
10b4a5f
358ab8f
10b4a5f
 
 
 
358ab8f
10b4a5f
358ab8f
 
 
 
 
 
 
 
10b4a5f
 
358ab8f
10b4a5f
358ab8f
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import torch
import torch.optim as optim
from tqdm import tqdm

from torch.autograd import Variable

from network_v0.model import PointModel
from loss_function import KeypointLoss


class Trainer(object):
    """Training driver for ``PointModel`` optimised with ``KeypointLoss``.

    Builds the model, an Adam optimizer and a ``MultiStepLR`` scheduler from
    ``config``, optionally resumes from a checkpoint, and runs the epoch loop,
    writing one checkpoint per epoch to ``config.ckpt_dir``.

    Args:
        config: Namespace-like object. Attributes read here: ``max_epoch``,
            ``start_epoch``, ``momentum``, ``init_lr``, ``lr_factor``,
            ``display``, ``use_gpu``, ``seed``, ``gpu``, ``ckpt_dir``,
            ``ckpt_name`` and (optionally) ``lr_milestones``. It is also
            passed through to ``KeypointLoss``.
        train_loader: Iterable of batch dicts with the keys ``"image_aug"``,
            ``"image"`` and ``"homography"``. ``len()`` of it is reported as
            the number of batches per epoch.
    """

    def __init__(self, config, train_loader=None):
        self.config = config
        # data parameters
        self.train_loader = train_loader
        # len() of a DataLoader is the number of batches, not samples.
        self.num_train = len(train_loader) if train_loader is not None else 0

        # training parameters
        self.max_epoch = config.max_epoch
        # config.start_epoch may arrive as a string; normalise it once so that
        # both the resume check and range() in train() work.
        self.start_epoch = int(config.start_epoch)
        self.momentum = config.momentum  # NOTE(review): not used by Adam
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display
        # Scheduler steps (epochs) at which the lr is multiplied by lr_factor;
        # [4, 8] is kept as the default for configs without lr_milestones.
        milestones = getattr(config, "lr_milestones", None)
        self.lr_milestones = [4, 8] if milestones is None else list(milestones)

        # misc params
        self.use_gpu = config.use_gpu
        # NOTE(review): the seed is stored but not applied here; seeding is
        # presumably done by the caller -- confirm.
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = "{}-{}".format(config.ckpt_name, config.seed)

        # build model
        self.model = PointModel(is_test=False)

        # training on GPU
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()

        print(
            "Number of model parameters: {:,}".format(
                sum(p.numel() for p in self.model.parameters())
            )
        )

        # build loss functional
        self.loss_func = KeypointLoss(config)

        # build optimizer and scheduler
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.lr_milestones, gamma=self.lr_factor
        )

        # resume
        if self.start_epoch > 0:
            (
                loaded_epoch,
                self.model,
                self.optimizer,
                self.lr_scheduler,
            ) = self.load_checkpoint(
                self.start_epoch,
                self.model,
                self.optimizer,
                self.lr_scheduler,
            )
            # train() iterates from self.start_epoch, so keep it (and the
            # config copy) in sync with the checkpoint that was loaded.
            self.start_epoch = int(loaded_epoch)
            self.config.start_epoch = self.start_epoch
            self.lr = self._current_lr()

    def _current_lr(self):
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]["lr"]

    def _to_device(self, tensor):
        """Move ``tensor`` to the current CUDA device when ``use_gpu`` is set."""
        return tensor.cuda() if self.use_gpu else tensor

    def train(self):
        """Run the epoch loop from ``self.start_epoch`` up to ``self.max_epoch``.

        A checkpoint ``<ckpt_name>_<epoch>.pth`` is written after every epoch.
        The untrained weights are saved as epoch 0 only on a fresh run, so a
        resumed run does not overwrite that checkpoint with resumed weights.
        """
        print("\nTrain on {} batches per epoch".format(self.num_train))
        if self.start_epoch == 0:
            self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            # The scheduler mutates the optimizer's lr; mirror it in self.lr.
            self.lr = self._current_lr()
            print(
                "\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, self.lr)
            )
            # train for one epoch
            self.train_one_epoch(epoch)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.save_checkpoint(
                epoch + 1, self.model, self.optimizer, self.lr_scheduler
            )

    def train_one_epoch(self, epoch):
        """Run one optimisation pass over ``self.train_loader``.

        Args:
            epoch: 0-based epoch index; used only for logging.
        """
        self.model.train()
        # The scheduler steps once per epoch, so the lr is constant in here.
        lr = self._current_lr()
        for i, data in enumerate(tqdm(self.train_loader)):
            # Works on both CPU and GPU (previously the tensors were only bound
            # inside the use_gpu branch).
            source_img = self._to_device(data["image_aug"])
            target_img = self._to_device(data["image"])
            homography = self._to_device(data["homography"])

            # forward propagation
            output = self.model(source_img, target_img, homography)

            # compute loss
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            # compute gradients and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Only format the losses (which forces a device sync via .item())
            # on iterations that are actually printed; display == 0 disables it.
            if self.display and i % self.display == 0:
                print(
                    "Epoch:{} Iter:{} lr:{:.6f} "
                    "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                    "loss={:.4f} ".format(
                        epoch + 1,
                        i,
                        lr,
                        loc_loss.item(),
                        desc_loss.item(),
                        score_loss.item(),
                        corres_loss.item(),
                        loss.item(),
                    )
                )

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Write training state to ``<ckpt_dir>/<ckpt_name>_<epoch>.pth``.

        The saved dict has the keys ``"epoch"``, ``"model_state"``,
        ``"optimizer_state"`` and ``"lr_scheduler"`` (``None`` when no
        scheduler is given). ``ckpt_dir`` is created if it does not exist.
        """
        filename = "{}_{}.pth".format(self.ckpt_name, epoch)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "lr_scheduler": (
                    lr_scheduler.state_dict() if lr_scheduler is not None else None
                ),
            },
            os.path.join(self.ckpt_dir, filename),
        )

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Restore state saved by ``save_checkpoint`` for ``epoch`` in place.

        Returns:
            Tuple ``(saved_epoch, model, optimizer, lr_scheduler)`` where the
            last three are the same objects that were passed in.
        """
        filename = "{}_{}.pth".format(self.ckpt_name, epoch)
        # Remap to CPU when not training on GPU so GPU-written checkpoints can
        # still be read; load_state_dict copies into the already-placed
        # tensors. NOTE: torch.load unpickles -- only load trusted checkpoints.
        map_location = None if self.use_gpu else "cpu"
        ckpt = torch.load(
            os.path.join(self.ckpt_dir, filename), map_location=map_location
        )
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        if lr_scheduler is not None and ckpt.get("lr_scheduler") is not None:
            lr_scheduler.load_state_dict(ckpt["lr_scheduler"])

        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))

        return ckpt["epoch"], model, optimizer, lr_scheduler