import os
import time
import random
from tqdm import tqdm
import argparse

import torch
import torchaudio
from accelerate import Accelerator
from einops import rearrange
from ema_pytorch import EMA
from vocos import Vocos

from model import CFM, UNetT, DiT
from model.utils import (
    get_tokenizer, 
    get_seedtts_testset_metainfo, 
    get_librispeech_test_clean_metainfo, 
    get_inference_prompt,
)

# Assumes one process per GPU (as with `accelerate launch`). The CUDA device is chosen by the
# *local* rank: the global `process_index` keeps counting across nodes and would name a GPU
# that does not exist on any node after the first.
accelerator = Accelerator()
device = f"cuda:{accelerator.local_process_index}"


# --------------------- Dataset Settings -------------------- #

target_sample_rate = 24000  # also the sample rate written into the saved wavs
n_mel_channels = 100
hop_length = 256
target_rms = 0.1

tokenizer = "pinyin"


# ---------------------- infer setting ---------------------- #

parser = argparse.ArgumentParser(description="batch inference")

# (flags, add_argument keyword arguments) -- registered in exactly this order.
_cli_options = (
    (('-s', '--seed'), dict(default=None, type=int)),
    (('-d', '--dataset'), dict(default="Emilia_ZH_EN")),
    (('-n', '--expname'), dict(required=True)),
    (('-c', '--ckptstep'), dict(default=1200000, type=int)),
    # ODE sampler options
    (('-nfe', '--nfestep'), dict(default=32, type=int)),
    (('-o', '--odemethod'), dict(default="euler")),
    (('-ss', '--swaysampling'), dict(default=-1, type=float)),
    # evaluation set
    (('-t', '--testset'), dict(required=True)),
)
for _flags, _kwargs in _cli_options:
    parser.add_argument(*_flags, **_kwargs)

args = parser.parse_args()


# Unpack the parsed options into the module-level names used by the rest of the script,
# and load the requested checkpoint onto this process's device.
seed, dataset_name = args.seed, args.dataset
exp_name, ckpt_step = args.expname, args.ckptstep
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)

nfe_step, ode_method, sway_sampling_coef = args.nfestep, args.odemethod, args.swaysampling

testset = args.testset


# Fixed inference settings (not exposed on the command line).
infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
cfg_strength, speed = 2.0, 1.0
use_truth_duration = no_ref_audio = False


# Select the backbone class and its hyper-parameters from the experiment name.
# Fail immediately on an unknown name instead of hitting a NameError on `model_cls`
# much later, after the prompts and vocoder have already been loaded.
if exp_name == "F5TTS_Base":
    model_cls = DiT
    model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)

elif exp_name == "E2TTS_Base":
    model_cls = UNetT
    model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)

else:
    raise ValueError(
        f"Unknown experiment name {exp_name!r}; expected 'F5TTS_Base' or 'E2TTS_Base'."
    )


# Resolve the test-set metadata list for the requested evaluation set.
# An unknown name raises here rather than leaving `metainfo` undefined.
if testset == "ls_pc_test_clean":
    metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
    librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
    metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)

elif testset == "seedtts_test_zh":
    metalst = "data/seedtts_testset/zh/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)

elif testset == "seedtts_test_en":
    metalst = "data/seedtts_testset/en/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)

else:
    raise ValueError(
        f"Unknown testset {testset!r}; expected one of "
        "'ls_pc_test_clean', 'seedtts_test_zh', 'seedtts_test_en'."
    )


# Path to save generated wavs.
# Without an explicit --seed, every launched process would draw a *different* random seed,
# hence a different output_dir -- but only the main process creates the directory. Draw the
# seed once on the main process and broadcast it so all ranks agree (a no-op when running
# as a single process).
if seed is None:
    from accelerate.utils import broadcast_object_list

    _seed_holder = [random.randint(-10000, 10000) if accelerator.is_main_process else None]
    broadcast_object_list(_seed_holder, from_process=0)
    seed = _seed_holder[0]
output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
    f"seed{seed}_{ode_method}_nfe{nfe_step}" \
    f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
    f"_cfg{cfg_strength}_speed{speed}" \
    f"{'_gt-dur' if use_truth_duration else ''}" \
    f"{'_no-ref-audio' if no_ref_audio else ''}"


# -------------------------------------------------#

use_ema = True  # load the EMA weights from the checkpoint rather than the raw model weights

# Pre-build the inference batches for the whole test set; each process takes its share later.
_prompt_options = dict(
    speed=speed,
    tokenizer=tokenizer,
    target_sample_rate=target_sample_rate,
    n_mel_channels=n_mel_channels,
    hop_length=hop_length,
    target_rms=target_rms,
    use_truth_duration=use_truth_duration,
    infer_batch_size=infer_batch_size,
)
prompts_all = get_inference_prompt(metainfo, **_prompt_options)

# Vocoder model.
# Vocos is never moved to `device` here, so it lives on the CPU. Load the local weights
# straight to the CPU instead of staging them on the GPU only to copy them back.
local = False
if local:
    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location="cpu")
    vocos.load_state_dict(state_dict)
    vocos.eval()
else:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

# Tokenizer: character map and vocabulary size for the chosen dataset / tokenizer type.
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model: the selected backbone (`model_cls` + `model_cfg`) wrapped in CFM, placed on `device`.
_backbone = model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels)
_mel_spec_kwargs = {
    "target_sample_rate": target_sample_rate,
    "n_mel_channels": n_mel_channels,
    "hop_length": hop_length,
}
model = CFM(
    transformer=_backbone,
    mel_spec_kwargs=_mel_spec_kwargs,
    odeint_kwargs={"method": ode_method},
    vocab_char_map=vocab_char_map,
).to(device)

# Load weights into `model`: either the EMA shadow weights (copied over into `model`)
# or the raw training weights.
if use_ema:
    ema_model = EMA(model, include_online_model = False).to(device)
    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
    ema_model.copy_params_from_ema_to_model()
else:
    model.load_state_dict(checkpoint['model_state_dict'])
# Disable train-only behavior (e.g. dropout) explicitly, regardless of what sample() does.
model.eval()

# Only the main process creates the output directory; exist_ok avoids the
# check-then-create race and tolerates a directory left over from a previous run.
if accelerator.is_main_process:
    os.makedirs(output_dir, exist_ok=True)

# start batch inference
accelerator.wait_for_everyone()  # output_dir exists before any rank writes into it
start = time.time()

# Each process receives its own slice of the prompt batches.
with accelerator.split_between_processes(prompts_all) as prompts:

    for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
        utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
        ref_mels = ref_mels.to(device)
        ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long, device = device)
        total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long, device = device)

        # Run the sampler AND the vocoder/post-processing without autograd tracking.
        # Vocos' parameters presumably require grad (nn.Module default); decoding outside
        # inference mode would build a graph per utterance and pass a grad-tracking
        # tensor to torchaudio.save.
        with torch.inference_mode():
            generated, _ = model.sample(
                cond = ref_mels,
                text = final_text_list,
                duration = total_mel_lens,
                lens = ref_mel_lens,
                steps = nfe_step,
                cfg_strength = cfg_strength,
                sway_sampling_coef = sway_sampling_coef,
                no_ref_audio = no_ref_audio,
                seed = seed,
            )
            # Final result
            for i, gen in enumerate(generated):
                # Keep only the frames after the reference prompt, up to the target length.
                gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
                gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
                generated_wave = vocos.decode(gen_mel_spec.cpu())  # vocos lives on the CPU
                # Quiet references were presumably scaled up to target_rms during
                # preprocessing; scale the output back down to the reference loudness.
                if ref_rms_list[i] < target_rms:
                    generated_wave = generated_wave * ref_rms_list[i] / target_rms
                torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    timediff = time.time() - start
    print(f"Done batch inference in {timediff / 60 :.2f} minutes.")