Vincentqyw
fix: roma
358ab8f
raw
history blame
4.07 kB
import torch
import numpy as np
import cv2
import os
from loss import batch_episym
from tqdm import tqdm
import sys
# Resolve the parent of this file's directory and put it at the front of
# sys.path, ahead of every other entry.
# NOTE(review): presumably this is what lets the `utils` import that follows
# resolve when this file is run from its own folder — confirm against the
# repository layout.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)
from utils import evaluation_utils, train_utils
def valid(valid_loader, model, match_loss, config, model_config):
    """Run one pass over ``valid_loader`` and return the reduced metrics.

    The model is switched to eval mode for the pass and is always switched
    back to train mode on exit, including when the pass raises.

    Args:
        valid_loader: Iterable yielding validation batches.
        model: Network exposing ``eval()`` / ``train()``; called as
            ``model(batch)``.
        match_loss: Object whose ``run(batch, res)`` returns a dict with the
            keys ``"acc_corr"``, ``"acc_incorr"`` and ``"total_loss"``, plus
            ``"mid_acc_corr"``, ``"pre_seed_conf"`` and ``"recall_seed_conf"``
            when ``config.model_name == "SGM"``.
        config: Reads ``local_rank`` (progress bar only on rank 0) and
            ``model_name``.
        model_config: Reads ``layer_num`` and ``seedlayer`` to size the
            per-layer accumulators.

    Returns:
        Tuple ``(total_loss, total_acc_corr, total_acc_incorr,
        total_precision, total_recall, total_acc_mid)``. ``total_loss`` is
        the sum over batches, passed through
        ``train_utils.reduce_tensor(..., "sum")``; every other entry is the
        mean over batches, passed through
        ``train_utils.reduce_tensor(..., "mean")``.

    Raises:
        ZeroDivisionError: If ``valid_loader`` yields no batches.
    """
    model.eval()
    try:
        loader_iter = iter(valid_loader)
        num_pair = 0
        total_loss, total_acc_corr, total_acc_incorr = 0, 0, 0
        total_precision = torch.zeros(model_config.layer_num, device="cuda")
        total_recall = torch.zeros(model_config.layer_num, device="cuda")
        total_acc_mid = torch.zeros(len(model_config.seedlayer) - 1, device="cuda")

        with torch.no_grad():
            if config.local_rank == 0:
                loader_iter = tqdm(loader_iter)
                print("validating...")

            for test_data in loader_iter:
                num_pair += 1
                test_data = train_utils.tocuda(test_data)
                res = model(test_data)
                loss_res = match_loss.run(test_data, res)

                total_acc_corr += loss_res["acc_corr"]
                total_acc_incorr += loss_res["acc_incorr"]
                total_loss += loss_res["total_loss"]

                if config.model_name == "SGM":
                    total_acc_mid += loss_res["mid_acc_corr"]
                    # Out-of-place on purpose, as in the original code.
                    total_precision = total_precision + loss_res["pre_seed_conf"]
                    total_recall = total_recall + loss_res["recall_seed_conf"]

            # Per-batch averages; total_loss is deliberately left as a sum.
            total_acc_corr /= num_pair
            total_acc_incorr /= num_pair
            total_precision /= num_pair
            total_recall /= num_pair
            total_acc_mid /= num_pair

            # apply tensor reduction
            total_loss = train_utils.reduce_tensor(total_loss, "sum")
            total_acc_corr = train_utils.reduce_tensor(total_acc_corr, "mean")
            total_acc_incorr = train_utils.reduce_tensor(total_acc_incorr, "mean")
            total_precision = train_utils.reduce_tensor(total_precision, "mean")
            total_recall = train_utils.reduce_tensor(total_recall, "mean")
            total_acc_mid = train_utils.reduce_tensor(total_acc_mid, "mean")
    finally:
        # Restore train mode even if validation failed, so a caller that
        # catches the error does not keep training a model stuck in eval mode.
        model.train()

    return (
        total_loss,
        total_acc_corr,
        total_acc_incorr,
        total_precision,
        total_recall,
        total_acc_mid,
    )
def dump_train_vis(res, data, step, config):
    """Write one match-visualisation PNG per image pair of a training batch.

    Matches are taken from ``res["p"]`` with its last row and column removed
    (presumably the unmatched/dustbin slots — confirm against the model). A
    match is kept when its score exceeds 0.2 and it is a mutual arg-max
    along both axes. Each kept match is flagged as an inlier when
    ``batch_episym`` on the normalised coordinates returns a distance below
    ``config.inlier_th``.

    Args:
        res: Model output dict; reads ``"p"``.
        data: Batch dict; reads ``"x1"``, ``"x2"``, ``"kpt1"``, ``"kpt2"``,
            ``"e_gt"``, ``"img_path1"`` and ``"img_path2"``.
        step: Training step, used as the name of the output sub-folder.
        config: Reads ``inlier_th``, ``train_vis_folder`` and ``log_base``.

    Side effects:
        Creates ``<train_vis_folder>/train_vis/<log_base>/<step>/`` if it is
        missing and writes one ``.png`` per pair into it.
    """
    # batch matching
    p = res["p"][:, :-1, :-1]
    score, index1 = torch.max(p, dim=-1)
    _, index2 = torch.max(p, dim=-2)
    mask_th = score > 0.2
    # Mutual check: row i picks column index1[i]; that column must pick row i.
    # Build the row ids on p's own device instead of the current CUDA device,
    # and read the size from the shape so an empty batch does not index p[0].
    row_ids = torch.arange(p.shape[1], device=p.device)[None]
    mask_mc = index2.gather(index=index1, dim=1) == row_ids
    mask_p = mask_th & mask_mc  # B*N

    gather_index = index1[:, :, None].expand(-1, -1, 2)
    corr1 = data["x1"]
    corr2 = data["x2"].gather(index=gather_index, dim=1)
    corr1_kpt = data["kpt1"]
    corr2_kpt = data["kpt2"].gather(index=gather_index, dim=1)
    epi_dis = batch_episym(corr1, corr2, data["e_gt"])
    mask_inlier = epi_dis < config.inlier_th  # B*N

    # cv2.imwrite does not create directories and only returns False when the
    # target folder is missing, so make sure it exists up front.
    out_dir = os.path.join(
        config.train_vis_folder, "train_vis", config.log_base, str(step)
    )
    os.makedirs(out_dir, exist_ok=True)

    # dump vis
    for cur_mask_p, cur_mask_inlier, cur_corr1, cur_corr2, img_path1, img_path2 in zip(
        mask_p, mask_inlier, corr1_kpt, corr2_kpt, data["img_path1"], data["img_path2"]
    ):
        img1, img2 = cv2.imread(img_path1), cv2.imread(img_path2)
        dis_play = evaluation_utils.draw_match(
            img1,
            img2,
            cur_corr1[cur_mask_p].cpu().numpy(),
            cur_corr2[cur_mask_p].cpu().numpy(),
            # Filter the inlier flags with the same mask as the points so
            # flag i describes drawn match i, not original keypoint i.
            inlier=cur_mask_inlier[cur_mask_p],
        )
        # <image1 file>_<image2 file>_<image1 parent folder>
        base_name_seq = (
            img_path1.split("/")[-1]
            + "_"
            + img_path2.split("/")[-1]
            + "_"
            + img_path1.split("/")[-2]
        )
        save_path = os.path.join(out_dir, base_name_seq + ".png")
        cv2.imwrite(save_path, dis_play)