import argparse
from pathlib import Path
from pprint import pprint
from typing import Optional

import pkg_resources
from omegaconf import OmegaConf

from ..models import get_model
from ..settings import TRAINING_PATH
from ..utils.experiments import load_experiment


def parse_config_path(name_or_path: Optional[str], defaults: str) -> Optional[Path]:
    """Resolve a config name to a packaged default YAML or an existing file path."""
    default_configs = {}
    for c in pkg_resources.resource_listdir("gluefactory", str(defaults)):
        if c.endswith(".yaml"):
            default_configs[Path(c).stem] = Path(
                pkg_resources.resource_filename("gluefactory", defaults + c)
            )
    if name_or_path is None:
        return None
    if name_or_path in default_configs:
        return default_configs[name_or_path]
    path = Path(name_or_path)
    if not path.exists():
        raise FileNotFoundError(
            f"Cannot find the config file: {name_or_path}. "
            f"Not in the default configs {list(default_configs.keys())} "
            "and not an existing path."
        )
    return path
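

# Usage sketch (the config name and resource directory below are illustrative,
# not necessarily shipped with the package):
#
#     path = parse_config_path("superpoint+lightglue", "configs/")   # by stem
#     path = parse_config_path("my/custom_conf.yaml", "configs/")    # by path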


def extract_benchmark_conf(conf, benchmark):
    """Keep the model config and overlay any benchmark-specific overrides."""
    mconf = OmegaConf.create({"model": conf.get("model", {})})
    if "benchmarks" in conf:
        return OmegaConf.merge(mconf, conf.benchmarks.get(benchmark, {}))
    return mconf
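

# Sketch of the expected config shape (keys below are illustrative):
#
#     conf = OmegaConf.create({
#         "model": {"name": "two_view_pipeline"},
#         "benchmarks": {"hpatches": {"eval": {"top_k": 1024}}},
#     })
#     extract_benchmark_conf(conf, "hpatches")
#     # -> the "model" entry merged with the "hpatches" overrides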


def parse_eval_args(benchmark, args, configs_path, default=None):
    """Build the evaluation config and experiment name.

    Configs are merged in order of increasing priority: the defaults, the
    checkpoint's training config, the named config file, and CLI dotlist
    overrides.
    """
    conf = {"data": {}, "model": {}, "eval": {}}
    if args.conf:
        conf_path = parse_config_path(args.conf, configs_path)
        custom_conf = OmegaConf.load(conf_path)
        conf = extract_benchmark_conf(OmegaConf.merge(conf, custom_conf), benchmark)
        args.tag = (
            args.tag if args.tag is not None else conf_path.name.replace(".yaml", "")
        )

    cli_conf = OmegaConf.from_cli(args.dotlist)
    conf = OmegaConf.merge(conf, cli_conf)
    conf.checkpoint = args.checkpoint if args.checkpoint else conf.get("checkpoint")

    if conf.checkpoint and not conf.checkpoint.endswith(".tar"):
        # An experiment name was given: use its training config as the base.
        checkpoint_conf = OmegaConf.load(TRAINING_PATH / conf.checkpoint / "config.yaml")
        conf = OmegaConf.merge(extract_benchmark_conf(checkpoint_conf, benchmark), conf)
    if default:
        conf = OmegaConf.merge(default, conf)

    if args.tag is not None:
        name = args.tag
    elif args.conf and conf.checkpoint:
        name = f"{args.conf}_{conf.checkpoint}"
    elif args.conf:
        name = args.conf
    elif conf.checkpoint:
        name = conf.checkpoint
    else:
        # Without a tag, config, or checkpoint there is nothing to name the run.
        raise ValueError("Provide at least one of --tag, --conf, or --checkpoint.")
    if len(args.dotlist) > 0 and not args.tag:
        name = name + "_" + ":".join(args.dotlist)
    print("Running benchmark:", benchmark)
    print("Experiment tag:", name)
    print("Config:")
    pprint(OmegaConf.to_container(conf))
    return name, conf
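

# Typical call from an evaluation script (a sketch; the benchmark name and
# configs directory are illustrative):
#
#     parser = get_eval_parser()
#     args = parser.parse_intermixed_args()
#     name, conf = parse_eval_args("hpatches", args, "configs/", default_conf)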


def load_model(model_conf, checkpoint):
    """Load a model from an experiment checkpoint, or build it from the config."""
    if checkpoint:
        model = load_experiment(checkpoint, conf=model_conf).eval()
    else:
        model = get_model("two_view_pipeline")(model_conf).eval()
    if not model.is_initialized():
        raise ValueError(
            "The provided model has non-initialized parameters. "
            "Try to load a checkpoint instead."
        )
    return model
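

# Either path returns a model in eval mode (sketch; names are illustrative):
#
#     model = load_model(conf.model, conf.checkpoint)  # from a trained experiment
#     model = load_model(conf.model, None)             # built from the config; it
#     # must load its own pretrained weights to pass the is_initialized() check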


def get_eval_parser():
    """Argument parser shared by all evaluation scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--tag", type=str, default=None)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--conf", type=str, default=None)
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--overwrite_eval", action="store_true")
    parser.add_argument("--plot", action="store_true")
    parser.add_argument("dotlist", nargs="*")
    return parser
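

# A guarded smoke-test sketch (assumptions: gluefactory is importable, and a
# config named "superpoint+lightglue" exists under "configs/"; both names are
# illustrative). It only exercises argument parsing and config merging, not an
# actual benchmark run.
if __name__ == "__main__":
    parser = get_eval_parser()
    args = parser.parse_args(["--conf", "superpoint+lightglue", "eval.top_k=1024"])
    name, conf = parse_eval_args("hpatches", args, "configs/")
    print("Resolved experiment name:", name)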