|
import numpy as np |
|
from collections import deque |
|
from ditk import logging |
|
from time import time |
|
|
|
from ding.framework import task |
|
from typing import TYPE_CHECKING |
|
if TYPE_CHECKING: |
|
from ding.framework.context import Context |
|
|
|
|
|
def epoch_timer(print_per: int = 1, smooth_window: int = 10):
    """
    Overview:
        Build a middleware that measures the time cost of each epoch and logs it, together \
        with a running mean, every ``print_per`` steps.
    Arguments:
        - print_per (:obj:`int`): Log once every ``print_per`` steps, judged by \
            ``ctx.total_step % print_per == 0``. Must be >= 1.
        - smooth_window (:obj:`int`): The window size used to smooth the mean. The reported mean \
            covers the most recent ``print_per * smooth_window`` epochs. Must be >= 1.
    Returns:
        - _epoch_timer (:obj:`Callable`): A generator function taking a ``Context``; it records the \
            time elapsed between entering the generator and resuming it after its single ``yield``.
    Raises:
        - ValueError: If ``print_per`` or ``smooth_window`` is smaller than 1.
    """
    # perf_counter is monotonic and high-resolution, so measured durations cannot go
    # negative or jump when the system wall clock is adjusted (unlike time.time()).
    from time import perf_counter

    # Validate eagerly: print_per == 0 would otherwise raise ZeroDivisionError mid-pipeline,
    # and smooth_window == 0 would yield an always-empty deque whose np.mean is nan.
    if print_per < 1:
        raise ValueError("print_per must be a positive integer, got {}".format(print_per))
    if smooth_window < 1:
        raise ValueError("smooth_window must be a positive integer, got {}".format(smooth_window))

    # Bounded history: old durations fall off automatically once the window is full.
    records = deque(maxlen=print_per * smooth_window)

    def _epoch_timer(ctx: "Context"):
        start = perf_counter()
        # The elapsed time covers whatever the caller runs between first entering this
        # generator and resuming it (presumably the downstream middleware of this epoch).
        yield
        time_cost = perf_counter() - start
        records.append(time_cost)
        if ctx.total_step % print_per == 0:
            logging.info(
                "[Epoch Timer][Node:{:>2}]: Cost: {:.2f}ms, Mean: {:.2f}ms".format(
                    task.router.node_id or 0, time_cost * 1000,
                    np.mean(records) * 1000
                )
            )

    return _epoch_timer
|
|