"""
Utility functions for training and validation: image, checkpoint, and HDF5 helpers.
"""
import os
import numpy as np
import torch


#################
## image utils ##
#################
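def convert_image(input_tensor, axis):
    """ Convert a single-channel image to a 3-channel image.

    The input is stacked three times along `axis`, so an (H, W, 1) array
    with axis=-1 becomes (H, W, 3).
    """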
    image_lst = [input_tensor for _ in range(3)]
    outputs = np.concatenate(image_lst, axis)
    return outputs
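

# Illustration only (hypothetical helper, not called elsewhere): for a
# single-channel (H, W, 1) array, stacking three copies along axis=-1 as
# convert_image does gives the same (H, W, 3) result as np.repeat.
def _demo_convert_image():
    import numpy as np
    gray = np.arange(6, dtype=np.float32).reshape(2, 3, 1)  # (H=2, W=3, C=1)
    stacked = np.concatenate([gray for _ in range(3)], axis=-1)
    repeated = np.repeat(gray, 3, axis=-1)
    return stacked, repeated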


######################
## checkpoint utils ##
######################
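def get_latest_checkpoint(checkpoint_root, checkpoint_name,
                          device=torch.device("cuda")):
    """ Load a checkpoint by filename, or the latest one if no name is given.

    "Latest" is the last .tar file in lexicographic order, so checkpoint
    names should sort by training progress (e.g. zero-padded epoch numbers).
    """
    # Load a specific checkpoint
    if checkpoint_name is not None:
        checkpoint_path = os.path.join(checkpoint_root, checkpoint_name)
    # Load the latest checkpoint
    else:
        # os.listdir does not expand wildcards, so filter by extension instead
        checkpoint_list = sorted(
            name for name in os.listdir(checkpoint_root)
            if name.endswith(".tar"))
        if not checkpoint_list:
            raise FileNotFoundError(
                "No .tar checkpoint found in %s" % checkpoint_root)
        checkpoint_path = os.path.join(checkpoint_root, checkpoint_list[-1])
    return torch.load(checkpoint_path, map_location=device)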
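

# Illustration only (hypothetical helper, not called elsewhere): sorting
# names lexicographically puts "ckpt_9.tar" after "ckpt_10.tar", so without
# zero-padded names the "latest" checkpoint chosen by filename may be wrong.
# One possible alternative, assuming file modification time tracks training
# progress, is to pick the most recently modified .tar file.
def _latest_tar_by_mtime(checkpoint_root):
    import os
    tar_files = [name for name in os.listdir(checkpoint_root)
                 if name.endswith(".tar")]
    if not tar_files:
        raise FileNotFoundError(
            "No .tar checkpoint found in %s" % checkpoint_root)
    # Choose the candidate whose file has the largest modification time
    return max(tar_files,
               key=lambda name: os.path.getmtime(
                   os.path.join(checkpoint_root, name)))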


def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
    """ Keep the newest `max_ckpt` checkpoints and delete the older ones.

    Checkpoints are ordered by filename, as in get_latest_checkpoint.
    """
    # Get sorted list of checkpoints
    checkpoint_list = sorted(
        name for name in os.listdir(checkpoint_root)
        if name.endswith(".tar"))

    # Remove everything except the last `max_ckpt` entries
    # (explicit end index, since checkpoint_list[:-0] would be empty)
    if len(checkpoint_list) > max_ckpt:
        remove_list = checkpoint_list[:len(checkpoint_list) - max_ckpt]
        for name in remove_list:
            full_name = os.path.join(checkpoint_root, name)
            os.remove(full_name)
            print("[Debug] Removed outdated checkpoint %s" % full_name)
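

# Illustration only (hypothetical helper, not called elsewhere): the
# retention rule keeps the last `max_ckpt` names of the sorted list and
# returns the rest for deletion. An explicit end index is used because
# names[:-0] is an empty list, which would keep everything when max_ckpt=0.
def _split_checkpoints(checkpoint_names, max_ckpt):
    names = sorted(checkpoint_names)
    n_remove = max(len(names) - max_ckpt, 0)
    return names[n_remove:], names[:n_remove]  # (kept, removed)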


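def adapt_checkpoint(state_dict):
    """ Strip the "module." prefix that nn.DataParallel and
    DistributedDataParallel add to parameter names, so the weights can be
    loaded into an unwrapped model. """
    prefix = "module."
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith(prefix):
            new_state_dict[k[len(prefix):]] = v
        else:
            new_state_dict[k] = v
    return new_state_dict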
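

# Illustration only (hypothetical helper, not called elsewhere): keys saved
# from a model wrapped in nn.DataParallel look like "module.conv.weight".
# On Python 3.9+, str.removeprefix strips the prefix once, giving the same
# mapping as adapt_checkpoint in a single dict comprehension.
def _strip_module_prefix(state_dict):
    return {k.removeprefix("module."): v for k, v in state_dict.items()}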


################
## HDF5 utils ##
################
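def parse_h5_data(h5_data):
    """ Load each dataset in an HDF5 file or group into a dict of NumPy
    arrays (assumes every member is a dataset, not a subgroup). """
    output_data = {}
    for key in h5_data.keys():
        # np.array() reads the whole dataset into memory
        output_data[key] = np.array(h5_data[key])

    return output_data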
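

# Illustration only (hypothetical helper, not called elsewhere): the same
# conversion written as a dict comprehension. It relies only on .keys() and
# item access, so a plain dict of lists can stand in for an h5py.File or
# h5py.Group in a quick unit test without creating an HDF5 file.
def _parse_mapping(h5_like):
    import numpy as np
    return {key: np.array(h5_like[key]) for key in h5_like.keys()}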