File size: 832 Bytes
37aeb5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import copy
import os

from omegaconf import DictConfig, OmegaConf
def parse_structured(fields, cfg) -> DictConfig:
    """Build a structured OmegaConf config by instantiating ``fields``.

    The top-level keys of ``cfg`` are passed as keyword arguments to
    ``fields``. ``fields`` is presumably a dataclass or attrs class acting
    as the schema; confirm this with callers. The resulting instance is
    wrapped with ``OmegaConf.structured``.

    Args:
        fields: Callable accepting the config keys as keyword arguments.
        cfg: Mapping of field values (``load_config`` passes a
            ``DictConfig``).

    Returns:
        The ``DictConfig`` returned by ``OmegaConf.structured``.
    """
    instance = fields(**cfg)
    return OmegaConf.structured(instance)
def load_config(fields, config, extras=None):
    """Load a config, apply CLI overrides, resolve it and build the schema.

    Args:
        fields: Callable invoked as ``fields(**cfg)`` with the top-level keys
            of the final config, via ``parse_structured``. It is presumably a
            dataclass/attrs schema class; confirm this with callers.
        config: One of:
            - a path to a file readable by ``OmegaConf.load`` (``str`` or
              any ``os.PathLike`` such as ``pathlib.Path``);
            - a plain ``dict``;
            - a ``DictConfig``.
            A ``DictConfig`` argument is deep-copied, so the caller's object
            is never modified by the in-place resolution below.
        extras: Optional list of dotlist override strings (e.g.
            ``["a.b=1"]``). They are parsed with ``OmegaConf.from_cli`` and
            merged on top of ``config``. A warning is printed when any are
            given.

    Returns:
        The structured ``DictConfig`` produced by ``parse_structured``.

    Raises:
        NotImplementedError: If ``config`` is not a path, ``dict`` or
            ``DictConfig``.
        TypeError: If the resulting config is not a ``DictConfig`` (e.g. a
            YAML file whose top level is a list).
    """
    if extras:
        print("Warning! extra parameter in cli is not verified, may cause errors.")

    if isinstance(config, (str, os.PathLike)):
        cfg = OmegaConf.load(os.fspath(config))
    elif isinstance(config, dict):
        cfg = OmegaConf.create(config)
    elif isinstance(config, DictConfig):
        # Copy so that OmegaConf.resolve (which works in place) never mutates
        # the caller's config. The dict and extras paths already produce new
        # objects.
        cfg = copy.deepcopy(config)
    else:
        raise NotImplementedError(f"Unsupported config type {type(config)}")

    if extras:
        cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(extras))

    # Replace every interpolation with its final value before instantiation.
    OmegaConf.resolve(cfg)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(cfg, DictConfig):
        raise TypeError(
            f"Expected the config to be a DictConfig, got {type(cfg).__name__}"
        )
    return parse_structured(fields, cfg)