File size: 6,136 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
from typing import TYPE_CHECKING
from tensorboardX import SummaryWriter
if TYPE_CHECKING:
# TYPE_CHECKING is always False at runtime, but mypy will evaluate the contents of this block.
# So if you import this module within TYPE_CHECKING, you will get code hints and other benefits.
# Here is a good answer on stackoverflow:
# https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
from ding.framework import Parallel
class DistributedWriter(SummaryWriter):
    """
    Overview:
        A ``SummaryWriter`` subclass whose file writer is created lazily and which can be attached to a \
        ``router`` via ``plugin``. Once attached to an active router, only the instance flagged as the \
        writer executes the write requests received on the ``"distributed_writer"`` event.
    Interfaces:
        ``get_instance``, ``plugin``, ``initialize``, ``__del__``
    """
    # Holds the first instance created through ``get_instance``; stays ``None`` until then.
    root = None

    def __init__(self, *args, **kwargs):
        """
        Overview:
            Build the writer without opening a real file writer yet.
        Arguments:
            - args (:obj:`Tuple`): Positional arguments forwarded to ``SummaryWriter.__init__``.
            - kwargs (:obj:`Dict`): Keyword arguments forwarded to ``SummaryWriter.__init__``. The \
                ``write_to_disk`` entry (default ``True``) is remembered and only applied later by \
                ``initialize``.
        """
        # Keep the caller's choice so ``initialize`` can restore it later.
        self._default_writer_to_disk = kwargs.get("write_to_disk", True)
        # Files must be written lazily, so the parent constructor is always told not to write to disk;
        # the file writer is set up by ``initialize`` instead.
        kwargs["write_to_disk"] = False
        super().__init__(*args, **kwargs)
        # Bookkeeping flags, all in their "stand-alone, not yet initialized" state.
        self._lazy_initialized = False
        self._is_writer = False
        self._router = None
        self._in_parallel = False

    @classmethod
    def get_instance(cls, *args, **kwargs) -> "DistributedWriter":
        """
        Overview:
            Create a new instance when any argument is given, registering it as ``root`` if no root \
            exists yet. When called without arguments, return the current ``root`` (which is ``None`` \
            if nothing has been registered).
        Arguments:
            - args (:obj:`Tuple`): Positional arguments forwarded to the class constructor.
            - kwargs (:obj:`Dict`): Keyword arguments forwarded to the class constructor.
        Returns:
            - instance (:obj:`DistributedWriter`): The newly created instance, or ``root``.
        """
        if not (args or kwargs):
            return cls.root
        instance = cls(*args, **kwargs)
        if cls.root is None:
            cls.root = instance
        return instance

    def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter":
        """
        Overview:
            Attach ``router`` to this writer. If the router is not active nothing is changed. Otherwise \
            the writer switches to parallel mode, subscribes to the ``"distributed_writer"`` event and, \
            when ``is_writer`` is true, initializes its file writer immediately.
        Arguments:
            - router (:obj:`Parallel`): The router to be plugged in.
            - is_writer (:obj:`bool`): Whether this instance is the one that performs the writes.
        Returns:
            - self (:obj:`DistributedWriter`): This instance, to allow call chaining.
        Examples:
            >>> DistributedWriter().plugin(router, is_writer=True)
        """
        if not router.is_active:
            return self
        self._in_parallel = True
        self._router = router
        self._is_writer = is_writer
        if is_writer:
            self.initialize()
        # Set for writers and non-writers alike, so a non-writer never runs ``initialize`` lazily.
        self._lazy_initialized = True
        router.on("distributed_writer", self._on_distributed_writer)
        return self

    def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
        """
        Overview:
            Handler registered on the router for the ``"distributed_writer"`` event. Instances that \
            are not the writer ignore the request.
        Arguments:
            - fn_name (:obj:`str`): Name of the method of this object to invoke.
            - args (:obj:`Tuple`): Positional arguments passed to that method.
            - kwargs (:obj:`Dict`): Keyword arguments passed to that method.
        """
        if not self._is_writer:
            return
        target = getattr(self, fn_name)
        target(*args, **kwargs)

    def initialize(self):
        """
        Overview:
            Set up the file writer: close the current one, restore the ``write_to_disk`` choice \
            remembered in ``__init__``, request a file writer from the parent class and mark this \
            instance as initialized.
        """
        self.close()
        self._write_to_disk = self._default_writer_to_disk
        self._get_file_writer()
        self._lazy_initialized = True

    def __del__(self):
        """
        Overview:
            Close the writer when the object is destroyed.
        """
        self.close()
def enable_parallel(fn_name, fn):
    """
    Overview:
        Wrap a writer method so that it becomes aware of parallel mode. The wrapper first makes sure \
        the writer is initialized. Then, if the writer runs in parallel mode and is not the writer \
        instance, the call is forwarded through the router as a ``"distributed_writer"`` event; \
        otherwise ``fn`` is executed directly.
    Arguments:
        - fn_name (:obj:`str`): The name of the function to be called, sent along with the event so \
            the receiving side knows which method to invoke.
        - fn (:obj:`Callable`): The original (unwrapped) function to be called.
    Returns:
        - wrapper (:obj:`Callable`): The wrapped function. It keeps the name, docstring and signature \
            metadata of ``fn``. It returns whatever ``fn`` returns when executed locally and ``None`` \
            when the call is forwarded through the router.
    """
    # Function-scope import: keeps this helper self-contained without touching the module imports.
    import functools

    # ``wraps`` preserves ``__name__``/``__doc__``/``__wrapped__`` of the patched ``add_*`` methods,
    # which would otherwise all show up as ``_parallel_fn`` in help() and tracebacks.
    @functools.wraps(fn)
    def _parallel_fn(self: "DistributedWriter", *args, **kwargs):
        # Lazy initialization on the first call of any wrapped method.
        if not self._lazy_initialized:
            self.initialize()
        if self._in_parallel and not self._is_writer:
            # Not the writer: hand the request over to the router instead of executing it here.
            self._router.emit("distributed_writer", fn_name, *args, **kwargs)
            return None
        # Stand-alone mode or the writer itself: run the original function and pass on its result.
        return fn(self, *args, **kwargs)

    return _parallel_fn
# Names of the logging methods that are wrapped with ``enable_parallel`` below.
ready_to_parallel_fns = [
    "add_audio",
    "add_custom_scalars",
    "add_custom_scalars_marginchart",
    "add_custom_scalars_multilinechart",
    "add_embedding",
    "add_figure",
    "add_graph",
    "add_graph_deprecated",
    "add_histogram",
    "add_histogram_raw",
    "add_hparams",
    "add_image",
    "add_image_with_boxes",
    "add_images",
    "add_mesh",
    "add_onnx_graph",
    "add_openvino_graph",
    "add_pr_curve",
    "add_pr_curve_raw",
    "add_scalar",
    "add_scalars",
    "add_text",
    "add_video",
]

# Replace every listed method found on the class with its parallel-aware wrapper.
# Names the class does not provide are skipped (presumably the available set varies
# between tensorboardX versions -- confirm).
for fn_name in ready_to_parallel_fns:
    if not hasattr(DistributedWriter, fn_name):
        continue
    setattr(
        DistributedWriter,
        fn_name,
        enable_parallel(fn_name, getattr(DistributedWriter, fn_name)),
    )
# Module-level writer instance, created as soon as this module is imported.
# It is constructed directly, so it is not registered as ``DistributedWriter.root``.
# Examples:
#     In main, `distributed_writer.plugin(task.router, is_writer=True)`,
#     In middleware, `distributed_writer.record()`
# NOTE(review): no ``record`` method is defined in this file; presumably one of the wrapped
# ``add_*`` methods (e.g. ``add_scalar``) is what a middleware is expected to call -- confirm.
distributed_writer = DistributedWriter()
|