# File-viewer page header (not part of the program): Vincentqyw | fix: roma | 358ab8f | raw | history blame | 21.1 kB
import os
import cv2
import time
import yaml
import torch
import datetime
from tensorboardX import SummaryWriter
import torchvision.transforms as tvf
import torch.nn as nn
import torch.nn.functional as F
from nets.geom import getK, getWarp, _grid_positions, getWarpNoValidate
from nets.loss import make_detector_loss, make_noise_score_map_loss
from nets.score import extract_kpts
from nets.multi_sampler import MultiSampler
from nets.noise_reliability_loss import MultiPixelAPLoss
from datasets.noise_simulator import NoiseSimulator
from nets.l2net import Quad_L2Net
class Trainer:
    """Training driver for a detector/descriptor network on clean/noisy image pairs.

    Each training step runs the model on a clean image pair and on a noisy
    version of the same pair, then combines:

    * the detector loss on the clean pair and on the noisy pair
      (``make_detector_loss``),
    * a score-map loss between clean and noisy score maps
      (``make_noise_score_map_loss``),
    * a reliability loss on the confidence maps produced by ``model.clf``
      (``MultiPixelAPLoss``).

    Scalars and images are logged to TensorBoard under ``self.log_dir``, and
    checkpoints are written to the same directory.
    """

    # Input channels the network is constructed with, per configured input type.
    _INPUT_CHANNELS = {"gray": 1, "rgb": 3, "raw-demosaic": 3, "raw": 4}

    # Per-channel statistics used to normalise "rgb" inputs.
    _RGB_MEAN = [0.485, 0.456, 0.406]
    _RGB_STD = [0.229, 0.224, 0.225]

    # Iteration periods for image logging and checkpointing.
    _LOG_IMAGES_EVERY = 100
    _SAVE_EVERY = 5000

    def __init__(self, config, device, loader, job_name, start_cnt):
        """Build the model, losses, optimizer, scheduler and TensorBoard writer.

        Args:
            config: Nested configuration mapping; the ``"network"`` and
                ``"training"`` sections are read here and during training.
            device: Torch device the model and losses are moved to.
            loader: Iterable yielding batches (dicts) consumed by ``train``.
            job_name: Run name; logs go to ``runs/<job_name>``. An empty string
                selects a timestamped directory instead.
            start_cnt: Iteration to resume from. When non-zero, the checkpoint
                ``model_<start_cnt>.pth`` in the log directory is loaded and
                counting continues at ``start_cnt + 1``.

        Raises:
            NotImplementedError: For an unsupported input type or optimizer.
        """
        self.config = config
        self.device = device
        self.loader = loader

        # tensorboard writer construction
        os.makedirs("./runs/", exist_ok=True)
        if job_name != "":
            self.log_dir = f"runs/{job_name}"
        else:
            self.log_dir = f'runs/{datetime.datetime.now().strftime("%m-%d-%H%M%S")}'
        self.writer = SummaryWriter(self.log_dir)
        with open(f"{self.log_dir}/config.yaml", "w") as f:
            yaml.dump(config, f)

        input_type = config["network"]["input_type"]
        if input_type not in self._INPUT_CHANNELS:
            raise NotImplementedError(f"unsupported input_type: {input_type!r}")
        inchan = self._INPUT_CHANNELS[input_type]
        # NOTE(security): the model class name is taken from the config and
        # resolved with eval(); only run this with trusted configuration files.
        self.model = eval(f'{config["network"]["model"]}(inchan={inchan})').to(device)

        # noise maker
        self.noise_maker = NoiseSimulator(device)

        # reliability map conv; attached before any checkpoint is loaded so that
        # its weights are part of the model's state dict. It follows `device`
        # (rather than a hard-coded .cuda()) so CPU / non-default GPUs work.
        self.model.clf = nn.Conv2d(128, 2, kernel_size=1).to(device)

        # load model
        self.cnt = 0
        if start_cnt != 0:
            checkpoint = torch.load(
                f"{self.log_dir}/model_{start_cnt:06d}.pth", map_location=device
            )
            self.model.load_state_dict(checkpoint)
            self.cnt = start_cnt + 1

        # sampler
        sampler = MultiSampler(
            ngh=7,
            subq=-8,
            subd=1,
            pos_d=3,
            neg_d=5,
            border=16,
            subd_neg=-8,
            maxpool_pos=True,
        ).to(device)
        self.reliability_relitive_loss = MultiPixelAPLoss(sampler, nq=20).to(device)

        # optimizer and scheduler
        train_cfg = self.config["training"]
        # "initial_lr" is set explicitly because the scheduler below is created
        # with an explicit last_epoch.
        param_groups = [
            {"params": self.model.parameters(), "initial_lr": train_cfg["lr"]}
        ]
        if train_cfg["optimizer"] == "SGD":
            self.optimizer = torch.optim.SGD(
                param_groups,
                lr=train_cfg["lr"],
                momentum=train_cfg["momentum"],
                weight_decay=train_cfg["weight_decay"],
            )
        elif train_cfg["optimizer"] == "Adam":
            self.optimizer = torch.optim.Adam(
                param_groups,
                lr=train_cfg["lr"],
                weight_decay=train_cfg["weight_decay"],
            )
        else:
            raise NotImplementedError(
                f"unsupported optimizer: {train_cfg['optimizer']!r}"
            )

        # NOTE(review): on a fresh run (start_cnt == 0) this passes last_epoch=0
        # rather than StepLR's default of -1 -- confirm that is intended.
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=train_cfg["lr_step"],
            gamma=train_cfg["lr_gamma"],
            last_epoch=start_cnt,
        )

        for name, tensor in self.model.state_dict().items():
            print(name, "\t", tensor.size())

    def save(self, iter_num):
        """Write the model's state dict to ``<log_dir>/model_<iter_num:06d>.pth``."""
        torch.save(self.model.state_dict(), f"{self.log_dir}/model_{iter_num:06d}.pth")

    def load(self, path):
        """Load a state dict from ``path`` into the model.

        Tensors are mapped onto ``self.device`` so checkpoints written on a
        different device can be restored.
        """
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def train(self, num_epochs=2):
        """Run the training loop over ``self.loader``.

        Args:
            num_epochs: Number of passes over the loader (defaults to 2).
        """
        self.model.train()
        for _ in range(num_epochs):
            for inputs in self.loader:
                self._train_step(inputs)

    def _to_channel_first(self, img):
        """Permute a batch with axes (0, 3, 1, 2), cast to float, move to device."""
        # assumes img is channel-last (B, H, W, C) -- TODO confirm against loader
        return img.permute(0, 3, 1, 2).float().to(self.device)

    def _normalize_inputs(self, img0, img1, noise_img0, noise_img1):
        """Normalise the clean/noisy batches according to the configured input type."""
        input_type = self.config["network"]["input_type"]
        if input_type == "rgb":
            # 3-channel rgb
            norm_rgb = tvf.Normalize(mean=self._RGB_MEAN, std=self._RGB_STD)
            return (
                norm_rgb(img0),
                norm_rgb(img1),
                norm_rgb(noise_img0),
                norm_rgb(noise_img1),
            )
        if input_type == "gray":
            # 1-channel: average over the channel axis
            img0 = torch.mean(img0, dim=1, keepdim=True)
            img1 = torch.mean(img1, dim=1, keepdim=True)
            noise_img0 = torch.mean(noise_img0, dim=1, keepdim=True)
            noise_img1 = torch.mean(noise_img1, dim=1, keepdim=True)
            # The noisy image is normalised with the statistics of its clean
            # counterpart, not with its own.
            norm_gray0 = tvf.Normalize(mean=img0.mean(), std=img0.std())
            norm_gray1 = tvf.Normalize(mean=img1.mean(), std=img1.std())
            return (
                norm_gray0(img0),
                norm_gray1(img1),
                norm_gray0(noise_img0),
                norm_gray1(noise_img1),
            )
        if input_type in ("raw", "raw-demosaic"):
            # 4-channel / 3-channel raw data is fed to the network as-is
            return img0, img1, noise_img0, noise_img1
        raise NotImplementedError(f"unsupported input_type: {input_type!r}")

    def _confidence(self, desc):
        """Return channel 1 of the softmax over ``model.clf`` applied to ``|desc|**2``."""
        logits = self.model.clf(torch.abs(desc) ** 2.0)
        return F.softmax(logits, dim=1)[:, 1:2]

    def _extract_kpts(self, score_map):
        """Run ``extract_kpts`` with the detector settings from the config.

        ``score_map`` is expected channel-last; it is permuted back to
        channel-first before extraction.
        """
        det_cfg = self.config["network"]["det"]
        return extract_kpts(
            score_map.permute(0, 3, 1, 2),
            k=det_cfg["kpt_n"],
            score_thld=det_cfg["score_thld"],
            nms_size=det_cfg["nms_size"],
            eof_size=det_cfg["eof_size"],
            edge_thld=det_cfg["edge_thld"],
        )

    def _train_step(self, inputs):
        """Process one batch: forward passes, losses, logging, optimisation."""
        net_cfg = self.config["network"]
        self.optimizer.zero_grad()
        t = time.time()

        # preprocess and add noise
        img0_ori, noise_img0_ori = self.preprocess_noise_pair(inputs["img0"], self.cnt)
        img1_ori, noise_img1_ori = self.preprocess_noise_pair(inputs["img1"], self.cnt)

        img0 = self._to_channel_first(img0_ori)
        img1 = self._to_channel_first(img1_ori)
        noise_img0 = self._to_channel_first(noise_img0_ori)
        noise_img1 = self._to_channel_first(noise_img1_ori)
        img0, img1, noise_img0, noise_img1 = self._normalize_inputs(
            img0, img1, noise_img0, noise_img1
        )
        batch_size = img0.shape[0]

        # forward passes (clean pair first, then the noisy pair)
        desc0, score_map0, _, _ = self.model(img0)
        desc1, score_map1, _, _ = self.model(img1)
        conf0 = self._confidence(desc0)
        conf1 = self._confidence(desc1)

        noise_desc0, noise_score_map0, _, _ = self.model(noise_img0)
        noise_desc1, noise_score_map1, _, _ = self.model(noise_img1)
        noise_conf0 = self._confidence(noise_desc0)
        noise_conf1 = self._confidence(noise_desc1)

        # spatial size of the score maps, taken while still channel-first
        cur_feat_size0 = torch.tensor(score_map0.shape[2:])
        cur_feat_size1 = torch.tensor(score_map1.shape[2:])

        # the loss functions below are fed channel-last tensors
        desc0 = desc0.permute(0, 2, 3, 1)
        desc1 = desc1.permute(0, 2, 3, 1)
        score_map0 = score_map0.permute(0, 2, 3, 1)
        score_map1 = score_map1.permute(0, 2, 3, 1)
        noise_desc0 = noise_desc0.permute(0, 2, 3, 1)
        noise_desc1 = noise_desc1.permute(0, 2, 3, 1)
        noise_score_map0 = noise_score_map0.permute(0, 2, 3, 1)
        noise_score_map1 = noise_score_map1.permute(0, 2, 3, 1)
        conf0 = conf0.permute(0, 2, 3, 1)
        conf1 = conf1.permute(0, 2, 3, 1)
        noise_conf0 = noise_conf0.permute(0, 2, 3, 1)
        noise_conf1 = noise_conf1.permute(0, 2, 3, 1)

        # correspondences between the two views
        r_K0 = getK(inputs["ori_img_size0"], cur_feat_size0, inputs["K0"]).to(
            self.device
        )
        r_K1 = getK(inputs["ori_img_size1"], cur_feat_size1, inputs["K1"]).to(
            self.device
        )
        pos0 = _grid_positions(cur_feat_size0[0], cur_feat_size0[1], batch_size).to(
            self.device
        )
        rel_pose = inputs["rel_pose"].to(self.device)
        depth0 = inputs["depth0"].to(self.device)
        depth1 = inputs["depth1"].to(self.device)

        pos0_for_rel, pos1_for_rel, _ = getWarpNoValidate(
            pos0, rel_pose, depth0, r_K0, depth1, r_K1, batch_size
        )
        pos0, pos1, _ = getWarp(pos0, rel_pose, depth0, r_K0, depth1, r_K1, batch_size)

        # losses
        reliab_loss_relative = self.reliability_relitive_loss(
            desc0,
            desc1,
            noise_desc0,
            noise_desc1,
            conf0,
            conf1,
            noise_conf0,
            noise_conf1,
            pos0_for_rel,
            pos1_for_rel,
            batch_size,
            img0.shape[2],
            img0.shape[3],
        )
        det_structured_loss, det_accuracy = make_detector_loss(
            pos0,
            pos1,
            desc0,
            desc1,
            score_map0,
            score_map1,
            batch_size,
            net_cfg["use_corr_n"],
            net_cfg["loss_type"],
            self.config,
        )
        det_structured_loss_noise, det_accuracy_noise = make_detector_loss(
            pos0,
            pos1,
            noise_desc0,
            noise_desc1,
            noise_score_map0,
            noise_score_map1,
            batch_size,
            net_cfg["use_corr_n"],
            net_cfg["loss_type"],
            self.config,
        )

        indices0, _ = self._extract_kpts(score_map0)
        indices1, _ = self._extract_kpts(score_map1)

        noise_score_loss0, mask0 = make_noise_score_map_loss(
            score_map0, noise_score_map0, indices0, img0.shape[0], thld=0.1
        )
        noise_score_loss1, mask1 = make_noise_score_map_loss(
            score_map1, noise_score_map1, indices1, img1.shape[0], thld=0.1
        )

        # each per-image term is halved (two images); weights: score 1.0, reliability 0.5
        total_loss = det_structured_loss + det_structured_loss_noise
        total_loss += noise_score_loss0 / 2.0 * 1.0
        total_loss += noise_score_loss1 / 2.0 * 1.0
        total_loss += reliab_loss_relative[0] / 2.0 * 0.5
        total_loss += reliab_loss_relative[1] / 2.0 * 0.5

        # scalar logging
        self.writer.add_scalar("acc/normal_acc", det_accuracy, self.cnt)
        self.writer.add_scalar("acc/noise_acc", det_accuracy_noise, self.cnt)
        self.writer.add_scalar("loss/total_loss", total_loss, self.cnt)
        self.writer.add_scalar(
            "loss/noise_score_loss",
            (noise_score_loss0 + noise_score_loss1) / 2.0,
            self.cnt,
        )
        self.writer.add_scalar("loss/det_loss_normal", det_structured_loss, self.cnt)
        self.writer.add_scalar(
            "loss/det_loss_noise", det_structured_loss_noise, self.cnt
        )
        print(
            "iter={},\tloss={:.4f},\tacc={:.4f},\t{:.4f}s/iter".format(
                self.cnt, total_loss, det_accuracy, time.time() - t
            )
        )

        # the parameter update is skipped when the clean detector loss is exactly 0
        if det_structured_loss != 0:
            total_loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()

        if self.cnt % self._LOG_IMAGES_EVERY == 0:
            self._log_images(
                [
                    {
                        "ori": img0_ori,
                        "noise_ori": noise_img0_ori,
                        "indices": indices0,
                        "score_map": score_map0,
                        "noise_score_map": noise_score_map0,
                        "mask": mask0,
                        "conf": conf0,
                        "noise_conf": noise_conf0,
                    },
                    {
                        "ori": img1_ori,
                        "noise_ori": noise_img1_ori,
                        "indices": indices1,
                        "score_map": score_map1,
                        "noise_score_map": noise_score_map1,
                        "mask": mask1,
                        "conf": conf1,
                        "noise_conf": noise_conf1,
                    },
                ]
            )

        if self.cnt % self._SAVE_EVERY == 0:
            self.save(self.cnt)
        self.cnt += 1

    def _log_images(self, views):
        """Write keypoint overlays, score maps, masks and confidences to TensorBoard.

        ``views`` holds one dict per image of the pair (img0 first, then img1);
        only the first sample of each batch is visualised.
        """
        is_raw = self.config["network"]["input_type"] == "raw"
        rendered = []
        for view in views:
            noise_indices, _ = self._extract_kpts(view["noise_score_map"])
            clean = view["ori"][0]
            noisy = view["noise_ori"][0]
            if is_raw:
                # only the first three channels of the packed raw are drawn
                clean = clean[..., :3]
                noisy = noisy[..., :3]
            rendered.append(
                {
                    "kpts": self.showKeyPoints(clean * 255.0, view["indices"][0]),
                    "noise_kpts": self.showKeyPoints(noisy * 255.0, noise_indices[0]),
                    "score_map": view["score_map"][0],
                    "noise_score_map": view["noise_score_map"][0],
                    "kpt_mask": view["mask"].unsqueeze(2),
                    "conf": view["conf"][0],
                    "noise_conf": view["noise_conf"][0],
                }
            )

        panels = (
            "kpts",
            "noise_kpts",
            "score_map",
            "noise_score_map",
            "kpt_mask",
            "conf",
            "noise_conf",
        )
        for panel in panels:
            for idx, images in enumerate(rendered):
                self.writer.add_image(
                    f"img{idx}/{panel}", images[panel], self.cnt, dataformats="HWC"
                )

    def showKeyPoints(self, img, indices):
        """Draw keypoints in green on ``img`` and return the resulting uint8 array.

        Args:
            img: Image tensor; converted to a uint8 numpy array for OpenCV.
            indices: Keypoint coordinates; the last axis is reversed before
                being handed to OpenCV (presumably (row, col) -> (x, y) --
                confirm against ``extract_kpts``).
        """
        key_points = cv2.KeyPoint_convert(indices.cpu().float().numpy()[:, ::-1])
        # detach/cpu so a GPU or grad-tracking tensor can also be visualised
        canvas = img.detach().cpu().numpy().astype("uint8")
        return cv2.drawKeypoints(canvas, key_points, None, color=(0, 255, 0))

    def _noise_ratio(self, iter_idx):
        """Return ``min(noise_maxstep, iter_idx) / noise_maxstep`` (ramps up to 1)."""
        max_step = self.config["network"]["noise_maxstep"]
        return min(max_step, iter_idx) / max_step

    def preprocess(self, img, iter_idx):
        """Convert a batch to the configured input type, optionally adding noise.

        Returns ``img`` unchanged when noise is disabled and the input type is
        not a raw variant. Otherwise the batch goes through the noise
        simulator's rgb -> raw conversion (with noise if enabled) and is
        returned as a tensor in the configured representation.

        Raises:
            NotImplementedError: For an unsupported input type.
        """
        net_cfg = self.config["network"]
        input_type = net_cfg["input_type"]
        if not net_cfg["noise"] and "raw" not in input_type:
            return img

        raw = self.noise_maker.rgb2raw(img, batched=True)
        if net_cfg["noise"]:
            raw = self.noise_maker.raw2noisyRaw(
                raw, ratio_dec=self._noise_ratio(iter_idx), batched=True
            )

        if input_type == "raw":
            return torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
        if input_type == "raw-demosaic":
            return torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))

        rgb = self.noise_maker.raw2rgb(raw, batched=True)
        if input_type in ("rgb", "gray"):
            return torch.tensor(rgb)
        raise NotImplementedError()

    def preprocess_noise_pair(self, img, iter_idx):
        """Return ``(clean, noisy)`` versions of a batch in the configured input type.

        For "rgb" and "gray" the clean image is the input ``img`` itself; for
        the raw variants both outputs are derived from the simulated raw data.
        Requires ``config["network"]["noise"]`` to be truthy.

        Raises:
            NotImplementedError: For an unsupported input type.
        """
        assert self.config["network"][
            "noise"
        ], "preprocess_noise_pair requires config['network']['noise'] to be enabled"
        input_type = self.config["network"]["input_type"]

        raw = self.noise_maker.rgb2raw(img, batched=True)
        noise_raw = self.noise_maker.raw2noisyRaw(
            raw, ratio_dec=self._noise_ratio(iter_idx), batched=True
        )

        if input_type == "raw":
            clean = torch.tensor(self.noise_maker.raw2packedRaw(raw, batched=True))
            noisy = torch.tensor(
                self.noise_maker.raw2packedRaw(noise_raw, batched=True)
            )
            return clean, noisy
        if input_type == "raw-demosaic":
            clean = torch.tensor(self.noise_maker.raw2demosaicRaw(raw, batched=True))
            noisy = torch.tensor(
                self.noise_maker.raw2demosaicRaw(noise_raw, batched=True)
            )
            return clean, noisy

        noise_rgb = self.noise_maker.raw2rgb(noise_raw, batched=True)
        if input_type in ("rgb", "gray"):
            return img, torch.tensor(noise_rgb)
        raise NotImplementedError()