import functools
import logging
from typing import TYPE_CHECKING, Callable

from tensorboardX import SummaryWriter

if TYPE_CHECKING:
    # TYPE_CHECKING is always False at runtime, but mypy will evaluate the contents of this block.
    # So if you import this module within TYPE_CHECKING, you will get code hints and other benefits.
    # Here is a good answer on stackoverflow:
    # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
    from ding.framework import Parallel


class DistributedWriter(SummaryWriter):
    """
    Overview:
        A simple subclass of SummaryWriter that supports writing to one process in multi-process mode.
        The best way is to use it in conjunction with the ``router`` to take advantage of the message \
            and event components of the router (see ``writer.plugin``).
    Interfaces:
        ``get_instance``, ``plugin``, ``initialize``, ``__del__``
    """
    root = None

    def __init__(self, *args, **kwargs):
        """
        Overview:
            Initialize the DistributedWriter object.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
        """
        # Remember what the caller asked for (defaulting to True, like SummaryWriter) so that
        # ``initialize`` can restore it once the real file writer is created.
        self._default_writer_to_disk = kwargs.get("write_to_disk", True)
        # We need to write data to files lazily, so we should not use file writer in __init__,
        # On the contrary, we will initialize the file writer when the user calls the
        # add_* function for the first time
        kwargs["write_to_disk"] = False
        super().__init__(*args, **kwargs)
        self._in_parallel = False
        self._router = None
        self._is_writer = False
        # Assigned last: ``__del__`` uses its presence to know construction completed.
        self._lazy_initialized = False

    @classmethod
    def get_instance(cls, *args, **kwargs) -> "DistributedWriter":
        """
        Overview:
            Get an instance, and set the root-level instance the first time one is created.
            If neither args nor kwargs are given, this method returns the root instance \
                (which is ``None`` if no instance has been created through this method yet).
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
        Returns:
            - writer (:obj:`DistributedWriter`): A new instance when arguments are given, otherwise the root instance.
        """
        if not (args or kwargs):
            return cls.root
        ins = cls(*args, **kwargs)
        if cls.root is None:
            cls.root = ins
        return ins

    def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter":
        """
        Overview:
            Plugin ``router``, so when using this writer with active router, it will automatically send requests\
                to the main writer instead of writing it to the disk. So we can collect data from multiple processes\
                and write them into one file.
        Arguments:
            - router (:obj:`Parallel`): The router to be plugged in.
            - is_writer (:obj:`bool`): Whether this writer is the main writer.
        Returns:
            - writer (:obj:`DistributedWriter`): ``self``, to allow chaining.
        Examples:
            >>> DistributedWriter().plugin(router, is_writer=True)
        """
        if router.is_active:
            self._in_parallel = True
            self._router = router
            self._is_writer = is_writer
            if is_writer:
                self.initialize()
            # Non-writers only forward requests, so they must never lazily open a file writer.
            self._lazy_initialized = True
            router.on("distributed_writer", self._on_distributed_writer)
        return self

    def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
        """
        Overview:
            This method is called when the router receives a request to write data.
            Only the main writer handles requests, and only for the ``add_*`` functions registered in \
                ``ready_to_parallel_fns``; any other name is rejected with a warning.
        Arguments:
            - fn_name (:obj:`str`): The name of the function to be called.
            - args (:obj:`Tuple`): The arguments passed to the function to be called.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the function to be called.
        """
        if not self._is_writer:
            return
        # ``fn_name`` arrives from another process; do not let it invoke arbitrary methods
        # such as ``close`` or ``initialize``.
        if fn_name not in ready_to_parallel_fns:
            logging.warning("DistributedWriter ignored request for unsupported function: %s", fn_name)
            return
        getattr(self, fn_name)(*args, **kwargs)

    def initialize(self):
        """
        Overview:
            Initialize the file writer.
        """
        self.close()
        self._write_to_disk = self._default_writer_to_disk
        self._get_file_writer()
        self._lazy_initialized = True

    def __del__(self):
        """
        Overview:
            Close the file writer.
        """
        # ``__del__`` also runs when ``__init__`` raised part-way (e.g. invalid SummaryWriter arguments);
        # such an instance lacks the state ``close`` relies on, so skip it instead of raising.
        if not hasattr(self, "_lazy_initialized"):
            return
        self.close()


def enable_parallel(fn_name: str, fn: Callable) -> Callable:
    """
    Overview:
        Decorator to enable parallel writing.
    Arguments:
        - fn_name (:obj:`str`): The name of the function to be called.
        - fn (:obj:`Callable`): The function to be called.
    Returns:
        - wrapped_fn (:obj:`Callable`): A wrapper with ``fn``'s metadata that forwards the call to the main \
            writer when running as a non-writer in parallel mode, and otherwise calls ``fn`` locally \
            (lazily initializing the file writer first).
    """

    @functools.wraps(fn)
    def _parallel_fn(self: DistributedWriter, *args, **kwargs):
        if self._in_parallel and not self._is_writer:
            # Forward before any lazy initialization: a non-writer never needs a local file writer.
            self._router.emit("distributed_writer", fn_name, *args, **kwargs)
            return None
        if not self._lazy_initialized:
            self.initialize()
        return fn(self, *args, **kwargs)

    return _parallel_fn


ready_to_parallel_fns = [
    'add_audio',
    'add_custom_scalars',
    'add_custom_scalars_marginchart',
    'add_custom_scalars_multilinechart',
    'add_embedding',
    'add_figure',
    'add_graph',
    'add_graph_deprecated',
    'add_histogram',
    'add_histogram_raw',
    'add_hparams',
    'add_image',
    'add_image_with_boxes',
    'add_images',
    'add_mesh',
    'add_onnx_graph',
    'add_openvino_graph',
    'add_pr_curve',
    'add_pr_curve_raw',
    'add_scalar',
    'add_scalars',
    'add_text',
    'add_video',
]
# Different tensorboardX versions expose different add_* functions, so only wrap the ones present.
for fn_name in ready_to_parallel_fns:
    if hasattr(DistributedWriter, fn_name):
        setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name)))

# Examples:
#     In main, `distributed_writer.plugin(task.router, is_writer=True)`,
#     In middleware, `distributed_writer.record()`
distributed_writer = DistributedWriter()