File size: 5,903 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import inspect
import os
from collections import OrderedDict
from typing import Optional, Iterable, Callable
# Retained only so that existing references to this name keep resolving. ``Registry.build`` used to flip
# this module-wide flag on the first failure and never reset it; the "already annotated" state is now
# tracked on each exception object instead, so this value is no longer modified.
_innest_error = True
# True when the environment variable ``DIENGINEREGTRACE`` is set to ``ON`` (case-insensitive) at import time.
_DI_ENGINE_REG_TRACE_IS_ON = os.environ.get('DIENGINEREGTRACE', 'OFF').upper() == 'ON'


class Registry(dict):
    """
    Overview:
        A helper class for managing registered modules. It extends ``dict`` and adds a ``register``
        function that can be used either as a plain call or as a decorator.
    Interfaces:
        ``__init__``, ``register``, ``get``, ``build``, ``query``, ``query_details``
    Examples:
        Creating a registry:
        >>> some_registry = Registry({"default": default_module})

        There are two ways of registering new modules:
        1): the normal way is just calling the register function:
        >>> def foo():
        >>>     ...
        >>> some_registry.register("foo_module", foo)
        2): used as a decorator when declaring the module:
        >>> @some_registry.register("foo_module")
        >>> @some_registry.register("foo_module_nickname")
        >>> def foo():
        >>>     ...

        Access of a module is just like using a dictionary, eg:
        >>> f = some_registry["foo_module"]
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize the Registry object.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                dict.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                dict.
        """
        super(Registry, self).__init__(*args, **kwargs)
        # Maps a registered name to the ``(filename, lineno)`` of the ``register`` call; only filled in
        # when tracing is switched on.
        self.__trace__ = dict()

    def register(
            self,
            module_name: Optional[str] = None,
            module: Optional[Callable] = None,
            force_overwrite: bool = False
    ) -> Optional[Callable]:
        """
        Overview:
            Register the module, either directly (``module`` given) or as a decorator (``module`` omitted).
        Arguments:
            - module_name (:obj:`Optional[str]`): The name of the module. In decorator mode it defaults to \
                the ``__name__`` of the decorated object.
            - module (:obj:`Optional[Callable]`): The module to be registered.
            - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
        Returns:
            - decorator (:obj:`Optional[Callable]`): ``None`` when ``module`` is given; otherwise a decorator \
                that registers its argument and returns it unchanged.
        Raises:
            - AssertionError: If ``module`` is given without ``module_name``, or if the name is already \
                registered and ``force_overwrite`` is False.
        """
        if _DI_ENGINE_REG_TRACE_IS_ON:
            # Index 1 is the frame of whoever called ``register``.
            caller = inspect.stack()[1]
            filename, lineno = caller.filename, caller.lineno

        # used as function call
        if module is not None:
            # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``.
            if module_name is None:
                raise AssertionError("module_name is required when module is passed directly")
            Registry._register_generic(self, module_name, module, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[module_name] = (filename, lineno)
            return None

        # used as decorator
        def register_fn(fn: Callable) -> Callable:
            name = fn.__name__ if module_name is None else module_name
            Registry._register_generic(self, name, fn, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[name] = (filename, lineno)
            return fn

        return register_fn

    @staticmethod
    def _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) -> None:
        """
        Overview:
            Store ``module`` under ``module_name`` in ``module_dict``.
        Arguments:
            - module_dict (:obj:`dict`): The dict to store the module.
            - module_name (:obj:`str`): The name of the module.
            - module (:obj:`Callable`): The module to be registered.
            - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
        Raises:
            - AssertionError: If ``module_name`` is already present and ``force_overwrite`` is False. The \
                exception carries the conflicting name as its only argument.
        """
        # Raised explicitly (instead of ``assert``) so duplicates are still rejected under ``python -O``.
        if not force_overwrite and module_name in module_dict:
            raise AssertionError(module_name)
        module_dict[module_name] = module

    def get(self, module_name: str) -> Callable:
        """
        Overview:
            Get the module. Unlike ``dict.get`` this takes no default and raises on a missing name.
        Arguments:
            - module_name (:obj:`str`): The name of the module.
        Returns:
            - module (:obj:`Callable`): The object registered under ``module_name``.
        Raises:
            - KeyError: If ``module_name`` is not registered.
        """
        return self[module_name]

    def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
        """
        Overview:
            Look up the callable registered as ``obj_type`` and call it with the given arguments.
        Arguments:
            - obj_type (:obj:`str`): The type of the object.
            - obj_args (:obj:`Tuple`): The arguments passed to the object.
            - obj_kwargs (:obj:`Dict`): The keyword arguments passed to the object.
        Returns:
            - obj (:obj:`object`): Whatever the registered callable returns.
        Raises:
            - KeyError: If ``obj_type`` is not registered.
            - Exception: Any exception raised by the registered callable is re-raised unchanged (same \
                type and ``args``), with a diagnostic note attached on a best-effort basis.
        """
        # Only the lookup is guarded here, so a KeyError raised *inside* the build function is not
        # mistaken for an unknown alias.
        try:
            build_fn = self[obj_type]
        except KeyError:
            raise KeyError("not support buildable-object type: {}".format(obj_type))

        try:
            return build_fn(*obj_args, **obj_kwargs)
        except Exception as e:
            Registry._annotate_build_error(e, build_fn, obj_type, obj_args, obj_kwargs)
            raise

    @staticmethod
    def _annotate_build_error(error: Exception, build_fn: Callable, obj_type: str, obj_args: tuple,
                              obj_kwargs: dict) -> None:
        """
        Overview:
            Attach a note describing the failed ``build`` call (expected signature vs. given arguments) to \
            ``error``. With nested ``build`` calls only the innermost one annotates the exception.
        Arguments:
            - error (:obj:`Exception`): The exception raised by ``build_fn``.
            - build_fn (:obj:`Callable`): The registered callable that failed.
            - obj_type (:obj:`str`): The alias under which ``build_fn`` was looked up.
            - obj_args (:obj:`tuple`): The positional arguments that were passed.
            - obj_kwargs (:obj:`dict`): The keyword arguments that were passed.
        """
        # An outer ``build`` sees the same exception object again while it propagates; skip it then.
        if getattr(error, '_registry_build_annotated', False):
            return
        try:
            try:
                expected = inspect.getfullargspec(build_fn)
            except TypeError:
                # ``getfullargspec`` rejects callables it cannot introspect.
                expected = '<signature unavailable>'
            message = 'for {}(alias={})'.format(build_fn, obj_type)
            message += '\nExpected args are:{}'.format(expected)
            message += '\nGiven args are:{}/{}'.format([type(a).__name__ for a in obj_args], list(obj_kwargs))
            message += '\nGiven args details are:{}/{}'.format(obj_args, obj_kwargs)
            if hasattr(error, 'add_note'):
                error.add_note(message)
            else:
                # Interpreters without ``add_note``: keep the note in the same ``__notes__`` attribute.
                error.__notes__ = list(getattr(error, '__notes__', ())) + [message]
            error._registry_build_annotated = True
        except Exception:
            # Diagnostics are best-effort (e.g. a broken ``__repr__`` on an argument); they must never
            # replace the original error, which the caller re-raises right after this returns.
            pass

    def query(self) -> Iterable:
        """
        Overview:
            Get all registered module names.
        Returns:
            - names (:obj:`Iterable`): The keys view of this registry.
        """
        return self.keys()

    def query_details(self, aliases: Optional[Iterable] = None) -> OrderedDict:
        """
        Overview:
            Get where the registered modules were registered from.
        Arguments:
            - aliases (:obj:`Optional[Iterable]`): The aliases of the modules. Defaults to every name in \
                the registry.
        Returns:
            - details (:obj:`OrderedDict`): Maps each alias to its ``(filename, lineno)`` record. When \
                ``aliases`` is omitted, names that have no record (for example entries passed to the \
                constructor) map to ``None``.
        Raises:
            - AssertionError: If tracing is not switched on.
            - KeyError: If an explicitly requested alias has no trace record.
        """
        # Raised explicitly (instead of ``assert``) so the guard is not stripped under ``python -O``.
        if not _DI_ENGINE_REG_TRACE_IS_ON:
            raise AssertionError("please exec 'export DIENGINEREGTRACE=ON' first")
        if aliases is None:
            # Entries added without ``register`` have no trace record; report them instead of failing.
            return OrderedDict((name, self.__trace__.get(name)) for name in self.keys())
        return OrderedDict((alias, self.__trace__[alias]) for alias in aliases)
|