File size: 6,338 Bytes
3ad39d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import argparse
import os
import pprint
import subprocess
import sys

import numpy as np
import torch
# import torch.multiprocessing
# torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data import DataLoader

from COTR.models import build_model
from COTR.utils import debug_utils, utils
from COTR.datasets import cotr_dataset
from COTR.trainers.cotr_trainer import COTRTrainer
from COTR.global_configs import general_config
from COTR.options.options import *
from COTR.options.options_utils import *


# Runs at import time, before any model, dataset or loader is created, with a
# fixed argument of 0. Presumably this seeds the random number generators for
# reproducibility -- confirm against COTR.utils.utils.fix_randomness.
utils.fix_randomness(0)


def train(opt):
    """Build the model, data loaders and optimizer from ``opt`` and run training.

    Steps, in order:
      1. Print the process environment and the output of ``nvidia-smi``.
      2. Build the model with ``build_model(opt)`` and move it to the current
         CUDA device.
      3. Create train/val datasets (``COTRZoomDataset`` when
         ``opt.enable_zoom`` is true, ``COTRDataset`` otherwise).
      4. Create an Adam optimizer with one parameter group per sub-module.
      5. Hand everything to ``COTRTrainer`` and call its ``train()`` method.

    Args:
        opt: parsed options namespace. This function reads ``enable_zoom``,
            ``batch_size``, ``shuffle_data``, ``workers``, ``learning_rate``
            and ``lr_backbone``; the same object is also forwarded to the
            model builder, the datasets and the trainer.

    Returns:
        None.
    """
    pprint.pprint(dict(os.environ), width=1)

    # Diagnostic dump only. subprocess.run waits for the child and closes its
    # pipe, so no unreaped process or open file descriptor is left behind.
    # A missing or non-executable nvidia-smi must not abort training.
    try:
        smi = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, check=False)
    except OSError as err:
        print(f'could not run nvidia-smi: {err}')
    else:
        print(smi.stdout.decode())

    device = torch.cuda.current_device()
    print(f'can see {torch.cuda.device_count()} gpus')
    print(f'current using gpu at {device} -- {torch.cuda.get_device_name(device)}')
    torch.cuda.empty_cache()

    model = build_model(opt)
    model = model.to(device)

    # Same dataset class for both splits; only the split name differs.
    if opt.enable_zoom:
        dataset_cls = cotr_dataset.COTRZoomDataset
    else:
        dataset_cls = cotr_dataset.COTRDataset
    train_dset = dataset_cls(opt, 'train')
    val_dset = dataset_cls(opt, 'val')

    train_loader = DataLoader(train_dset, batch_size=opt.batch_size,
                              shuffle=opt.shuffle_data, num_workers=opt.workers,
                              worker_init_fn=utils.worker_init_fn, pin_memory=True)
    # Unlike the training loader, the validation loader drops the last
    # incomplete batch (drop_last=True).
    val_loader = DataLoader(val_dset, batch_size=opt.batch_size,
                            shuffle=opt.shuffle_data, num_workers=opt.workers,
                            drop_last=True, worker_init_fn=utils.worker_init_fn, pin_memory=True)

    # One parameter group per non-backbone sub-module, all at opt.learning_rate.
    head_modules = (model.transformer, model.corr_embed,
                    model.query_proj, model.input_proj)
    optim_list = [{"params": module.parameters(), "lr": opt.learning_rate}
                  for module in head_modules]
    # The backbone gets its own learning rate; with lr_backbone <= 0 its
    # parameters are not handed to the optimizer at all.
    if opt.lr_backbone > 0:
        optim_list.append({"params": model.backbone.parameters(), "lr": opt.lr_backbone})

    optim = torch.optim.Adam(optim_list)
    # The fourth positional argument is passed as None; its meaning is defined
    # by COTRTrainer (not visible from this file).
    trainer = COTRTrainer(opt, model, optim, None, train_loader, val_loader)
    trainer.train()


if __name__ == "__main__":
    # Command-line entry point: declare the options, derive the run name and
    # output paths, resolve which checkpoint (if any) to load, then train.
    parser = argparse.ArgumentParser()

    # Shared option groups; these helpers are expected to come from the
    # COTR.options star imports at the top of the file.
    set_general_arguments(parser)
    set_dataset_arguments(parser)
    set_nn_arguments(parser)
    set_COTR_arguments(parser)

    # Keypoint / zoom options.
    parser.add_argument('--num_kp', type=int,
                        default=100)
    parser.add_argument('--kp_pool', type=int,
                        default=100)
    parser.add_argument('--enable_zoom', type=str2bool,
                        default=False)
    parser.add_argument('--zoom_start', type=float,
                        default=1.0)
    parser.add_argument('--zoom_end', type=float,
                        default=0.1)
    parser.add_argument('--zoom_levels', type=int,
                        default=10)
    parser.add_argument('--zoom_jitter', type=float,
                        default=0.5)

    # Output locations.
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--tb_dir', type=str, default=general_config['tb_out'], help='tensorboard runs directory')

    # Optimisation, schedule and augmentation options.
    parser.add_argument('--learning_rate', type=float,
                        default=1e-4, help='learning rate')
    parser.add_argument('--lr_backbone', type=float,
                        default=1e-5, help='backbone learning rate')
    parser.add_argument('--batch_size', type=int,
                        default=32, help='batch size for training')
    parser.add_argument('--cycle_consis', type=str2bool, default=True,
                        help='cycle consistency')
    parser.add_argument('--bidirectional', type=str2bool, default=True,
                        help='left2right and right2left')
    parser.add_argument('--max_iter', type=int,
                        default=200000, help='total training iterations')
    parser.add_argument('--valid_iter', type=int,
                        default=1000, help='interval of validation')
    parser.add_argument('--resume', type=str2bool, default=False,
                        help='resume training with same model name')
    parser.add_argument('--cc_resume', type=str2bool, default=False,
                        help='resume from last run if possible')
    parser.add_argument('--need_rotation', type=str2bool, default=False,
                        help='rotation augmentation')
    parser.add_argument('--max_rotation', type=float, default=0,
                        help='max rotation for data augmentation')
    parser.add_argument('--rotation_chance', type=float, default=0,
                        help='the probability of being rotated')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
    parser.add_argument('--suffix', type=str, default='', help='model suffix')

    opt = parser.parse_args()
    # Keep the exact invocation on the options object.
    opt.command = ' '.join(sys.argv)

    # dim_feedforward is looked up from the selected layer name; opt.layer is
    # not declared in this file (presumably by one of the set_*_arguments
    # helpers above -- confirm). An unknown layer name raises KeyError.
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    opt.dim_feedforward = layer_2_channels[opt.layer]
    opt.num_queries = opt.num_kp

    # The run name determines both the checkpoint and tensorboard directories.
    opt.name = get_compact_naming_cotr(opt)
    opt.out = os.path.join(opt.out_dir, opt.name)
    opt.tb_out = os.path.join(opt.tb_dir, opt.name)

    # --cc_resume overrides --resume and --load_weights: resume only when a
    # checkpoint from a previous run of the same name exists, otherwise start
    # without resuming (any --load_weights value is then still honoured).
    if opt.cc_resume:
        if os.path.isfile(os.path.join(opt.out, 'checkpoint.pth.tar')):
            print('resuming from last run')
            opt.load_weights = None
            opt.resume = True
        else:
            opt.resume = False

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let --resume silently override --load_weights
    # in the path selection below.
    if opt.load_weights and opt.resume:
        parser.error('--load_weights and --resume cannot be used together')
    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    if opt.resume:
        opt.load_weights_path = os.path.join(opt.out, 'checkpoint.pth.tar')

    opt.scenes_name_list = build_scenes_name_list_from_opt(opt)

    # opt.confirm is not declared in this file (presumably a general option).
    if opt.confirm:
        confirm_opt(opt)
    else:
        print_opt(opt)

    save_opt(opt)
    train(opt)