# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from models.svc.transformer.transformer import PositionalEncoding

from utils.f0 import f0_to_coarse


class ContentEncoder(nn.Module):
    """Project frame-level content features to ``output_dim``.

    When ``cfg`` contains a truthy ``use_conformer_for_content_features``
    entry, the features are first passed through a positional encoding and a
    6-layer Conformer (operating at ``input_dim``) before the final linear
    projection. Otherwise only the linear projection is applied.

    Args:
        cfg: Configuration object; must support ``in`` membership tests and,
            when the key is present, attribute access to
            ``use_conformer_for_content_features``.
        input_dim: Feature dimension of the incoming content features
            (must be non-zero).
        output_dim: Feature dimension produced by the linear projection.
    """

    def __init__(self, cfg, input_dim, output_dim):
        super().__init__()
        self.cfg = cfg

        assert input_dim != 0, "ContentEncoder requires a non-zero input_dim"
        self.nn = nn.Linear(input_dim, output_dim)

        # Introduce conformer or not
        use_conformer = (
            "use_conformer_for_content_features" in cfg
            and cfg.use_conformer_for_content_features
        )
        if use_conformer:
            self.pos_encoder = PositionalEncoding(input_dim)
            self.conformer = Conformer(
                input_dim=input_dim,
                num_heads=2,
                ffn_dim=256,
                num_layers=6,
                depthwise_conv_kernel_size=3,
            )
        else:
            self.conformer = None

    def forward(self, x, length=None):
        """Encode content features.

        Args:
            x: Tensor of shape (N, seq_len, input_dim).
            length: Optional tensor of shape (N,) with the number of valid
                frames per batch element. Only used by the conformer. When
                omitted, every sequence is treated as fully valid.

        Returns:
            Tensor of shape (N, seq_len, output_dim).
        """
        if self.conformer is not None:
            if length is None:
                # The Conformer builds its padding mask from `lengths` and
                # cannot accept None; fall back to "no padding".
                length = torch.full(
                    (x.shape[0],), x.shape[1], dtype=torch.long, device=x.device
                )
            x = self.pos_encoder(x)
            x, _ = self.conformer(x, length)
        return self.nn(x)


class MelodyEncoder(nn.Module):
    """Encode a frame-level pitch (F0) contour.

    Behaviour depends on ``cfg.n_bins_melody``:

    * ``0``  -- no quantization; ``forward`` returns the pitch with a trailing
      singleton dimension.
    * ``>0`` -- the pitch is quantized with ``f0_to_coarse`` and looked up in
      an embedding table; an optional voiced/unvoiced embedding is added.

    No sub-modules are created when ``cfg.input_melody_dim`` is 0.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = cfg.input_melody_dim
        self.output_dim = cfg.output_melody_dim
        self.n_bins = cfg.n_bins_melody
        self.pitch_min = cfg.pitch_min
        self.pitch_max = cfg.pitch_max

        if self.input_dim == 0:
            # Melody conditioning disabled: nothing to build.
            return

        if self.n_bins == 0:
            # Not use quantization
            # NOTE(review): this projection is never applied in forward();
            # confirm whether that is intentional.
            self.nn = nn.Linear(self.input_dim, self.output_dim)
            return

        # Quantized pitch: one embedding per pitch bin plus a 2-entry table
        # for the voiced/unvoiced flag.
        self.f0_min = cfg.f0_min
        self.f0_max = cfg.f0_max
        self.nn = nn.Embedding(
            num_embeddings=self.n_bins,
            embedding_dim=self.output_dim,
            padding_idx=None,
        )
        self.uv_embedding = nn.Embedding(2, self.output_dim)

    def forward(self, x, uv=None, length=None):
        """Encode the pitch contour.

        Args:
            x: Pitch tensor of shape (N, frame_len).
            uv: Optional index tensor fed to the 2-entry voiced/unvoiced
                embedding (only used when quantization is enabled).
            length: Currently unused; kept for interface compatibility.

        Returns:
            ``x`` with a trailing singleton dimension when ``n_bins == 0``,
            otherwise the embedded pitch (plus the ``uv`` embedding if given).
        """
        if self.n_bins == 0:
            return x.unsqueeze(-1)

        # f0_to_coarse presumably maps F0 values to integer bin indices
        # suitable for the embedding lookup -- confirm against utils.f0.
        coarse_pitch = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
        embedded = self.nn(coarse_pitch)
        if uv is None:
            return embedded
        return embedded + self.uv_embedding(uv)


class LoudnessEncoder(nn.Module):
    """Encode a frame-level loudness (energy) contour.

    Behaviour depends on ``cfg.n_bins_loudness``:

    * ``0``  -- no quantization; the loudness is projected with a linear
      layer.
    * ``>0`` -- the loudness is bucketized against ``n_bins - 1`` boundaries
      and looked up in an embedding table of ``n_bins`` entries. Boundaries
      are log-spaced when ``cfg.use_log_loudness`` is truthy and linearly
      spaced otherwise.

    No sub-modules are created when ``cfg.input_loudness_dim`` is 0.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_loudness_dim
        self.output_dim = self.cfg.output_loudness_dim
        self.n_bins = self.cfg.n_bins_loudness

        if self.input_dim != 0:
            if self.n_bins == 0:
                # Not use quantization
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # TODO: set trivially now
                self.loudness_min = 1e-30
                self.loudness_max = 1.5

                if cfg.use_log_loudness:
                    # Kept as a frozen Parameter so existing checkpoints that
                    # contain the "energy_bins" key still load.
                    self.energy_bins = nn.Parameter(
                        torch.exp(
                            torch.linspace(
                                np.log(self.loudness_min),
                                np.log(self.loudness_max),
                                self.n_bins - 1,
                            )
                        ),
                        requires_grad=False,
                    )
                else:
                    # Previously no boundaries were created here, so forward()
                    # failed with AttributeError. Registered as a
                    # non-persistent buffer so the state_dict keys for this
                    # configuration do not change.
                    self.register_buffer(
                        "energy_bins",
                        torch.linspace(
                            self.loudness_min,
                            self.loudness_max,
                            self.n_bins - 1,
                        ),
                        persistent=False,
                    )

                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )

    def forward(self, x):
        """Encode the loudness contour.

        Args:
            x: Loudness tensor of shape (N, frame_len).

        Returns:
            The linear projection of ``x.unsqueeze(-1)`` when ``n_bins == 0``,
            otherwise the embedding of the bucket index of every frame.
        """
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            # n_bins - 1 boundaries yield indices in [0, n_bins - 1].
            x = torch.bucketize(x, self.energy_bins)
        return self.nn(x)


class SingerEncoder(nn.Module):
    """Look up a learned embedding for each singer id.

    The table has ``cfg.singer_table_size`` rows of width
    ``cfg.output_singer_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # A singer is identified by a single integer id.
        self.input_dim = 1
        self.output_dim = self.cfg.output_singer_dim

        table_size = cfg.singer_table_size
        self.nn = nn.Embedding(table_size, self.output_dim, padding_idx=None)

    def forward(self, x):
        """Embed singer ids.

        Args:
            x: Integer id tensor of shape (N, 1).

        Returns:
            Tensor of shape (N, 1, output_dim).
        """
        singer_embedding = self.nn(x)
        return singer_embedding


class ConditionEncoder(nn.Module):
    """Encode all available conditions of a batch and merge them.

    Sub-encoders are created for every content feature enabled in ``cfg``
    (``use_whisper``, ``use_contentvec``, ``use_mert``, ``use_wenet``), for
    melody and loudness (always), and for the singer id (``use_spkid``).
    ``cfg.merge_mode`` selects how the individual outputs are combined:
    ``"concat"`` concatenates along the feature dimension, ``"add"`` sums
    them element-wise.
    """

    # Content features handled uniformly. For a name ``n`` the config uses
    # ``use_n`` / ``n_dim``, the module attribute is ``n_encoder`` and the
    # batch key is ``n_feat``. The order fixes both the registration order of
    # the sub-modules and the order of the merged outputs.
    _CONTENT_FEATURES = ("whisper", "contentvec", "mert", "wenet")

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.merge_mode = cfg.merge_mode

        for name in self._CONTENT_FEATURES:
            if getattr(cfg, f"use_{name}"):
                encoder = ContentEncoder(
                    self.cfg,
                    getattr(self.cfg, f"{name}_dim"),
                    self.cfg.content_encoder_dim,
                )
                setattr(self, f"{name}_encoder", encoder)

        self.melody_encoder = MelodyEncoder(self.cfg)
        self.loudness_encoder = LoudnessEncoder(self.cfg)
        if cfg.use_spkid:
            self.singer_encoder = SingerEncoder(self.cfg)

    def forward(self, x):
        """Encode and merge the conditions found in ``x``.

        Args:
            x: Mapping from feature name to tensor. Recognised keys are
                ``frame_pitch`` (with optional ``frame_uv``),
                ``frame_energy``, ``whisper_feat``, ``contentvec_feat``,
                ``mert_feat``, ``wenet_feat`` and ``spk_id``. ``target_len``
                is required whenever ``frame_pitch`` or a content feature is
                present. The mapping is not modified.

        Returns:
            The merged encoder output, or ``None`` when ``merge_mode`` is
            neither ``"concat"`` nor ``"add"``.

        Raises:
            ValueError: If ``spk_id`` is given without any frame-level
                feature from which the sequence length could be taken.
        """
        outputs = []

        if "frame_pitch" in x:
            # Read the optional uv flag without writing a default back into
            # the caller's batch.
            uv = x["frame_uv"] if "frame_uv" in x else None
            outputs.append(
                self.melody_encoder(x["frame_pitch"], uv=uv, length=x["target_len"])
            )

        if "frame_energy" in x:
            outputs.append(self.loudness_encoder(x["frame_energy"]))

        for name in self._CONTENT_FEATURES:
            key = f"{name}_feat"
            if key in x:
                encoder = getattr(self, f"{name}_encoder")
                outputs.append(encoder(x[key], length=x["target_len"]))

        if "spk_id" in x:
            speaker_enc_out = self.singer_encoder(x["spk_id"])  # (N, 1, dim)
            if not outputs:
                raise ValueError(
                    "spk_id requires at least one frame-level feature "
                    "to determine the sequence length"
                )
            # Broadcast the per-utterance embedding over the frame axis. All
            # outputs must share seq_len to be merged, so the most recent one
            # is as good a reference as any.
            seq_len = outputs[-1].shape[1]
            outputs.append(speaker_enc_out.expand(-1, seq_len, -1))

        if self.merge_mode == "concat":
            return torch.cat(outputs, dim=-1)
        if self.merge_mode == "add":
            # (#modules, N, seq_len, output_dim) -> (N, seq_len, output_dim)
            return torch.stack(outputs, dim=0).sum(dim=0)
        return None