|
import torch |
|
import numpy as np |
|
import cv2 |
|
import os |
|
from loss import batch_episym |
|
from tqdm import tqdm |
|
|
|
import sys |
|
|
|
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
|
sys.path.insert(0, ROOT_DIR) |
|
|
|
from utils import evaluation_utils, train_utils |
|
|
|
|
|
def valid(valid_loader, model, match_loss, config, model_config):
    """Run one pass over ``valid_loader`` and return the aggregated metrics.

    The model is switched to eval mode for the pass and is always switched
    back to train mode afterwards, even when validation raises.

    Args:
        valid_loader: Iterable yielding validation batches.
        model: Network under evaluation; called as ``model(batch)``.
        match_loss: Object whose ``run(batch, res)`` returns a mapping with
            ``"acc_corr"``, ``"acc_incorr"`` and ``"total_loss"``.  When
            ``config.model_name == "SGM"`` the keys ``"mid_acc_corr"``,
            ``"pre_seed_conf"`` and ``"recall_seed_conf"`` are read as well.
        config: Run configuration; ``local_rank`` and ``model_name`` are read.
        model_config: Model configuration; ``layer_num`` and ``seedlayer``
            size the SGM accumulators (both are read for every model).

    Returns:
        Tuple ``(total_loss, total_acc_corr, total_acc_incorr,
        total_precision, total_recall, total_acc_mid)``, each passed through
        ``train_utils.reduce_tensor``.  ``total_loss`` is the un-averaged sum
        over batches reduced with ``"sum"``; every other entry is divided by
        the number of batches and reduced with ``"mean"``.  The last three
        entries stay all-zero unless ``config.model_name == "SGM"``.

    Raises:
        ValueError: If ``valid_loader`` yields no batches.
    """
    model.eval()
    try:
        loader_iter = iter(valid_loader)
        num_pair = 0
        total_loss, total_acc_corr, total_acc_incorr = 0, 0, 0
        total_precision = torch.zeros(model_config.layer_num, device="cuda")
        total_recall = torch.zeros(model_config.layer_num, device="cuda")
        total_acc_mid = torch.zeros(len(model_config.seedlayer) - 1, device="cuda")

        with torch.no_grad():
            if config.local_rank == 0:
                # Announce before the bar is created so the message does not
                # land in the middle of the progress output.
                print("validating...")
                loader_iter = tqdm(loader_iter)
            for test_data in loader_iter:
                num_pair += 1
                test_data = train_utils.tocuda(test_data)
                res = model(test_data)
                loss_res = match_loss.run(test_data, res)

                total_acc_corr += loss_res["acc_corr"]
                total_acc_incorr += loss_res["acc_incorr"]
                total_loss += loss_res["total_loss"]

                if config.model_name == "SGM":
                    total_acc_mid += loss_res["mid_acc_corr"]
                    total_precision = total_precision + loss_res["pre_seed_conf"]
                    total_recall = total_recall + loss_res["recall_seed_conf"]

        if num_pair == 0:
            # Without this guard the averages below fail with an opaque
            # ZeroDivisionError (0 / 0 on the integer accumulators).
            raise ValueError(
                "valid_loader produced no batches; cannot average validation metrics"
            )

        # Per-batch means; total_loss is deliberately left as a plain sum,
        # matching the "sum" reduction applied to it below.
        total_acc_corr /= num_pair
        total_acc_incorr /= num_pair
        total_precision /= num_pair
        total_recall /= num_pair
        total_acc_mid /= num_pair

        return (
            train_utils.reduce_tensor(total_loss, "sum"),
            train_utils.reduce_tensor(total_acc_corr, "mean"),
            train_utils.reduce_tensor(total_acc_incorr, "mean"),
            train_utils.reduce_tensor(total_precision, "mean"),
            train_utils.reduce_tensor(total_recall, "mean"),
            train_utils.reduce_tensor(total_acc_mid, "mean"),
        )
    finally:
        # Restore training mode on every exit path so a failed validation
        # pass cannot leave the model stuck in eval mode.
        model.train()
|
|
|
|
|
def dump_train_vis(res, data, step, config):
    """Write a match visualization image for every pair in the batch.

    Matches are taken from ``res["p"]`` with its last row and column removed:
    a row is kept when its best score exceeds 0.2 and the best column's own
    best row points back to it (mutual check).  Each kept match is flagged as
    an inlier when its ``batch_episym`` distance under ``data["e_gt"]`` is
    below ``config.inlier_th``.

    Images are written to
    ``<config.train_vis_folder>/train_vis/<config.log_base>/<step>/`` and are
    named ``<file1>_<file2>_<parent dir of file1>.png``.  The directory is
    created if it does not exist.

    Args:
        res: Model output mapping; only ``"p"`` is read (indexed as a
            3-D tensor here).
        data: Batch mapping; ``"x1"``, ``"x2"``, ``"kpt1"``, ``"kpt2"``,
            ``"e_gt"``, ``"img_path1"`` and ``"img_path2"`` are read.
        step: Training step, used as the innermost output directory name.
        config: Run configuration; ``inlier_th``, ``train_vis_folder`` and
            ``log_base`` are read.
    """
    # NOTE(review): the dropped last row/column is presumably a dustbin
    # slot for unmatched points -- confirm against the model.
    p = res["p"][:, :-1, :-1]
    score, index1 = torch.max(p, dim=-1)
    _, index2 = torch.max(p, dim=-2)
    mask_th = score > 0.2
    # Mutual check: row i's best column must pick row i as its best row.
    # Build the row ids on p's own device instead of the current CUDA device.
    row_ids = torch.arange(p.shape[1], device=p.device)[None]
    mask_mc = index2.gather(index=index1, dim=1) == row_ids
    mask_p = mask_th & mask_mc

    # Same gather index for both coordinate tensors (last dim of size 2).
    pair_index = index1[:, :, None].expand(-1, -1, 2)
    corr1, corr2 = data["x1"], data["x2"].gather(index=pair_index, dim=1)
    corr1_kpt, corr2_kpt = data["kpt1"], data["kpt2"].gather(index=pair_index, dim=1)
    epi_dis = batch_episym(corr1, corr2, data["e_gt"])
    mask_inlier = epi_dis < config.inlier_th

    # cv2.imwrite does not create missing directories, so make sure the
    # target exists before writing anything.
    out_dir = os.path.join(
        config.train_vis_folder, "train_vis", config.log_base, str(step)
    )
    os.makedirs(out_dir, exist_ok=True)

    for cur_mask_p, cur_mask_inlier, cur_corr1, cur_corr2, img_path1, img_path2 in zip(
        mask_p, mask_inlier, corr1_kpt, corr2_kpt, data["img_path1"], data["img_path2"]
    ):
        img1, img2 = cv2.imread(img_path1), cv2.imread(img_path2)
        dis_play = evaluation_utils.draw_match(
            img1,
            img2,
            cur_corr1[cur_mask_p].cpu().numpy(),
            cur_corr2[cur_mask_p].cpu().numpy(),
            # Restrict the inlier flags to the same rows as the drawn
            # correspondences so that flag i describes drawn match i.
            inlier=cur_mask_inlier[cur_mask_p],
        )
        base_name_seq = "_".join(
            (
                os.path.basename(img_path1),
                os.path.basename(img_path2),
                os.path.basename(os.path.dirname(img_path1)),
            )
        )
        save_path = os.path.join(out_dir, base_name_seq + ".png")
        cv2.imwrite(save_path, dis_play)
|
|