Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import time | |
from typing import Optional, Union | |
import torch | |
from mmengine.dist.utils import master_only | |
from mmengine.logging import MMLogger, print_log | |
class TimeCounter:
    """A tool that counts the average running time of a function or a method.

    Users can use it as a decorator or context manager to calculate the average
    running time of code blocks.

    Instances are cached per ``tag``: constructing ``TimeCounter`` again with a
    tag that was already used returns the existing instance (together with its
    accumulated statistics), and the new constructor arguments are ignored.

    Args:
        log_interval (int): The interval of logging. A message is emitted
            whenever the run counter is a multiple of ``log_interval``.
            Must be >= 1. Defaults to 1.
        warmup_interval (int): The interval of warmup. Runs numbered below
            ``warmup_interval`` (i.e. the first ``warmup_interval - 1`` runs)
            are excluded from the average. Must be >= 1. Defaults to 1.
        with_sync (bool): Whether to synchronize cuda before and after the
            timed code (only when CUDA is available). Defaults to True.
        tag (str, optional): Function tag. Used to distinguish between
            different functions or methods being called. When used as a
            decorator without a tag, the wrapped function's name is used.
            Defaults to None.
        logger (MMLogger, optional): Formatted logger used to record messages.
            Defaults to None.

    Examples:
        >>> import time
        >>> from mmengine.utils.dl_utils import TimeCounter
        >>> @TimeCounter()
        ... def fun1():
        ...     time.sleep(0.1)
        >>> fun1()
        [fun1]-time per run averaged in the past 1 runs: 100.0 ms
        >>> @TimeCounter(log_interval=2, tag='fun')
        ... def fun2():
        ...     time.sleep(0.2)
        >>> for _ in range(3):
        ...     fun2()
        [fun]-time per run averaged in the past 2 runs: 200.0 ms
        >>> with TimeCounter(tag='fun3'):
        ...     time.sleep(0.3)
        [fun3]-time per run averaged in the past 1 runs: 300.0 ms
    """

    # Registry of constructed instances keyed by tag, so that repeated
    # construction with the same tag shares one set of statistics.
    instance_dict: dict = dict()

    log_interval: int
    warmup_interval: int
    with_sync: bool
    tag: Optional[str]
    logger: Optional[MMLogger]
    __count: int  # number of runs started so far (including warmup runs)
    __pure_inf_time: float  # summed seconds of the non-warmup runs
    __start_time: float  # perf_counter() value at the last __enter__

    def __new__(cls,
                log_interval: int = 1,
                warmup_interval: int = 1,
                with_sync: bool = True,
                tag: Optional[str] = None,
                logger: Optional[MMLogger] = None):
        # Validate before the cache lookup so bad arguments always fail fast;
        # log_interval == 0 would otherwise raise ZeroDivisionError on the
        # first timed run.
        assert log_interval >= 1, 'log_interval must be >= 1'
        assert warmup_interval >= 1
        if tag is not None and tag in cls.instance_dict:
            return cls.instance_dict[tag]

        instance = super().__new__(cls)
        cls.instance_dict[tag] = instance

        instance.log_interval = log_interval
        instance.warmup_interval = warmup_interval
        instance.with_sync = with_sync
        instance.tag = tag
        instance.logger = logger

        instance.__count = 0
        instance.__pure_inf_time = 0.
        instance.__start_time = 0.
        return instance

    def _synchronize(self) -> None:
        """Wait for pending CUDA work so wall-clock time covers GPU kernels."""
        if self.with_sync and torch.cuda.is_available():
            torch.cuda.synchronize()

    def __call__(self, fn):
        """Wrap ``fn`` so that every call is timed and reported."""
        # Local import: the file header does not import functools.
        import functools

        if self.tag is None:
            self.tag = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            self.__count += 1
            self._synchronize()
            start_time = time.perf_counter()
            try:
                return fn(*args, **kwargs)
            finally:
                # Always record the elapsed time (even if ``fn`` raised) so
                # the run counter and the accumulated time stay in sync,
                # matching the context-manager path where __exit__ always
                # runs.
                self._synchronize()
                elapsed = time.perf_counter() - start_time
                self.print_time(elapsed)

        return wrapper

    def __enter__(self):
        assert self.tag is not None, 'In order to clearly distinguish ' \
                                     'printing information in different ' \
                                     'contexts, please specify the ' \
                                     'tag parameter'
        self.__count += 1
        self._synchronize()
        self.__start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._synchronize()
        elapsed = time.perf_counter() - self.__start_time
        self.print_time(elapsed)

    @master_only
    def print_time(self, elapsed: Union[int, float]) -> None:
        """Accumulate ``elapsed`` and periodically log the average time per run.

        Runs numbered below ``warmup_interval`` are ignored. A message is
        logged whenever the run counter is a multiple of ``log_interval``.
        Guarded by ``master_only`` (expected to restrict execution to the
        main process in distributed runs).

        Args:
            elapsed (int | float): Duration of the latest run in seconds.
        """
        if self.__count < self.warmup_interval:
            return

        self.__pure_inf_time += elapsed
        if self.__count % self.log_interval == 0:
            # Number of runs actually contributing to the average (warmup
            # runs excluded); this is what the message must report.
            num_runs = self.__count - self.warmup_interval + 1
            times_per_count = 1000 * self.__pure_inf_time / num_runs
            print_log(
                f'[{self.tag}]-time per run averaged in the past '
                f'{num_runs} runs: {times_per_count:.1f} ms',
                self.logger)