import os
import time
import torch

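# Map activation names to callables; 'none' is the identity function.
act_dict = {
    'none': lambda x: x,
    'relu': torch.relu,
    'tanh': torch.tanh,
    'sigmoid': torch.sigmoid,
}


def get_act(act):
    """Look up an activation function by name (raises KeyError for unknown names)."""
    return act_dict[act]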
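

# Sketch (hypothetical helper, not part of the original module): the same
# name -> function registry idea as act_dict, shown with stdlib scalar
# functions so it runs without PyTorch, and raising a descriptive error
# that lists the valid names instead of a bare KeyError.
import math

_scalar_acts = {
    'none': lambda x: x,
    'relu': lambda x: max(x, 0.0),
    'tanh': math.tanh,
    'sigmoid': lambda x: 1.0 / (1.0 + math.exp(-x)),
}


def get_scalar_act(name):
    """Return the scalar activation registered under `name` (hypothetical helper)."""
    try:
        return _scalar_acts[name]
    except KeyError:
        raise ValueError("unknown activation {!r}; expected one of {}".format(
            name, sorted(_scalar_acts))) from None

# Usage: get_scalar_act('relu')(-2.0) returns 0.0.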


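def to_list(var):
    """Recursively convert tensors inside dicts, lists, and tuples to Python
    lists (or Python scalars, for 0-d tensors)."""
    if isinstance(var, dict):
        return {k: to_list(v) for k, v in var.items()}
    elif isinstance(var, list):
        return [to_list(v) for v in var]
    elif isinstance(var, tuple):
        # Build a real tuple; a bare (... for ...) would be a lazy generator.
        return tuple(to_list(v) for v in var)
    elif isinstance(var, torch.Tensor):
        return var.tolist()
    else:
        return var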
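

# Sketch (hypothetical helper, not part of the original module): the same
# recursive conversion as to_list, but duck-typed on a `tolist()` method
# instead of an isinstance check against torch.Tensor, so it also converts
# NumPy arrays and runs without importing torch. Tuples stay tuples.
def to_list_ducktyped(var):
    """Recursively replace objects exposing .tolist() with plain Python values."""
    if isinstance(var, dict):
        return {k: to_list_ducktyped(v) for k, v in var.items()}
    if isinstance(var, list):
        return [to_list_ducktyped(v) for v in var]
    if isinstance(var, tuple):
        return tuple(to_list_ducktyped(v) for v in var)
    if hasattr(var, 'tolist'):
        return var.tolist()
    return var

# Usage: to_list_ducktyped({'a': (1, [2])}) returns {'a': (1, [2])}.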


def repeat(tensor, size, dim=0):
    """Repeat each element `size` times along `dim`, e.g. [1, 2] -> [1, 1, 2, 2]."""
    return tensor.repeat_interleave(size, dim)
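

# Sketch (hypothetical helpers, pure Python): `repeat` above wraps
# Tensor.repeat_interleave, which repeats each element in place, whereas
# Tensor.repeat tiles the whole sequence. The two list versions below mirror
# those semantics for the 1-D case only.
def repeat_each(seq, size):
    """Repeat every element `size` times, like repeat_interleave on a 1-D tensor."""
    return [x for x in seq for _ in range(size)]


def tile(seq, size):
    """Repeat the whole sequence `size` times, like Tensor.repeat on a 1-D tensor."""
    return list(seq) * size

# Usage: repeat_each([1, 2], 2) returns [1, 1, 2, 2]; tile([1, 2], 2) returns [1, 2, 1, 2].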


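def get_default_device():
    """Return the CUDA device with the most free memory, or the CPU if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return torch.device("cpu")

    # One free-memory value (MiB) per GPU, in nvidia-smi index (PCI bus) order.
    # CUDA enumerates fastest-first by default, so set CUDA_DEVICE_ORDER=PCI_BUS_ID
    # for the indices to agree; nvidia-smi also ignores CUDA_VISIBLE_DEVICES.
    cmd = 'nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits'
    max_free_mem = -1
    max_cuda_index = 0  # fall back to the first GPU if nvidia-smi prints nothing
    with os.popen(cmd) as result:
        for i, line in enumerate(result):
            if i >= torch.cuda.device_count():
                break  # never return an index CUDA cannot see
            free_mem = int(line.strip())
            if free_mem > max_free_mem:
                max_free_mem = free_mem
                max_cuda_index = i

    return torch.device("cuda:{}".format(max_cuda_index))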
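

# Sketch (hypothetical helper): choosing the GPU with the most free memory as
# a pure function over nvidia-smi text, assuming the output format of
#   nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits
# i.e. one integer (MiB) per line in nvidia-smi index order. Keeping the
# parsing apart from the subprocess call makes it testable without a GPU.
def pick_max_free_gpu(smi_output, default=0):
    """Return the line index with the largest free-memory value (first wins on ties)."""
    best_index, best_free = default, -1
    for i, line in enumerate(smi_output.strip().splitlines()):
        free_mem = int(line.strip())
        if free_mem > best_free:
            best_index, best_free = i, free_mem
    return best_index

# Usage: pick_max_free_gpu("1024\n8000\n4096\n") returns 1.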


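def cumem_stats(device, msg):
    """Print the memory occupied by live tensors on `device`, in GiB.

    empty_cache() returns unused cached blocks to the driver so that nvidia-smi
    reflects them; it does not change memory_allocated().
    """
    torch.cuda.empty_cache()
    allocated_gib = torch.cuda.memory_allocated(device) / (1024 ** 3)
    print("{}, device:{}, memory_allocated: {:.3f}G".format(msg, device, allocated_gib))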
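

# Sketch (hypothetical helper): cumem_stats divides a byte count by 1024**3,
# so the printed figure is in GiB (e.g. 1610612736 bytes -> "1.500G"). Only
# memory held by live tensors is counted; memory cached by the allocator is
# reported separately by torch.cuda.memory_reserved(). This helper just
# reproduces the message format so it can be checked without a GPU.
GIB = 1024 ** 3


def format_mem_line(msg, device, num_bytes):
    """Format a byte count the way cumem_stats prints it (hypothetical helper)."""
    return "{}, device:{}, memory_allocated: {:.3f}G".format(msg, device, num_bytes / GIB)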


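# Timestamp of the previous cutime_stats call (None before the first call).
cutime_stats_time = None


def cutime_stats(device, msg=''):
    """Print the time elapsed since the previous call, after waiting for all
    queued CUDA work on `device` to finish (kernels run asynchronously)."""
    global cutime_stats_time
    torch.cuda.synchronize(device)
    if cutime_stats_time is not None:
        print("{} time: {:.6f}s".format(msg, time.perf_counter() - cutime_stats_time))

    # Reset after printing so the next interval excludes the print itself.
    cutime_stats_time = time.perf_counter()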
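

# Sketch (hypothetical helper): a scoped alternative to the global timer in
# cutime_stats. CUDA kernels launch asynchronously, so a clock reading taken
# without synchronizing mostly measures launch overhead; `sync` is therefore
# called before each reading. For GPU code one might pass
# sync=lambda: torch.cuda.synchronize(device); the default no-op keeps the
# sketch runnable on the CPU.
import time
from contextlib import contextmanager


@contextmanager
def timed(msg='', sync=lambda: None, report=print):
    """Time the enclosed block, calling `sync` before each clock reading."""
    sync()
    start = time.perf_counter()
    try:
        yield
    finally:
        sync()
        report("{} time: {:.6f}s".format(msg, time.perf_counter() - start))

# Usage:
#     with timed('forward', sync=lambda: torch.cuda.synchronize(device)):
#         y = model(x)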