# File size: 1,905 Bytes
# 079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from typing import Callable
import torch
from .time_helper_base import TimeWrapper


def get_cuda_time_wrapper() -> Callable[[], 'TimeWrapper']:
    """
    Overview:
        Build and return the ``TimeWrapperCuda`` class. The class is declared inside this function, so its
        CUDA events are only created when the function is called rather than when the module is imported,
        which keeps the module importable on machines without a cuda device.

    Returns:
        - TimeWrapperCuda(:obj:`class`): The class itself (not an instance), see ``TimeWrapperCuda``

    .. note::
        Timing relies on ``torch.cuda.synchronize()``, reference:
        <https://blog.csdn.net/u013548568/article/details/81368019>

    """

    # TODO find a way to autodoc the class within method
    class TimeWrapperCuda(TimeWrapper):
        """
        Overview:
            A ``TimeWrapper`` subclass that measures elapsed time with a pair of CUDA events.

            Notes:
                ``torch.cuda.synchronize()`` is required around the event records, reference:
                <https://blog.csdn.net/u013548568/article/details/81368019>

        Interfaces:
            ``start_time``, ``end_time``
        """
        # Class-level events: both are constructed once, while this class body is executed,
        # and are then shared by every call to ``start_time`` / ``end_time``.
        start_record = torch.cuda.Event(enable_timing=True)
        end_record = torch.cuda.Event(enable_timing=True)

        # overwrite
        @classmethod
        def start_time(cls):
            """
            Overview:
                Override of ``TimeWrapper.start_time``: synchronize the device, then record the start event.
            """
            # Synchronize first so that work queued before this call is not counted in the interval.
            torch.cuda.synchronize()
            begin_marker = cls.start_record.record()
            # Keep the value returned by ``record()`` on the class under the name ``start``.
            cls.start = begin_marker

        # overwrite
        @classmethod
        def end_time(cls):
            """
            Overview:
                Override of ``TimeWrapper.end_time``: record the end event, synchronize the device and
                report the interval between the two events.
            Returns:
                - time(:obj:`float`): The time between ``start_time`` and ``end_time``, i.e. the event \
                    elapsed time (milliseconds) divided by 1000
            """
            finish_marker = cls.end_record.record()
            # Keep the value returned by ``record()`` on the class under the name ``end``.
            cls.end = finish_marker
            # Synchronize after recording so that the end event has completed before it is queried.
            torch.cuda.synchronize()
            elapsed_ms = cls.start_record.elapsed_time(cls.end_record)
            return elapsed_ms / 1000

    return TimeWrapperCuda