import os
from datetime import datetime
from pathlib import Path

import torch
import typer
from accelerate import Accelerator
from accelerate.utils import LoggerType
from torch import Tensor
from torch.optim import AdamW

# from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import MusdbDataset
from splitter import Splitter

# Honor DISABLE_TQDM as a conventional boolean env var. The previous code used
# os.environ.get(..., False), which returns the raw string when the var is set,
# so DISABLE_TQDM=0 or DISABLE_TQDM=false still disabled tqdm (non-empty
# strings are truthy). Treat "", "0", "false", and "no" (any case) as off.
DISABLE_TQDM = os.environ.get("DISABLE_TQDM", "").strip().lower() not in ("", "0", "false", "no")
app = typer.Typer(pretty_exceptions_show_locals=False)


def spectrogram_loss(masked_target: Tensor, original: Tensor) -> Tensor:
    """Mean squared error between a masked STFT and the original mixture STFT.

    Args:
        masked_target: a masked STFT generated by applying a net's
            estimated mask for source S to the ground truth STFT for source S.
        original: an original input mixture.

    Returns:
        A scalar tensor: the mean of the element-wise squared difference.
    """
    residual = masked_target - original
    return (residual * residual).mean()


@app.command()
def train(
    dataset: str = "data/musdb18-wav",
    output_dir: str = None,
    fp16: bool = False,
    cpu: bool = True,
    max_steps: int = 100,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 1,
    effective_batch_size: int = 4,
    max_grad_norm: float = 0.0,
) -> None:
    """Train a Splitter source-separation model on a MUSDB-style dataset.

    Args:
        dataset: root directory of the wav-format MUSDB dataset.
        output_dir: where the model and tracker logs are written; defaults to
            a timestamped directory under ``experiments/``.
        fp16: enable fp16 mixed precision in Accelerate.
        cpu: force CPU execution (default True).
        max_steps: hard cap on optimizer steps; if <= 0, the budget is derived
            from ``num_train_epochs``.
        num_train_epochs: epoch count, used only when ``max_steps <= 0``.
        per_device_train_batch_size: per-process dataloader batch size
            (currently must be 1 — enforced by an assert in the loop).
        effective_batch_size: desired global batch size; the gradient
            accumulation factor is derived from it.
        max_grad_norm: clip gradients to this L2 norm when > 0.
    """
    if not output_dir:
        now_str = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_dir = f"experiments/{now_str}"
    output_dir = Path(output_dir)
    logging_dir = output_dir / "tracker_logs"
    accelerator = Accelerator(
        fp16=fp16,
        cpu=cpu,
        logging_dir=logging_dir,
        log_with=[LoggerType.TENSORBOARD],
    )
    accelerator.init_trackers(logging_dir / "run")

    train_dataset = MusdbDataset(root=dataset, is_train=True)
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=per_device_train_batch_size,
    )

    # One output stem per dataset target (e.g. vocals/drums/bass/other).
    model = Splitter(stem_names=list(train_dataset.targets))
    optimizer = AdamW(
        model.parameters(),
        lr=1e-3,
        eps=1e-8,
    )
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    # NOTE(review): when max_steps <= 0 this counts dataloader batches, not
    # optimizer steps; with gradient accumulation > 1 the epoch-derived budget
    # overestimates optimizer updates — confirm intended semantics.
    num_train_steps = (
        max_steps if max_steps > 0 else len(train_dataloader) * num_train_epochs
    )
    accelerator.print(f"Num train steps: {num_train_steps}")

    step_batch_size = per_device_train_batch_size * accelerator.num_processes
    gradient_accumulation_steps = max(
        1,
        effective_batch_size // step_batch_size,
    )

    accelerator.print(
        f"Gradient Accumulation Steps: {gradient_accumulation_steps}\nEffective Batch Size: {gradient_accumulation_steps * step_batch_size}"
    )
    global_step = 0
    while global_step < num_train_steps:
        accelerator.wait_for_everyone()
        model.train()
        batch_iterator = tqdm(
            train_dataloader,
            desc="Batch",
            disable=((not accelerator.is_local_main_process) or DISABLE_TQDM),
        )
        for batch_idx, batch in enumerate(batch_iterator):
            assert per_device_train_batch_size == 1, "For now limit to 1."
            x_wav, y_target_wavs = batch
            predictions = model(x_wav)

            # Per-stem reconstruction loss: compare the net's masked STFT
            # against the STFT of the ground-truth stem waveform.
            stem_losses = []
            for name, masked_stft in predictions.items():
                target_stft, _ = model.compute_stft(y_target_wavs[name].squeeze())
                loss = spectrogram_loss(
                    masked_target=masked_stft,
                    original=target_stft,
                )
                stem_losses.append(loss)
                accelerator.log({f"train-loss-{name}": 1.0 * loss}, step=global_step)

            # Scale by the accumulation factor so accumulated gradients match
            # one large-batch step.
            total_loss = (
                torch.sum(torch.stack(stem_losses)) / gradient_accumulation_steps
            )
            accelerator.print(f"global step: {global_step}\tloss: {total_loss:.4f}")
            accelerator.log({"train-loss": 1.0 * total_loss}, step=global_step)
            accelerator.backward(total_loss)
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                if max_grad_norm > 0:
                    accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                # Stop mid-epoch once the step budget is exhausted. Previously
                # the budget was only checked between epochs, so max_steps was
                # overshot by up to one full pass over the dataloader.
                if global_step >= num_train_steps:
                    break

    accelerator.wait_for_everyone()
    accelerator.end_training()
    accelerator.print(f"Saving model to {output_dir}...")
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )

    accelerator.wait_for_everyone()
    accelerator.print("DONE!")


if __name__ == "__main__":
    # Dispatch to the typer CLI (exposes the `train` command) when run as a script.
    app()