# Source path: gomoku / DI-engine / ding / utils / registry.py
# Repository page header (kept as a comment): "zjowowen's picture", "init space", 079c32c
import inspect
import os
from collections import OrderedDict
from typing import Optional, Iterable, Callable
# Module-private flag; initialised to True at import time.
_innest_error = True
# Registration tracing switch, evaluated once at import time: it is True only when
# the ``DIENGINEREGTRACE`` environment variable equals ``ON`` (compared case-insensitively);
# an unset variable is treated as ``OFF``.
_DI_ENGINE_REG_TRACE_IS_ON = os.getenv('DIENGINEREGTRACE', 'OFF').upper() == 'ON'
class Registry(dict):
    """
    Overview:
        A helper class for managing registering modules, it extends a dictionary
        and provides register functions.
    Interfaces:
        ``__init__``, ``register``, ``get``, ``build``, ``query``, ``query_details``
    Examples:
        Creating a registry:
        >>> some_registry = Registry({"default": default_module})

        There're two ways of registering new modules:
        1): normal way is just calling register function:
        >>> def foo():
        >>>     ...
        >>> some_registry.register("foo_module", foo)
        2): used as decorator when declaring the module:
        >>> @some_registry.register("foo_module")
        >>> @some_registry.register("foo_module_nickname")
        >>> def foo():
        >>>     ...

        Access of module is just like using a dictionary, eg:
        >>> f = some_registry["foo_module"]
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize the Registry object.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                dict.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                dict.
        """
        super(Registry, self).__init__(*args, **kwargs)
        # alias -> (filename, lineno) of the ``register`` call site; only filled while tracing is on.
        self.__trace__ = dict()

    def register(
            self,
            module_name: Optional[str] = None,
            module: Optional[Callable] = None,
            force_overwrite: bool = False
    ) -> Optional[Callable]:
        """
        Overview:
            Register the module, either directly (``register(name, module)``) or as a decorator \
            (``@register(name)`` / ``@register()``, the latter using the function's ``__name__``).
        Arguments:
            - module_name (:obj:`Optional[str]`): The name of the module.
            - module (:obj:`Optional[Callable]`): The module to be registered.
            - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
        Returns:
            - decorator (:obj:`Optional[Callable]`): ``None`` when ``module`` is given; otherwise a decorator \
                that registers the decorated callable and returns it unchanged.
        Raises:
            - AssertionError: ``module`` is given without ``module_name``, or the name is already registered \
                and ``force_overwrite`` is False.
        """
        # Record where the registration was requested from (the caller of ``register``).
        location = None
        if _DI_ENGINE_REG_TRACE_IS_ON:
            caller = inspect.stack()[1]
            location = (caller.filename, caller.lineno)

        # used as function call
        if module is not None:
            if module_name is None:
                # Explicit raise (not ``assert``) so the check survives ``python -O``.
                raise AssertionError("module_name is required when module is passed directly")
            Registry._register_generic(self, module_name, module, force_overwrite)
            if location is not None:
                self.__trace__[module_name] = location
            return None

        # used as decorator
        def register_fn(fn: Callable) -> Callable:
            name = fn.__name__ if module_name is None else module_name
            Registry._register_generic(self, name, fn, force_overwrite)
            if location is not None:
                self.__trace__[name] = location
            return fn

        return register_fn

    @staticmethod
    def _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) -> None:
        """
        Overview:
            Store ``module`` under ``module_name`` in ``module_dict``.
        Arguments:
            - module_dict (:obj:`dict`): The dict to store the module.
            - module_name (:obj:`str`): The name of the module.
            - module (:obj:`Callable`): The module to be registered.
            - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
        Raises:
            - AssertionError: ``module_name`` already exists and ``force_overwrite`` is False; the \
                exception argument is the conflicting name.
        """
        # Explicit raise (not ``assert``) so duplicates are still rejected under ``python -O``.
        if not force_overwrite and module_name in module_dict:
            raise AssertionError(module_name)
        module_dict[module_name] = module

    def get(self, module_name: str) -> Callable:
        """
        Overview:
            Get the module. Unlike ``dict.get`` there is no default: a missing name raises ``KeyError``.
        Arguments:
            - module_name (:obj:`str`): The name of the module.
        Returns:
            - module (:obj:`Callable`): The object stored under ``module_name``.
        """
        return self[module_name]

    def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
        """
        Overview:
            Build the object by calling the module registered as ``obj_type``.
        Arguments:
            - obj_type (:obj:`str`): The type of the object.
            - obj_args (:obj:`Tuple`): The arguments passed to the object.
            - obj_kwargs (:obj:`Dict`): The keyword arguments passed to the object.
        Returns:
            - obj (:obj:`object`): Whatever the registered callable returns.
        Raises:
            - KeyError: ``obj_type`` is not registered.
            - Exception: Any exception raised by the registered callable is re-raised unchanged, with a \
                diagnostic note (callable, expected signature, given arguments) attached in ``__notes__``.
        """
        # Look the builder up on its own, so that a KeyError raised *inside* the builder
        # is not mistaken for an unknown ``obj_type``.
        try:
            build_fn = self[obj_type]
        except KeyError:
            raise KeyError("not support buildable-object type: {}".format(obj_type))

        try:
            return build_fn(*obj_args, **obj_kwargs)
        except Exception as e:
            # With nested builds the same exception travels through every enclosing ``build``;
            # only the innermost one (the first to see it) describes the failure.
            if not getattr(e, '_registry_build_noted', False):
                message = Registry._describe_build_failure(build_fn, obj_type, obj_args, obj_kwargs)
                Registry._annotate_build_failure(e, message)
            raise

    @staticmethod
    def _describe_build_failure(build_fn: Callable, obj_type: str, obj_args: tuple, obj_kwargs: dict) -> str:
        """
        Overview:
            Format a human-readable description of a failed ``build`` call.
        """
        try:
            argspec = inspect.getfullargspec(build_fn)
        except TypeError:
            # Some callables cannot be introspected; that must not hide the real error.
            argspec = '<signature unavailable>'
        message = 'for {}(alias={})'.format(build_fn, obj_type)
        message += '\nExpected args are:{}'.format(argspec)
        message += '\nGiven args are:{}/{}'.format(obj_args, obj_kwargs.keys())
        message += '\nGiven args details are:{}/{}'.format(obj_args, obj_kwargs)
        return message

    @staticmethod
    def _annotate_build_failure(exc: BaseException, message: str) -> None:
        """
        Overview:
            Attach ``message`` to ``exc`` as a PEP 678 note and mark the exception as annotated.
        """
        try:
            if hasattr(exc, 'add_note'):
                # Python 3.11+: the note is shown in the traceback.
                exc.add_note(message)
            else:
                # Older interpreters: keep the note reachable on the exception object.
                exc.__notes__ = list(getattr(exc, '__notes__', ())) + [message]
            exc._registry_build_noted = True
        except (AttributeError, TypeError):
            # Best-effort diagnostics: never let them replace the original failure.
            pass

    def query(self) -> Iterable:
        """
        Overview:
            All registered module names.
        Returns:
            - names (:obj:`Iterable`): The key view of this registry.
        """
        return self.keys()

    def query_details(self, aliases: Optional[Iterable] = None) -> OrderedDict:
        """
        Overview:
            Get the details of the registered modules, i.e. where each alias was registered.
        Arguments:
            - aliases (:obj:`Optional[Iterable]`): The aliases of the modules. When omitted, every alias \
                that has a recorded registration site is reported; entries stored without going through \
                ``register`` (e.g. passed to the constructor) have no site and are skipped.
        Returns:
            - details (:obj:`OrderedDict`): Mapping from alias to ``(filename, lineno)``.
        Raises:
            - AssertionError: Registration tracing is not enabled.
            - KeyError: An explicitly requested alias has no recorded registration site.
        """
        # Explicit raise (not ``assert``) so the guidance is still shown under ``python -O``.
        if not _DI_ENGINE_REG_TRACE_IS_ON:
            raise AssertionError("please exec 'export DIENGINEREGTRACE=ON' first")
        if aliases is None:
            aliases = [alias for alias in self if alias in self.__trace__]
        return OrderedDict((alias, self.__trace__[alias]) for alias in aliases)