from typing import Mapping
import torch
import math
from speechbrain.inference.interfaces import Pretrained


class AttentionMLP(torch.nn.Module):
    """Two-layer MLP producing attention weights normalized along axis 2.

    The input's last dimension is projected to ``hidden_dim``, passed through
    a ReLU, then reduced to a single score; the scores are turned into
    weights with a softmax over dimension 2.

    Arguments
    ---------
    input_dim : int
        Size of the last dimension of the input.
    hidden_dim : int
        Width of the hidden ReLU layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Created in the same order as the Sequential entries so parameter
        # initialisation (and state_dict keys "layers.0/2.*") stay the same.
        hidden = torch.nn.Linear(input_dim, hidden_dim)
        activation = torch.nn.ReLU()
        score = torch.nn.Linear(hidden_dim, 1, bias=False)
        self.layers = torch.nn.Sequential(hidden, activation, score)

    def forward(self, x):
        """Returns weights shaped like ``x`` with the last dim reduced to 1;
        the weights sum to one along dimension 2."""
        scores = self.layers(x)
        return scores.softmax(dim=2)


class Discrete_EmbeddingLayer(torch.nn.Module):
    """This class handles embedding layers for discrete tokens.

    All codebooks share one ``torch.nn.Embedding`` table of
    ``num_codebooks * vocab_size`` rows; codebook ``i`` owns rows
    ``[i * vocab_size, (i + 1) * vocab_size)``.

    Arguments
    ---------
    num_codebooks : int
        Number of codebooks of the tokenizer.
    vocab_size : int
        Size of the dictionary of embeddings for each codebook.
    emb_dim : int
        The size of each embedding vector.
    pad_index : int (default: 0)
        Accepted for interface compatibility; not used by this class.
    init : bool (default: False)
        Stored as ``self.init``; this class does not act on it itself.
        Use ``init_embedding`` to load pretrained weights.
    freeze : bool (default: False)
        If True, the embedding is frozen. If False, the model will be trained
        alongside with the rest of the pipeline.
    available_layers : list, optional
        Identifiers of every codebook/layer in the table, in table order
        (position ``i`` owns table slice ``i``).
    layers : list, optional
        Subset of ``available_layers`` actually used. If None or empty,
        all ``num_codebooks`` codebooks are used.
    chunk_size : int (default: 100)
        The size of the lengthwise chunks used when evaluating via
        Gumbel softmax in ``encode_logits``.

    Example
    -------
    >>> emb = Discrete_EmbeddingLayer(2, 1024, 1024)
    >>> tokens = torch.randint(0, 1024, (4, 4, 2))
    >>> in_emb = emb(tokens)
    >>> print(in_emb.shape)
    torch.Size([4, 4, 2, 1024])
    """

    def __init__(
        self,
        num_codebooks,
        vocab_size,
        emb_dim,
        pad_index=0,
        init=False,
        freeze=False,
        available_layers=None,
        layers=None,
        chunk_size=100,
    ):
        super(Discrete_EmbeddingLayer, self).__init__()
        self.vocab_size = vocab_size
        self.num_codebooks = num_codebooks
        self.freeze = freeze
        self.embedding = torch.nn.Embedding(
            num_codebooks * vocab_size, emb_dim
        ).requires_grad_(not self.freeze)
        self.init = init
        self.layers = layers
        self.available_layers = available_layers
        self.register_buffer("offsets", self.build_offsets())
        self.register_buffer("layer_embs", self.compute_layer_embs())
        self.chunk_size = chunk_size

    def init_embedding(self, weights):
        """Replaces the embedding table with ``weights``.

        The new parameter honours ``freeze``, and the cached ``layer_embs``
        buffer is rebuilt so ``encode_logits`` uses the new table.
        """
        with torch.no_grad():
            self.embedding.weight = torch.nn.Parameter(
                weights, requires_grad=not self.freeze
            )
        self.layer_embs = self.compute_layer_embs()

    def build_offsets(self):
        """Returns the first table row of each selected codebook.

        With no ``layers`` selection every codebook is used; otherwise the
        offsets of the selected layers are kept in ``available_layers`` order.
        """
        offsets = torch.arange(
            0,
            self.num_codebooks * self.vocab_size,
            self.vocab_size,
        )
        if self.layers:
            selected_layers = set(self.layers)
            indexes = [
                idx for idx, layer in enumerate(self.available_layers)
                if layer in selected_layers
            ]
            offsets = offsets[indexes]
        return offsets

    def forward(self, in_tokens):
        """Computes the embedding for discrete tokens.

        Arguments
        ---------
        in_tokens : torch.Tensor
            A (Batch x Time x num_codebooks) tensor of token ids, each in
            ``[0, vocab_size)`` for its own codebook.

        Returns
        -------
        in_embs : torch.Tensor
            (Batch x Time x num_codebooks x emb_dim) embeddings.
        """
        with torch.set_grad_enabled(not self.freeze):
            # Make token ids unique across codebooks by shifting each
            # codebook into its own slice of the shared table.
            in_tokens_offset = in_tokens + self.offsets.to(in_tokens.device)
            in_embs = self.embedding(in_tokens_offset.int())
            return in_embs

    def compute_layer_embs(self):
        """Returns the table slices of the selected codebooks.

        Returns
        -------
        layer_embs : torch.Tensor
            Shape (1, 1, n_layers, vocab_size, emb_dim); the two leading
            singleton dims broadcast over batch and time in ``encode_logits``.
        """
        # Detached: this is a cached buffer, so it must be a graph leaf
        # (otherwise deepcopy fails and it keeps a stale autograd graph).
        # It is rebuilt explicitly by init_embedding / load_state_dict.
        weight = self.embedding.weight.detach()

        if self.layers:
            layer_idx_map = {
                layer: idx
                for idx, layer in enumerate(self.available_layers)
            }
            # NOTE(review): these slices follow ``self.layers`` order while
            # build_offsets follows ``available_layers`` order; the two agree
            # only when ``layers`` is listed in table order - confirm intent.
            layer_idx = [layer_idx_map[layer] for layer in self.layers]
        else:
            # No selection: use every codebook, matching build_offsets.
            layer_idx = range(self.num_codebooks)

        layer_embs = torch.stack([
            weight[idx * self.vocab_size:(idx + 1) * self.vocab_size]
            for idx in layer_idx
        ])
        return layer_embs.unsqueeze(0).unsqueeze(0)

    def encode_logits(self, logits, length=None):
        """Embeds token logits with a straight-through Gumbel-softmax.

        The forward value equals (up to floating-point rounding) the
        embedding of each head's argmax token. Gradients reach ``logits``
        through a relaxed Gumbel-softmax sample.

        Arguments
        ---------
        logits : torch.Tensor
            Token logits [batch, length, heads, vocab_size]; ``heads`` must
            match the number of selected codebooks.
        length : torch.Tensor, optional
            Unused; accepted for interface compatibility.

        Returns
        -------
        emb : torch.Tensor
            Embeddings [batch, length, heads, emb_dim].
        """
        # Relaxed (differentiable) one-hot sample.
        units_gumbel = torch.nn.functional.gumbel_softmax(
            logits,
            hard=False,
            dim=-1
        )

        # Straight-through trick: hard one-hot of the argmax in the forward
        # pass, gradient of the relaxed sample in the backward pass.
        _, argmax_idx = logits.max(dim=-1, keepdim=True)
        units_ref = torch.zeros_like(logits).scatter_(
            dim=-1, index=argmax_idx, src=torch.ones_like(logits)
        )
        units_hard = units_ref - units_gumbel.detach() + units_gumbel

        # Weighted sum of each head's table slice, processed in time chunks
        # to bound the (B, T, heads, vocab, emb) intermediate. At least one
        # chunk so an empty time axis does not hit chunk(0).
        num_chunks = max(1, math.ceil(units_hard.size(1) / self.chunk_size))
        units_hard_chunked = units_hard.chunk(num_chunks, dim=1)
        emb = torch.cat(
            [
                (self.layer_embs * units_hard_chunk.unsqueeze(-1)).sum(-2)
                for units_hard_chunk in units_hard_chunked
            ],
            dim=1
        )
        return emb

    def load_state_dict(self, state_dict, strict=True, *args, **kwargs):
        """Loads parameters, then rebuilds ``layer_embs`` from the loaded
        embedding table. Extra arguments are forwarded to torch."""
        result = super().load_state_dict(state_dict, strict, *args, **kwargs)
        self.layer_embs = self.compute_layer_embs()
        return result


class DiscreteSpkEmb(Pretrained):
    """Ready-to-use speaker-embedding extractor operating on discrete tokens.

    The hyperparameter file must provide three modules, used through
    ``self.mods``:

    - ``discrete_embedding_layer``: embeds tokens ``[batch, time, heads]``
      (when called) or token logits ``[batch, time, heads, tokens]``
      (through its ``encode_logits`` method) into per-head vectors;
    - ``attention_mlp``: produces weights that are used for a weighted sum
      over the head axis of those vectors;
    - ``embedding_model``: maps the head-pooled features, together with the
      optional relative lengths, to the final embedding.

    Call ``encode_batch`` for tokens, ``encode_logits`` for logits, or call
    the instance directly to dispatch on the input rank.
    """

    def _pool_and_embed(self, embeddings, length):
        """Attention-pools per-head embeddings and runs the embedding model.

        ``embeddings`` is presumably [batch, time, heads, emb] and the
        attention weights [batch, time, heads, 1] - confirm against the
        configured modules.
        """
        att_w = self.mods.attention_mlp(embeddings)
        # Move the head axis of the weights last so the matmul computes a
        # weighted sum over heads: [..., 1, heads] @ [..., heads, emb].
        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
        embeddings = self.mods.embedding_model(feats, length)
        return embeddings.squeeze(1)

    def encode_batch(self, audio, length=None):
        """Encodes a batch of discrete tokens into one embedding per item.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of tokenized audio [batch, time, heads].
        length : torch.Tensor, optional
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0. Forwarded to ``embedding_model``.

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        embeddings = self.mods.discrete_embedding_layer(audio)
        return self._pool_and_embed(embeddings, length)

    def encode_logits(self, logits, length=None):
        """Encodes a batch of token logits into one embedding per item.

        Arguments
        ---------
        logits : torch.Tensor
            Batch of token logits [batch, time, heads, tokens].
        length : torch.Tensor, optional
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0. Forwarded to ``embedding_model``.

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        embeddings = self.mods.discrete_embedding_layer.encode_logits(logits)
        return self._pool_and_embed(embeddings, length)

    def forward(self, audio, length=None):
        """Encodes tokens or token logits, dispatching on the input rank.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of tokenized audio [batch, time, heads]
            or logits [batch, time, heads, tokens].
        length : torch.Tensor, optional
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch].

        Returns
        -------
        torch.Tensor
            The encoded batch.

        Raises
        ------
        ValueError
            If ``audio`` is neither 3-D nor 4-D.
        """
        audio_dim = audio.dim()
        if audio_dim == 3:
            return self.encode_batch(audio, length)
        if audio_dim == 4:
            return self.encode_logits(audio, length)
        raise ValueError(f"Unsupported audio shape {audio.shape}")