# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from tqdm import tqdm
import glob
import json
import torchaudio

from utils.util import has_existed
from utils.io import save_audio


def get_splitted_utterances(
    raw_wav_dir, trimed_wav_dir, n_utterance_splits, overlapping
):
    res = []
    raw_song_files = glob.glob(
        os.path.join(raw_wav_dir, "**/pjs*_song.wav"), recursive=True
    )
    trimed_song_files = glob.glob(
        os.path.join(trimed_wav_dir, "**/*.wav"), recursive=True
    )

    if len(raw_song_files) * n_utterance_splits == len(trimed_song_files):
        print("Splitted done...")
        for wav_file in tqdm(trimed_song_files):
            uid = wav_file.split("/")[-1].split(".")[0]
            utt = {"Dataset": "pjs", "Singer": "male1", "Uid": uid, "Path": wav_file}

            waveform, sample_rate = torchaudio.load(wav_file)
            duration = waveform.size(-1) / sample_rate
            utt["Duration"] = duration

            res.append(utt)

    else:
        for wav_file in tqdm(raw_song_files):
            song_id = wav_file.split("/")[-1].split(".")[0]

            waveform, sample_rate = torchaudio.load(wav_file)
            trimed_waveform = torchaudio.functional.vad(waveform, sample_rate)
            trimed_waveform = torchaudio.functional.vad(
                trimed_waveform.flip(dims=[1]), sample_rate
            ).flip(dims=[1])

            audio_len = trimed_waveform.size(-1)
            lapping_len = overlapping * sample_rate

            for i in range(n_utterance_splits):
                start = i * audio_len // 3
                end = start + audio_len // 3 + lapping_len
                splitted_waveform = trimed_waveform[:, start:end]

                utt = {
                    "Dataset": "pjs",
                    "Singer": "male1",
                    "Uid": "{}_{}".format(song_id, i),
                }

                # Duration
                duration = splitted_waveform.size(-1) / sample_rate
                utt["Duration"] = duration

                # Save trimed wav
                splitted_waveform_file = os.path.join(
                    trimed_wav_dir, "{}.wav".format(utt["Uid"])
                )
                save_audio(splitted_waveform_file, splitted_waveform, sample_rate)

                # Path
                utt["Path"] = splitted_waveform_file

                res.append(utt)

    res = sorted(res, key=lambda x: x["Uid"])
    return res


def main(output_path, dataset_path, n_utterance_splits=3, overlapping=1):
    """
    1. Split one raw utterance to three splits (since some samples are too long)
    2. Overlapping of ajacent splits is 1 s
    """
    print("-" * 10)
    print("Preparing training dataset for PJS...")

    save_dir = os.path.join(output_path, "pjs")
    raw_wav_dir = os.path.join(dataset_path, "PJS_corpus_ver1.1")

    # Trim for silence
    trimed_wav_dir = os.path.join(dataset_path, "trim")
    os.makedirs(trimed_wav_dir, exist_ok=True)

    # Total utterances
    utterances = get_splitted_utterances(
        raw_wav_dir, trimed_wav_dir, n_utterance_splits, overlapping
    )
    total_uids = [utt["Uid"] for utt in utterances]

    # Test uids
    n_test_songs = 3
    test_uids = []
    for i in range(1, n_test_songs + 1):
        test_uids += [
            "pjs00{}_song_{}".format(i, split_id)
            for split_id in range(n_utterance_splits)
        ]

    # Train uids
    train_uids = [uid for uid in total_uids if uid not in test_uids]

    for dataset_type in ["train", "test"]:
        output_file = os.path.join(save_dir, "{}.json".format(dataset_type))
        if has_existed(output_file):
            continue

        uids = eval("{}_uids".format(dataset_type))
        res = [utt for utt in utterances if utt["Uid"] in uids]
        for i in range(len(res)):
            res[i]["index"] = i

        time = sum([utt["Duration"] for utt in res])
        print(
            "{}, Total size: {}, Total Duraions = {} s = {:.2f} hour\n".format(
                dataset_type, len(res), time, time / 3600
            )
        )

        # Save
        os.makedirs(save_dir, exist_ok=True)
        with open(output_file, "w") as f:
            json.dump(res, f, indent=4, ensure_ascii=False)