|
from . import matchers |
|
from . import readers |
|
from . import evaluators |
|
from . import extractors |
|
|
|
def load_component(compo_name, model_name, config):
    """Build a pipeline component of the requested kind.

    ``compo_name`` selects which loader handles the request
    ('extractor', 'reader', 'matcher' or 'evaluator'); ``model_name``
    and ``config`` are forwarded unchanged to that loader.

    Returns whatever the selected loader returns.

    Raises:
        NotImplementedError: if ``compo_name`` is not one of the four
            kinds listed above.
    """
    if compo_name == 'extractor':
        return load_extractor(model_name, config)
    if compo_name == 'reader':
        return load_reader(model_name, config)
    if compo_name == 'matcher':
        return load_matcher(model_name, config)
    if compo_name == 'evaluator':
        return load_evaluator(model_name, config)
    # No loader is registered for this component kind.
    raise NotImplementedError
|
|
|
|
|
def load_extractor(model_name, config):
    """Instantiate the extractor identified by ``model_name``.

    'root' maps to ``extractors.ExtractSIFT`` and 'sp' maps to
    ``extractors.ExtractSuperpoint``; ``config`` is passed to the
    chosen constructor as its only argument.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    if model_name == 'root':
        return extractors.ExtractSIFT(config)
    if model_name == 'sp':
        return extractors.ExtractSuperpoint(config)
    raise NotImplementedError
|
|
|
def load_matcher(model_name, config):
    """Instantiate the matcher identified by ``model_name``.

    'SGM' and 'SG' both build ``matchers.GNN_Matcher``, which receives
    ``config`` followed by the variant name; 'NN' builds
    ``matchers.NN_Matcher`` from ``config`` alone.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    if model_name in ('SGM', 'SG'):
        # Both GNN variants share one class; the name selects the variant.
        return matchers.GNN_Matcher(config, model_name)
    if model_name == 'NN':
        return matchers.NN_Matcher(config)
    raise NotImplementedError
|
|
|
def load_reader(model_name, config):
    """Instantiate the reader identified by ``model_name``.

    Only 'standard' is recognised; it builds
    ``readers.standard_reader`` from ``config``.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    if model_name != 'standard':
        raise NotImplementedError
    return readers.standard_reader(config)
|
|
|
def load_evaluator(model_name, config):
    """Instantiate the evaluator identified by ``model_name``.

    'AUC' builds ``evaluators.auc_eval`` and 'FM' builds
    ``evaluators.FMbench_eval``; ``config`` is passed to the chosen
    constructor as its only argument.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    if model_name == 'AUC':
        evaluator = evaluators.auc_eval(config)
    elif model_name == 'FM':
        evaluator = evaluators.FMbench_eval(config)
    else:
        # Without this branch an unknown name fell through to the return
        # with ``evaluator`` unbound and surfaced as UnboundLocalError.
        # Reject it explicitly, as the other loaders in this module do.
        raise NotImplementedError(
            'unknown evaluator: %r' % (model_name,)
        )
    return evaluator
|
|