Spaces:
Running
Running
import os | |
from argparse import ArgumentParser, RawDescriptionHelpFormatter | |
from collections.abc import Mapping | |
import yaml | |
__all__ = ['Config'] | |
class ArgsParser(ArgumentParser):
    """Command-line parser for a config file path plus ``-o key=value`` overrides."""

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        # parse_args() requires this option; without registering it,
        # ``args.config`` never exists and every call fails with AttributeError.
        self.add_argument('-c', '--config', help='configuration file to use')
        self.add_argument('-o',
                          '--opt',
                          nargs='*',
                          help='set configuration options')
        self.add_argument('--local_rank')

    def parse_args(self, argv=None):
        """Parse ``argv`` (``sys.argv[1:]`` when None).

        Returns:
            argparse.Namespace whose ``opt`` attribute has been converted from
            a list of ``key=value`` strings into a nested dict.

        Raises:
            SystemExit: via ``self.error`` when ``--config`` is not given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        # self.error (not assert) so the check survives ``python -O`` and the
        # user gets the standard argparse usage message.
        if args.config is None:
            self.error('Please specify --config=configure_file_path.')
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``['a.b.c=1', 'x=foo']`` into ``{'a': {'b': {'c': 1}}, 'x': 'foo'}``.

        Each value is parsed as YAML, so ``1`` becomes an int and ``[1, 2]``
        becomes a list. Several options that share a dotted prefix are merged
        into the same nested dict instead of overwriting one another.

        Args:
            opts: list of ``key=value`` strings, or None.

        Returns:
            dict: nested override dict (empty when ``opts`` is falsy).
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                self.error(
                    "Invalid --opt item '{}': expected key=value.".format(s))
            k, v = s.split('=', 1)
            keys = k.split('.')
            # NOTE(review): yaml.Loader can construct arbitrary Python objects
            # (e.g. !!python/object tags). That is tolerable only because the
            # values come from the invoking user's own command line; consider
            # yaml.SafeLoader.
            value = yaml.load(v, Loader=yaml.Loader)
            cur = config
            for key in keys[:-1]:
                nxt = cur.get(key)
                # Reuse an existing sub-dict so that earlier options sharing
                # this prefix are kept. Create one only when it is missing or
                # when a scalar was set there by an earlier option.
                if not isinstance(nxt, dict):
                    nxt = {}
                    cur[key] = nxt
                cur = nxt
            cur[keys[-1]] = value
        return config
class AttrDict(dict):
    """A ``dict`` whose top-level keys can also be read as attributes.

    Only attribute reads are redirected to item lookup, and only one level
    deep: nested dict values are returned unchanged, not wrapped.
    """

    def __init__(self, **kwargs):
        super().__init__()
        super().update(kwargs)

    def __getattr__(self, key):
        # Called only after normal attribute lookup has failed, so real dict
        # methods (keys, items, ...) always take precedence over stored keys.
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]
def _merge_dict(config, merge_dct): | |
"""Recursive dict merge. Inspired by :meth:``dict.update()``, instead of | |
updating only top-level keys, dict_merge recurses down into dicts nested to | |
an arbitrary depth, updating keys. The ``merge_dct`` is merged into | |
``dct``. | |
Args: | |
config: dict onto which the merge is executed | |
merge_dct: dct merged into config | |
Returns: dct | |
""" | |
for key, value in merge_dct.items(): | |
sub_keys = key.split('.') | |
key = sub_keys[0] | |
if key in config and len(sub_keys) > 1: | |
_merge_dict(config[key], {'.'.join(sub_keys[1:]): value}) | |
elif key in config and isinstance(config[key], dict) and isinstance( | |
value, Mapping): | |
_merge_dict(config[key], value) | |
else: | |
config[key] = value | |
return config | |
def print_dict(cfg, print_func=print, delimiter=0):
    """Pretty-print a (possibly nested) config mapping, one key per line.

    Keys are emitted in sorted order. A nested dict, or a list whose first
    element is a dict, is printed under its key with four more spaces of
    indentation per level. Every other value is printed inline as ``key : value``.

    Args:
        cfg: mapping to print.
        print_func: callable that receives each formatted line.
        delimiter: number of leading spaces for the current nesting level.
    """
    pad = ' ' * delimiter
    for name, val in sorted(cfg.items()):
        if isinstance(val, dict):
            children = [val]
        elif isinstance(val, list) and val and isinstance(val[0], dict):
            children = val
        else:
            children = None

        if children is None:
            print_func('{}{} : {}'.format(pad, name, val))
            continue

        print_func('{}{} : '.format(pad, str(name)))
        for child in children:
            print_dict(child, print_func, delimiter + 4)
class Config(object):
    """Load a YAML config file, resolving ``_BASE_``-style inheritance.

    A config file may list parent files under ``BASE_KEY``. Those are loaded
    recursively and merged first, then the file's own entries are merged on
    top, so values in the child file take priority.

    Attributes:
        BASE_KEY: top-level key naming the parent config file(s).
        cfg: the fully merged configuration dict.
    """

    def __init__(self, config_path, BASE_KEY='_BASE_'):
        self.BASE_KEY = BASE_KEY
        self.cfg = self._load_config_with_base(config_path)

    def _load_config_with_base(self, file_path):
        """Load config from file, recursively merging any base configs it names.

        Relative base paths are resolved against the directory containing
        ``file_path``, and a leading ``~`` is expanded.

        Args:
            file_path (str): Path of the ``.yml``/``.yaml`` file to load.

        Returns:
            dict: merged configuration plus a ``'filename'`` entry holding the
            file's basename without extension.

        Raises:
            AssertionError: if the extension is not ``.yml`` or ``.yaml``.
        """
        _, ext = os.path.splitext(file_path)
        assert ext in ['.yml', '.yaml'], 'only support yaml files for now'
        # Explicit UTF-8: config files may contain non-ASCII text, which the
        # platform default codec (e.g. on Windows) can fail to decode.
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only load config files from trusted sources (or use SafeLoader).
        with open(file_path, encoding='utf-8') as f:
            file_cfg = yaml.load(f, Loader=yaml.Loader)
        # An empty YAML document loads as None; treat it as an empty mapping.
        if file_cfg is None:
            file_cfg = {}
        # NOTE: cfgs outside have higher priority than cfgs in _BASE_
        if self.BASE_KEY in file_cfg:
            base_ymls = file_cfg.pop(self.BASE_KEY)
            if base_ymls is None:
                base_ymls = []
            # Also accept a single path. list('base.yml') would otherwise
            # iterate over the characters of the string.
            if isinstance(base_ymls, str):
                base_ymls = [base_ymls]
            all_base_cfg = AttrDict()
            for base_yml in base_ymls:
                base_yml = os.path.expanduser(base_yml)
                if not os.path.isabs(base_yml):
                    base_yml = os.path.join(os.path.dirname(file_path),
                                            base_yml)
                # The recursive call opens and reads the file itself.
                base_cfg = self._load_config_with_base(base_yml)
                all_base_cfg = _merge_dict(all_base_cfg, base_cfg)
            file_cfg = _merge_dict(all_base_cfg, file_cfg)
        file_cfg['filename'] = os.path.splitext(
            os.path.basename(file_path))[0]
        return file_cfg

    def merge_dict(self, args):
        """Merge ``args`` (e.g. command-line overrides) into ``self.cfg`` in place.

        Dotted string keys such as ``'Global.epoch_num'`` address nested entries.
        """
        self.cfg = _merge_dict(self.cfg, args)

    def print_cfg(self, print_func=print):
        """Print the merged config as an indented tree, one key per line,
        with the nesting of keys shown by indentation."""
        print_func('----------- Config -----------')
        print_dict(self.cfg, print_func)
        print_func('---------------------------------------------')

    def save(self, p, cfg=None):
        """Write ``cfg`` (default: ``self.cfg``) to path ``p`` as block-style YAML.

        Key order is preserved (``sort_keys=False``). The top level is
        converted to a plain ``dict``. Presumably this is so that a dict
        subclass (e.g. ``AttrDict``) is dumped as an ordinary mapping rather
        than a Python-tagged object.
        """
        if cfg is None:
            cfg = self.cfg
        with open(p, 'w', encoding='utf-8') as f:
            yaml.dump(dict(cfg), f, default_flow_style=False, sort_keys=False)