File size: 1,199 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from contextlib import contextmanager
from typing import Iterator

import torch
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
from pytorch_lightning.utilities import rank_zero_only
class InferenceProfiler(SimpleProfiler):
    """SimpleProfiler that synchronizes CUDA around every profiled action.

    CUDA kernels are launched asynchronously, so a host-side timer around an
    action could otherwise stop before the GPU work queued inside it has
    finished. Calling ``torch.cuda.synchronize()`` before starting and before
    stopping the timer makes each recorded duration include that work.
    Intended for use at test/inference time.

    ``start``, ``stop`` and ``summary`` are wrapped with Lightning's
    ``rank_zero_only``, so they only run on the rank-zero process.
    """

    def __init__(self):
        super().__init__()
        # Rebind as instance attributes so these calls are skipped on non-zero ranks.
        self.start = rank_zero_only(self.start)
        self.stop = rank_zero_only(self.stop)
        self.summary = rank_zero_only(self.summary)

    @staticmethod
    def _synchronize() -> None:
        """Block until queued CUDA work completes; no-op when CUDA is unavailable."""
        # torch.cuda.synchronize() raises on CPU-only machines/builds, so guard it.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @contextmanager
    def profile(self, action_name: str) -> Iterator[str]:
        """Record the duration of the enclosed block under ``action_name``.

        Args:
            action_name: Key under which the duration is recorded.

        Yields:
            ``action_name`` unchanged.
        """
        # Sync and start outside the try: if either fails, stop() must not run
        # for an action that was never started (it would raise and mask the
        # original exception).
        self._synchronize()
        self.start(action_name)
        try:
            yield action_name
        finally:
            self._synchronize()
            self.stop(action_name)
def build_profiler(name):
    """Construct the profiler selected by ``name``.

    Args:
        name: ``None`` for a ``PassThroughProfiler`` (profiling disabled),
            ``"inference"`` for an ``InferenceProfiler``, or ``"pytorch"`` for
            Lightning's ``PyTorchProfiler`` built with ``use_cuda=True``,
            ``profile_memory=True`` and ``row_limit=100``.

    Returns:
        A new profiler instance.

    Raises:
        ValueError: If ``name`` is not one of the values above.
    """
    if name is None:
        return PassThroughProfiler()
    if name == "inference":
        return InferenceProfiler()
    if name == "pytorch":
        # Imported lazily so PyTorchProfiler is only loaded when it is requested.
        from pytorch_lightning.profiler import PyTorchProfiler
        return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
    raise ValueError(f"Invalid profiler: {name}")
|