import base64
import collections
import os
import random
from typing import Dict, List, Union

import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset

from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy

# Switch torch.multiprocessing to the "file_system" sharing strategy to prevent
# "too many open files" errors, as suggested here:
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
# NOTE: this is a module-level side effect that runs on import.
torch.multiprocessing.set_sharing_strategy("file_system")


def _parse_sample(item):
    language_name = None
    attn_file = None
    if len(item) == 5:
        text, wav_file, speaker_name, language_name, attn_file = item
    elif len(item) == 4:
        text, wav_file, speaker_name, language_name = item
    elif len(item) == 3:
        text, wav_file, speaker_name = item
    else:
        raise ValueError(" [!] Dataset cannot parse the sample.")
    return text, wav_file, speaker_name, language_name, attn_file


def noise_augment_audio(wav):
    """Return `wav` plus low-amplitude random noise.

    The noise is drawn with `np.random.rand` (one value per element of `wav`)
    and scaled by 1/32768.

    Args:
        wav (np.ndarray): Waveform to augment. It is not modified in place.

    Returns:
        np.ndarray: New array with the same shape as `wav`.
    """
    noise_scale = 1.0 / 32768.0
    noise = np.random.rand(*wav.shape)
    return wav + noise_scale * noise


def string2filename(string):
    """Generate a safe and reversible filename based on a string.

    The string is UTF-8 encoded and then encoded with URL-safe base64.

    Args:
        string (str): Text to convert.

    Returns:
        str: URL-safe base64 representation of `string`.
    """
    utf8_bytes = string.encode("utf-8")
    encoded = base64.urlsafe_b64encode(utf8_bytes)
    return encoded.decode("utf-8", "ignore")


class TTSDataset(Dataset):
    def __init__(
        self,
        outputs_per_step: int = 1,
        compute_linear_spec: bool = False,
        ap: AudioProcessor = None,
        samples: List[Dict] = None,
        tokenizer: "TTSTokenizer" = None,
        compute_f0: bool = False,
        compute_energy: bool = False,
        f0_cache_path: str = None,
        energy_cache_path: str = None,
        return_wav: bool = False,
        batch_group_size: int = 0,
        min_text_len: int = 0,
        max_text_len: int = float("inf"),
        min_audio_len: int = 0,
        max_audio_len: int = float("inf"),
        phoneme_cache_path: str = None,
        precompute_num_workers: int = 0,
        speaker_id_mapping: Dict = None,
        d_vector_mapping: Dict = None,
        language_id_mapping: Dict = None,
        use_noise_augment: bool = False,
        start_by_longest: bool = False,
        verbose: bool = False,
    ):
        """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.

        If you need something different, you can subclass and override.

        Args:
            outputs_per_step (int): Number of time frames predicted per step.

            compute_linear_spec (bool): compute linear spectrogram if True.

            ap (TTS.tts.utils.AudioProcessor): Audio processor object.

            samples (list): List of dataset samples.

            tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs.
                NOTE(review): the default is None, but this constructor reads `tokenizer.use_phonemes`
                unconditionally and no internal initialization is visible in this class, so a tokenizer
                has to be passed.

            compute_f0 (bool): compute f0 if True. Defaults to False.

            compute_energy (bool): compute energy if True. Defaults to False.

            f0_cache_path (str): Path to store f0 cache. Defaults to None.

            energy_cache_path (str): Path to store energy cache. Defaults to None.

            return_wav (bool): Return the waveform of the sample. Defaults to False.

            batch_group_size (int): Range of batch randomization after sorting
                sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
                batch. Set 0 to disable. Defaults to 0.

            min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored.
                Defaults to float("inf").

            min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored.
                Defaults to 0.

            max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored.
                The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to
                this value if you encounter an OOM error in training. Defaults to float("inf").

            phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a
                separate file. Defaults to None.

            precompute_num_workers (int): Number of workers to precompute features. It is forwarded to the
                phoneme, F0 and energy helper datasets. Defaults to 0.

            speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
                embedding layer. Defaults to None.

            d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.

            language_id_mapping (dict): Mapping of language names to IDs. `collate_fn` uses it to convert the
                `language_name` of each sample to `language_ids`. Defaults to None.

            use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.

            start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False.

            verbose (bool): Print diagnostic information. Defaults to false.
        """
        super().__init__()
        self.batch_group_size = batch_group_size
        # stored on the private attribute; the `samples` property exposes it and keeps helper datasets in sync
        self._samples = samples
        self.outputs_per_step = outputs_per_step
        self.compute_linear_spec = compute_linear_spec
        self.return_wav = return_wav
        self.compute_f0 = compute_f0
        self.compute_energy = compute_energy
        self.f0_cache_path = f0_cache_path
        self.energy_cache_path = energy_cache_path
        self.min_audio_len = min_audio_len
        self.max_audio_len = max_audio_len
        self.min_text_len = min_text_len
        self.max_text_len = max_text_len
        self.ap = ap
        self.phoneme_cache_path = phoneme_cache_path
        self.speaker_id_mapping = speaker_id_mapping
        self.d_vector_mapping = d_vector_mapping
        self.language_id_mapping = language_id_mapping
        self.use_noise_augment = use_noise_augment
        self.start_by_longest = start_by_longest

        self.verbose = verbose
        # index of the fallback sample `load_data` switches to when an item violates the length limits
        self.rescue_item_idx = 1
        # NOTE(review): not read anywhere in the visible part of this file
        self.pitch_computed = False
        self.tokenizer = tokenizer

        # helper datasets are only created when the corresponding feature is requested
        if self.tokenizer.use_phonemes:
            self.phoneme_dataset = PhonemeDataset(
                self.samples, self.tokenizer, phoneme_cache_path, precompute_num_workers=precompute_num_workers
            )

        if compute_f0:
            self.f0_dataset = F0Dataset(
                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
            )
        if compute_energy:
            self.energy_dataset = EnergyDataset(
                self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
            )
        if self.verbose:
            self.print_logs()

    @property
    def lengths(self):
        lens = []
        for item in self.samples:
            _, wav_file, *_ = _parse_sample(item)
            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16bit audio
            lens.append(audio_len)
        return lens

    @property
    def samples(self):
        return self._samples

    @samples.setter
    def samples(self, new_samples):
        self._samples = new_samples
        if hasattr(self, "f0_dataset"):
            self.f0_dataset.samples = new_samples
        if hasattr(self, "energy_dataset"):
            self.energy_dataset.samples = new_samples
        if hasattr(self, "phoneme_dataset"):
            self.phoneme_dataset.samples = new_samples

    def __len__(self):
        """Return the number of samples currently held by the dataset."""
        num_samples = len(self.samples)
        return num_samples

    def __getitem__(self, idx):
        """Return the sample for `idx`; all the work is delegated to `load_data`."""
        sample = self.load_data(idx)
        return sample

    def print_logs(self, level: int = 0) -> None:
        indent = "\t" * level
        print("\n")
        print(f"{indent}> DataLoader initialization")
        print(f"{indent}| > Tokenizer:")
        self.tokenizer.print_logs(level + 1)
        print(f"{indent}| > Number of instances : {len(self.samples)}")

    def load_wav(self, filename):
        waveform = self.ap.load_wav(filename)
        assert waveform.size > 0
        return waveform

    def get_phonemes(self, idx, text):
        """Fetch the phoneme entry for `idx` and check it against `text`.

        Args:
            idx (int): Index into `self.phoneme_dataset`.
            text (str): Text expected to be stored in the returned entry.

        Returns:
            Dict: The entry returned by `self.phoneme_dataset[idx]`; it has a non-empty `token_ids`.
        """
        entry = self.phoneme_dataset[idx]
        cached_text = entry["text"]
        # the phoneme dataset and this dataset must be indexing the same samples
        assert text == cached_text, f"{text} != {cached_text}"
        assert len(entry["token_ids"]) > 0
        return entry

    def get_f0(self, idx):
        """Fetch the F0 entry for `idx` and check it belongs to the same audio as the sample.

        Args:
            idx (int): Index into `self.f0_dataset` and `self.samples`.

        Returns:
            Dict: The entry returned by `self.f0_dataset[idx]`.
        """
        f0_entry = self.f0_dataset[idx]
        expected_name = self.samples[idx]["audio_unique_name"]
        assert expected_name == f0_entry["audio_unique_name"]
        return f0_entry

    def get_energy(self, idx):
        """Fetch the energy entry for `idx` and check it belongs to the same audio as the sample.

        Args:
            idx (int): Index into `self.energy_dataset` and `self.samples`.

        Returns:
            Dict: The entry returned by `self.energy_dataset[idx]`.
        """
        energy_entry = self.energy_dataset[idx]
        expected_name = self.samples[idx]["audio_unique_name"]
        assert expected_name == energy_entry["audio_unique_name"]
        return energy_entry

    @staticmethod
    def get_attn_mask(attn_file):
        return np.load(attn_file)

    def get_token_ids(self, idx, text):
        if self.tokenizer.use_phonemes:
            token_ids = self.get_phonemes(idx, text)["token_ids"]
        else:
            token_ids = self.tokenizer.text_to_ids(text)
        return np.array(token_ids, dtype=np.int32)

    def load_data(self, idx):
        item = self.samples[idx]

        raw_text = item["text"]

        wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)

        # apply noise for augmentation
        if self.use_noise_augment:
            wav = noise_augment_audio(wav)

        # get token ids
        token_ids = self.get_token_ids(idx, item["text"])

        # get pre-computed attention maps
        attn = None
        if "alignment_file" in item:
            attn = self.get_attn_mask(item["alignment_file"])

        # after phonemization the text length may change
        # this is a shareful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix
        if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len:
            self.rescue_item_idx += 1
            return self.load_data(self.rescue_item_idx)

        # get f0 values
        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]
        energy = None
        if self.compute_energy:
            energy = self.get_energy(idx)["energy"]

        sample = {
            "raw_text": raw_text,
            "token_ids": token_ids,
            "wav": wav,
            "pitch": f0,
            "energy": energy,
            "attn": attn,
            "item_idx": item["audio_file"],
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
            "wav_file_name": os.path.basename(item["audio_file"]),
            "audio_unique_name": item["audio_unique_name"],
        }
        return sample

    @staticmethod
    def _compute_lengths(samples):
        new_samples = []
        for item in samples:
            audio_length = os.path.getsize(item["audio_file"]) / 16 * 8  # assuming 16bit audio
            text_lenght = len(item["text"])
            item["audio_length"] = audio_length
            item["text_length"] = text_lenght
            new_samples += [item]
        return new_samples

    @staticmethod
    def filter_by_length(lengths: List[int], min_len: int, max_len: int):
        idxs = np.argsort(lengths)  # ascending order
        ignore_idx = []
        keep_idx = []
        for idx in idxs:
            length = lengths[idx]
            if length < min_len or length > max_len:
                ignore_idx.append(idx)
            else:
                keep_idx.append(idx)
        return ignore_idx, keep_idx

    @staticmethod
    def sort_by_length(samples: List[List]):
        audio_lengths = [s["audio_length"] for s in samples]
        idxs = np.argsort(audio_lengths)  # ascending order
        return idxs

    @staticmethod
    def create_buckets(samples, batch_group_size: int):
        assert batch_group_size > 0
        for i in range(len(samples) // batch_group_size):
            offset = i * batch_group_size
            end_offset = offset + batch_group_size
            temp_items = samples[offset:end_offset]
            random.shuffle(temp_items)
            samples[offset:end_offset] = temp_items
        return samples

    @staticmethod
    def _select_samples_by_idx(idxs, samples):
        samples_new = []
        for idx in idxs:
            samples_new.append(samples[idx])
        return samples_new

    def preprocess_samples(self):
        r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length
        range.
        """
        samples = self._compute_lengths(self.samples)

        # sort items based on the sequence length in ascending order
        text_lengths = [i["text_length"] for i in samples]
        audio_lengths = [i["audio_length"] for i in samples]
        text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len)
        audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len)
        keep_idx = list(set(audio_keep_idx) & set(text_keep_idx))
        ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx))

        samples = self._select_samples_by_idx(keep_idx, samples)

        sorted_idxs = self.sort_by_length(samples)

        if self.start_by_longest:
            longest_idxs = sorted_idxs[-1]
            sorted_idxs[-1] = sorted_idxs[0]
            sorted_idxs[0] = longest_idxs

        samples = self._select_samples_by_idx(sorted_idxs, samples)

        if len(samples) == 0:
            raise RuntimeError(" [!] No samples left")

        # shuffle batch groups
        # create batches with similar length items
        # the larger the `batch_group_size`, the higher the length variety in a batch.
        if self.batch_group_size > 0:
            samples = self.create_buckets(samples, self.batch_group_size)

        # update items to the new sorted items
        audio_lengths = [s["audio_length"] for s in samples]
        text_lengths = [s["text_length"] for s in samples]
        self.samples = samples

        if self.verbose:
            print(" | > Preprocessing samples")
            print(" | > Max text length: {}".format(np.max(text_lengths)))
            print(" | > Min text length: {}".format(np.min(text_lengths)))
            print(" | > Avg text length: {}".format(np.mean(text_lengths)))
            print(" | ")
            print(" | > Max audio length: {}".format(np.max(audio_lengths)))
            print(" | > Min audio length: {}".format(np.min(audio_lengths)))
            print(" | > Avg audio length: {}".format(np.mean(audio_lengths)))
            print(f" | > Num. instances discarded samples: {len(ignore_idx)}")
            print(" | > Batch group size: {}.".format(self.batch_group_size))

    @staticmethod
    def _sort_batch(batch, text_lengths):
        """Sort the batch by the input text length for RNN efficiency.

        Args:
            batch (Dict): Batch returned by `__getitem__`.
            text_lengths (List[int]): Lengths of the input character sequences.
        """
        text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True)
        batch = [batch[idx] for idx in ids_sorted_decreasing]
        return batch, text_lengths, ids_sorted_decreasing

    def collate_fn(self, batch):
        r"""
        Perform preprocessing and create a final data batch:
        1. Sort batch instances by text-length
        2. Convert Audio signal to features.
        3. PAD sequences wrt r.
        4. Load to Torch.

        Args:
            batch (List[Dict]): Samples as returned by `load_data`.

        Returns:
            Dict: Batch with the keys `token_id`, `token_id_lengths`, `speaker_names`, `linear`, `mel`,
            `mel_lengths`, `stop_targets`, `item_idxs`, `d_vectors`, `speaker_ids`, `attns`, `waveform`,
            `raw_text`, `pitch`, `energy`, `language_ids` and `audio_unique_names`. Entries that are not
            requested by the dataset configuration are None. All per-sample entries follow the order of
            decreasing text length.

        Raises:
            TypeError: If the first element of `batch` is not a mapping.
        """

        # Puts each data field into a tensor with outer dimension batch size
        if isinstance(batch[0], collections.abc.Mapping):
            token_ids_lengths = np.array([len(d["token_ids"]) for d in batch])

            # sort items with text input length for RNN efficiency
            batch, token_ids_lengths, _ = self._sort_batch(batch, token_ids_lengths)

            # convert list of dicts to dict of lists
            batch = {k: [dic[k] for dic in batch] for k in batch[0]}

            # get language ids from language names
            if self.language_id_mapping is not None:
                language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]]
            else:
                language_ids = None
            # get pre-computed d-vectors
            if self.d_vector_mapping is not None:
                embedding_keys = list(batch["audio_unique_name"])
                d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys]
            else:
                d_vectors = None

            # get numerical speaker ids from speaker names
            if self.speaker_id_mapping:
                speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]]
            else:
                speaker_ids = None
            # compute features
            mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]]

            mel_lengths = [m.shape[1] for m in mel]

            # lengths adjusted by the reduction factor
            mel_lengths_adjusted = [
                m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step))
                if m.shape[1] % self.outputs_per_step
                else m.shape[1]
                for m in mel
            ]

            # compute 'stop token' targets
            stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]

            # PAD stop targets
            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

            # PAD sequences with longest instance in the batch
            token_ids = prepare_data(batch["token_ids"]).astype(np.int32)

            # PAD features with longest instance
            mel = prepare_tensor(mel, self.outputs_per_step)

            # B x D x T --> B x T x D
            mel = mel.transpose(0, 2, 1)

            # convert things to pytorch
            token_ids_lengths = torch.LongTensor(token_ids_lengths)
            token_ids = torch.LongTensor(token_ids)
            mel = torch.FloatTensor(mel).contiguous()
            mel_lengths = torch.LongTensor(mel_lengths)
            stop_targets = torch.FloatTensor(stop_targets)

            # speaker vectors
            if d_vectors is not None:
                d_vectors = torch.FloatTensor(d_vectors)

            if speaker_ids is not None:
                speaker_ids = torch.LongTensor(speaker_ids)

            if language_ids is not None:
                language_ids = torch.LongTensor(language_ids)

            # compute linear spectrogram
            linear = None
            if self.compute_linear_spec:
                linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]]
                linear = prepare_tensor(linear, self.outputs_per_step)
                linear = linear.transpose(0, 2, 1)
                assert mel.shape[1] == linear.shape[1]
                linear = torch.FloatTensor(linear).contiguous()

            # format waveforms
            wav_padded = None
            if self.return_wav:
                max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length
                wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len)
                for i, w in enumerate(batch["wav"]):
                    mel_length = mel_lengths_adjusted[i]
                    # pad with the edge value first so the cut below never runs short
                    w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge")
                    w = w[: mel_length * self.ap.hop_length]
                    wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w)
                wav_padded.transpose_(1, 2)

            # format F0
            if self.compute_f0:
                pitch = prepare_data(batch["pitch"])
                assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}"
                pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 xT
            else:
                pitch = None
            # format energy
            if self.compute_energy:
                energy = prepare_data(batch["energy"])
                assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}"
                energy = torch.FloatTensor(energy)[:, None, :].contiguous()  # B x 1 xT
            else:
                energy = None
            # format attention masks
            attns = None
            if batch["attn"][0] is not None:
                # `batch` is already in sorted order here; indexing it with the sort permutation again
                # would pair the attention maps with the wrong samples
                attns = [attn.T for attn in batch["attn"]]
                for idx, attn in enumerate(attns):
                    pad2 = mel.shape[1] - attn.shape[1]
                    pad1 = token_ids.shape[1] - attn.shape[0]
                    assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}"
                    attn = np.pad(attn, [[0, pad1], [0, pad2]])
                    attns[idx] = attn
                attns = prepare_tensor(attns, self.outputs_per_step)
                attns = torch.FloatTensor(attns).unsqueeze(1)

            return {
                "token_id": token_ids,
                "token_id_lengths": token_ids_lengths,
                "speaker_names": batch["speaker_name"],
                "linear": linear,
                "mel": mel,
                "mel_lengths": mel_lengths,
                "stop_targets": stop_targets,
                "item_idxs": batch["item_idx"],
                "d_vectors": d_vectors,
                "speaker_ids": speaker_ids,
                "attns": attns,
                "waveform": wav_padded,
                "raw_text": batch["raw_text"],
                "pitch": pitch,
                "energy": energy,
                "language_ids": language_ids,
                "audio_unique_names": batch["audio_unique_name"],
            }

        raise TypeError(
            (
                "batch must contain tensors, numbers, dicts or lists;\
                         found {}".format(
                    type(batch[0])
                )
            )
        )


class PhonemeDataset(Dataset):
    """Phoneme Dataset for converting input text to phonemes and then token IDs

    At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data
    loading latency. If `cache_path` is already present, it skips the pre-computation.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        tokenizer (TTSTokenizer):
            Tokenizer to convert input text to phonemes.

        cache_path (str):
            Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation.

        precompute_num_workers (int):
            Number of workers used for pre-computing the phonemes. Defaults to 0.
    """

    def __init__(
        self,
        samples: Union[List[Dict], List[List]],
        tokenizer: "TTSTokenizer",
        cache_path: str,
        precompute_num_workers=0,
    ):
        self.samples = samples
        self.tokenizer = tokenizer
        self.cache_path = cache_path
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)

    def __getitem__(self, index):
        """Return the text, token IDs and decoded phonemes of the sample at `index`.

        Returns:
            Dict: `text` (input text), `ph_hat` (token IDs converted back to text by the tokenizer),
            `token_ids` and `token_ids_len`.
        """
        sample = self.samples[index]
        # the cache file name is derived from the unique audio name of the sample
        cache_name = string2filename(sample["audio_unique_name"])
        token_ids = self.compute_or_load(cache_name, sample["text"], sample["language"])
        decoded = self.tokenizer.ids_to_text(token_ids)
        return {
            "text": sample["text"],
            "ph_hat": decoded,
            "token_ids": token_ids,
            "token_ids_len": len(token_ids),
        }

    def __len__(self):
        """Return the number of samples."""
        num_samples = len(self.samples)
        return num_samples

    def compute_or_load(self, file_name, text, language):
        """Compute phonemes for the given text.

        If the phonemes are already cached, load them from cache.
        """
        file_ext = "_phoneme.npy"
        cache_path = os.path.join(self.cache_path, file_name + file_ext)
        try:
            ids = np.load(cache_path)
        except FileNotFoundError:
            ids = self.tokenizer.text_to_ids(text, language=language)
            np.save(cache_path, ids)
        return ids

    def get_pad_id(self):
        """Get pad token ID for sequence padding"""
        return self.tokenizer.pad_id

    def precompute(self, num_workers=1):
        """Precompute phonemes for all samples.

        We use pytorch dataloader because we are lazy. Iterating the loader calls `__getitem__` for
        every sample, which is what fills the cache; the batches themselves are discarded.

        Args:
            num_workers (int): Number of DataLoader workers. It is also used as the batch size
                (a batch size of 1 is used when it is 0 or negative).
        """
        print("[*] Pre-computing phonemes...")
        total = len(self)
        batch_size = num_workers if num_workers > 0 else 1
        dataloader = torch.utils.data.DataLoader(
            batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
        )
        with tqdm.tqdm(total=total) as pbar:
            processed = 0
            for _ in dataloader:
                # the last batch may be smaller than `batch_size`; never advance past the total
                step = min(batch_size, total - processed)
                pbar.update(step)
                processed += step

    def collate_fn(self, batch):
        ids = [item["token_ids"] for item in batch]
        ids_lens = [item["token_ids_len"] for item in batch]
        texts = [item["text"] for item in batch]
        texts_hat = [item["ph_hat"] for item in batch]
        ids_lens_max = max(ids_lens)
        ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id())
        for i, ids_len in enumerate(ids_lens):
            ids_torch[i, :ids_len] = torch.LongTensor(ids[i])
        return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch}

    def print_logs(self, level: int = 0) -> None:
        indent = "\t" * level
        print("\n")
        print(f"{indent}> PhonemeDataset ")
        print(f"{indent}| > Tokenizer:")
        self.tokenizer.print_logs(level + 1)
        print(f"{indent}| > Number of instances : {len(self.samples)}")


class F0Dataset:
    """F0 Dataset for computing F0 from wav files in CPU

    Pre-compute F0 values for all the samples at initialization if `cache_path` is not None and does not exist
    yet. It also computes the mean and std of F0 values if `normalize_f0` is True.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        ap (AudioProcessor):
            AudioProcessor to compute F0 from wav files.

        cache_path (str):
            Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation.
            Defaults to None.

        precompute_num_workers (int):
            Number of workers used for pre-computing the F0 values. Defaults to 0.

        normalize_f0 (bool):
            Whether to normalize F0 values by mean and std. Defaults to True.
    """

    def __init__(
        self,
        samples: Union[List[List], List[Dict]],
        ap: "AudioProcessor",
        audio_config=None,  # pylint: disable=unused-argument
        verbose=False,
        cache_path: str = None,
        precompute_num_workers=0,
        normalize_f0=True,
    ):
        self.samples = samples
        self.ap = ap
        self.verbose = verbose
        self.cache_path = cache_path
        self.normalize_f0 = normalize_f0
        self.pad_id = 0.0
        self.mean = None
        self.std = None
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)
        if normalize_f0:
            self.load_stats(cache_path)

    def __getitem__(self, idx):
        """Return the (optionally normalized) F0 values of the sample at `idx`.

        Returns:
            Dict: `audio_unique_name` of the sample and its `f0` array.
        """
        sample = self.samples[idx]
        wav_file = sample["audio_file"]
        # the cache file name is derived from the unique audio name of the sample
        cache_name = string2filename(sample["audio_unique_name"])
        f0 = self.compute_or_load(wav_file, cache_name)
        if self.normalize_f0:
            stats_ready = self.mean is not None and self.std is not None
            assert stats_ready, " [!] Mean and STD is not available"
            f0 = self.normalize(f0)
        return {"audio_unique_name": sample["audio_unique_name"], "f0": f0}

    def __len__(self):
        """Return the number of samples."""
        num_samples = len(self.samples)
        return num_samples

    def precompute(self, num_workers=0):
        """Compute the F0 of every sample and, if normalization is enabled, save the F0 statistics.

        Iterating the DataLoader calls `__getitem__` for every sample, which computes and caches the
        F0 values. When `self.normalize_f0` is True the mean and std of the collected values are
        written to `pitch_stats` under `self.cache_path`.

        Args:
            num_workers (int): Number of DataLoader workers. It is also used as the batch size
                (a batch size of 1 is used when it is 0 or negative).
        """
        print("[*] Pre-computing F0s...")
        total = len(self)
        batch_size = num_workers if num_workers > 0 else 1
        computed_data = []

        # we do not normalize at preprocessing; the statistics do not exist yet
        normalize_f0 = self.normalize_f0
        self.normalize_f0 = False
        try:
            dataloader = torch.utils.data.DataLoader(
                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
            )
            with tqdm.tqdm(total=total) as pbar:
                for batch in dataloader:
                    f0 = batch["f0"]
                    # one row per sample; rows are padded with the pad value
                    computed_data.extend(f0)
                    # the last batch may be smaller than `batch_size`
                    pbar.update(len(f0))
        finally:
            # restore the flag even if loading a sample fails
            self.normalize_f0 = normalize_f0

        if self.normalize_f0:
            pitch_mean, pitch_std = self.compute_pitch_stats(computed_data)
            pitch_stats = {"mean": pitch_mean, "std": pitch_std}
            np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True)

    def get_pad_id(self):
        """Return the value used to pad F0 sequences."""
        pad_value = self.pad_id
        return pad_value

    @staticmethod
    def create_pitch_file_path(file_name, cache_path):
        pitch_file = os.path.join(cache_path, file_name + "_pitch.npy")
        return pitch_file

    @staticmethod
    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
        wav = ap.load_wav(wav_file)
        pitch = ap.compute_f0(wav)
        if pitch_file:
            np.save(pitch_file, pitch)
        return pitch

    @staticmethod
    def compute_pitch_stats(pitch_vecs):
        nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std

    def load_stats(self, cache_path):
        stats_path = os.path.join(cache_path, "pitch_stats.npy")
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
        self.std = stats["std"].astype(np.float32)

    def normalize(self, pitch):
        zero_idxs = np.where(pitch == 0.0)[0]
        pitch = pitch - self.mean
        pitch = pitch / self.std
        pitch[zero_idxs] = 0.0
        return pitch

    def denormalize(self, pitch):
        zero_idxs = np.where(pitch == 0.0)[0]
        pitch *= self.std
        pitch += self.mean
        pitch[zero_idxs] = 0.0
        return pitch

    def compute_or_load(self, wav_file, audio_unique_name):
        """
        compute pitch and return a numpy array of pitch values
        """
        pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path)
        if not os.path.exists(pitch_file):
            pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file)
        else:
            pitch = np.load(pitch_file)
        return pitch.astype(np.float32)

    def collate_fn(self, batch):
        audio_unique_name = [item["audio_unique_name"] for item in batch]
        f0s = [item["f0"] for item in batch]
        f0_lens = [len(item["f0"]) for item in batch]
        f0_lens_max = max(f0_lens)
        f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id())
        for i, f0_len in enumerate(f0_lens):
            f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i])
        return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens}

    def print_logs(self, level: int = 0) -> None:
        indent = "\t" * level
        print("\n")
        print(f"{indent}> F0Dataset ")
        print(f"{indent}| > Number of instances : {len(self.samples)}")


class EnergyDataset:
    """Energy Dataset for computing Energy from wav files in CPU

    Pre-compute Energy values for all the samples at initialization if `cache_path` is given and does not exist yet.
    It also computes the mean and std of Energy values during that pre-computation if `normalize_energy` is True.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict. `__getitem__` reads the `"audio_file"` and
            `"audio_unique_name"` keys of a sample.

        ap (AudioProcessor):
            AudioProcessor to compute Energy from wav files.

        verbose (bool):
            Stored on the instance; not used by the methods of this class. Defaults to False.

        cache_path (str):
            Path to cache Energy values. If `cache_path` is already present or None, it skips the pre-computation.
            When `normalize_energy` is True, `energy_stats.npy` is read from this directory. Defaults to None.

        precompute_num_workers (int):
            Number of workers used for pre-computing the Energy values. Defaults to 0.

        normalize_energy (bool):
            Whether to normalize Energy values by mean and std. Defaults to True.
    """

    def __init__(
        self,
        samples: Union[List[List], List[Dict]],
        ap: "AudioProcessor",
        verbose=False,
        cache_path: str = None,
        precompute_num_workers=0,
        normalize_energy=True,
    ):
        self.samples = samples
        self.ap = ap
        self.verbose = verbose
        self.cache_path = cache_path
        self.normalize_energy = normalize_energy
        self.pad_id = 0.0
        self.mean = None
        self.std = None
        # Pre-compute only for a cache directory that does not exist yet.
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)
        if normalize_energy:
            self.load_stats(cache_path)

    def __getitem__(self, idx):
        """Return the (optionally normalized) energy of sample `idx` with its unique name."""
        item = self.samples[idx]
        energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
        if self.normalize_energy:
            assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
            energy = self.normalize(energy)
        return {"audio_unique_name": item["audio_unique_name"], "energy": energy}

    def __len__(self):
        return len(self.samples)

    def precompute(self, num_workers=0):
        """Compute the energy of every sample and, if enabled, save the energy statistics.

        Iterates the whole dataset through a `DataLoader`, which runs `compute_or_load` for each sample. When
        `normalize_energy` is True, the mean and std of the collected values are saved as `energy_stats.npy`
        under `self.cache_path`.

        Args:
            num_workers (int): Number of `DataLoader` workers; also used as the batch size when positive.
                Defaults to 0.
        """
        print("[*] Pre-computing energys...")
        batch_size = num_workers if num_workers > 0 else 1
        # we do not normalize at preprocessing: the stats are only produced at the end of this method
        normalize_energy = self.normalize_energy
        self.normalize_energy = False
        computed_data = []
        try:
            with tqdm.tqdm(total=len(self)) as pbar:
                dataloder = torch.utils.data.DataLoader(
                    batch_size=batch_size,
                    dataset=self,
                    shuffle=False,
                    num_workers=num_workers,
                    collate_fn=self.collate_fn,
                )
                for batch in dataloder:
                    energy = batch["energy"]
                    # one padded row per sample
                    computed_data.extend(energy)
                    # the last batch may hold fewer than `batch_size` samples
                    pbar.update(len(energy))
        finally:
            # restore the flag even if iteration fails
            self.normalize_energy = normalize_energy

        if self.normalize_energy:
            energy_mean, energy_std = self.compute_energy_stats(computed_data)
            energy_stats = {"mean": energy_mean, "std": energy_std}
            np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)

    def get_pad_id(self):
        return self.pad_id

    @staticmethod
    def create_energy_file_path(wav_file, cache_path):
        """Build the cache file path `<cache_path>/<basename of wav_file without extension>_energy.npy`."""
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        energy_file = os.path.join(cache_path, file_name + "_energy.npy")
        return energy_file

    @staticmethod
    def _compute_and_save_energy(ap, wav_file, energy_file=None):
        """Compute the energy of `wav_file` with the settings of `ap` and save it to `energy_file` if given."""
        wav = ap.load_wav(wav_file)
        energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length)
        if energy_file:
            np.save(energy_file, energy)
        return energy

    @staticmethod
    def compute_energy_stats(energy_vecs):
        """Return the mean and std over all non-zero values of the given energy vectors."""
        # zeros are skipped; this also drops the zero padding added by `collate_fn`
        nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs])
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std

    def load_stats(self, cache_path):
        """Load `energy_stats.npy` from `cache_path` and set `self.mean` / `self.std` as `np.float32`."""
        stats_path = os.path.join(cache_path, "energy_stats.npy")
        # NOTE: the stats are stored as a pickled dict, hence `allow_pickle=True`; only point this at a
        # cache directory you trust.
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
        self.std = stats["std"].astype(np.float32)

    def normalize(self, energy):
        """Return `(energy - mean) / std`, keeping the positions that were 0.0 in the input at 0.0."""
        zero_idxs = np.where(energy == 0.0)[0]
        energy = energy - self.mean
        energy = energy / self.std
        energy[zero_idxs] = 0.0
        return energy

    def denormalize(self, energy):
        """Scale by std and shift by mean in place, keeping the positions that were 0.0 in the input at 0.0."""
        zero_idxs = np.where(energy == 0.0)[0]
        energy *= self.std
        energy += self.mean
        energy[zero_idxs] = 0.0
        return energy

    def compute_or_load(self, wav_file, audio_unique_name):
        """
        compute energy and return a numpy array of energy values

        The values are loaded from the cache file derived from `audio_unique_name` when it exists; otherwise
        they are computed from `wav_file` and saved to that file. The result is cast to `np.float32`.
        """
        energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path)
        if not os.path.exists(energy_file):
            energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
        else:
            energy = np.load(energy_file)
        return energy.astype(np.float32)

    def collate_fn(self, batch):
        """Pad the energy vectors of a batch to a common length and stack them.

        Energy values are real-valued, so the padded batch is built as a `float32` tensor. An integer tensor
        would truncate every value towards zero.

        Args:
            batch (List[Dict]): Items holding an `"audio_unique_name"` entry and an `"energy"` sequence.

        Returns:
            Dict: `"audio_unique_name"` (list of names), `"energy"` (float32 tensor of shape
            `[len(batch), max_len]` padded with `self.get_pad_id()`) and `"energy_lens"` (list with the unpadded
            length of each item).
        """
        audio_unique_name = [item["audio_unique_name"] for item in batch]
        energys = [item["energy"] for item in batch]
        energy_lens = [len(energy) for energy in energys]
        energy_lens_max = max(energy_lens)
        energys_torch = torch.full((len(energys), energy_lens_max), self.get_pad_id(), dtype=torch.float32)
        for i, (energy, energy_len) in enumerate(zip(energys, energy_lens)):
            energys_torch[i, :energy_len] = torch.as_tensor(energy, dtype=torch.float32)
        return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens}

    def print_logs(self, level: int = 0) -> None:
        """Print a short summary of the dataset, indented by `level` tab characters."""
        indent = "\t" * level
        print("\n")
        print(f"{indent}> energyDataset ")
        print(f"{indent}| > Number of instances : {len(self.samples)}")