rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame
12.8 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
from typing import TYPE_CHECKING, Any, Optional, Union
from mmengine.config import Config, ConfigDict
from mmengine.utils import ManagerMixin
from .registry import Registry
if TYPE_CHECKING:
import torch.nn as nn
from mmengine.optim.scheduler import _ParamScheduler
from mmengine.runner import Runner
def build_from_cfg(
        cfg: Union[dict, ConfigDict, Config],
        registry: Registry,
        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
    """Instantiate a class, or call a function, described by a config.

    ``cfg`` and ``default_args`` are merged into one set of keyword
    arguments; when a key is present in both, the value from ``cfg`` wins.
    The merged arguments must provide the key "type", which is either a
    string looked up in ``registry`` or a callable used directly. "type"
    (and the optional "_scope_") are removed from the merged arguments and
    everything that is left is passed to the resolved callable.

    If "_scope_" is given, the lookup and the construction happen inside
    ``registry.switch_scope_and_registry(scope)``.

    Classes deriving from ``ManagerMixin`` are created through
    ``get_instance`` instead of their constructor.

    Examples:
        >>> from mmengine import Registry, build_from_cfg
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     def __init__(self, depth, stages=4):
        >>>         self.depth = depth
        >>>         self.stages = stages
        >>> cfg = dict(type='ResNet', depth=50)
        >>> model = build_from_cfg(cfg, MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Return a result of the calling function

    Args:
        cfg (dict or ConfigDict or Config): Config dict. Either it or
            ``default_args`` must contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict or ConfigDict or Config, optional): Fallback
            arguments used for keys missing from ``cfg``.
            Defaults to None.

    Returns:
        object: Whatever the resolved class or callable returns.

    Raises:
        TypeError: If ``cfg``, ``registry`` or ``default_args`` has an
            unsupported type, or if "type" is neither a str nor callable.
        KeyError: If "type" is missing from both ``cfg`` and
            ``default_args``, or the named type is not registered.
    """
    # Imported here rather than at module level to avoid a circular import.
    from ..logging import print_log

    config_types = (dict, ConfigDict, Config)

    if not isinstance(cfg, config_types):
        raise TypeError(
            f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}')

    if 'type' not in cfg and (default_args is None
                              or 'type' not in default_args):
        raise KeyError(
            '`cfg` or `default_args` must contain the key "type", '
            f'but got {cfg}\n{default_args}')

    if not isinstance(registry, Registry):
        raise TypeError('registry must be a mmengine.Registry object, '
                        f'but got {type(registry)}')

    if default_args is not None and not isinstance(default_args,
                                                   config_types):
        raise TypeError(
            'default_args should be a dict, ConfigDict, Config or None, '
            f'but got {type(default_args)}')

    # Work on a copy so the caller's config is left untouched; defaults only
    # fill in keys that ``cfg`` did not set.
    build_kwargs = cfg.copy()
    if default_args is not None:
        for key, fallback in default_args.items():
            build_kwargs.setdefault(key, fallback)

    # When `_scope_` is present, the default scope is switched to it for the
    # duration of the lookup and construction below.
    scope = build_kwargs.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        target = build_kwargs.pop('type')

        if isinstance(target, str):
            factory = registry.get(target)
            if factory is None:
                raise KeyError(
                    f'{target} is not in the {registry.scope}::{registry.name} registry. '  # noqa: E501
                    f'Please check whether the value of `{target}` is '
                    'correct or it was registered as expected. More details '
                    'can be found at '
                    'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
                )
        elif not callable(target):
            raise TypeError(
                f'type must be a str or valid type, but got {type(target)}')
        else:
            # Covers classes, functions, partial functions and more.
            factory = target

        is_class = inspect.isclass(factory)

        # `ManagerMixin` subclasses go through `get_instance` so that the
        # created object can be accessed globally afterwards.
        if is_class and issubclass(factory, ManagerMixin):  # type: ignore
            obj = factory.get_instance(**build_kwargs)  # type: ignore
        else:
            obj = factory(**build_kwargs)  # type: ignore

        # Only classes, functions and methods are reported by name/module;
        # any other callable is reported as-is.
        if (is_class or inspect.isfunction(factory)
                or inspect.ismethod(factory)):
            message = (
                f'An `{factory.__name__}` instance is built from '  # type: ignore # noqa: E501
                'registry, and its implementation can be found in '
                f'{factory.__module__}')  # type: ignore
        else:
            message = (
                'An instance is built from registry, and its constructor '
                f'is {factory}')
        print_log(message, logger='current', level=logging.DEBUG)

        return obj
def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
                          registry: Registry) -> 'Runner':
    """Build a Runner object.

    The runner class is taken from the "runner_type" key of ``cfg`` (a
    registered name or a class), falling back to the name ``'Runner'`` when
    the key is absent. The instance is created with the class's
    ``from_cfg`` method.

    Examples:
        >>> from mmengine.registry import Registry, build_runner_from_cfg
        >>> RUNNERS = Registry('runners', build_func=build_runner_from_cfg)
        >>> @RUNNERS.register_module()
        >>> class CustomRunner(Runner):
        >>>     def setup_env(env_cfg):
        >>>         pass
        >>> cfg = dict(runner_type='CustomRunner', ...)
        >>> custom_runner = RUNNERS.build(cfg)

    Args:
        cfg (dict or ConfigDict or Config): Config dict. If "runner_type" key
            exists, it will be used to build a custom runner. Otherwise, it
            will be used to build a default runner.
        registry (:obj:`Registry`): The registry to search the type from.

    Returns:
        object: The constructed runner object.

    Raises:
        AssertionError: If ``cfg`` or ``registry`` has an unsupported type.
        KeyError: If the runner name is not found in the registry.
        TypeError: If "runner_type" is neither a str nor a class.
    """
    from ..config import Config, ConfigDict
    from ..logging import print_log

    assert isinstance(
        cfg,
        (dict, ConfigDict, Config
         )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
    # The message must be a single string: a comma between the two literals
    # would turn it into a tuple and garble the AssertionError text.
    assert isinstance(registry, Registry), (
        'registry should be a mmengine.Registry object, '
        f'but got {type(registry)}')

    # Copy so that popping `_scope_` does not modify the caller's config.
    args = cfg.copy()

    # Runner should be built under target scope, if `_scope_` is defined
    # in cfg, current default scope should switch to specified scope
    # temporarily.
    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        # Read with `get`, not `pop`: "runner_type" stays in `args` and is
        # passed on to `from_cfg` together with the other keys.
        obj_type = args.get('runner_type', 'Runner')
        if isinstance(obj_type, str):
            runner_cls = registry.get(obj_type)
            if runner_cls is None:
                raise KeyError(
                    f'{obj_type} is not in the {registry.name} registry. '
                    f'Please check whether the value of `{obj_type}` is '
                    'correct or it was registered as expected. More details '
                    'can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
                )
        elif inspect.isclass(obj_type):
            runner_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        runner = runner_cls.from_cfg(args)  # type: ignore
        print_log(
            f'An `{runner_cls.__name__}` instance is built from '  # type: ignore # noqa: E501
            'registry, its implementation can be found in '
            f'{runner_cls.__module__}',  # type: ignore
            logger='current',
            level=logging.DEBUG)
        return runner
def build_model_from_cfg(
    cfg: Union[dict, ConfigDict, Config],
    registry: Registry,
    default_args: Optional[Union[dict, 'ConfigDict', 'Config']] = None
) -> 'nn.Module':
    """Build a model from a single config or from a list of configs.

    A single config is handed straight to ``build_from_cfg``. A list of
    configs is built element by element, in order, and the resulting
    modules are wrapped in the ``Sequential`` container imported from the
    sibling ``model`` package.

    Args:
        cfg (dict, list[dict]): The config of modules, which is either a config
            dict or a list of config dicts.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments forwarded to
            ``build_from_cfg`` for every module. Defaults to None.

    Returns:
        nn.Module: The built module, or a ``Sequential`` of built modules
        when ``cfg`` is a list.
    """
    from ..model import Sequential

    # Anything that is not a list is treated as one module config.
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    return Sequential(*(build_from_cfg(item_cfg, registry, default_args)
                        for item_cfg in cfg))
def build_scheduler_from_cfg(
    cfg: Union[dict, ConfigDict, Config],
    registry: Registry,
    default_args: Optional[Union[dict, ConfigDict, Config]] = None
) -> '_ParamScheduler':
    """Builds a ``ParamScheduler`` instance from config.

    ``ParamScheduler`` supports building instance by its constructor or
    method ``build_iter_from_epoch``. Therefore, its registry needs a build
    function to handle both cases.

    Args:
        cfg (dict or ConfigDict or Config): Config dictionary. If the key
            ``convert_to_iter_based`` is present and truthy, the instance
            is built by the class method ``build_iter_from_epoch``;
            otherwise it is built through ``build_from_cfg``.
        registry (:obj:`Registry`): The ``PARAM_SCHEDULERS`` registry.
        default_args (dict or ConfigDict or Config, optional): Default
            initialization arguments; keys already present in ``cfg`` take
            precedence. Expected to contain ``optimizer`` (not checked
            here). If ``convert_to_iter_based`` is set, the merged
            arguments must additionally contain ``epoch_length``.
            Defaults to None.

    Returns:
        object: The constructed ``ParamScheduler``.

    Raises:
        AssertionError: If ``cfg`` or ``registry`` has an unsupported type,
            or if conversion is requested without ``epoch_length`` or with
            a falsy ``by_epoch``.
        KeyError: If the scheduler name is not found in the registry.
        TypeError: If "type" is neither a str nor a class when converting.
    """
    assert isinstance(
        cfg,
        (dict, ConfigDict, Config
         )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
    # The message must be a single string: a comma between the two literals
    # would turn it into a tuple and garble the AssertionError text.
    assert isinstance(registry, Registry), (
        'registry should be a mmengine.Registry object, '
        f'but got {type(registry)}')

    # Copy so the caller's config is not modified, then fill in defaults for
    # keys that ``cfg`` did not set.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        convert_to_iter = args.pop('convert_to_iter_based', False)
        if convert_to_iter:
            scheduler_type = args.pop('type')
            # `by_epoch` defaults to True when it is not given.
            assert 'epoch_length' in args and args.get('by_epoch', True), (
                'Only epoch-based parameter scheduler can be converted to '
                'iter-based, and `epoch_length` should be set')
            if isinstance(scheduler_type, str):
                scheduler_cls = registry.get(scheduler_type)
                if scheduler_cls is None:
                    raise KeyError(
                        f'{scheduler_type} is not in the {registry.name} '
                        'registry. Please check whether the value of '
                        f'`{scheduler_type}` is correct or it was registered '
                        'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
                    )
            elif inspect.isclass(scheduler_type):
                scheduler_cls = scheduler_type
            else:
                raise TypeError('type must be a str or valid type, but got '
                                f'{type(scheduler_type)}')
            return scheduler_cls.build_iter_from_epoch(  # type: ignore
                **args)
        else:
            # `epoch_length` is only used for the conversion above, so it is
            # dropped before the regular build.
            args.pop('epoch_length', None)
            return build_from_cfg(args, registry)