File size: 5,903 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import inspect
import logging
import os
from collections import OrderedDict
from typing import Optional, Iterable, Callable

# Module-level switch used by ``Registry.build``: it starts out True and is set to False by ``build``
# the first time a registered build function raises.
_innest_error = True

# Registration tracing is enabled when the environment variable DIENGINEREGTRACE equals "ON"
# (compared case-insensitively); any other value, or an unset variable, leaves it disabled.
_DI_ENGINE_REG_TRACE_IS_ON = os.getenv('DIENGINEREGTRACE', 'OFF').upper() == 'ON'


class Registry(dict):
    """
    Overview:
        A helper class for managing registered modules. It extends a dictionary (name -> module)
        and provides a ``register`` function usable either as a plain call or as a decorator.
    Interfaces:
        ``__init__``, ``register``, ``get``, ``build``, ``query``, ``query_details``
    Examples:
        Creating a registry:
        >>> some_registry = Registry({"default": default_module})

        There are two ways of registering new modules:
        1): the normal way is to call the register function directly:
        >>> def foo():
        >>>     ...
        >>> some_registry.register("foo_module", foo)
        2): use it as a decorator when declaring the module:
        >>> @some_registry.register("foo_module")
        >>> @some_registry.register("foo_module_nickname")
        >>> def foo():
        >>>     ...

        Accessing a module works just like a dictionary, e.g.:
        >>> f = some_registry["foo_module"]
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize the Registry object.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                dict.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                dict.
        """

        super(Registry, self).__init__(*args, **kwargs)
        # Maps registered name -> (filename, lineno) of the ``register`` call site. Only filled while
        # registration tracing (DIENGINEREGTRACE=ON) is enabled.
        self.__trace__ = dict()

    def register(
            self,
            module_name: Optional[str] = None,
            module: Optional[Callable] = None,
            force_overwrite: bool = False
    ) -> Callable:
        """
        Overview:
            Register a module, either directly (when ``module`` is given) or as a decorator (when it is not).
        Arguments:
            - module_name (:obj:`Optional[str]`): The name to register under. Defaults to the ``__name__`` \
                of the registered callable.
            - module (:obj:`Optional[Callable]`): The module to be registered. If None, a decorator is returned.
            - force_overwrite (:obj:`bool`): Whether to overwrite an existing module with the same name.
        Returns:
            - result (:obj:`Callable`): ``module`` itself when called directly, otherwise a decorator that \
                registers the decorated callable and returns it unchanged.
        Raises:
            - AssertionError: If the name is already registered and ``force_overwrite`` is False.
        """

        if _DI_ENGINE_REG_TRACE_IS_ON:
            # Locate the caller of ``register``. Stepping one frame back is O(1), whereas
            # ``inspect.stack()`` builds a FrameInfo (including source context) for every frame on the stack.
            frame = inspect.currentframe()
            caller = frame.f_back if frame is not None else None
            if caller is not None:
                info = inspect.getframeinfo(caller)
                filename, lineno = info.filename, info.lineno
            else:
                # ``inspect.currentframe`` returns None on interpreters without Python stack frame support.
                filename, lineno = '<unknown>', 0
            del frame, caller  # do not keep frame objects alive longer than needed

        # used as function call
        if module is not None:
            # Same naming rule as the decorator path below.
            name = module.__name__ if module_name is None else module_name
            Registry._register_generic(self, name, module, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[name] = (filename, lineno)
            return module

        # used as decorator
        def register_fn(fn: Callable) -> Callable:
            name = fn.__name__ if module_name is None else module_name
            Registry._register_generic(self, name, fn, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[name] = (filename, lineno)
            return fn

        return register_fn

    @staticmethod
    def _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) -> None:
        """
        Overview:
            Store ``module`` in ``module_dict`` under ``module_name``.
        Arguments:
            - module_dict (:obj:`dict`): The dict to store the module.
            - module_name (:obj:`str`): The name of the module.
            - module (:obj:`Callable`): The module to be registered.
            - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
        Raises:
            - AssertionError: If ``module_name`` is already present and ``force_overwrite`` is False.
        """

        # Explicit raise instead of ``assert`` so duplicates are still rejected under ``python -O``;
        # the exception type (AssertionError) is kept for backward compatibility.
        if not force_overwrite and module_name in module_dict:
            raise AssertionError(module_name)
        module_dict[module_name] = module

    def get(self, module_name: str) -> Callable:
        """
        Overview:
            Get the module registered under ``module_name``.
        Arguments:
            - module_name (:obj:`str`): The name of the module.
        Raises:
            - KeyError: If no module is registered under ``module_name`` (unlike ``dict.get``).
        """

        return self[module_name]

    def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
        """
        Overview:
            Look up the build function registered under ``obj_type`` and call it with the given arguments.
        Arguments:
            - obj_type (:obj:`str`): The registered name of the build function.
            - obj_args (:obj:`Tuple`): Positional arguments passed to the build function.
            - obj_kwargs (:obj:`Dict`): Keyword arguments passed to the build function.
        Returns:
            - obj (:obj:`object`): Whatever the build function returns.
        Raises:
            - KeyError: If ``obj_type`` is not registered. Exceptions raised by the build function itself \
                (including its own ``KeyError``) propagate unchanged.
        """

        # Look-up and invocation use separate ``try`` blocks so that a KeyError raised *inside* the build
        # function is not misreported as an unknown ``obj_type``.
        try:
            build_fn = self[obj_type]
        except KeyError:
            raise KeyError("not support buildable-object type: {}".format(obj_type)) from None
        try:
            return build_fn(*obj_args, **obj_kwargs)
        except Exception:
            # build_fn execution fail
            global _innest_error
            if _innest_error:
                # NOTE(review): the module-level flag is never reset, so only the first failing build in the
                # process (the innermost one when builds are nested) is reported.
                _innest_error = False
                logging.getLogger(__name__).error(
                    'Registry.build failed %s',
                    Registry._describe_build_failure(build_fn, obj_type, obj_args, obj_kwargs),
                )
            # Bare ``raise`` re-raises the original exception with its traceback untouched.
            raise

    @staticmethod
    def _describe_build_failure(build_fn: Callable, obj_type: str, obj_args: tuple, obj_kwargs: dict) -> str:
        """
        Overview:
            Render a human-readable summary of a failed ``build`` call (target, expected and given arguments).
            Best-effort: it never raises, so it cannot mask the exception being reported.
        """

        try:
            try:
                expected = inspect.getfullargspec(build_fn)
            except TypeError:
                # ``getfullargspec`` raises TypeError for callables it cannot introspect.
                expected = '<unavailable>'
            message = 'for {}(alias={})'.format(build_fn, obj_type)
            message += '\nExpected args are:{}'.format(expected)
            message += '\nGiven args are:{}/{}'.format(
                [type(arg).__name__ for arg in obj_args], list(obj_kwargs.keys())
            )
            message += '\nGiven args details are:{}/{}'.format(obj_args, obj_kwargs)
            return message
        except Exception:  # e.g. a user-defined ``__repr__`` on an argument that raises
            return 'for alias={!r} (argument details unavailable)'.format(obj_type)

    def query(self) -> Iterable:
        """
        Overview:
            Return all registered module names (a live keys view of this dict).
        """

        return self.keys()

    def query_details(self, aliases: Optional[Iterable] = None) -> OrderedDict:
        """
        Overview:
            Get the registration call sites of the registered modules.
        Arguments:
            - aliases (:obj:`Optional[Iterable]`): The names to query; defaults to all registered names.
        Returns:
            - details (:obj:`OrderedDict`): Name -> (filename, lineno) of the ``register`` call.
        Raises:
            - AssertionError: If registration tracing (DIENGINEREGTRACE=ON) is disabled.
            - KeyError: If a queried name was registered while tracing was disabled.
        """

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not _DI_ENGINE_REG_TRACE_IS_ON:
            raise AssertionError("please exec 'export DIENGINEREGTRACE=ON' first")
        if aliases is None:
            aliases = self.keys()
        return OrderedDict((alias, self.__trace__[alias]) for alias in aliases)