rawalkhirodkar's picture
Add initial commit
28c256d
raw
history blame
4.56 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import time
from typing import Optional, Union
import torch
from mmengine.dist.utils import master_only
from mmengine.logging import MMLogger, print_log
class TimeCounter:
"""A tool that counts the average running time of a function or a method.
Users can use it as a decorator or context manager to calculate the average
running time of code blocks.
Args:
log_interval (int): The interval of logging. Defaults to 1.
warmup_interval (int): The interval of warmup. Defaults to 1.
with_sync (bool): Whether to synchronize cuda. Defaults to True.
tag (str, optional): Function tag. Used to distinguish between
different functions or methods being called. Defaults to None.
logger (MMLogger, optional): Formatted logger used to record messages.
Defaults to None.
Examples:
>>> import time
>>> from mmengine.utils.dl_utils import TimeCounter
>>> @TimeCounter()
... def fun1():
... time.sleep(0.1)
... fun1()
[fun1]-time per run averaged in the past 1 runs: 100.0 ms
>>> @@TimeCounter(log_interval=2, tag='fun')
... def fun2():
... time.sleep(0.2)
>>> for _ in range(3):
... fun2()
[fun]-time per run averaged in the past 2 runs: 200.0 ms
>>> with TimeCounter(tag='fun3'):
... time.sleep(0.3)
[fun3]-time per run averaged in the past 1 runs: 300.0 ms
"""
instance_dict: dict = dict()
log_interval: int
warmup_interval: int
logger: Optional[MMLogger]
__count: int
__pure_inf_time: float
def __new__(cls,
log_interval: int = 1,
warmup_interval: int = 1,
with_sync: bool = True,
tag: Optional[str] = None,
logger: Optional[MMLogger] = None):
assert warmup_interval >= 1
if tag is not None and tag in cls.instance_dict:
return cls.instance_dict[tag]
instance = super().__new__(cls)
cls.instance_dict[tag] = instance
instance.log_interval = log_interval
instance.warmup_interval = warmup_interval
instance.with_sync = with_sync
instance.tag = tag
instance.logger = logger
instance.__count = 0
instance.__pure_inf_time = 0.
instance.__start_time = 0.
return instance
@master_only
def __call__(self, fn):
if self.tag is None:
self.tag = fn.__name__
def wrapper(*args, **kwargs):
self.__count += 1
if self.with_sync and torch.cuda.is_available():
torch.cuda.synchronize()
start_time = time.perf_counter()
result = fn(*args, **kwargs)
if self.with_sync and torch.cuda.is_available():
torch.cuda.synchronize()
elapsed = time.perf_counter() - start_time
self.print_time(elapsed)
return result
return wrapper
@master_only
def __enter__(self):
assert self.tag is not None, 'In order to clearly distinguish ' \
'printing information in different ' \
'contexts, please specify the ' \
'tag parameter'
self.__count += 1
if self.with_sync and torch.cuda.is_available():
torch.cuda.synchronize()
self.__start_time = time.perf_counter()
@master_only
def __exit__(self, exc_type, exc_val, exc_tb):
if self.with_sync and torch.cuda.is_available():
torch.cuda.synchronize()
elapsed = time.perf_counter() - self.__start_time
self.print_time(elapsed)
def print_time(self, elapsed: Union[int, float]) -> None:
"""print times per count."""
if self.__count >= self.warmup_interval:
self.__pure_inf_time += elapsed
if self.__count % self.log_interval == 0:
times_per_count = 1000 * self.__pure_inf_time / (
self.__count - self.warmup_interval + 1)
print_log(
f'[{self.tag}]-time per run averaged in the past '
f'{self.__count} runs: {times_per_count:.1f} ms',
self.logger)