from typing import Callable
import torch
from .time_helper_base import TimeWrapper


def get_cuda_time_wrapper() -> Callable[[], 'TimeWrapper']:
    """
    Overview:
        Return the ``TimeWrapperCuda`` class. The class is defined inside this function so that the \
        module stays importable in environments without a CUDA device.
    Returns:
        - TimeWrapperCuda(:obj:`class`): See the ``TimeWrapperCuda`` class defined below.

    .. note::
        Must use ``torch.cuda.synchronize()``, reference: \
        <https://blog.csdn.net/u013548568/article/details/81368019>
    """

    # TODO find a way to autodoc the class within method
    class TimeWrapperCuda(TimeWrapper):
        """
        Overview:
            A timer class that inherits from the ``TimeWrapper`` class and measures elapsed time with CUDA events.
        Notes:
            Must use ``torch.cuda.synchronize()``, reference: \
            <https://blog.csdn.net/u013548568/article/details/81368019>
        Interfaces:
            ``start_time``, ``end_time``
        """
        # Class variables are initialized when this class body is executed, i.e. when the outer function is called.
        start_record = torch.cuda.Event(enable_timing=True)
        end_record = torch.cuda.Event(enable_timing=True)

        # overwrite
        @classmethod
        def start_time(cls):
            """
            Overview:
                Implement and override the ``start_time`` method of the ``TimeWrapper`` class.
            """
            torch.cuda.synchronize()
            cls.start_record.record()

        # overwrite
        @classmethod
        def end_time(cls):
            """
            Overview:
                Implement and override the ``end_time`` method of the ``TimeWrapper`` class.
            Returns:
                - time(:obj:`float`): The time in seconds between ``start_time`` and ``end_time``.
            """
            cls.end_record.record()
            torch.cuda.synchronize()
            # ``elapsed_time`` returns milliseconds, so divide by 1000 to convert to seconds.
            return cls.start_record.elapsed_time(cls.end_record) / 1000

    return TimeWrapperCuda
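

# A minimal usage sketch, not part of the original module. It assumes a CUDA-capable
# device is available and times a matrix multiplication with the returned wrapper class.
if __name__ == '__main__':
    if torch.cuda.is_available():
        TimeWrapperCuda = get_cuda_time_wrapper()
        x = torch.randn(1024, 1024, device='cuda')
        TimeWrapperCuda.start_time()
        y = x @ x
        elapsed = TimeWrapperCuda.end_time()
        print('matmul with output shape {} took {:.6f}s'.format(tuple(y.shape), elapsed))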