import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CNNPrenet(nn.Module):
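    """
    Convolutional prenet: three Conv1d/BatchNorm/ReLU/Dropout blocks applied
    to a (batch, time) input, followed by a tanh squashing of the output.

    Example
    -------
    >>> prenet = CNNPrenet()
    >>> x = torch.rand(4, 100)
    >>> prenet(x).shape
    torch.Size([4, 512, 100])
    """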
    def __init__(self):
        super(CNNPrenet, self).__init__()

        # Define the layers using Sequential container
        self.conv_layers = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1),

            nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        # Add a singleton channel dimension: (batch, time) -> (batch, 1, time)
        x = x.unsqueeze(1)

        # Pass through the convolutional stack: (batch, 1, time) -> (batch, 512, time)
        x = self.conv_layers(x)

        # Squash the output to the range [-1, 1]
        x = torch.tanh(x)

        return x



class CNNDecoderPrenet(nn.Module):
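    """
    Decoder prenet: two fully connected layers with ReLU and dropout,
    followed by a linear projection to the final model dimension.

    Arguments
    ---------
    input_dim: int
       input feature dimension (e.g. number of mel channels)
    hidden_dim: int
       hidden dimension of the first linear layer
    output_dim: int
       output dimension of the second linear layer
    final_dim: int
       dimension of the final linear projection
    dropout_rate: float
       dropout probability applied after each ReLU

    Example
    -------
    >>> prenet = CNNDecoderPrenet()
    >>> x = torch.rand(4, 80, 100)
    >>> prenet(x).shape
    torch.Size([4, 512, 100])
    """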
    def __init__(self, input_dim=80, hidden_dim=256, output_dim=256, final_dim=512, dropout_rate=0.5):
        super(CNNDecoderPrenet, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.linear_projection = nn.Linear(output_dim, final_dim)  # project to the final model dimension
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Move the feature dimension last: (batch, features, time) -> (batch, time, features)
        x = x.transpose(1, 2)
        # Apply the two fully connected layers with ReLU and dropout
        x = F.relu(self.layer1(x))
        x = self.dropout(x)
        x = F.relu(self.layer2(x))
        x = self.dropout(x)
        # Project to the final dimension and restore the (batch, features, time) layout
        x = self.linear_projection(x)
        x = x.transpose(1, 2)

        return x




class CNNPostNet(nn.Module):
    """
    Conv Postnet
    Arguments
    ---------
    n_mel_channels: int
       input feature dimension for convolution layers
    postnet_embedding_dim: int
       output feature dimension for convolution layers
    postnet_kernel_size: int
       postnet convolution kernel size
    postnet_n_convolutions: int
       number of convolution layers
    postnet_dropout: float
       dropout probability for the postnet
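
    Example
    -------
    >>> postnet = CNNPostNet()
    >>> x = torch.rand(4, 80, 100)
    >>> postnet(x).shape
    torch.Size([4, 80, 100])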
    """

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
        postnet_dropout=0.1,
    ):
        super(CNNPostNet, self).__init__()

        self.conv_pre = nn.Conv1d(
            in_channels=n_mel_channels,
            out_channels=postnet_embedding_dim,
            kernel_size=postnet_kernel_size,
            padding="same",
        )

        self.convs_intermediate = nn.ModuleList()
        for i in range(1, postnet_n_convolutions - 1):
            self.convs_intermediate.append(
                nn.Conv1d(
                    in_channels=postnet_embedding_dim,
                    out_channels=postnet_embedding_dim,
                    kernel_size=postnet_kernel_size,
                    padding="same",
                ),
            )

        self.conv_post = nn.Conv1d(
            in_channels=postnet_embedding_dim,
            out_channels=n_mel_channels,
            kernel_size=postnet_kernel_size,
            padding="same",
        )

        self.tanh = nn.Tanh()
        self.ln1 = nn.LayerNorm(postnet_embedding_dim)
        self.ln2 = nn.LayerNorm(postnet_embedding_dim)
        self.ln3 = nn.LayerNorm(n_mel_channels)
        self.dropout1 = nn.Dropout(postnet_dropout)
        self.dropout2 = nn.Dropout(postnet_dropout)
        self.dropout3 = nn.Dropout(postnet_dropout)


    def forward(self, x):
        """Computes the forward pass
        Arguments
        ---------
        x: torch.Tensor
            a (batch, n_mel_channels, time_steps) input tensor
        Returns
        -------
        output: torch.Tensor (the predicted spectrogram, same shape as the input)
        """
        x = self.conv_pre(x)
        x = self.ln1(x.permute(0, 2, 1)).permute(0, 2, 1)  # LayerNorm over features: (batch, time, features), then back
        x = self.tanh(x)
        x = self.dropout1(x)

        for conv in self.convs_intermediate:
            x = conv(x)
        x = self.ln2(x.permute(0, 2, 1)).permute(0, 2, 1)  # LayerNorm over features: (batch, time, features), then back
        x = self.tanh(x)
        x = self.dropout2(x)

        x = self.conv_post(x)
        x = self.ln3(x.permute(0, 2, 1)).permute(0, 2, 1)  # LayerNorm over features: (batch, time, features), then back
        x = self.dropout3(x)

        return x


class ScaledPositionalEncoding(nn.Module):
    """
    This class implements the absolute sinusoidal positional encoding function
    with an adaptive weight parameter alpha.

    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))

    Arguments
    ---------
    input_size: int
        Embedding dimension.
    max_len : int, optional
        Max length of the input sequences (default 2500).
    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = ScaledPositionalEncoding(input_size=a.shape[-1])
    >>> b = enc(a)
    >>> b.shape
    torch.Size([1, 120, 512])
    """

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(
                f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
            )
        self.max_len = max_len
        self.alpha = nn.Parameter(torch.ones(1))  # Define alpha as a trainable parameter
        pe = torch.zeros(self.max_len, input_size, requires_grad=False)
        positions = torch.arange(0, self.max_len).unsqueeze(1).float()
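        # Inverse frequencies 10000^(-2i / input_size) for the sinusoids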
        denominator = torch.exp(
            torch.arange(0, input_size, 2).float()
            * -(math.log(10000.0) / input_size)
        )

        pe[:, 0::2] = torch.sin(positions * denominator)
        pe[:, 1::2] = torch.cos(positions * denominator)
        pe = pe.unsqueeze(0)
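        # Register as a buffer: saved with the module and moved across devices, but not trained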
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Arguments
        ---------
        x : torch.Tensor
            Input feature of shape (batch, time, features).

        Returns
        -------
        The scaled positional encoding of shape (1, time, features),
        to be added to the input.
        """
        # Slice the precomputed table to the input length and scale by the trainable alpha
        return self.pe[:, : x.size(1)] * self.alpha
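

# A minimal smoke test for the modules above (illustrative only: the batch,
# channel, and sequence sizes are arbitrary assumptions, not training values).
if __name__ == "__main__":
    batch, n_mels, time_steps = 4, 80, 100

    prenet = CNNPrenet()
    signal = torch.rand(batch, time_steps)        # (batch, time)
    print(prenet(signal).shape)                   # torch.Size([4, 512, 100])

    dec_prenet = CNNDecoderPrenet()
    mel = torch.rand(batch, n_mels, time_steps)   # (batch, n_mels, time)
    print(dec_prenet(mel).shape)                  # torch.Size([4, 512, 100])

    postnet = CNNPostNet()
    post_out = postnet(mel)                       # same shape as the input
    print(post_out.shape)                         # torch.Size([4, 80, 100])

    enc = ScaledPositionalEncoding(input_size=512)
    feats = torch.rand(batch, time_steps, 512)    # (batch, time, features)
    print((feats + enc(feats)).shape)             # torch.Size([4, 100, 512])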