import os
from glob import glob
from logging import getLogger
from typing import Literal, Optional, Tuple
from pathlib import Path
from threading import Thread
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from accelerate import Accelerator
from datasets import Dataset
from .pretrained import pretrained_checkpoints
from .constants import *
from torch.utils.tensorboard import SummaryWriter
import time
from tqdm.auto import tqdm
from huggingface_hub import HfApi, upload_folder

from .synthesizer import commons
from .synthesizer.models import (
    SynthesizerTrnMs768NSFsid,
    MultiPeriodDiscriminator,
)

from .utils.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from .utils.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from .utils.data_utils import TextAudioCollateMultiNSFsid

# Module-scoped logger, named after this module's import path.
logger = getLogger(name=__name__)


class TrainingCheckpoint:
    """Snapshot of generator/discriminator training state for one epoch.

    ``save`` writes two files:

    * a generator checkpoint holding ``epoch``, the generator weights
      (``model``), optimizer and scheduler state, and the loss values passed
      to the constructor;
    * a discriminator checkpoint holding ``epoch``, the discriminator
      weights, optimizer and scheduler state.
    """

    def __init__(
        self,
        epoch: int,
        G: SynthesizerTrnMs768NSFsid,
        D: MultiPeriodDiscriminator,
        optimizer_G: torch.optim.AdamW,
        optimizer_D: torch.optim.AdamW,
        scheduler_G: torch.optim.lr_scheduler.ExponentialLR,
        scheduler_D: torch.optim.lr_scheduler.ExponentialLR,
        loss_gen: float,
        loss_fm: float,
        loss_mel: float,
        loss_kl: float,
        loss_gen_all: float,
        loss_disc: float,
    ):
        """Store references to the training state; nothing is copied.

        Args:
            epoch: Epoch number recorded in both checkpoint files.
            G: Generator model.
            D: Discriminator model.
            optimizer_G: Generator optimizer.
            optimizer_D: Discriminator optimizer.
            scheduler_G: Generator learning-rate scheduler.
            scheduler_D: Discriminator learning-rate scheduler.
            loss_gen: Adversarial generator loss (stored in the G file).
            loss_fm: Feature-matching loss (stored in the G file).
            loss_mel: Mel reconstruction loss (stored in the G file).
            loss_kl: KL loss (stored in the G file).
            loss_gen_all: Total generator loss (stored in the G file).
            loss_disc: Discriminator loss (stored in the G file).
        """
        self.epoch = epoch
        self.G = G
        self.D = D
        self.optimizer_G = optimizer_G
        self.optimizer_D = optimizer_D
        self.scheduler_G = scheduler_G
        self.scheduler_D = scheduler_D
        self.loss_gen = loss_gen
        self.loss_fm = loss_fm
        self.loss_mel = loss_mel
        self.loss_kl = loss_kl
        self.loss_gen_all = loss_gen_all
        self.loss_disc = loss_disc

    @staticmethod
    def _unwrap(model: torch.nn.Module) -> torch.nn.Module:
        """Return the inner module of a (Distributed)DataParallel wrapper.

        Wrapped models prefix every state-dict key with ``module.``. Such keys
        would not load back into a bare model, so the wrapper is stripped
        before saving. Unwrapped models are returned unchanged.
        """
        if isinstance(
            model,
            (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
        ):
            return model.module
        return model

    @staticmethod
    def _atomic_save(obj: dict, path: str) -> None:
        """``torch.save`` ``obj`` to ``path`` via a temp file and ``os.replace``.

        The default ``G_latest.pth`` / ``D_latest.pth`` names are overwritten
        on every save. Writing in place means an interrupted save could leave
        a truncated file in place of the previous, valid checkpoint.
        """
        tmp_path = f"{path}.tmp"
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)

    def save(
        self,
        exp_dir="./",
        g_checkpoint: str | None = None,
        d_checkpoint: str | None = None,
    ):
        """Write the generator and discriminator checkpoints into ``exp_dir``.

        Args:
            exp_dir: Directory that the checkpoint file names are joined to.
            g_checkpoint: Generator file name; defaults to ``"G_latest.pth"``.
            d_checkpoint: Discriminator file name; defaults to
                ``"D_latest.pth"``.
        """
        g_path = g_checkpoint if g_checkpoint is not None else "G_latest.pth"
        d_path = d_checkpoint if d_checkpoint is not None else "D_latest.pth"
        self._atomic_save(
            {
                "epoch": self.epoch,
                "model": self._unwrap(self.G).state_dict(),
                "optimizer": self.optimizer_G.state_dict(),
                "scheduler": self.scheduler_G.state_dict(),
                "loss_gen": self.loss_gen,
                "loss_fm": self.loss_fm,
                "loss_mel": self.loss_mel,
                "loss_kl": self.loss_kl,
                "loss_gen_all": self.loss_gen_all,
                "loss_disc": self.loss_disc,
            },
            os.path.join(exp_dir, g_path),
        )
        self._atomic_save(
            {
                "epoch": self.epoch,
                "model": self._unwrap(self.D).state_dict(),
                "optimizer": self.optimizer_D.state_dict(),
                "scheduler": self.scheduler_D.state_dict(),
            },
            os.path.join(exp_dir, d_path),
        )


def latest_checkpoint_file(files: list[str]) -> str:
    """Return the most recent checkpoint path from ``files``.

    When every name has the form ``<prefix>_<epoch>.<ext>`` (e.g.
    ``G_120.pth``), the file with the highest integer epoch wins. If any name
    lacks a ``_`` or has a non-integer second field (e.g. ``G_latest.pth``),
    the choice falls back to the file with the newest ``os.path.getctime``.

    Args:
        files: Checkpoint paths to choose from.

    Returns:
        The selected path, exactly as it appears in ``files``.

    Raises:
        ValueError: If ``files`` is empty.
    """
    try:
        return max(files, key=lambda x: int(Path(x).stem.split("_")[1]))
    except (IndexError, ValueError):
        # IndexError: no "_" in the stem; ValueError: int() failed (or the
        # list is empty, in which case the max() below re-raises ValueError).
        return max(files, key=os.path.getctime)


class RVCTrainer:
    """GAN trainer for the RVC generator/discriminator pair.

    Checkpoints and generator exports are written to ``exp_dir``. TensorBoard
    logs go to ``exp_dir/logs/<YYYYmmdd-HHMMSS>``.
    """

    def __init__(
        self,
        exp_dir: str,
        dataset_train: Dataset,
        dataset_test: Optional[Dataset] = None,
        sr: int = SR_48K,
    ):
        """
        Args:
            exp_dir: Directory for checkpoints, exported weights and logs.
            dataset_train: Dataset used for training.
            dataset_test: Optional dataset whose samples are synthesized and
                logged to TensorBoard as audio after each epoch.
            sr: Sampling rate passed to the generator and mel computations.
        """
        self.exp_dir = exp_dir
        self.dataset_train = dataset_train
        self.dataset_test = dataset_test
        self.sr = sr
        self.writer = SummaryWriter(
            os.path.join(exp_dir, "logs", time.strftime("%Y%m%d-%H%M%S"))
        )

    def latest_checkpoint(self, fallback_to_pretrained: bool = True):
        """Find the newest ``G_*.pth`` / ``D_*.pth`` pair in ``exp_dir``.

        Returns:
            ``(g_path, d_path)`` if both kinds of file exist. Otherwise the
            result of ``pretrained_checkpoints()`` when
            ``fallback_to_pretrained`` is true, else ``None``.
        """
        files_g = glob(os.path.join(self.exp_dir, "G_*.pth"))
        if not files_g:
            return pretrained_checkpoints() if fallback_to_pretrained else None
        latest_g = latest_checkpoint_file(files_g)

        files_d = glob(os.path.join(self.exp_dir, "D_*.pth"))
        if not files_d:
            return pretrained_checkpoints() if fallback_to_pretrained else None
        latest_d = latest_checkpoint_file(files_d)

        return latest_g, latest_d

    @staticmethod
    def _resolve_finished_epoch(checkpoint: dict, path: str) -> int:
        """Return the last completed epoch for a loaded generator checkpoint.

        The ``"epoch"`` entry written by ``TrainingCheckpoint.save`` is used
        when present. Otherwise the epoch is parsed from a
        ``<prefix>_<epoch>.pth`` file name, and 0 is returned if that fails
        too (e.g. pretrained weights).
        """
        if "epoch" in checkpoint:
            return int(checkpoint["epoch"])
        try:
            return int(Path(path).stem.split("_")[1])
        except (IndexError, ValueError):
            return 0

    @staticmethod
    def _resume_scheduler(
        optimizer: torch.optim.Optimizer,
        lr_decay: float,
        finished_epoch: int,
        checkpoint: dict,
    ) -> torch.optim.lr_scheduler.ExponentialLR:
        """Build an ExponentialLR for ``optimizer`` when resuming training.

        Call this only after the optimizer state has been loaded.
        """
        if "scheduler" in checkpoint:
            # The saved state restores last_epoch and the LR bookkeeping.
            # last_epoch=-1 leaves the optimizer's (restored) LR untouched.
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=lr_decay, last_epoch=-1
            )
            scheduler.load_state_dict(checkpoint["scheduler"])
            return scheduler
        # No saved scheduler: continue from finished_epoch. ExponentialLR with
        # last_epoch >= 0 raises KeyError unless every param group carries
        # "initial_lr", which a fresh optimizer does not have.
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=lr_decay, last_epoch=finished_epoch - 1
        )

    def setup_models(
        self,
        resume_from: Tuple[str, str] | None = None,
        accelerator: Accelerator | None = None,
        lr=1e-4,
        lr_decay=0.999875,
        betas: Tuple[float, float] = (0.8, 0.99),
        eps=1e-9,
        use_spectral_norm=False,
        segment_size=17280,
        filter_length=N_FFT,
        hop_length=HOP_LENGTH,
        inter_channels=192,
        hidden_channels=192,
        filter_channels=768,
        n_heads=2,
        n_layers=6,
        kernel_size=3,
        p_dropout=0.0,
        resblock: Literal["1", "2"] = "1",
        resblock_kernel_sizes: list[int] = [3, 7, 11],
        resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_initial_channel=512,
        upsample_rates: list[int] = [12, 10, 2, 2],
        upsample_kernel_sizes: list[int] = [24, 20, 4, 4],
        spk_embed_dim=109,
        gin_channels=256,
    ) -> Tuple[
        SynthesizerTrnMs768NSFsid,
        MultiPeriodDiscriminator,
        torch.optim.AdamW,
        torch.optim.AdamW,
        torch.optim.lr_scheduler.ExponentialLR,
        torch.optim.lr_scheduler.ExponentialLR,
        int,
    ]:
        """Build G/D with AdamW optimizers and ExponentialLR schedulers.

        If ``resume_from`` is a ``(g_path, d_path)`` pair, the weights are
        loaded from those files. Optimizer and scheduler state is loaded too
        when the files contain it. Models and optimizers are passed through
        ``accelerator.prepare``; schedulers are not.

        Returns:
            ``(G, D, optimizer_G, optimizer_D, scheduler_G, scheduler_D,
            finished_epoch)``, where ``finished_epoch`` is 0 for a fresh run.
        """
        if accelerator is None:
            accelerator = Accelerator()

        G = SynthesizerTrnMs768NSFsid(
            spec_channels=filter_length // 2 + 1,
            segment_size=segment_size // hop_length,
            inter_channels=inter_channels,
            hidden_channels=hidden_channels,
            filter_channels=filter_channels,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            resblock=resblock,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            upsample_initial_channel=upsample_initial_channel,
            upsample_rates=upsample_rates,
            upsample_kernel_sizes=upsample_kernel_sizes,
            spk_embed_dim=spk_embed_dim,
            gin_channels=gin_channels,
            sr=self.sr,
        ).to(accelerator.device)
        D = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm).to(
            accelerator.device
        )

        optimizer_G = torch.optim.AdamW(G.parameters(), lr, betas=betas, eps=eps)
        optimizer_D = torch.optim.AdamW(D.parameters(), lr, betas=betas, eps=eps)

        if resume_from is not None:
            g_checkpoint, d_checkpoint = resume_from
            logger.info(f"Resuming from {g_checkpoint} and {d_checkpoint}")

            G_checkpoint = torch.load(
                g_checkpoint, map_location=accelerator.device, weights_only=True
            )
            D_checkpoint = torch.load(
                d_checkpoint, map_location=accelerator.device, weights_only=True
            )

            finished_epoch = self._resolve_finished_epoch(G_checkpoint, g_checkpoint)

            # Restore weights and optimizer state BEFORE creating schedulers
            # (see _resume_scheduler for why the order matters).
            G.load_state_dict(G_checkpoint["model"])
            if "optimizer" in G_checkpoint:
                optimizer_G.load_state_dict(G_checkpoint["optimizer"])
            D.load_state_dict(D_checkpoint["model"])
            if "optimizer" in D_checkpoint:
                optimizer_D.load_state_dict(D_checkpoint["optimizer"])

            scheduler_G = self._resume_scheduler(
                optimizer_G, lr_decay, finished_epoch, G_checkpoint
            )
            scheduler_D = self._resume_scheduler(
                optimizer_D, lr_decay, finished_epoch, D_checkpoint
            )
        else:
            finished_epoch = 0
            scheduler_G = torch.optim.lr_scheduler.ExponentialLR(
                optimizer_G, gamma=lr_decay, last_epoch=-1
            )
            scheduler_D = torch.optim.lr_scheduler.ExponentialLR(
                optimizer_D, gamma=lr_decay, last_epoch=-1
            )

        G, D, optimizer_G, optimizer_D = accelerator.prepare(
            G, D, optimizer_G, optimizer_D
        )

        return G, D, optimizer_G, optimizer_D, scheduler_G, scheduler_D, finished_epoch

    def setup_dataloader(
        self,
        dataset: Dataset,
        batch_size=1,
        shuffle=True,
        accelerator: Accelerator | None = None,
    ):
        """Wrap ``dataset`` in an accelerator-prepared DataLoader.

        The dataset is switched to torch format on ``accelerator.device``.
        Batches are collated with ``TextAudioCollateMultiNSFsid``.
        """
        if accelerator is None:
            accelerator = Accelerator()

        dataset = dataset.with_format("torch", device=accelerator.device)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=TextAudioCollateMultiNSFsid(),
        )
        return accelerator.prepare(loader)

    def run(
        self,
        G,
        D,
        optimizer_G,
        optimizer_D,
        scheduler_G,
        scheduler_D,
        finished_epoch,
        loader_train,
        loader_test,
        accelerator: Accelerator | None = None,
        epochs=100,
        segment_size=17280,
        filter_length=N_FFT,
        hop_length=HOP_LENGTH,
        n_mel_channels=N_MELS,
        win_length=WIN_LENGTH,
        mel_fmin=0.0,
        mel_fmax: float | None = None,
        c_mel=45,
        c_kl=1.0,
        upload_to_hub: str | None = None,
        upload_window_minutes=5,
    ):
        """Train epochs ``finished_epoch + 1`` through ``epochs``.

        Each batch runs one discriminator update and then one generator
        update. The generator loss is adversarial + feature matching +
        ``c_mel`` * mel L1 + ``c_kl`` * KL.

        After each epoch:

        * both schedulers are stepped;
        * average losses and learning rates are logged (main process);
        * ``loader_test`` samples, if given, are synthesized and logged as
          audio;
        * on the main process only, the checkpoint is saved to ``exp_dir``
          and the generator is exported with ``save_pretrained``;
        * if ``upload_to_hub`` is set, ``exp_dir`` is pushed to that repo.
          Pushes happen at most once per ``upload_window_minutes``, plus
          always after the final epoch.
        """
        if accelerator is None:
            accelerator = Accelerator()

        if accelerator.is_main_process:
            logger.info("Start training")

        upload_state_last = 0.0

        prev_loss_gen = -1.0
        prev_loss_fm = -1.0
        prev_loss_mel = -1.0
        prev_loss_kl = -1.0
        prev_loss_disc = -1.0
        prev_loss_gen_all = -1.0

        with accelerator.autocast():
            epoch_iterator = tqdm(
                range(1, epochs + 1),
                desc="Training",
                disable=not accelerator.is_main_process,
            )
            for epoch in epoch_iterator:
                if epoch <= finished_epoch:
                    continue

                G.train()
                D.train()

                epoch_loss_gen = 0.0
                epoch_loss_fm = 0.0
                epoch_loss_mel = 0.0
                epoch_loss_kl = 0.0
                epoch_loss_disc = 0.0
                epoch_loss_gen_all = 0.0
                num_batches = 0

                batch_iterator = tqdm(
                    loader_train,
                    desc=f"Epoch {epoch}",
                    leave=False,
                    disable=not accelerator.is_main_process,
                )
                for batch in batch_iterator:
                    (
                        phone,
                        phone_lengths,
                        pitch,
                        pitchf,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = batch

                    # Generator forward
                    optimizer_G.zero_grad()
                    (
                        y_hat,
                        ids_slice,
                        x_mask,
                        z_mask,
                        (z, z_p, m_p, logs_p, m_q, logs_q),
                    ) = G(
                        phone,
                        phone_lengths,
                        pitch,
                        pitchf,
                        spec,
                        spec_lengths,
                        sid,
                    )
                    mel = spec_to_mel_torch(
                        spec,
                        filter_length,
                        n_mel_channels,
                        self.sr,
                        mel_fmin,
                        mel_fmax,
                    )
                    y_mel = commons.slice_segments(
                        mel, ids_slice, segment_size // hop_length
                    )
                    y_hat_mel = mel_spectrogram_torch(
                        y_hat.squeeze(1),
                        filter_length,
                        n_mel_channels,
                        self.sr,
                        hop_length,
                        win_length,
                        mel_fmin,
                        mel_fmax,
                    )
                    wave = commons.slice_segments(
                        wave, ids_slice * hop_length, segment_size
                    )

                    # Discriminator: generator output is detached so only D
                    # receives gradients here.
                    optimizer_D.zero_grad()
                    y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = D(wave, y_hat.detach())

                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                        y_d_hat_r, y_d_hat_g
                    )
                    accelerator.backward(loss_disc)
                    optimizer_D.step()

                    # Re-compute discriminator output with the updated D.
                    y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = D(wave, y_hat)

                    # Generator update
                    loss_gen, losses_gen = generator_loss(y_d_hat_g)
                    loss_mel = F.l1_loss(y_mel, y_hat_mel) * c_mel
                    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * c_kl
                    loss_fm = feature_loss(fmap_r, fmap_g)
                    loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
                    accelerator.backward(loss_gen_all)
                    optimizer_G.step()

                    prev_loss_gen = loss_gen.item()
                    prev_loss_fm = loss_fm.item()
                    prev_loss_mel = loss_mel.item()
                    prev_loss_kl = loss_kl.item()
                    prev_loss_disc = loss_disc.item()
                    prev_loss_gen_all = loss_gen_all.item()

                    if accelerator.is_main_process:
                        batch_iterator.set_postfix(
                            {
                                "g_loss": f"{prev_loss_gen:.4f}",
                                "d_loss": f"{prev_loss_disc:.4f}",
                                "mel_loss": f"{prev_loss_mel:.4f}",
                                "total": f"{prev_loss_gen_all:.4f}",
                            }
                        )

                    epoch_loss_gen += prev_loss_gen
                    epoch_loss_fm += prev_loss_fm
                    epoch_loss_mel += prev_loss_mel
                    epoch_loss_kl += prev_loss_kl
                    epoch_loss_disc += prev_loss_disc
                    epoch_loss_gen_all += prev_loss_gen_all
                    num_batches += 1

                scheduler_G.step()
                scheduler_D.step()

                if accelerator.is_main_process and num_batches > 0:
                    avg_gen = epoch_loss_gen / num_batches
                    avg_disc = epoch_loss_disc / num_batches
                    avg_fm = epoch_loss_fm / num_batches
                    avg_mel = epoch_loss_mel / num_batches
                    avg_kl = epoch_loss_kl / num_batches
                    avg_total = epoch_loss_gen_all / num_batches

                    logger.info(
                        f"Epoch {epoch} | "
                        f"Generator Loss: {avg_gen:.4f} | "
                        f"Discriminator Loss: {avg_disc:.4f} | "
                        f"Mel Loss: {avg_mel:.4f} | "
                        f"Total Loss: {avg_total:.4f}"
                    )

                    epoch_iterator.set_postfix(
                        {
                            "g_loss": f"{avg_gen:.4f}",
                            "d_loss": f"{avg_disc:.4f}",
                            "total": f"{avg_total:.4f}",
                        }
                    )

                    self.writer.add_scalar("Loss/Generator", avg_gen, epoch)
                    self.writer.add_scalar("Loss/Feature_Matching", avg_fm, epoch)
                    self.writer.add_scalar("Loss/Mel", avg_mel, epoch)
                    self.writer.add_scalar("Loss/KL", avg_kl, epoch)
                    self.writer.add_scalar("Loss/Discriminator", avg_disc, epoch)
                    self.writer.add_scalar("Loss/Generator_Total", avg_total, epoch)
                    self.writer.add_scalar(
                        "Learning_Rate/Generator",
                        scheduler_G.get_last_lr()[0],
                        epoch,
                    )
                    self.writer.add_scalar(
                        "Learning_Rate/Discriminator",
                        scheduler_D.get_last_lr()[0],
                        epoch,
                    )

                if loader_test is not None:
                    # Synthesize in eval mode; G.train() is called again at the
                    # start of the next epoch.
                    G.eval()
                    with torch.no_grad():
                        sample_idx = 0
                        test_iterator = tqdm(
                            loader_test,
                            desc=f"Testing epoch {epoch}",
                            leave=False,
                            disable=not accelerator.is_main_process,
                        )
                        for (
                            phone,
                            phone_lengths,
                            pitch,
                            pitchf,
                            _spec,
                            _spec_lengths,
                            _wave,
                            _wave_lengths,
                            sid,
                        ) in test_iterator:
                            audio_segments = G.infer(
                                phone, phone_lengths, pitch, pitchf, sid
                            )[0]

                            for audio in audio_segments:
                                # Log only from the main process, matching the
                                # scalar logging above.
                                if accelerator.is_main_process:
                                    audio_numpy = audio[0].data.cpu().float().numpy()
                                    self.writer.add_audio(
                                        f"Audio/{sample_idx}",
                                        audio_numpy,
                                        epoch,
                                        sample_rate=self.sr,
                                    )
                                sample_idx += 1

                # Only the main process writes checkpoints and uploads, so
                # several processes never write or push the same files.
                if accelerator.is_main_process:
                    TrainingCheckpoint(
                        epoch,
                        G,
                        D,
                        optimizer_G,
                        optimizer_D,
                        scheduler_G,
                        scheduler_D,
                        prev_loss_gen,
                        prev_loss_fm,
                        prev_loss_mel,
                        prev_loss_kl,
                        prev_loss_gen_all,
                        prev_loss_disc,
                    ).save(self.exp_dir)
                    G.save_pretrained(self.exp_dir)

                    if upload_to_hub is not None:
                        window_seconds = 60 * upload_window_minutes
                        elapsed = time.time() - upload_state_last
                        if elapsed > window_seconds or epoch == epochs:
                            try:
                                self.push_to_hub(upload_to_hub)
                                upload_state_last = time.time()
                            except Exception:
                                # Best-effort: a failed upload must not abort training.
                                logger.exception("Failed to upload to Hub.")
                        else:
                            next_upload = window_seconds - elapsed
                            logger.info(
                                f"Skipping upload to Hub (next upload in {next_upload:.0f} seconds)"
                            )

    def train(
        self,
        resume_from: Tuple[str, str] | None = None,
        accelerator: Accelerator | None = None,
        batch_size=1,
        epochs=100,
        lr=1e-4,
        lr_decay=0.999875,
        betas: Tuple[float, float] = (0.8, 0.99),
        eps=1e-9,
        use_spectral_norm=False,
        segment_size=17280,
        filter_length=N_FFT,
        hop_length=HOP_LENGTH,
        inter_channels=192,
        hidden_channels=192,
        filter_channels=768,
        n_heads=2,
        n_layers=6,
        kernel_size=3,
        p_dropout=0.0,
        resblock: Literal["1", "2"] = "1",
        resblock_kernel_sizes: list[int] = [3, 7, 11],
        resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_initial_channel=512,
        upsample_rates: list[int] = [12, 10, 2, 2],
        upsample_kernel_sizes: list[int] = [24, 20, 4, 4],
        spk_embed_dim=109,
        gin_channels=256,
        n_mel_channels=N_MELS,
        win_length=WIN_LENGTH,
        mel_fmin=0.0,
        mel_fmax: float | None = None,
        c_mel=45,
        c_kl=1.0,
        upload_to_hub: str | None = None,
        upload_window_minutes=5,
    ):
        """Set up models and data loaders, then run training.

        Creates ``exp_dir`` if needed. If ``resume_from`` is falsy, resumes
        from ``latest_checkpoint()``, which may return pretrained weights.
        ``upload_window_minutes`` is forwarded to ``run``.
        """
        os.makedirs(self.exp_dir, exist_ok=True)

        if accelerator is None:
            accelerator = Accelerator()

        (
            G,
            D,
            optimizer_G,
            optimizer_D,
            scheduler_G,
            scheduler_D,
            finished_epoch,
        ) = self.setup_models(
            resume_from=resume_from or self.latest_checkpoint(),
            accelerator=accelerator,
            lr=lr,
            lr_decay=lr_decay,
            betas=betas,
            eps=eps,
            use_spectral_norm=use_spectral_norm,
            segment_size=segment_size,
            filter_length=filter_length,
            hop_length=hop_length,
            inter_channels=inter_channels,
            hidden_channels=hidden_channels,
            filter_channels=filter_channels,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            resblock=resblock,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            upsample_initial_channel=upsample_initial_channel,
            upsample_rates=upsample_rates,
            upsample_kernel_sizes=upsample_kernel_sizes,
            spk_embed_dim=spk_embed_dim,
            gin_channels=gin_channels,
        )

        loader_train = self.setup_dataloader(
            self.dataset_train,
            batch_size=batch_size,
            accelerator=accelerator,
        )

        loader_test = (
            self.setup_dataloader(
                self.dataset_test,
                batch_size=batch_size,
                accelerator=accelerator,
                shuffle=False,
            )
            if self.dataset_test is not None
            else None
        )

        return self.run(
            G,
            D,
            optimizer_G,
            optimizer_D,
            scheduler_G,
            scheduler_D,
            finished_epoch,
            loader_train,
            loader_test,
            accelerator,
            epochs=epochs,
            segment_size=segment_size,
            filter_length=filter_length,
            hop_length=hop_length,
            n_mel_channels=n_mel_channels,
            win_length=win_length,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
            c_mel=c_mel,
            c_kl=c_kl,
            upload_to_hub=upload_to_hub,
            upload_window_minutes=upload_window_minutes,
        )

    def push_to_hub(self, repo: str, private: bool = True):
        """Upload ``exp_dir`` to the Hugging Face Hub repo ``repo``.

        The repo is created if it does not exist (private by default).

        Raises:
            FileNotFoundError: If ``exp_dir`` does not exist.

        Returns:
            The value returned by ``huggingface_hub.upload_folder``.
        """
        if not os.path.exists(self.exp_dir):
            raise FileNotFoundError("exp_dir not found")

        api = HfApi()
        repo_id = api.create_repo(repo_id=repo, private=private, exist_ok=True).repo_id

        return upload_folder(
            repo_id=repo_id,
            folder_path=self.exp_dir,
            commit_message="Upload via ZeroRVC",
        )

    def __del__(self):
        # __init__ may have failed before the writer was created.
        if hasattr(self, "writer"):
            self.writer.close()