from . import matchers
from . import readers
from . import evaluators
from . import extractors


def load_component(compo_name, model_name, config):
    """Build a component by dispatching on the component kind.

    Args:
        compo_name: Kind of component to build; one of "extractor",
            "reader", "matcher" or "evaluator".
        model_name: Model identifier forwarded to the kind-specific loader.
        config: Configuration forwarded unchanged to that loader.

    Returns:
        The object returned by the selected ``load_*`` function.

    Raises:
        NotImplementedError: If ``compo_name`` is not one of the four
            kinds listed above.
    """
    # Guard-clause dispatch: each branch hands off to its dedicated loader.
    if compo_name == "extractor":
        return load_extractor(model_name, config)
    if compo_name == "reader":
        return load_reader(model_name, config)
    if compo_name == "matcher":
        return load_matcher(model_name, config)
    if compo_name == "evaluator":
        return load_evaluator(model_name, config)
    raise NotImplementedError


def load_extractor(model_name, config):
    """Instantiate the extractor selected by ``model_name``.

    Args:
        model_name: "root" selects ``extractors.ExtractSIFT``; "sp" selects
            ``extractors.ExtractSuperpoint``.
        config: Configuration passed as the sole constructor argument.

    Returns:
        The newly constructed extractor instance.

    Raises:
        NotImplementedError: If ``model_name`` is neither "root" nor "sp".
    """
    if model_name == "root":
        return extractors.ExtractSIFT(config)
    if model_name == "sp":
        return extractors.ExtractSuperpoint(config)
    raise NotImplementedError


def load_matcher(model_name, config):
    """Instantiate the matcher selected by ``model_name``.

    Args:
        model_name: "SGM" or "SG" selects ``matchers.GNN_Matcher`` (the
            name is passed on as its second argument); "NN" selects
            ``matchers.NN_Matcher``.
        config: Configuration passed as the first constructor argument.

    Returns:
        The newly constructed matcher instance.

    Raises:
        NotImplementedError: If ``model_name`` is not "SGM", "SG" or "NN".
    """
    # Both GNN variants share one class and differ only in the name
    # handed to the constructor.
    for gnn_variant in ("SGM", "SG"):
        if model_name == gnn_variant:
            return matchers.GNN_Matcher(config, gnn_variant)
    if model_name == "NN":
        return matchers.NN_Matcher(config)
    raise NotImplementedError


def load_reader(model_name, config):
    """Instantiate the reader selected by ``model_name``.

    Args:
        model_name: Only "standard" is recognised; it selects
            ``readers.standard_reader``.
        config: Configuration passed as the sole argument.

    Returns:
        The object returned by ``readers.standard_reader(config)``.

    Raises:
        NotImplementedError: If ``model_name`` is not "standard".
    """
    if model_name == "standard":
        return readers.standard_reader(config)
    raise NotImplementedError


def load_evaluator(model_name, config):
    """Instantiate the evaluator selected by ``model_name``.

    Args:
        model_name: "AUC" selects ``evaluators.auc_eval``; "FM" selects
            ``evaluators.FMbench_eval``.
        config: Configuration passed as the sole argument.

    Returns:
        The newly constructed evaluator.

    Raises:
        NotImplementedError: If ``model_name`` is neither "AUC" nor "FM".
    """
    if model_name == "AUC":
        evaluator = evaluators.auc_eval(config)
    elif model_name == "FM":
        evaluator = evaluators.FMbench_eval(config)
    else:
        # Without this branch an unknown name fell through to ``return``
        # with ``evaluator`` unbound and surfaced as UnboundLocalError;
        # raise the same exception type the other loaders use instead.
        raise NotImplementedError(
            "unknown evaluator: {!r}".format(model_name)
        )
    return evaluator