Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,561 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import functools
import time
from typing import Optional, Union

import torch

from mmengine.dist.utils import master_only
from mmengine.logging import MMLogger, print_log
class TimeCounter:
    """A tool that counts the average running time of a function or a method.

    Users can use it as a decorator or context manager to calculate the average
    running time of code blocks.

    Instances are cached per ``tag``: constructing ``TimeCounter`` with a
    ``tag`` that was used before returns the existing instance (with its
    accumulated statistics), and the other arguments are ignored in that case.

    Args:
        log_interval (int): The interval of logging. Must be >= 1.
            Defaults to 1.
        warmup_interval (int): The interval of warmup. The first
            ``warmup_interval - 1`` runs are excluded from the average.
            Must be >= 1. Defaults to 1.
        with_sync (bool): Whether to synchronize cuda. Defaults to True.
        tag (str, optional): Function tag. Used to distinguish between
            different functions or methods being called. Defaults to None.
        logger (MMLogger, optional): Formatted logger used to record messages.
            Defaults to None.

    Examples:
        >>> import time
        >>> from mmengine.utils.dl_utils import TimeCounter
        >>> @TimeCounter()
        ... def fun1():
        ...     time.sleep(0.1)
        >>> fun1()
        [fun1]-time per run averaged in the past 1 runs: 100.0 ms

        >>> @TimeCounter(log_interval=2, tag='fun')
        ... def fun2():
        ...     time.sleep(0.2)
        >>> for _ in range(3):
        ...     fun2()
        [fun]-time per run averaged in the past 2 runs: 200.0 ms

        >>> with TimeCounter(tag='fun3'):
        ...     time.sleep(0.3)
        [fun3]-time per run averaged in the past 1 runs: 300.0 ms
    """

    # Registry of counters keyed by tag, shared by the whole class.
    instance_dict: dict = dict()

    log_interval: int
    warmup_interval: int
    with_sync: bool
    tag: Optional[str]
    logger: Optional[MMLogger]
    __count: int
    __pure_inf_time: float
    __start_time: Optional[float]

    def __new__(cls,
                log_interval: int = 1,
                warmup_interval: int = 1,
                with_sync: bool = True,
                tag: Optional[str] = None,
                logger: Optional[MMLogger] = None):
        # ``count % log_interval`` in ``print_time`` would otherwise raise
        # ZeroDivisionError on the first timed call.
        assert log_interval >= 1, 'log_interval must be a positive integer'
        assert warmup_interval >= 1
        # Reuse an existing counter for a known tag so that statistics keep
        # accumulating across separate ``TimeCounter(tag=...)`` calls.
        if tag is not None and tag in cls.instance_dict:
            return cls.instance_dict[tag]

        instance = super().__new__(cls)
        cls.instance_dict[tag] = instance

        instance.log_interval = log_interval
        instance.warmup_interval = warmup_interval
        instance.with_sync = with_sync
        instance.tag = tag
        instance.logger = logger
        instance.__count = 0
        instance.__pure_inf_time = 0.
        instance.__start_time = 0.
        return instance

    def _synchronize(self) -> None:
        """Wait for pending CUDA work so host-side timing includes it."""
        if self.with_sync and torch.cuda.is_available():
            torch.cuda.synchronize()

    @master_only
    def _tic(self) -> float:
        """Register a new run and return its start timestamp.

        NOTE(review): wrapped by ``master_only``; on non-master ranks this is
        expected to be skipped and yield ``None``.
        """
        self.__count += 1
        self._synchronize()
        return time.perf_counter()

    @master_only
    def _toc(self, start_time: float) -> None:
        """Finish the run started at ``start_time`` and record its duration."""
        self._synchronize()
        self.print_time(time.perf_counter() - start_time)

    def __call__(self, fn):
        """Return a wrapper around ``fn`` that times every call.

        This method is deliberately not decorated with ``master_only``.
        ``master_only`` presumably skips the call (returning ``None``) on
        non-master ranks, which would replace the user's function with
        ``None`` there. Only the bookkeeping in ``_tic``/``_toc`` is
        restricted to the master process; ``fn`` itself always runs.
        """
        if self.tag is None:
            self.tag = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Keep the start time local (not on ``self``) so recursive or
            # re-entrant calls of ``fn`` do not clobber each other.
            start_time = self._tic()
            result = fn(*args, **kwargs)
            self._toc(start_time)
            return result

        return wrapper

    def __enter__(self):
        """Start timing a ``with`` block and return this counter."""
        assert self.tag is not None, 'In order to clearly distinguish ' \
                                     'printing information in different ' \
                                     'contexts, please specify the ' \
                                     'tag parameter'
        self.__start_time = self._tic()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop timing the ``with`` block; exceptions are not suppressed."""
        self._toc(self.__start_time)

    def print_time(self, elapsed: Union[int, float]) -> None:
        """Accumulate ``elapsed`` and log the running average periodically.

        Runs before ``warmup_interval`` are ignored. A message is logged
        whenever the total run count is a multiple of ``log_interval``.

        Args:
            elapsed (int | float): Duration of the latest run in seconds
                (a ``time.perf_counter`` difference).
        """
        if self.__count < self.warmup_interval:
            # Still warming up: this run is not part of the average.
            return
        self.__pure_inf_time += elapsed
        if self.__count % self.log_interval != 0:
            return
        # Number of runs actually accumulated since warmup ended. This is
        # also the count reported in the message, so the message stays
        # accurate when ``warmup_interval > 1``.
        num_timed_runs = self.__count - self.warmup_interval + 1
        times_per_count = 1000 * self.__pure_inf_time / num_timed_runs
        print_log(
            f'[{self.tag}]-time per run averaged in the past '
            f'{num_timed_runs} runs: {times_per_count:.1f} ms',
            self.logger)
|