from typing import Callable, Type

import torch

from .time_helper_base import TimeWrapper


def get_cuda_time_wrapper() -> Type[TimeWrapper]:
    """
    Overview:
        Build and return the ``TimeWrapperCuda`` class, which measures elapsed time with CUDA events.
        The class is defined inside this factory so that the ``torch.cuda.Event`` objects in its class
        body are only constructed when this function is called, not when this module is imported.
        This is intended to keep the module usable when no CUDA device is available.
    Returns:
        - TimeWrapperCuda(:obj:`type`): A ``TimeWrapper`` subclass, see ``TimeWrapperCuda``. Every call \
            creates a new class object with its own pair of events.

    .. note::
        CUDA work is launched asynchronously with respect to the host, so ``torch.cuda.synchronize()`` \
        is called around the timed region before the elapsed time is read.
    """

    # TODO find a way to autodoc the class within method
    class TimeWrapperCuda(TimeWrapper):
        """
        Overview:
            A ``TimeWrapper`` subclass that times a region using a start/end pair of CUDA events.
        Interfaces:
            ``start_time``, ``end_time``

        .. note::
            The events are class attributes shared by every use of this class, so measurements must not \
            be nested or interleaved: a second ``start_time`` re-records ``start_record``.
        """
        # Class attributes: created once, when this class body runs (i.e. once per factory call).
        start_record = torch.cuda.Event(enable_timing=True)
        end_record = torch.cuda.Event(enable_timing=True)

        # overwrite
        @classmethod
        def start_time(cls) -> None:
            """
            Overview:
                Implement and override the ``start_time`` method of ``TimeWrapper``: wait for all \
                previously queued CUDA work, then record the start event.
            """
            torch.cuda.synchronize()
            # ``Event.record()`` returns None, so there is nothing useful to store from it.
            cls.start_record.record()

        # overwrite
        @classmethod
        def end_time(cls) -> float:
            """
            Overview:
                Implement and override the ``end_time`` method of ``TimeWrapper``: record the end event, \
                wait for it to complete, and return the elapsed time.
            Returns:
                - time(:obj:`float`): Seconds elapsed between the ``start_time`` and ``end_time`` records.
            """
            cls.end_record.record()
            # Block the host until the recorded events have completed so ``elapsed_time`` can be read.
            torch.cuda.synchronize()
            # ``Event.elapsed_time`` reports milliseconds; convert to seconds.
            return cls.start_record.elapsed_time(cls.end_record) / 1000

    return TimeWrapperCuda