Spaces:
Running
Running
File size: 2,181 Bytes
a80d6bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
"""
This file contains some useful functions for train / val.
"""
import os
import numpy as np
import torch
#################
## image utils ##
#################
def convert_image(input_tensor, axis):
    """ Convert single channel images to 3-channel images.

    Three copies of ``input_tensor`` are joined end to end along ``axis``
    with ``np.concatenate``. When ``axis`` has size 1 (a single channel),
    the result has size 3 along that axis; other axes are unchanged.

    Args:
        input_tensor: Array-like image data accepted by ``np.concatenate``.
        axis: Axis along which the three copies are concatenated.

    Returns:
        A NumPy array holding the three concatenated copies.
    """
    return np.concatenate([input_tensor] * 3, axis)
######################
## checkpoint utils ##
######################
def get_latest_checkpoint(checkpoint_root, checkpoint_name,
                          device=torch.device("cuda")):
    """ Get the latest checkpoint or by filename.

    Args:
        checkpoint_root: Directory containing the checkpoint files.
        checkpoint_name: File name (relative to ``checkpoint_root``) of the
            checkpoint to load, or ``None`` to load the latest ``.tar`` file
            in ``checkpoint_root``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        FileNotFoundError: If the requested file or directory does not exist,
            or if ``checkpoint_name`` is ``None`` and ``checkpoint_root``
            contains no ``.tar`` files.
    """
    if checkpoint_name is None:
        # os.listdir does not expand glob patterns, so filter by suffix here.
        candidates = sorted(
            name for name in os.listdir(checkpoint_root)
            if name.endswith(".tar"))
        if not candidates:
            raise FileNotFoundError(
                "No '.tar' checkpoints found in %s" % checkpoint_root)
        # "Latest" = last in lexicographic filename order; this matches
        # chronological order only if names sort that way (e.g. zero-padded
        # epoch/step numbers). NOTE(review): confirm the naming scheme.
        checkpoint_name = candidates[-1]
    return torch.load(
        os.path.join(checkpoint_root, checkpoint_name),
        map_location=device)
def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
    """ Remove the outdated checkpoints, keeping only the newest ones.

    Checkpoints are the ``.tar`` files listed directly in ``checkpoint_root``.
    "Newest" means last in lexicographic filename order, which matches
    chronological order only if the names sort that way (e.g. zero-padded
    step numbers). NOTE(review): confirm the naming scheme.

    Args:
        checkpoint_root: Directory containing the checkpoint files.
        max_ckpt: Number of most recent checkpoints to keep; ``0`` removes
            every ``.tar`` checkpoint.

    Raises:
        ValueError: If ``max_ckpt`` is negative.
    """
    if max_ckpt < 0:
        raise ValueError("max_ckpt must be non-negative, got %r" % (max_ckpt,))
    checkpoint_list = sorted(
        name for name in os.listdir(checkpoint_root) if name.endswith(".tar"))
    # Slice by the number of files to drop rather than ``[:-max_ckpt]``:
    # with max_ckpt == 0 that slice is ``[:0]`` and would remove nothing.
    num_to_remove = max(len(checkpoint_list) - max_ckpt, 0)
    for name in checkpoint_list[:num_to_remove]:
        full_name = os.path.join(checkpoint_root, name)
        os.remove(full_name)
        print("[Debug] Remove outdated checkpoint %s" % (full_name))
def adapt_checkpoint(state_dict):
    """ Strip one leading ``'module.'`` from every key of a state dict.

    Keys without that prefix are kept unchanged and values are not copied.
    If a stripped key collides with another key, the entry that comes later
    in ``state_dict.items()`` order wins. The prefix is typically added when
    a model is wrapped in ``torch.nn.DataParallel``/``DistributedDataParallel``
    (NOTE(review): confirm this is where these checkpoints come from).

    Returns:
        A new ``dict``; the input mapping is not modified.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
################
## HDF5 utils ##
################
def parse_h5_data(h5_data):
    """ Parse h5 dataset into a plain dict of NumPy arrays.

    Every key returned by ``h5_data.keys()`` is mapped to
    ``np.array(h5_data[key])``, which copies the stored values into memory.
    Only ``keys()`` and item access are used, so ``h5_data`` is presumably an
    ``h5py.File``/``Group`` but any such mapping works.
    """
    return {name: np.array(h5_data[name]) for name in h5_data.keys()}
|