File size: 1,276 Bytes
d90b3a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import torch.cuda
class Metric:
    """
    Dumb utility to collect wall-time measurements and report their deciles.

    Measurements are kept in insertion order in ``self.measurements`` and are
    summarised on demand by :meth:`report`.
    """

    def __init__(self, label):
        """Create an empty metric whose reports are prefixed with ``label``."""
        self.label = label
        self.measurements = []

    def collect(self, measurement):
        """Append one measurement to the collection."""
        self.measurements.append(measurement)

    def get_measurements(self):
        """Return a shallow copy of the measurements collected so far."""
        return self.measurements[:]

    def report(self):
        """
        Print the label followed by the 0.0, 0.1, ..., 0.9 quantiles of the
        collected measurements.

        When nothing has been collected yet, a placeholder line is printed
        instead, because ``torch.quantile`` rejects empty input.
        """
        if not self.measurements:
            print(self.label, "(no measurements)")
            return
        # Force a floating dtype: integer measurements would otherwise produce
        # an int64 tensor, which torch.quantile does not accept. The default
        # dtype matches the dtype of the quantile levels computed below.
        samples = torch.tensor(
            self.measurements, dtype=torch.get_default_dtype()
        )
        print(
            self.label,
            torch.quantile(samples, torch.arange(10) / 10.0),
        )
def monitor_method_cuda_wall_times(metric, obj, methodname):
    """
    Measure CUDA timings for a method on an object or class.

    The attribute ``methodname`` on ``obj`` is replaced with a wrapper that
    times every call using CUDA events, stores the elapsed time (as returned
    by ``torch.cuda.Event.elapsed_time``) in ``metric`` and prints a report
    after each call. The timing is recorded even when the wrapped method
    raises.

    For instance:

    >>> metric = Metric('!LNORM')
    >>> monitor_method_cuda_wall_times(metric, LayerNorm, 'forward')
    """
    import functools

    oldmeth = getattr(obj, methodname)

    @functools.wraps(oldmeth)
    def newmeth(*args, **kw):
        # Events are created per call so that nested or recursive calls to
        # the patched method cannot overwrite an outer call's start event.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        try:
            return oldmeth(*args, **kw)
        finally:
            end_event.record()
            # Wait for the recorded work to finish; elapsed_time requires
            # both events to have completed.
            torch.cuda.synchronize()
            elapsed = start_event.elapsed_time(end_event)
            metric.collect(elapsed)
            metric.report()

    setattr(obj, methodname, newmeth)
|