|
import os |
|
from torch.multiprocessing import Process, Manager, set_start_method, Pool |
|
import functools |
|
import argparse |
|
import yaml |
|
import numpy as np |
|
import sys |
|
import cv2 |
|
from tqdm import trange |
|
|
|
# Force the "spawn" start method for every process/pool created below;
# force=True overrides a start method that may already have been set.
set_start_method("spawn", force=True)

# Put the parent of this file's directory at the front of sys.path so the
# project-level imports that follow (components, utils) can be resolved.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)
|
|
|
from components import load_component |
|
from utils import evaluation_utils, metrics |
|
|
|
# Command-line interface. Parsing happens at module level because the handler
# functions below read the module-level ``args`` object directly.
parser = argparse.ArgumentParser(description="dump eval data.")
parser.add_argument(
    "--config_path",
    type=str,
    default="configs/eval/scannet_eval_sgm.yaml",
    help="Path to the YAML evaluation config.",
)
parser.add_argument(
    "--num_process_match",
    type=int,
    default=4,
    help="Pool size (and batch size) used by the matching stage.",
)
parser.add_argument(
    "--num_process_eval",
    type=int,
    default=4,
    help="Pool size (and batch size) used by the evaluation stage.",
)
parser.add_argument(
    "--vis_folder",
    type=str,
    default=None,
    help="Directory for per-pair match visualizations; omitted = no dump.",
)
args = parser.parse_args()
|
|
|
|
|
def feed_match(info, matcher):
    """Run ``matcher`` on a single image pair and return its correspondences.

    Args:
        info: Mapping providing the keys ``"x1"``, ``"x2"``, ``"desc1"``,
            ``"desc2"``, ``"img1"`` and ``"img2"``. Only ``.shape[:2]`` of the
            two images is used.
        matcher: Object exposing ``run(test_data)`` that returns exactly two
            values ``(corr1, corr2)``.

    Returns:
        The two-element list ``[corr1, corr2]`` produced by the matcher.
    """
    test_data = {
        "x1": info["x1"],
        "x2": info["x2"],
        "desc1": info["desc1"],
        "desc2": info["desc2"],
        # The first two shape entries are passed in reversed order
        # (presumably (height, width) -> (width, height); confirm with matcher).
        "size1": np.flip(np.asarray(info["img1"].shape[:2])),
        "size2": np.flip(np.asarray(info["img2"].shape[:2])),
    }
    corr1, corr2 = matcher.run(test_data)
    return [corr1, corr2]
|
|
|
|
|
def reader_handler(config, read_que):
    """Producer stage: read every sample and push it onto ``read_que``.

    Args:
        config: Reader configuration; ``config["name"]`` selects the reader
            passed to ``load_component``.
        read_que: Queue receiving one ``reader.run(index)`` result per index,
            followed by the string sentinel ``"over"``.
    """
    reader = load_component("reader", config["name"], config)
    # The reader is addressed by integer index through ``run``, so iterate
    # over indices rather than over the reader itself.
    for index in range(len(reader)):
        read_que.put(reader.run(index))
    # Tell the consumer that no more samples will arrive.
    read_que.put("over")
|
|
|
|
|
def _match_and_forward(pool, match_func, batch, match_que):
    """Match every item of ``batch`` on ``pool`` and push the results on."""
    results = pool.map(match_func, batch)
    for cur_item, cur_result in zip(batch, results):
        # Attach the correspondences to the item before handing it on.
        cur_item["corr1"], cur_item["corr2"] = cur_result[0], cur_result[1]
        match_que.put(cur_item)


def match_handler(config, read_que, match_que):
    """Matching stage: batch items from ``read_que`` and match them in a pool.

    Items are collected until ``args.num_process_match`` of them are cached,
    then matched together with ``pool.map``. Each item gets ``"corr1"`` and
    ``"corr2"`` entries and is put on ``match_que``. When the ``"over"``
    sentinel is read, any partial batch is processed and ``"over"`` is
    forwarded downstream.

    Args:
        config: Matcher configuration; ``config["name"]`` selects the matcher
            passed to ``load_component``.
        read_que: Input queue fed by the reader stage.
        match_que: Output queue consumed by the evaluation stage.
    """
    matcher = load_component("matcher", config["name"], config)
    match_func = functools.partial(feed_match, matcher=matcher)
    pool = Pool(args.num_process_match)
    try:
        cache = []
        while True:
            item = read_que.get()
            if item == "over":
                # Flush whatever is left before signalling completion.
                if cache:
                    _match_and_forward(pool, match_func, cache, match_que)
                match_que.put("over")
                break
            cache.append(item)
            if len(cache) == args.num_process_match:
                _match_and_forward(pool, match_func, cache, match_que)
                cache = []
    finally:
        # Release the worker processes even if matching raised.
        pool.close()
        pool.join()
|
|
|
|
|
def _evaluate_batch(pool, evaluator, batch):
    """Run ``evaluator.run`` over ``batch`` on ``pool`` and record results."""
    for cur_res in pool.map(evaluator.run, batch):
        evaluator.res_inqueue(cur_res)


def _dump_match_visualization(item, config, vis_folder):
    """Write ``<vis_folder>/<item["index"]>.png`` showing the item's matches."""
    corr1_norm = evaluation_utils.normalize_intrinsic(item["corr1"], item["K1"])
    corr2_norm = evaluation_utils.normalize_intrinsic(item["corr2"], item["K2"])
    inlier_mask = metrics.compute_epi_inlier(
        corr1_norm, corr2_norm, item["e"], config["inlier_th"]
    )
    display = evaluation_utils.draw_match(
        item["img1"], item["img2"], item["corr1"], item["corr2"], inlier_mask
    )
    cv2.imwrite(os.path.join(vis_folder, str(item["index"]) + ".png"), display)


def evaluate_handler(config, match_que):
    """Evaluation stage: score matched pairs in batches, then summarize.

    Reads at most ``config["num_pair"]`` items from ``match_que`` and stops
    early on the ``"over"`` sentinel. Items are evaluated in batches of
    ``args.num_process_eval`` via ``pool.map(evaluator.run, ...)``, and every
    result is passed to ``evaluator.res_inqueue``. If ``args.vis_folder`` is
    set, a visualization image is written for each item. ``evaluator.parse()``
    is called once at the end.

    Args:
        config: Evaluator configuration; uses ``"name"``, ``"num_pair"`` and,
            when visualizing, ``"inlier_th"``.
        match_que: Queue fed by the matching stage.
    """
    evaluator = load_component("evaluator", config["name"], config)
    pool = Pool(args.num_process_eval)
    try:
        cache = []
        for _ in trange(config["num_pair"]):
            item = match_que.get()
            if item == "over":
                break
            cache.append(item)
            if len(cache) == args.num_process_eval:
                _evaluate_batch(pool, evaluator, cache)
                cache = []
            if args.vis_folder is not None:
                _dump_match_visualization(item, config, args.vis_folder)
        # Flush the partial final batch. This must happen whether the loop
        # ended on "over" or by reaching num_pair; otherwise the tail is
        # dropped whenever num_pair is not a multiple of the batch size.
        if cache:
            _evaluate_batch(pool, evaluator, cache)
    finally:
        # Release the worker processes even if evaluation raised.
        pool.close()
        pool.join()
    evaluator.parse()
|
|
|
|
|
if __name__ == "__main__":
    # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6
    # and can construct arbitrary objects on older releases. safe_load parses
    # plain YAML on every version.
    # NOTE(review): switch to an explicit Loader if a config relies on
    # Python-specific YAML tags.
    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    # makedirs also creates missing parent directories and tolerates an
    # already-existing folder.
    if args.vis_folder is not None:
        os.makedirs(args.vis_folder, exist_ok=True)

    # One manager process serves both bounded queues linking the stages.
    with Manager() as manager:
        read_que = manager.Queue(maxsize=100)
        match_que = manager.Queue(maxsize=100)

        read_process = Process(
            target=reader_handler, args=(config["reader"], read_que)
        )
        match_process = Process(
            target=match_handler, args=(config["matcher"], read_que, match_que)
        )
        evaluate_process = Process(
            target=evaluate_handler, args=(config["evaluator"], match_que)
        )

        # Start all three stages so they run concurrently, then wait for each.
        read_process.start()
        match_process.start()
        evaluate_process.start()

        read_process.join()
        match_process.join()
        evaluate_process.join()
|
|