|
import gc |
|
import datetime |
|
import inspect |
|
|
|
import torch |
|
import numpy as np |
|
|
|
# Size in bytes of a single element for each supported torch dtype.
# Values are written as ``bits / 8`` and are therefore floats.
# Several keys are aliases of the same dtype object (e.g. ``torch.double`` is
# ``torch.float64``); when keys collide, the later entry wins, so both
# spellings must carry the same value.
dtype_memory_size_dict = {
    torch.float64: 64 / 8,
    torch.double: 64 / 8,
    torch.float32: 32 / 8,
    torch.float: 32 / 8,
    torch.float16: 16 / 8,
    torch.half: 16 / 8,
    torch.int64: 64 / 8,
    torch.long: 64 / 8,
    torch.int32: 32 / 8,
    torch.int: 32 / 8,
    torch.int16: 16 / 8,
    # Was ``16 / 6``: since torch.short aliases torch.int16, that typo made
    # every int16 tensor count as ~2.67 bytes per element instead of 2.
    torch.short: 16 / 8,
    torch.uint8: 8 / 8,
    torch.int8: 8 / 8,
}

# These dtypes are looked up defensively with getattr -- presumably because
# some torch builds do not provide them -- and registered only when present.
if getattr(torch, "bfloat16", None) is not None:
    dtype_memory_size_dict[torch.bfloat16] = 16 / 8
if getattr(torch, "bool", None) is not None:
    dtype_memory_size_dict[torch.bool] = 8 / 8
|
|
|
|
|
def get_mem_space(x):
    """Return the size in bytes of one element of torch dtype ``x``.

    The size is looked up in ``dtype_memory_size_dict``.

    Args:
        x: the dtype to look up (used as a dictionary key).

    Returns:
        The per-element byte size, or ``0`` when the dtype is not in the
        table. In that case a message is printed to stdout and the caller's
        size computation simply contributes nothing for that tensor.
    """
    try:
        return dtype_memory_size_dict[x]
    except KeyError:
        print(f"dtype {x} is not supported!")
        # The original fell through to ``return ret`` with ``ret`` unbound,
        # which raised UnboundLocalError right after printing the message.
        return 0
|
|
|
|
|
import contextlib, sys |
|
|
|
@contextlib.contextmanager
def file_writer(file_name=None):
    """Yield a writable text stream for log output.

    Args:
        file_name: path of the file to append to, or ``None`` to write to
            ``sys.stdout``.

    Yields:
        The opened file object, or ``sys.stdout`` when ``file_name`` is None.

    The file is opened in append mode so successive uses with the same path
    accumulate output, and it is closed on exit even if the ``with`` body
    raises. ``sys.stdout`` is never closed.
    """
    if file_name is None:
        # stdout is not ours to close; just hand it out.
        yield sys.stdout
        return

    # The original mode string "aw" is rejected by open() in Python 3
    # (exactly one of create/read/write/append is allowed); "a" appends.
    with open(file_name, "a") as writer:
        yield writer
|
|
|
|
|
class MemTracker(object):
    """
    Class used to track pytorch memory usage

    Each call to :meth:`track` writes a line with the caller's location, the
    total size of live CUDA tensors and the memory reported by
    ``torch.cuda.memory_allocated()``; with ``detail`` enabled it also lists
    which tensor groups appeared or disappeared since the previous call.

    Arguments:
        detail(bool, default True): whether the function shows the detail gpu memory usage
        path(str): prefix prepended to the timestamped log file name
        verbose(bool, default False): whether show the trivial exception
        device(int): GPU number, default is 0
        log_to_disk(bool, default False): write to the log file instead of stdout

    NOTE(review): ``device`` is stored on the instance but no method in this
    class reads it; the allocated-memory query uses torch's default device.
    """

    def __init__(self, detail=True, path='', verbose=False, device=0, log_to_disk=False):
        self.print_detail = detail
        # Tensor summary tuples seen on the previous track() call.
        self.last_tensor_sizes = set()
        # Log file name is fixed at construction time.
        self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt'
        self.verbose = verbose
        # True until the first track() call has written the header line.
        self.begin = True
        self.device = device
        self.log_to_disk = log_to_disk

    def get_tensors(self):
        """Yield every GC-tracked object that is (or wraps) a CUDA tensor.

        Objects whose inspection raises are skipped; the exception is printed
        only when ``verbose`` is set.
        """
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                    tensor = obj
                else:
                    continue
                if tensor.is_cuda:
                    yield tensor
            except Exception as e:
                if self.verbose:
                    print('A trivial exception occurred: {}'.format(e))

    def get_tensor_usage(self):
        """Return the summed size of all tensors from get_tensors(), in MiB."""
        sizes = [np.prod(np.array(tensor.size())) * get_mem_space(tensor.dtype) for tensor in self.get_tensors()]
        return np.sum(sizes) / 1024 ** 2

    def get_allocate_usage(self):
        """Return ``torch.cuda.memory_allocated()`` converted to MiB."""
        return torch.cuda.memory_allocated() / 1024 ** 2

    def clear_cache(self):
        """Run the garbage collector, then empty torch's CUDA cache."""
        gc.collect()
        torch.cuda.empty_cache()

    def print_all_gpu_tensor(self, file=None):
        """Print size, dtype and MiB footprint of every tracked tensor to ``file``."""
        for x in self.get_tensors():
            print(x.size(), x.dtype, np.prod(np.array(x.size())) * get_mem_space(x.dtype) / 1024 ** 2, file=file)

    def track(self):
        """
        Track the GPU memory usage

        Writes to the log file when ``log_to_disk`` is set, otherwise to
        stdout. The first call also writes a header line.
        """
        # context=0: only filename/lineno/function are used, so skip reading
        # source lines for every frame on the stack.
        frameinfo = inspect.stack(0)[1]
        where_str = frameinfo.filename + ' line ' + str(frameinfo.lineno) + ': ' + frameinfo.function

        file_name = self.gpu_profile_fn if self.log_to_disk else None

        with file_writer(file_name) as f:

            if self.begin:
                f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
                        f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                        f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
                self.begin = False

            if self.print_detail is True:
                # Take one snapshot so the counts and the entries describe
                # the same set of tensors (the original scanned the GC heap
                # twice and called list.count() once per tensor).
                tensors = list(self.get_tensors())
                shape_keys = [(tuple(t.size()), t.dtype) for t in tensors]
                counts = {}
                for key in shape_keys:
                    counts[key] = counts.get(key, 0) + 1

                # One summary tuple per distinct (type, shape, dtype) group:
                # (type, shape, how many, MiB per tensor, dtype).
                new_tensor_sizes = {
                    (type(t),
                     shape,
                     counts[(shape, dtype)],
                     np.prod(np.array(shape)) * get_mem_space(dtype) / 1024 ** 2,
                     dtype)
                    for t, (shape, dtype) in zip(tensors, shape_keys)
                }
                # Release the strong references before reporting totals.
                del tensors

                for t, s, n, m, data_type in new_tensor_sizes - self.last_tensor_sizes:
                    f.write(
                        f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m * n)[:6]} M | {str(t):<20} | {data_type}\n')
                for t, s, n, m, data_type in self.last_tensor_sizes - new_tensor_sizes:
                    f.write(
                        f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m * n)[:6]} M | {str(t):<20} | {data_type}\n')

                self.last_tensor_sizes = new_tensor_sizes

            f.write(f"\nAt {where_str:<50}"
                    f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                    f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
|
|