"""
Utility functions
"""
import pickle
from pathlib import Path

import pax
import toml
import yaml

from tacotron import Tacotron


def load_tacotron_config(config_file=Path("tacotron.toml")):
    """
    Read the Tacotron settings from a TOML file.

    ``config_file`` is handed to ``toml.load`` unchanged and defaults to
    ``tacotron.toml`` (a relative path).

    Returns the value stored under the top-level ``"tacotron"`` key of the
    parsed document. A ``KeyError`` is raised when that key is absent.
    """
    parsed = toml.load(config_file)
    tacotron_section = parsed["tacotron"]
    return tacotron_section


def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
    """
    Restore a Tacotron training checkpoint from a pickle file.

    The file at ``path`` is unpickled into a dict. When ``net`` is not
    ``None`` its ``load_state_dict`` is called with the checkpoint's
    ``"model_state_dict"`` entry; when ``optim`` is not ``None`` its
    ``load_state_dict`` is called with the ``"optim_state_dict"`` entry.
    Passing ``None`` for either one skips that part of the restore.

    Returns a ``(step, net, optim)`` tuple: ``step`` is the checkpoint's
    ``"step"`` entry, and ``net`` / ``optim`` are whatever the corresponding
    ``load_state_dict`` call returned (or ``None`` if the argument was
    ``None``).

    NOTE: the file is read with ``pickle.load``, which can execute arbitrary
    code while unpickling -- only load checkpoints from trusted sources.
    """
    with open(path, "rb") as ckpt:
        state = pickle.load(ckpt)

    restored_net = (
        None if net is None else net.load_state_dict(state["model_state_dict"])
    )
    restored_optim = (
        None if optim is None else optim.load_state_dict(state["optim_state_dict"])
    )
    return state["step"], restored_net, restored_optim


def create_tacotron_model(config):
    """
    Build a new ``Tacotron`` instance from a configuration mapping.

    Each entry listed below is read from ``config`` and forwarded to the
    ``Tacotron`` constructor as a keyword argument whose name is the
    lower-cased key (e.g. ``config["MEL_DIM"]`` becomes ``mel_dim=...``).
    A missing entry raises ``KeyError``.
    """
    # Kept in the order the constructor arguments are looked up.
    setting_names = (
        "MEL_DIM",
        "ATTN_BIAS",
        "RR",
        "MAX_RR",
        "MEL_MIN",
        "SIGMOID_NOISE",
        "PAD_TOKEN",
        "PRENET_DIM",
        "ATTN_HIDDEN_DIM",
        "ATTN_RNN_DIM",
        "RNN_DIM",
        "POSTNET_DIM",
        "TEXT_DIM",
    )
    return Tacotron(**{name.lower(): config[name] for name in setting_names})


def load_wavegru_config(config_file):
    """
    Read the WaveGRU settings from a YAML file.

    ``config_file`` is opened as UTF-8 text and parsed with
    ``yaml.safe_load``; the parsed document is returned as-is.
    """
    with open(config_file, encoding="utf-8") as stream:
        settings = yaml.safe_load(stream)
    return settings


def load_wavegru_ckpt(net, optim, ckpt_file):
    """
    Restore a WaveGRU training checkpoint from a pickle file.

    The file at ``ckpt_file`` is unpickled into a dict. When ``net`` is not
    ``None`` its ``load_state_dict`` is called with the checkpoint's
    ``"net_state_dict"`` entry; when ``optim`` is not ``None`` its
    ``load_state_dict`` is called with the ``"optim_state_dict"`` entry.
    Passing ``None`` for either one skips that part of the restore.

    Returns a ``(step, net, optim)`` tuple: ``step`` is the checkpoint's
    ``"step"`` entry, and ``net`` / ``optim`` are whatever the corresponding
    ``load_state_dict`` call returned (or ``None`` if the argument was
    ``None``).

    NOTE: the file is read with ``pickle.load``, which can execute arbitrary
    code while unpickling -- only load checkpoints from trusted sources.
    """
    with open(ckpt_file, "rb") as ckpt:
        state = pickle.load(ckpt)

    restored_net = (
        None if net is None else net.load_state_dict(state["net_state_dict"])
    )
    restored_optim = (
        None if optim is None else optim.load_state_dict(state["optim_state_dict"])
    )
    return state["step"], restored_net, restored_optim