"""
Utility functions shared by training and validation: image, checkpoint, and HDF5 helpers.
"""
import os
import numpy as np
import torch


#################
## image utils ##
#################
def convert_image(input_tensor, axis):
    """Replicate a single-channel image three times to build a 3-channel image.

    Args:
        input_tensor: image array (or array-like accepted by
            ``np.concatenate``). Presumably it has a size-1 channel dimension
            at ``axis`` — TODO confirm with callers; the code itself only
            concatenates.
        axis: axis along which the three copies are concatenated.

    Returns:
        numpy.ndarray: three copies of ``input_tensor`` joined along ``axis``.
    """
    return np.concatenate((input_tensor,) * 3, axis)


######################
## checkpoint utils ##
######################
def get_latest_checkpoint(
    checkpoint_root, checkpoint_name, device=torch.device("cuda")
):
    """Load a specific checkpoint by file name, or the latest one.

    Args:
        checkpoint_root: directory containing the checkpoint files.
        checkpoint_name: file name (relative to ``checkpoint_root``) of the
            checkpoint to load. If ``None``, the ``.tar`` file whose name
            sorts last lexicographically is loaded.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object returned by ``torch.load`` for the selected file.

    Raises:
        FileNotFoundError: if ``checkpoint_name`` is ``None`` and
            ``checkpoint_root`` contains no ``.tar`` file (or does not exist).
    """
    if checkpoint_name is None:
        # os.listdir does not expand wildcards, so select ".tar" files by
        # suffix rather than passing "*.tar" as part of the path.
        # "Latest" is plain lexicographic order of file names, which presumes
        # names sort chronologically (e.g. zero-padded step numbers) — confirm.
        candidates = sorted(
            name for name in os.listdir(checkpoint_root) if name.endswith(".tar")
        )
        if not candidates:
            raise FileNotFoundError(
                "No '.tar' checkpoint found in %s" % (checkpoint_root,)
            )
        checkpoint_name = candidates[-1]

    return torch.load(
        os.path.join(checkpoint_root, checkpoint_name), map_location=device
    )


def remove_old_checkpoints(checkpoint_root, max_ckpt=15):
    """Delete all but the newest ``max_ckpt`` checkpoints in a directory.

    Checkpoints are the entries of ``checkpoint_root`` whose names end with
    ``.tar``. "Newest" means last in lexicographic order of file name. This
    presumes names sort chronologically (e.g. zero-padded step numbers) —
    confirm with the code that writes them.

    Args:
        checkpoint_root: directory holding the checkpoint files.
        max_ckpt: number of checkpoints to keep; ``0`` removes all of them.

    Raises:
        ValueError: if ``max_ckpt`` is negative.
    """
    if max_ckpt < 0:
        raise ValueError("max_ckpt must be non-negative, got %r" % (max_ckpt,))

    checkpoint_list = sorted(
        name for name in os.listdir(checkpoint_root) if name.endswith(".tar")
    )

    # Use an explicit, clamped end index: `checkpoint_list[:-max_ckpt]` is
    # empty when max_ckpt == 0, which would silently keep every checkpoint.
    num_to_remove = max(len(checkpoint_list) - max_ckpt, 0)
    for name in checkpoint_list[:num_to_remove]:
        full_name = os.path.join(checkpoint_root, name)
        os.remove(full_name)
        print("[Debug] Remove outdated checkpoint %s" % (full_name))


def adapt_checkpoint(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"module."`` stripped from keys.

    Keys that do not start with ``"module."`` are kept as-is; values are
    shared, not copied. The prefix is presumably the one added when a model
    is wrapped in (Distributed)DataParallel — confirm against the saving code.
    If both ``"module.x"`` and ``"x"`` are present, whichever comes later in
    iteration order wins.

    Args:
        state_dict: mapping from parameter name to value.

    Returns:
        dict: new mapping with the prefix removed from matching keys.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


################
## HDF5 utils ##
################
def parse_h5_data(h5_data):
    """Read every entry of an HDF5 container into an in-memory NumPy array.

    Args:
        h5_data: mapping-like object exposing ``keys()`` and ``[key]`` access
            (presumably an ``h5py.File`` or ``h5py.Group`` — confirm).

    Returns:
        dict: ``{key: np.array(h5_data[key])}`` for every key of ``h5_data``.
    """
    return {name: np.array(h5_data[name]) for name in h5_data.keys()}