LanHarmony's picture
introduce control net
1daafb4
raw
history blame
26.3 kB
# Copyright (c) OpenMMLab. All rights reserved.
import ast
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
from .misc import import_modules_from_strings
from .path import check_file_exist
if platform.system() == 'Windows':
import regex as re
else:
import re
BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
DEPRECATION_KEY = '_deprecation_'
RESERVED_KEYS = ['filename', 'text', 'pretty_text']
class ConfigDict(Dict):
def __missing__(self, name):
raise KeyError(name)
def __getattr__(self, name):
try:
value = super(ConfigDict, self).__getattr__(name)
except KeyError:
ex = AttributeError(f"'{self.__class__.__name__}' object has no "
f"attribute '{name}'")
except Exception as e:
ex = e
else:
return value
raise ex
def add_args(parser, cfg, prefix=''):
for k, v in cfg.items():
if isinstance(v, str):
parser.add_argument('--' + prefix + k)
elif isinstance(v, int):
parser.add_argument('--' + prefix + k, type=int)
elif isinstance(v, float):
parser.add_argument('--' + prefix + k, type=float)
elif isinstance(v, bool):
parser.add_argument('--' + prefix + k, action='store_true')
elif isinstance(v, dict):
add_args(parser, v, prefix + k + '.')
elif isinstance(v, abc.Iterable):
parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
else:
print(f'cannot parse key {prefix + k} of type {type(v)}')
return parser
class Config:
"""A facility for config and config files.
It supports common file formats as configs: python/json/yaml. The interface
is the same as a dict object and also allows access config values as
attributes.
Example:
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
>>> cfg.a
1
>>> cfg.b
{'b1': [0, 1]}
>>> cfg.b.b1
[0, 1]
>>> cfg = Config.fromfile('tests/data/config/a.py')
>>> cfg.filename
"/home/kchen/projects/mmcv/tests/data/config/a.py"
>>> cfg.item4
'test'
>>> cfg
"Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
"{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
"""
@staticmethod
def _validate_py_syntax(filename):
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
content = f.read()
try:
ast.parse(content)
except SyntaxError as e:
raise SyntaxError('There are syntax errors in config '
f'file {filename}: {e}')
@staticmethod
def _substitute_predefined_vars(filename, temp_config_name):
file_dirname = osp.dirname(filename)
file_basename = osp.basename(filename)
file_basename_no_extension = osp.splitext(file_basename)[0]
file_extname = osp.splitext(filename)[1]
support_templates = dict(
fileDirname=file_dirname,
fileBasename=file_basename,
fileBasenameNoExtension=file_basename_no_extension,
fileExtname=file_extname)
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
for key, value in support_templates.items():
regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
value = value.replace('\\', '/')
config_file = re.sub(regexp, value, config_file)
with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
tmp_config_file.write(config_file)
@staticmethod
def _pre_substitute_base_vars(filename, temp_config_name):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
base_var_dict = {}
regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
base_vars = set(re.findall(regexp, config_file))
for base_var in base_vars:
randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
base_var_dict[randstr] = base_var
regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
config_file = re.sub(regexp, f'"{randstr}"', config_file)
with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
tmp_config_file.write(config_file)
return base_var_dict
@staticmethod
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
"""Substitute variable strings to their actual values."""
cfg = copy.deepcopy(cfg)
if isinstance(cfg, dict):
for k, v in cfg.items():
if isinstance(v, str) and v in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[v].split('.'):
new_v = new_v[new_k]
cfg[k] = new_v
elif isinstance(v, (list, tuple, dict)):
cfg[k] = Config._substitute_base_vars(
v, base_var_dict, base_cfg)
elif isinstance(cfg, tuple):
cfg = tuple(
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg)
elif isinstance(cfg, list):
cfg = [
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg
]
elif isinstance(cfg, str) and cfg in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[cfg].split('.'):
new_v = new_v[new_k]
cfg = new_v
return cfg
@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
check_file_exist(filename)
fileExtname = osp.splitext(filename)[1]
if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
raise IOError('Only py/yml/yaml/json type are supported now!')
with tempfile.TemporaryDirectory() as temp_config_dir:
temp_config_file = tempfile.NamedTemporaryFile(
dir=temp_config_dir, suffix=fileExtname)
if platform.system() == 'Windows':
temp_config_file.close()
temp_config_name = osp.basename(temp_config_file.name)
# Substitute predefined variables
if use_predefined_variables:
Config._substitute_predefined_vars(filename,
temp_config_file.name)
else:
shutil.copyfile(filename, temp_config_file.name)
# Substitute base variables from placeholders to strings
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name)
if filename.endswith('.py'):
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)
Config._validate_py_syntax(filename)
mod = import_module(temp_module_name)
sys.path.pop(0)
cfg_dict = {
name: value
for name, value in mod.__dict__.items()
if not name.startswith('__')
}
# delete imported module
del sys.modules[temp_module_name]
elif filename.endswith(('.yml', '.yaml', '.json')):
import annotator.uniformer.mmcv as mmcv
cfg_dict = mmcv.load(temp_config_file.name)
# close temp file
temp_config_file.close()
# check deprecation information
if DEPRECATION_KEY in cfg_dict:
deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
warning_msg = f'The config file {filename} will be deprecated ' \
'in the future.'
if 'expected' in deprecation_info:
warning_msg += f' Please use {deprecation_info["expected"]} ' \
'instead.'
if 'reference' in deprecation_info:
warning_msg += ' More information can be found at ' \
f'{deprecation_info["reference"]}'
warnings.warn(warning_msg)
cfg_text = filename + '\n'
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
cfg_text += f.read()
if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance(
base_filename, list) else [base_filename]
cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)
base_cfg_dict = dict()
for c in cfg_dict_list:
duplicate_keys = base_cfg_dict.keys() & c.keys()
if len(duplicate_keys) > 0:
raise KeyError('Duplicate key is not allowed among bases. '
f'Duplicate keys: {duplicate_keys}')
base_cfg_dict.update(c)
# Substitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = '\n'.join(cfg_text_list)
return cfg_dict, cfg_text
@staticmethod
def _merge_a_into_b(a, b, allow_list_keys=False):
"""merge dict ``a`` into dict ``b`` (non-inplace).
Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
Args:
a (dict): The source dict to be merged into ``b``.
b (dict): The origin dict to be fetch keys from ``a``.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in source ``a`` and will replace the element of the
corresponding index in b if b is a list. Default: False.
Returns:
dict: The modified dict of ``b`` using ``a``.
Examples:
# Normally merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# Delete b first and merge a into b.
>>> Config._merge_a_into_b(
... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
{'obj': {'a': 2}}
# b is a list
>>> Config._merge_a_into_b(
... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
[{'a': 2}, {'b': 2}]
"""
b = b.copy()
for k, v in a.items():
if allow_list_keys and k.isdigit() and isinstance(b, list):
k = int(k)
if len(b) <= k:
raise KeyError(f'Index {k} exceeds the length of list {b}')
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
elif isinstance(v,
dict) and k in b and not v.pop(DELETE_KEY, False):
allowed_types = (dict, list) if allow_list_keys else dict
if not isinstance(b[k], allowed_types):
raise TypeError(
f'{k}={v} in child config cannot inherit from base '
f'because {k} is a dict in the child config but is of '
f'type {type(b[k])} in base config. You may set '
f'`{DELETE_KEY}=True` to ignore the base config')
b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
else:
b[k] = v
return b
@staticmethod
def fromfile(filename,
use_predefined_variables=True,
import_custom_modules=True):
cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables)
if import_custom_modules and cfg_dict.get('custom_imports', None):
import_modules_from_strings(**cfg_dict['custom_imports'])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
@staticmethod
def fromstring(cfg_str, file_format):
"""Generate config from config str.
Args:
cfg_str (str): Config str.
file_format (str): Config file format corresponding to the
config str. Only py/yml/yaml/json type are supported now!
Returns:
obj:`Config`: Config obj.
"""
if file_format not in ['.py', '.json', '.yaml', '.yml']:
raise IOError('Only py/yml/yaml/json type are supported now!')
if file_format != '.py' and 'dict(' in cfg_str:
# check if users specify a wrong suffix for python
warnings.warn(
'Please check "file_format", the file format may be .py')
with tempfile.NamedTemporaryFile(
'w', encoding='utf-8', suffix=file_format,
delete=False) as temp_file:
temp_file.write(cfg_str)
# on windows, previous implementation cause error
# see PR 1077 for details
cfg = Config.fromfile(temp_file.name)
os.remove(temp_file.name)
return cfg
@staticmethod
def auto_argparser(description=None):
"""Generate argparser from config file automatically (experimental)"""
partial_parser = ArgumentParser(description=description)
partial_parser.add_argument('config', help='config file path')
cfg_file = partial_parser.parse_known_args()[0].config
cfg = Config.fromfile(cfg_file)
parser = ArgumentParser(description=description)
parser.add_argument('config', help='config file path')
add_args(parser, cfg)
return parser, cfg
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
raise TypeError('cfg_dict must be a dict, but '
f'got {type(cfg_dict)}')
for key in cfg_dict:
if key in RESERVED_KEYS:
raise KeyError(f'{key} is reserved for config file')
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super(Config, self).__setattr__('_filename', filename)
if cfg_text:
text = cfg_text
elif filename:
with open(filename, 'r') as f:
text = f.read()
else:
text = ''
super(Config, self).__setattr__('_text', text)
@property
def filename(self):
return self._filename
@property
def text(self):
return self._text
@property
def pretty_text(self):
indent = 4
def _indent(s_, num_spaces):
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = f"'{v}'"
else:
v_str = str(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent)
return attr_str
def _format_list(k, v, use_mapping=False):
# check if all items in the list are dict
if all(isinstance(_, dict) for _ in v):
v_str = '[\n'
v_str += '\n'.join(
f'dict({_indent(_format_dict(v_), indent)}),'
for v_ in v).rstrip(',')
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent) + ']'
else:
attr_str = _format_basic_types(k, v, use_mapping)
return attr_str
def _contain_invalid_identifier(dict_str):
contain_invalid_identifier = False
for key_name in dict_str:
contain_invalid_identifier |= \
(not str(key_name).isidentifier())
return contain_invalid_identifier
def _format_dict(input_dict, outest_level=False):
r = ''
s = []
use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping:
r += '{'
for idx, (k, v) in enumerate(input_dict.items()):
is_last = idx >= len(input_dict) - 1
end = '' if outest_level or is_last else ','
if isinstance(v, dict):
v_str = '\n' + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: dict({v_str}'
else:
attr_str = f'{str(k)}=dict({v_str}'
attr_str = _indent(attr_str, indent) + ')' + end
elif isinstance(v, list):
attr_str = _format_list(k, v, use_mapping) + end
else:
attr_str = _format_basic_types(k, v, use_mapping) + end
s.append(attr_str)
r += '\n'.join(s)
if use_mapping:
r += '}'
return r
cfg_dict = self._cfg_dict.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
# copied from setup.cfg
yapf_style = dict(
based_on_style='pep8',
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True)
text, _ = FormatCode(text, style_config=yapf_style, verify=True)
return text
def __repr__(self):
return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
def __len__(self):
return len(self._cfg_dict)
def __getattr__(self, name):
return getattr(self._cfg_dict, name)
def __getitem__(self, name):
return self._cfg_dict.__getitem__(name)
def __setattr__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setattr__(name, value)
def __setitem__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setitem__(name, value)
def __iter__(self):
return iter(self._cfg_dict)
def __getstate__(self):
return (self._cfg_dict, self._filename, self._text)
def __setstate__(self, state):
_cfg_dict, _filename, _text = state
super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
super(Config, self).__setattr__('_filename', _filename)
super(Config, self).__setattr__('_text', _text)
def dump(self, file=None):
cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
if self.filename.endswith('.py'):
if file is None:
return self.pretty_text
else:
with open(file, 'w', encoding='utf-8') as f:
f.write(self.pretty_text)
else:
import annotator.uniformer.mmcv as mmcv
if file is None:
file_format = self.filename.split('.')[-1]
return mmcv.dump(cfg_dict, file_format=file_format)
else:
mmcv.dump(cfg_dict, file)
def merge_from_dict(self, options, allow_list_keys=True):
"""Merge list into cfg_dict.
Merge the dict parsed by MultipleKVAction into this cfg.
Examples:
>>> options = {'model.backbone.depth': 50,
... 'model.backbone.with_cp':True}
>>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... model=dict(backbone=dict(depth=50, with_cp=True)))
# Merge list element
>>> cfg = Config(dict(pipeline=[
... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
>>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
>>> cfg.merge_from_dict(options, allow_list_keys=True)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(pipeline=[
... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
Args:
options (dict): dict of configs to merge from.
allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
are allowed in ``options`` and will replace the element of the
corresponding index in the config if the config is a list.
Default: True.
"""
option_cfg_dict = {}
for full_key, v in options.items():
d = option_cfg_dict
key_list = full_key.split('.')
for subkey in key_list[:-1]:
d.setdefault(subkey, ConfigDict())
d = d[subkey]
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
super(Config, self).__setattr__(
'_cfg_dict',
Config._merge_a_into_b(
option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options can
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""
@staticmethod
def _parse_int_float_bool(val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ['true', 'false']:
return True if val.lower() == 'true' else False
return val
@staticmethod
def _parse_iterable(val):
"""Parse iterable values in the string.
All elements inside '()' or '[]' are treated as iterable values.
Args:
val (str): Value string.
Returns:
list | tuple: The expanded list or tuple from the string.
Examples:
>>> DictAction._parse_iterable('1,2,3')
[1, 2, 3]
>>> DictAction._parse_iterable('[a, b, c]')
['a', 'b', 'c']
>>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
[(1, 2, 3), ['a', 'b'], 'c']
"""
def find_next_comma(string):
"""Find the position of next comma in the string.
If no ',' is found in the string, return the string length. All
chars inside '()' and '[]' are treated as one element and thus ','
inside these brackets are ignored.
"""
assert (string.count('(') == string.count(')')) and (
string.count('[') == string.count(']')), \
f'Imbalanced brackets exist in {string}'
end = len(string)
for idx, char in enumerate(string):
pre = string[:idx]
# The string before this ',' is balanced
if ((char == ',') and (pre.count('(') == pre.count(')'))
and (pre.count('[') == pre.count(']'))):
end = idx
break
return end
# Strip ' and " characters and replace whitespace.
val = val.strip('\'\"').replace(' ', '')
is_tuple = False
if val.startswith('(') and val.endswith(')'):
is_tuple = True
val = val[1:-1]
elif val.startswith('[') and val.endswith(']'):
val = val[1:-1]
elif ',' not in val:
# val is a single value
return DictAction._parse_int_float_bool(val)
values = []
while len(val) > 0:
comma_idx = find_next_comma(val)
element = DictAction._parse_iterable(val[:comma_idx])
values.append(element)
val = val[comma_idx + 1:]
if is_tuple:
values = tuple(values)
return values
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split('=', maxsplit=1)
options[key] = self._parse_iterable(val)
setattr(namespace, self.dest, options)