File size: 6,136 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from functools import wraps
from typing import TYPE_CHECKING

from tensorboardX import SummaryWriter

if TYPE_CHECKING:
    # TYPE_CHECKING is always False at runtime, but mypy will evaluate the contents of this block.
    # So if you import this module within TYPE_CHECKING, you will get code hints and other benefits.
    # Here is a good answer on stackoverflow:
    # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
    from ding.framework import Parallel


class DistributedWriter(SummaryWriter):
    """
    Overview:
        A ``SummaryWriter`` subclass for multi-process use in which a single process does the writing.
        Once ``plugin`` has been called with an active router, an instance that is not the main writer \
            forwards its wrapped ``add_*`` calls as ``"distributed_writer"`` events, and the instance marked \
            as the main writer executes the events it receives on itself.
    Interfaces:
        ``get_instance``, ``plugin``, ``initialize``, ``__del__``
    """
    # First instance created through ``get_instance``; handed back when ``get_instance`` gets no arguments.
    root = None

    def __init__(self, *args, **kwargs):
        """
        Overview:
            Construct the writer with disk writing switched off; the file writer is set up later by \
                ``initialize``.
        Arguments:
            - args (:obj:`Tuple`): Positional arguments forwarded to ``SummaryWriter.__init__``.
            - kwargs (:obj:`Dict`): Keyword arguments forwarded to ``SummaryWriter.__init__``. The value of \
                ``write_to_disk`` (``True`` when absent) is remembered and applied by ``initialize``.
        """
        # Remember what the caller asked for so that ``initialize`` can restore it.
        self._default_writer_to_disk = kwargs.get("write_to_disk", True)
        # The parent constructor always runs with ``write_to_disk=False``: data must be written to files
        # lazily, so the file writer is only created by ``initialize`` (called explicitly, by ``plugin``
        # for the main writer, or on the first wrapped ``add_*`` call).
        kwargs["write_to_disk"] = False
        super().__init__(*args, **kwargs)
        self._lazy_initialized = False
        self._is_writer = False
        self._router = None
        self._in_parallel = False

    @classmethod
    def get_instance(cls, *args, **kwargs) -> "DistributedWriter":
        """
        Overview:
            Create a new instance from the given arguments, or return the root instance when called without \
                any arguments. The first instance created through this method becomes the root instance.
        Arguments:
            - args (:obj:`Tuple`): Positional arguments forwarded to ``__init__``.
            - kwargs (:obj:`Dict`): Keyword arguments forwarded to ``__init__``.
        Returns:
            - instance (:obj:`DistributedWriter`): The newly created instance, or ``cls.root`` when no \
                arguments are given (``root`` is ``None`` until it has been set).
        """
        if not (args or kwargs):
            return cls.root
        instance = cls(*args, **kwargs)
        if cls.root is None:
            cls.root = instance
        return instance

    def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter":
        """
        Overview:
            Attach ``router`` to this writer. With an active router, wrapped ``add_*`` calls on a non-writer \
                instance are sent to the main writer instead of being written locally, so data from multiple \
                processes can be collected into one file. With an inactive router nothing is changed.
        Arguments:
            - router (:obj:`Parallel`): The router to be plugged in.
            - is_writer (:obj:`bool`): Whether this writer is the main writer.
        Returns:
            - writer (:obj:`DistributedWriter`): This instance, to allow chained calls.
        Examples:
            >>> DistributedWriter().plugin(router, is_writer=True)
        """
        if not router.is_active:
            # No active router: leave the writer's state untouched.
            return self
        self._in_parallel = True
        self._router = router
        self._is_writer = is_writer
        if is_writer:
            self.initialize()
        # Also set for non-writers, which prevents the lazy ``initialize`` call on their first ``add_*``.
        self._lazy_initialized = True
        router.on("distributed_writer", self._on_distributed_writer)
        return self

    def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
        """
        Overview:
            Handler for the ``"distributed_writer"`` event. Only the main writer acts on it, by calling its \
                own method named ``fn_name``; every other instance ignores the event.
        Arguments:
            - fn_name (:obj:`str`): The name of the method to be called.
            - args (:obj:`Tuple`): The positional arguments passed to that method.
            - kwargs (:obj:`Dict`): The keyword arguments passed to that method.
        """
        if not self._is_writer:
            return
        target = getattr(self, fn_name)
        target(*args, **kwargs)

    def initialize(self):
        """
        Overview:
            Set up the file writer: close the current one, restore the ``write_to_disk`` setting remembered \
                in ``__init__``, ask the parent class for a file writer, and mark the writer as initialized.
        """
        self.close()
        self._write_to_disk = self._default_writer_to_disk
        self._get_file_writer()
        self._lazy_initialized = True

    def __del__(self):
        """
        Overview:
            Close the writer when the object is destroyed.
        """
        self.close()


def enable_parallel(fn_name, fn):
    """
    Overview:
        Wrap a writer method so that a call is either executed locally or forwarded through the router, \
            depending on the state of the writer it is called on.
    Arguments:
        - fn_name (:obj:`str`): The method name sent along with the ``"distributed_writer"`` event, which \
            tells the receiving side which method to call.
        - fn (:obj:`Callable`): The original, unwrapped method.
    Returns:
        - wrapper (:obj:`Callable`): A function carrying the name and docstring of ``fn``. It returns the \
            result of ``fn`` when the call is executed locally and ``None`` when the call is forwarded.
    """

    # ``wraps`` keeps ``__name__``/``__doc__`` of the original method on the replacement.
    @wraps(fn)
    def _parallel_fn(self: "DistributedWriter", *args, **kwargs):
        # First use of a writer that was neither plugged in nor initialized explicitly.
        if not self._lazy_initialized:
            self.initialize()
        if self._in_parallel and not self._is_writer:
            # Not the main writer: hand the call over to the router instead of writing here.
            self._router.emit("distributed_writer", fn_name, *args, **kwargs)
            return None
        return fn(self, *args, **kwargs)

    return _parallel_fn


# Names of the writer methods that are replaced by their ``enable_parallel`` wrappers below.
ready_to_parallel_fns = [
    'add_audio',
    'add_custom_scalars',
    'add_custom_scalars_marginchart',
    'add_custom_scalars_multilinechart',
    'add_embedding',
    'add_figure',
    'add_graph',
    'add_graph_deprecated',
    'add_histogram',
    'add_histogram_raw',
    'add_hparams',
    'add_image',
    'add_image_with_boxes',
    'add_images',
    'add_mesh',
    'add_onnx_graph',
    'add_openvino_graph',
    'add_pr_curve',
    'add_pr_curve_raw',
    'add_scalar',
    'add_scalars',
    'add_text',
    'add_video',
]
for fn_name in ready_to_parallel_fns:
    # Names in the list that the class does not provide are skipped.
    if not hasattr(DistributedWriter, fn_name):
        continue
    setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name)))

# Module-level writer instance, created at import time with default arguments.
# Examples:
#     In main, `distributed_writer.plugin(task.router, is_writer=True)`,
#     In middleware, call one of the wrapped methods, e.g. `distributed_writer.add_scalar(...)`
distributed_writer = DistributedWriter()