Vincentqyw
update: features and matchers
a80d6bb
raw
history blame
2.18 kB
"""
Utility functions for training and validation: image conversion,
checkpoint handling and HDF5 parsing.
"""
import os
import numpy as np
import torch
#################
## image utils ##
#################
def convert_image(input_tensor, axis):
    """ Replicate an array three times along the given axis.

    Used to turn a single-channel image into a 3-channel one by stacking
    three copies of it along `axis`.

    Args:
        input_tensor: Array-like accepted by np.concatenate.
        axis: Axis along which the three copies are joined.

    Returns:
        The np.ndarray produced by concatenating the three copies.
    """
    triplet = (input_tensor,) * 3
    return np.concatenate(triplet, axis)
######################
## checkpoint utils ##
######################
def get_latest_checkpoint(checkpoint_root, checkpoint_name,
                          device=torch.device("cuda")):
    """ Load a checkpoint by filename, or the latest one in a folder.

    Args:
        checkpoint_root: Directory containing the checkpoint files.
        checkpoint_name: Filename of the checkpoint to load. When None, the
            ".tar" file that sorts last by name in checkpoint_root is loaded.
        device: Passed to torch.load as map_location.

    Returns:
        The object returned by torch.load for the selected checkpoint.

    Raises:
        FileNotFoundError: If checkpoint_name is None and checkpoint_root
            contains no ".tar" file.
    """
    if checkpoint_name is None:
        # os.listdir does not expand wildcards, so filter on the extension
        # explicitly instead of joining "*.tar" onto the directory path.
        candidates = sorted(
            name for name in os.listdir(checkpoint_root)
            if name.endswith(".tar"))
        if not candidates:
            raise FileNotFoundError(
                "No '.tar' checkpoint found in %s" % checkpoint_root)
        # "Latest" means last in lexicographic filename order.
        checkpoint_name = candidates[-1]

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(
        os.path.join(checkpoint_root, checkpoint_name),
        map_location=device)
    return checkpoint
def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
    """ Remove the outdated checkpoints.

    Keeps the `max_ckpt` ".tar" files that sort last by name in
    checkpoint_root and deletes the rest. Other files are left untouched.

    Args:
        checkpoint_root: Directory containing the checkpoint files.
        max_ckpt: Maximum number of ".tar" files to keep (0 removes all).

    Raises:
        ValueError: If max_ckpt is negative.
    """
    if max_ckpt < 0:
        raise ValueError("max_ckpt must be non-negative, got %d" % max_ckpt)

    # Get sorted list of checkpoints
    checkpoint_list = sorted(
        name for name in os.listdir(checkpoint_root)
        if name.endswith(".tar"))

    # Count from the front rather than slicing with [:-max_ckpt]: the
    # negative-index form yields an empty list when max_ckpt is 0.
    num_outdated = max(len(checkpoint_list) - max_ckpt, 0)
    for name in checkpoint_list[:num_outdated]:
        full_name = os.path.join(checkpoint_root, name)
        os.remove(full_name)
        print("[Debug] Remove outdated checkpoint %s" % (full_name))
def adapt_checkpoint(state_dict):
    """ Strip a leading 'module.' from the keys of a state dict.

    Keys that start with 'module.' are stored without that prefix; all
    other keys are copied unchanged. The input mapping is not modified.

    Args:
        state_dict: Mapping with string keys.

    Returns:
        A new dict with the adapted keys and the original values.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
################
## HDF5 utils ##
################
def parse_h5_data(h5_data):
    """ Convert every entry of an h5 dataset into a numpy array.

    Args:
        h5_data: Object exposing keys() and item access by key.

    Returns:
        A dict mapping each key to np.array(h5_data[key]).
    """
    return {name: np.array(h5_data[name]) for name in h5_data.keys()}