File size: 1,584 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from . import matchers
from . import readers
from . import evaluators
from . import extractors
def load_component(compo_name, model_name, config):
    """Build a pipeline component of the requested kind.

    Dispatches on ``compo_name`` to the matching ``load_*`` factory defined
    in this module. ``model_name`` and ``config`` are forwarded unchanged.

    Args:
        compo_name: One of ``"extractor"``, ``"reader"``, ``"matcher"`` or
            ``"evaluator"``.
        model_name: Model identifier understood by the selected factory.
        config: Configuration object handed through to the factory.

    Returns:
        Whatever the selected factory returns.

    Raises:
        NotImplementedError: If ``compo_name`` is not one of the kinds above.
    """
    # Guard-style dispatch: each recognised kind returns immediately, so
    # reaching the end means the kind was not recognised.
    if compo_name == "extractor":
        return load_extractor(model_name, config)
    if compo_name == "reader":
        return load_reader(model_name, config)
    if compo_name == "matcher":
        return load_matcher(model_name, config)
    if compo_name == "evaluator":
        return load_evaluator(model_name, config)
    raise NotImplementedError
def load_extractor(model_name, config):
    """Construct a feature extractor by short model name.

    Args:
        model_name: ``"root"`` selects ``extractors.ExtractSIFT``;
            ``"sp"`` selects ``extractors.ExtractSuperpoint``.
        config: Configuration object passed to the extractor constructor.

    Returns:
        The newly constructed extractor instance.

    Raises:
        NotImplementedError: If ``model_name`` is not recognised.
    """
    if model_name == "root":
        return extractors.ExtractSIFT(config)
    if model_name == "sp":
        return extractors.ExtractSuperpoint(config)
    raise NotImplementedError
def load_matcher(model_name, config):
    """Construct a keypoint matcher by short model name.

    Args:
        model_name: ``"SGM"`` or ``"SG"`` select ``matchers.GNN_Matcher``
            (the variant string is passed to its constructor);
            ``"NN"`` selects ``matchers.NN_Matcher``.
        config: Configuration object passed to the matcher constructor.

    Returns:
        The newly constructed matcher instance.

    Raises:
        NotImplementedError: If ``model_name`` is not recognised.
    """
    # Both GNN variants share one class; the literal variant tag (not the
    # caller's object) is what gets handed to the constructor.
    for gnn_variant in ("SGM", "SG"):
        if model_name == gnn_variant:
            return matchers.GNN_Matcher(config, gnn_variant)
    if model_name == "NN":
        return matchers.NN_Matcher(config)
    raise NotImplementedError
def load_reader(model_name, config):
    """Construct a data reader by name.

    Args:
        model_name: Only ``"standard"`` (``readers.standard_reader``) is
            supported.
        config: Configuration object passed to the reader constructor.

    Returns:
        The newly constructed reader.

    Raises:
        NotImplementedError: If ``model_name`` is anything other than
            ``"standard"``.
    """
    if model_name != "standard":
        raise NotImplementedError
    return readers.standard_reader(config)
def load_evaluator(model_name, config):
    """Construct an evaluator by short model name.

    Args:
        model_name: ``"AUC"`` selects ``evaluators.auc_eval``;
            ``"FM"`` selects ``evaluators.FMbench_eval``.
        config: Configuration object passed to the evaluator constructor.

    Returns:
        The newly constructed evaluator.

    Raises:
        NotImplementedError: If ``model_name`` is not recognised (the same
            exception type the other ``load_*`` helpers raise).
    """
    if model_name == "AUC":
        return evaluators.auc_eval(config)
    if model_name == "FM":
        return evaluators.FMbench_eval(config)
    # Without an explicit raise, an unknown name would fall through to a
    # return of an unbound local (UnboundLocalError), which is inconsistent
    # with the sibling loaders and hides the real cause.
    raise NotImplementedError(f"unsupported evaluator: {model_name!r}")
|