"""Benchmarking and measurement utilities"""
import functools

import pynvml
import torch
from pynvml.nvml import NVMLError


def check_cuda_device(default_value):
    """
    wraps a function and returns the default value instead of running the
    wrapped function if cuda isn't available or the device is auto
    :param default_value:
    :return:
    """

    def deco(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # the device may arrive as the first positional arg or as a keyword
            device = kwargs.get("device", args[0] if args else None)

            if (
                device is None
                or not torch.cuda.is_available()
                or device == "auto"
                or torch.device(device).type == "cpu"
            ):
                return default_value

            return func(*args, **kwargs)

        return wrapper

    return deco
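
# A minimal illustration of the short-circuit (hypothetical function, kept as
# a comment so it never runs): on a CPU-only host the body is skipped and the
# default is returned instead.
#
#     @check_cuda_device(-1.0)
#     def gpu_clock_mhz(device=0):
#         ...  # never reached without a CUDA device
#
#     gpu_clock_mhz("cpu")  # -> -1.0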


@check_cuda_device(0.0)
def gpu_memory_usage(device=0):
    return torch.cuda.memory_allocated(device) / 1024.0**3  # bytes -> GiB


@check_cuda_device((0.0, 0.0, 0.0))
def gpu_memory_usage_all(device=0):
    usage = torch.cuda.memory_allocated(device) / 1024.0**3
    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
    smi = gpu_memory_usage_smi(device)
    # allocated memory, allocator cache (reserved minus allocated), and memory
    # outside the PyTorch allocator (other processes, CUDA context, etc.)
    return usage, reserved - usage, max(0, smi - reserved)


def mps_memory_usage_all():
    usage = torch.mps.current_allocated_memory() / 1024.0**3
    reserved = torch.mps.driver_allocated_memory() / 1024.0**3
    return usage, reserved - usage, 0.0  # no NVML-style probe for MPS, so misc is 0


@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
    # normalize torch.device objects and "cuda:N" strings to a bare index
    if isinstance(device, torch.device):
        device = device.index
    if isinstance(device, str) and device.startswith("cuda:"):
        device = int(device[5:])
    try:
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return info.used / 1024.0**3
    except NVMLError:
        return 0.0


def log_gpu_memory_usage(log, msg, device):
    if torch.backends.mps.is_available():
        usage, cache, misc = mps_memory_usage_all()
    else:
        usage, cache, misc = gpu_memory_usage_all(device)
    extras = []
    if cache > 0:
        extras.append(f"+{cache:.03f}GB cache")
    if misc > 0:
        extras.append(f"+{misc:.03f}GB misc")
    log.info(
        f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
    )
    return usage, cache, misc
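

# --- Usage sketch (assumption: not part of the original module; the logger
# name and tensor size below are arbitrary) ----------------------------------
if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("bench")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    log_gpu_memory_usage(logger, "before allocation", device)

    if device != "cpu":
        # ~0.4GB of float32; on CPU-only hosts the helpers above short-circuit
        # via check_cuda_device and report zeros, so we skip the allocation.
        x = torch.empty(100_000_000, device=device)
        log_gpu_memory_usage(logger, "after allocation", device)
        del x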