# Vincentqyw
# fix: roma
# 358ab8f
# raw
# history blame
# 5.13 kB
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import os, pdb
import torch
import torch.optim as optim
from tools import common, trainer
from tools.dataloader import *
from nets.patchnet import *
from nets.losses import *
# Network architecture, given as a Python expression that the script eval()s.
default_net = "Quad_L2Net_ConfCFS()"


def _synthetic_pairs(source):
    """Return the spec string of a ``SyntheticPairDataset`` built on *source*.

    *source* is the source text of a dataset expression (e.g. ``"web_images"``).
    The returned string is not evaluated here; the training script eval()s it
    later. All synthetic databases below share the same augmentation settings.
    """
    return (
        "SyntheticPairDataset(\n"
        f"{source},\n"
        "'RandomScale(256,1024,can_upscale=True)',\n"
        "'RandomTilting(0.5), PixelNoise(25)')"
    )


# Dataset specifications, kept as source strings until the script evaluates them.
toy_db_debug = _synthetic_pairs("ImgFolder('imgs')")
db_web_images = _synthetic_pairs("web_images")
db_aachen_images = _synthetic_pairs("aachen_db_images")
db_aachen_style_transfer = (
    "TransformedPairs(\n"
    "aachen_style_transfer_pairs,\n"
    "'RandomScale(256,1024,can_upscale=True), RandomTilting(0.5), PixelNoise(25)')"
)
db_aachen_flow = "aachen_flow_pairs"

# One-letter codes accepted by ``--train-data``, mapped to their spec strings.
data_sources = {
    "D": toy_db_debug,
    "W": db_web_images,
    "A": db_aachen_images,
    "F": db_aachen_flow,
    "S": db_aachen_style_transfer,
}
# Templates for the training components. The back-quoted placeholders
# (`data`, `sampler`, `N`) are substituted by the script before eval().
default_dataloader = "\n".join(
    [
        "PairLoader(CatPairDataset(`data`),",
        "scale = 'RandomScale(256,1024,can_upscale=True)',",
        "distort = 'ColorJitter(0.2,0.2,0.2,0.1)',",
        "crop = 'RandomCrop(192)')",
    ]
)
default_sampler = "\n".join(
    [
        "NghSampler2(ngh=7, subq=-8, subd=1, pos_d=3, neg_d=5, border=16,",
        "subd_neg=-8,maxpool_pos=True)",
    ]
)
default_loss = "\n".join(
    [
        "MultiLoss(",
        "1, ReliabilityLoss(`sampler`, base=0.5, nq=20),",
        "1, CosimLoss(N=`N`),",
        "1, PeakyLoss(N=`N`))",
    ]
)
class MyTrainer(trainer.Trainer):
    """Trainer specialisation that defines the per-batch forward/backward step.

    The surrounding machinery (iterating the loader, stepping the optimizer,
    ...) is presumably provided by ``trainer.Trainer``. Only the single-batch
    step is overridden here.
    """

    def forward_backward(self, inputs):
        """Run the network on one batch and back-propagate its loss.

        ``inputs`` is a dict holding the image pair under ``"img1"`` and
        ``"img2"``. Both keys are removed from it in place. The remaining
        entries, merged with the network outputs, are passed as keyword
        arguments to ``self.loss_func``.

        Returns:
            The ``(loss, details)`` pair produced by ``self.loss_func``.
        """
        first = inputs.pop("img1")
        second = inputs.pop("img2")
        predictions = self.net(imgs=[first, second])

        # Network outputs override any batch entry with the same name.
        loss_kwargs = dict(inputs, **predictions)
        loss, details = self.loss_func(**loss_kwargs)

        # Skip the backward pass when autograd is disabled (e.g. under no_grad).
        if torch.is_grad_enabled():
            loss.backward()
        return loss, details
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Train R2D2")
    parser.add_argument("--data-loader", type=str, default=default_dataloader)
    parser.add_argument(
        "--train-data",
        type=str,
        default=list("WASF"),
        nargs="+",
        choices=set(data_sources.keys()),
    )
    parser.add_argument(
        "--net", type=str, default=default_net, help="network architecture"
    )
    parser.add_argument(
        "--pretrained", type=str, default="", help="pretrained model path"
    )
    parser.add_argument(
        "--save-path", type=str, required=True, help="model save_path path"
    )
    parser.add_argument("--loss", type=str, default=default_loss, help="loss function")
    parser.add_argument(
        "--sampler", type=str, default=default_sampler, help="AP sampler"
    )
    parser.add_argument(
        "--N", type=int, default=16, help="patch size for repeatability"
    )
    parser.add_argument(
        "--epochs", type=int, default=25, help="number of training epochs"
    )
    parser.add_argument("--batch-size", "--bs", type=int, default=8, help="batch size")
    # Must be parsed as a number. With ``type=str``, any value given on the
    # command line reached optim.Adam(lr=...) below as a string.
    parser.add_argument("--learning-rate", "--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", "--wd", type=float, default=5e-4)
    parser.add_argument(
        "--threads", type=int, default=8, help="number of worker threads"
    )
    parser.add_argument("--gpu", type=int, nargs="+", default=[0], help="-1 for CPU")
    args = parser.parse_args()

    iscuda = common.torch_set_gpu(args.gpu)
    common.mkdir_for(args.save_path)

    # Create data loader.
    # The wildcard import brings the dataset names referenced by the spec
    # strings into the namespace the eval() calls below use. A star import is
    # only legal at module level, which is why this is not wrapped in a function.
    from datasets import *

    # NOTE(security): the --data-loader, --net, --loss and --sampler strings are
    # executed with eval(). They come straight from the command line, so only
    # run this script with trusted arguments.
    db = [data_sources[key] for key in args.train_data]
    db = eval(args.data_loader.replace("`data`", ",".join(db)).replace("\n", ""))
    print("Training image database =", db)
    loader = threaded_loader(db, iscuda, args.threads, args.batch_size, shuffle=True)

    # Create network.
    print("\n>> Creating net = " + args.net)
    net = eval(args.net)
    print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")

    # Initialization.
    if args.pretrained:
        # Identity map_location: keep storages where they were deserialized
        # instead of restoring them to the device they were saved from.
        checkpoint = torch.load(args.pretrained, map_location=lambda a, b: a)
        net.load_pretrained(checkpoint["state_dict"])

    # Create losses: fill the template placeholders, then evaluate the spec.
    loss = args.loss.replace("`sampler`", args.sampler).replace("`N`", str(args.N))
    print("\n>> Creating loss = " + loss)
    loss = eval(loss.replace("\n", ""))

    # Create optimizer over the trainable parameters only.
    optimizer = optim.Adam(
        [p for p in net.parameters() if p.requires_grad],
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    train = MyTrainer(net, loader, loss, optimizer)
    if iscuda:
        train = train.cuda()

    # Training loop #
    for epoch in range(args.epochs):
        print(f"\n>> Starting epoch {epoch}...")
        train()

    # Store the architecture string alongside the weights so the model can be
    # rebuilt with eval(checkpoint["net"]).
    print(f"\n>> Saving model to {args.save_path}")
    torch.save({"net": args.net, "state_dict": net.state_dict()}, args.save_path)