File size: 4,561 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import functools
import time
from typing import Optional, Union

import torch

from mmengine.dist.utils import is_main_process, master_only
from mmengine.logging import MMLogger, print_log


class TimeCounter:
    """A tool that counts the average running time of a function or a method.

    Users can use it as a decorator or context manager to calculate the average
    running time of code blocks. Instances are shared per ``tag``: creating a
    ``TimeCounter`` with a tag that is already registered returns the existing
    instance (the new arguments are ignored), so its statistics keep
    accumulating.

    Args:
        log_interval (int): The interval of logging. Must be >= 1.
            Defaults to 1.
        warmup_interval (int): The interval of warmup. Runs numbered below
            ``warmup_interval`` (1-based) are excluded from the average, so
            the default of 1 times every run. Must be >= 1. Defaults to 1.
        with_sync (bool): Whether to call ``torch.cuda.synchronize()`` before
            and after the timed code (only when CUDA is available).
            Defaults to True.
        tag (str, optional): Function tag. Used to distinguish between
            different functions or methods being called. When used as a
            decorator and left as None, the decorated function's ``__name__``
            is used; it is required when used as a context manager.
            Defaults to None.
        logger (MMLogger, optional): Formatted logger used to record messages.
            Defaults to None.

    Examples:
        >>> import time
        >>> from mmengine.utils.dl_utils import TimeCounter
        >>> @TimeCounter()
        ... def fun1():
        ...     time.sleep(0.1)
        >>> fun1()
        [fun1]-time per run averaged in the past 1 runs: 100.0 ms

        >>> @TimeCounter(log_interval=2, tag='fun')
        ... def fun2():
        ...     time.sleep(0.2)
        >>> for _ in range(3):
        ...     fun2()
        [fun]-time per run averaged in the past 2 runs: 200.0 ms

        >>> with TimeCounter(tag='fun3'):
        ...     time.sleep(0.3)
        [fun3]-time per run averaged in the past 1 runs: 300.0 ms
    """

    # Registry of created instances keyed by ``tag``, shared class-wide.
    instance_dict: dict = dict()

    log_interval: int
    warmup_interval: int
    with_sync: bool
    tag: Optional[str]
    logger: Optional[MMLogger]
    __count: int
    __pure_inf_time: float
    __start_time: float

    def __new__(cls,
                log_interval: int = 1,
                warmup_interval: int = 1,
                with_sync: bool = True,
                tag: Optional[str] = None,
                logger: Optional[MMLogger] = None):
        # ``print_time`` takes ``count % log_interval``; reject 0 up front
        # instead of failing with ZeroDivisionError on the first run.
        assert log_interval >= 1, 'log_interval must be >= 1'
        assert warmup_interval >= 1
        if tag is not None and tag in cls.instance_dict:
            return cls.instance_dict[tag]

        instance = super().__new__(cls)
        cls.instance_dict[tag] = instance

        instance.log_interval = log_interval
        instance.warmup_interval = warmup_interval
        instance.with_sync = with_sync
        instance.tag = tag
        instance.logger = logger

        instance.__count = 0
        instance.__pure_inf_time = 0.
        instance.__start_time = 0.

        return instance

    def _sync_cuda(self) -> None:
        """Wait for queued CUDA work when syncing is enabled and CUDA exists."""
        if self.with_sync and torch.cuda.is_available():
            torch.cuda.synchronize()

    def __call__(self, fn):
        """Wrap ``fn`` so that each call made on the main process is timed.

        The wrapper is returned on every rank. The main-process check runs
        per call, so on other ranks the wrapper simply forwards to ``fn``.
        Decorating with ``master_only`` here would instead return ``None``
        on those ranks and make the decorated function uncallable.
        """
        if self.tag is None:
            self.tag = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if not is_main_process():
                return fn(*args, **kwargs)

            self._sync_cuda()
            start_time = time.perf_counter()

            result = fn(*args, **kwargs)

            self._sync_cuda()
            elapsed = time.perf_counter() - start_time

            # Count the run only after it completed: a call that raises must
            # not bump the run count without contributing its elapsed time.
            self.__count += 1
            self.print_time(elapsed)

            return result

        return wrapper

    @master_only
    def __enter__(self):
        assert self.tag is not None, 'In order to clearly distinguish ' \
                                     'printing information in different ' \
                                     'contexts, please specify the ' \
                                     'tag parameter'

        self.__count += 1

        self._sync_cuda()
        self.__start_time = time.perf_counter()

    @master_only
    def __exit__(self, exc_type, exc_val, exc_tb):
        self._sync_cuda()
        elapsed = time.perf_counter() - self.__start_time
        self.print_time(elapsed)

    def print_time(self, elapsed: Union[int, float]) -> None:
        """Accumulate ``elapsed`` and log the running average periodically.

        Runs numbered below ``warmup_interval`` are ignored. Every
        ``log_interval``-th run, the mean of all timed runs so far is logged
        in milliseconds.

        Args:
            elapsed (int | float): Duration of the latest run, in seconds
                (a ``time.perf_counter`` difference).
        """
        if self.__count < self.warmup_interval:
            return  # still warming up; not part of the average

        self.__pure_inf_time += elapsed

        if self.__count % self.log_interval == 0:
            # Number of runs actually accumulated (warmup runs excluded).
            timed_runs = self.__count - self.warmup_interval + 1
            times_per_count = 1000 * self.__pure_inf_time / timed_runs
            print_log(
                f'[{self.tag}]-time per run averaged in the past '
                f'{timed_runs} runs: {times_per_count:.1f} ms',
                self.logger)