NEOX / megatron /devutil.py
akswelh's picture
Upload 251 files
d90b3a8 verified
raw
history blame
1.28 kB
import functools

import torch.cuda
class Metric:
    """
    Collect scalar measurements under a label and print their distribution.

    ``report`` prints the label followed by the 0.0, 0.1, ..., 0.9 quantiles
    of everything collected so far (the maximum, quantile 1.0, is not shown).
    """

    def __init__(self, label):
        # Printed as the first field of every report line.
        self.label = label
        # Raw measurements, in collection order.
        self.measurements = []

    def collect(self, measurement):
        """Append one measurement."""
        self.measurements.append(measurement)

    def get_measurements(self):
        """Return a shallow copy of the collected measurements."""
        return self.measurements[:]

    def report(self):
        """Print the label and the 0.0..0.9 quantiles of the measurements.

        Prints a placeholder instead of raising when nothing was collected.
        """
        if not self.measurements:
            # torch.quantile raises on an empty input tensor.
            print(self.label, "no measurements")
            return
        # torch.quantile only accepts floating-point input, so integer
        # measurements are cast; the default dtype also matches the dtype of
        # the quantile tensor below (int arange divided by a Python float).
        values = torch.tensor(self.measurements, dtype=torch.get_default_dtype())
        print(
            self.label,
            torch.quantile(values, torch.arange(10) / 10.0),
        )
def monitor_method_cuda_wall_times(metric, obj, methodname):
    """
    Replace ``obj.<methodname>`` with a wrapper that times each call using CUDA events.

    After every call (including calls that raise), the elapsed time reported by
    ``torch.cuda.Event.elapsed_time`` is passed to ``metric.collect`` and then
    ``metric.report()`` is invoked, so a report is emitted on every call.
    The wrapper calls ``torch.cuda.synchronize()`` after each call, blocking
    the host until the device has finished; intended for debugging only.

    Example::

        metric = Metric('!LNORM')
        monitor_method_cuda_wall_times(metric, LayerNorm, 'forward')

    Args:
        metric: object providing ``collect(value)`` and ``report()``.
        obj: object or class whose attribute is replaced in place via ``setattr``.
        methodname: name of the attribute to wrap.
    """
    oldmeth = getattr(obj, methodname)

    @functools.wraps(oldmeth)
    def newmeth(*args, **kw):
        # Fresh events per call: a single shared pair would be re-recorded by
        # re-entrant or recursive calls, corrupting the outer call's timing.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        try:
            return oldmeth(*args, **kw)
        finally:
            end_event.record()
            # Both events must have completed before elapsed_time can be read.
            torch.cuda.synchronize()
            metric.collect(start_event.elapsed_time(end_event))
            metric.report()

    setattr(obj, methodname, newmeth)