# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
from typing import TYPE_CHECKING, Any, Optional, Union

from mmengine.config import Config, ConfigDict
from mmengine.utils import ManagerMixin
from .registry import Registry

if TYPE_CHECKING:
    import torch.nn as nn

    from mmengine.optim.scheduler import _ParamScheduler
    from mmengine.runner import Runner


def build_from_cfg(
        cfg: Union[dict, ConfigDict, Config],
        registry: Registry,
        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
    """Build a module from config dict when it is a class configuration, or
    call a function from config dict when it is a function configuration.

    If the global variable default scope (:obj:`DefaultScope`) exists,
    :meth:`build` will firstly get the corresponding registry and then call
    its own :meth:`build`.

    At least one of the ``cfg`` and ``default_args`` contains the key "type",
    which should be either str or class. If they all contain it, the key
    in ``cfg`` will be used because ``cfg`` has a higher priority than
    ``default_args`` that means if a key exists in both of them, the value of
    the key will be ``cfg[key]``. They will be merged first and the key "type"
    will be popped and the remaining keys will be used as initialization
    arguments.

    Examples:
        >>> from mmengine import Registry, build_from_cfg
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     def __init__(self, depth, stages=4):
        >>>         self.depth = depth
        >>>         self.stages = stages
        >>> cfg = dict(type='ResNet', depth=50)
        >>> model = build_from_cfg(cfg, MODELS)
        >>> # Returns an instantiated object
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
        >>> # Return a result of the calling function

    Args:
        cfg (dict or ConfigDict or Config): Config dict. It should at least
            contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict or ConfigDict or Config, optional): Default
            initialization arguments. Defaults to None.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg``, ``registry`` or ``default_args`` has an
            unsupported type, or if the resolved "type" is neither a str
            nor a callable.
        KeyError: If neither ``cfg`` nor ``default_args`` contains the key
            "type", or if the str "type" is not found in the registry.
    """
    # Avoid circular import
    from ..logging import print_log

    if not isinstance(cfg, (dict, ConfigDict, Config)):
        raise TypeError(
            f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}')

    # Validate ``default_args`` before it is probed with ``in`` below, so a
    # non-mapping value yields this explicit message instead of an
    # incidental "argument of type ... is not iterable" error.
    if not (isinstance(default_args,
                       (dict, ConfigDict, Config)) or default_args is None):
        raise TypeError(
            'default_args should be a dict, ConfigDict, Config or None, '
            f'but got {type(default_args)}')

    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')

    if not isinstance(registry, Registry):
        raise TypeError('registry must be a mmengine.Registry object, '
                        f'but got {type(registry)}')

    # Work on a copy so the caller's config is not mutated by the pops below.
    args = cfg.copy()
    if default_args is not None:
        # ``cfg`` takes precedence: defaults only fill in missing keys.
        for name, value in default_args.items():
            args.setdefault(name, value)

    # Instance should be built under target scope, if `_scope_` is defined
    # in cfg, current default scope should switch to specified scope
    # temporarily.
    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        obj_type = args.pop('type')
        if isinstance(obj_type, str):
            obj_cls = registry.get(obj_type)
            if obj_cls is None:
                raise KeyError(
                    f'{obj_type} is not in the {registry.scope}::{registry.name} registry. '  # noqa: E501
                    f'Please check whether the value of `{obj_type}` is '
                    'correct or it was registered as expected. More details '
                    'can be found at '
                    'https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
                )
        # this will include classes, functions, partial functions and more
        elif callable(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        # If `obj_cls` inherits from `ManagerMixin`, it should be
        # instantiated by `ManagerMixin.get_instance` to ensure that it
        # can be accessed globally.
        if inspect.isclass(obj_cls) and \
                issubclass(obj_cls, ManagerMixin):  # type: ignore
            obj = obj_cls.get_instance(**args)  # type: ignore
        else:
            obj = obj_cls(**args)  # type: ignore

        # Only classes, functions and methods are guaranteed to expose
        # ``__name__``/``__module__``; other callables (e.g. partials) get a
        # generic message.
        if (inspect.isclass(obj_cls) or inspect.isfunction(obj_cls)
                or inspect.ismethod(obj_cls)):
            print_log(
                f'An `{obj_cls.__name__}` instance is built from '  # type: ignore # noqa: E501
                'registry, and its implementation can be found in '
                f'{obj_cls.__module__}',  # type: ignore
                logger='current',
                level=logging.DEBUG)
        else:
            print_log(
                'An instance is built from registry, and its constructor '
                f'is {obj_cls}',
                logger='current',
                level=logging.DEBUG)
        return obj


def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
                          registry: Registry) -> 'Runner':
    """Build a Runner object.

    Examples:
        >>> from mmengine.registry import Registry, build_runner_from_cfg
        >>> RUNNERS = Registry('runners', build_func=build_runner_from_cfg)
        >>> @RUNNERS.register_module()
        >>> class CustomRunner(Runner):
        >>>     def setup_env(self, env_cfg):
        >>>         pass
        >>> cfg = dict(runner_type='CustomRunner', ...)
        >>> custom_runner = RUNNERS.build(cfg)

    Args:
        cfg (dict or ConfigDict or Config): Config dict. If "runner_type" key
            exists, it will be used to build a custom runner. Otherwise, it
            will be used to build a default runner.
        registry (:obj:`Registry`): The registry to search the type from.

    Returns:
        object: The constructed runner object.

    Raises:
        AssertionError: If ``cfg`` or ``registry`` has an unsupported type.
        KeyError: If the str "runner_type" is not found in the registry.
        TypeError: If "runner_type" is neither a str nor a class.
    """
    from ..config import Config, ConfigDict
    from ..logging import print_log

    assert isinstance(
        cfg,
        (dict, ConfigDict, Config
         )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
    # The message must be a single string; a comma between the two literals
    # would turn it into a tuple and garble the AssertionError text.
    assert isinstance(
        registry, Registry), ('registry should be a mmengine.Registry object, '
                              f'but got {type(registry)}')

    args = cfg.copy()
    # Runner should be built under target scope, if `_scope_` is defined
    # in cfg, current default scope should switch to specified scope
    # temporarily.
    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        # ``runner_type`` is read but left in ``args``; the full config is
        # handed to ``from_cfg`` below.
        obj_type = args.get('runner_type', 'Runner')
        if isinstance(obj_type, str):
            runner_cls = registry.get(obj_type)
            if runner_cls is None:
                raise KeyError(
                    f'{obj_type} is not in the {registry.name} registry. '
                    f'Please check whether the value of `{obj_type}` is '
                    'correct or it was registered as expected. More details '
                    'can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
                )
        elif inspect.isclass(obj_type):
            runner_cls = obj_type
        else:
            raise TypeError(
                f'type must be a str or valid type, but got {type(obj_type)}')

        runner = runner_cls.from_cfg(args)  # type: ignore
        print_log(
            f'An `{runner_cls.__name__}` instance is built from '  # type: ignore # noqa: E501
            'registry, its implementation can be found in '
            f'{runner_cls.__module__}',  # type: ignore
            logger='current',
            level=logging.DEBUG)
        return runner


def build_model_from_cfg(
    cfg: Union[dict, ConfigDict, Config, list],
    registry: Registry,
    default_args: Optional[Union[dict, ConfigDict, Config]] = None
) -> 'nn.Module':
    """Build a PyTorch model from config dict(s). Different from
    ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.

    Args:
        cfg (dict, list[dict]): The config of modules, which is either a config
            dict or a list of config dicts. If cfg is a list, the built
            modules will be wrapped with ``nn.Sequential``.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn.Module.
    """
    # Imported lazily, matching the other local imports in this module.
    from ..model import Sequential

    if isinstance(cfg, list):
        # Each entry is built independently with the same defaults, then the
        # results are chained in order.
        modules = [
            build_from_cfg(_cfg, registry, default_args) for _cfg in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_scheduler_from_cfg(
    cfg: Union[dict, ConfigDict, Config],
    registry: Registry,
    default_args: Optional[Union[dict, ConfigDict, Config]] = None
) -> '_ParamScheduler':
    """Builds a ``ParamScheduler`` instance from config.

    ``ParamScheduler`` supports building instance by its constructor or
    method ``build_iter_from_epoch``. Therefore, its registry needs a build
    function to handle both cases.

    Args:
        cfg (dict or ConfigDict or Config): Config dictionary. If it contains
            the key ``convert_to_iter_based`` with a truthy value, instance
            will be built by method ``build_iter_from_epoch``, otherwise
            instance will be built by its constructor.
        registry (:obj:`Registry`): The ``PARAM_SCHEDULERS`` registry.
        default_args (dict or ConfigDict or Config, optional): Default
            initialization arguments. It must contain key ``optimizer``. If
            ``convert_to_iter_based`` is defined in ``cfg``, it must
            additionally contain key ``epoch_length``. Defaults to None.

    Returns:
        object: The constructed ``ParamScheduler``.

    Raises:
        AssertionError: If ``cfg`` or ``registry`` has an unsupported type,
            or if conversion is requested without ``epoch_length`` or with
            ``by_epoch`` set to a falsy value.
        KeyError: If the str "type" is not found in the registry.
        TypeError: If "type" is neither a str nor a class.
    """
    assert isinstance(
        cfg,
        (dict, ConfigDict, Config
         )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
    # The message must be a single string; a comma between the two literals
    # would turn it into a tuple and garble the AssertionError text.
    assert isinstance(
        registry, Registry), ('registry should be a mmengine.Registry object, '
                              f'but got {type(registry)}')

    args = cfg.copy()
    if default_args is not None:
        # ``cfg`` takes precedence: defaults only fill in missing keys.
        for name, value in default_args.items():
            args.setdefault(name, value)
    # If `_scope_` is defined in cfg, switch to that scope temporarily while
    # resolving the scheduler type.
    scope = args.pop('_scope_', None)
    with registry.switch_scope_and_registry(scope) as registry:
        convert_to_iter = args.pop('convert_to_iter_based', False)
        if convert_to_iter:
            scheduler_type = args.pop('type')
            assert 'epoch_length' in args and args.get('by_epoch', True), (
                'Only epoch-based parameter scheduler can be converted to '
                'iter-based, and `epoch_length` should be set')
            if isinstance(scheduler_type, str):
                scheduler_cls = registry.get(scheduler_type)
                if scheduler_cls is None:
                    raise KeyError(
                        f'{scheduler_type} is not in the {registry.name} '
                        'registry. Please check whether the value of '
                        f'`{scheduler_type}` is correct or it was registered '
                        'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'  # noqa: E501
                    )
            elif inspect.isclass(scheduler_type):
                scheduler_cls = scheduler_type
            else:
                raise TypeError('type must be a str or valid type, but got '
                                f'{type(scheduler_type)}')
            return scheduler_cls.build_iter_from_epoch(  # type: ignore
                **args)
        else:
            # ``epoch_length`` is only consumed on the conversion path above;
            # drop it so it is not forwarded as a constructor argument.
            args.pop('epoch_length', None)
            return build_from_cfg(args, registry)