from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel


# Global default controlling whether Conv1d/Linear layers use bias terms.
biases = False


class Pool2BN(nn.Module):
    """Concatenate global average- and max-pooling, then batch-normalize."""

    def __init__(self, num_channels):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(num_channels * 2)

    def forward(self, x):
        # x: (batch, num_channels, length) -> (batch, 2 * num_channels)
        avgp = torch.nn.functional.adaptive_avg_pool1d(x, 1)[:, :, 0]
        maxp = torch.nn.functional.adaptive_max_pool1d(x, 1)[:, :, 0]
        x = torch.cat((avgp, maxp), dim=1)
        x = self.bn(x)
        return x
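

# A minimal usage sketch for Pool2BN (not part of the original file; shapes are
# illustrative): average- and max-pooled features are concatenated channel-wise,
# doubling the feature dimension before batch normalization.
def _demo_pool2bn():
    pool = Pool2BN(num_channels=8)
    x = torch.randn(4, 8, 50)   # (batch, channels, length)
    out = pool(x)               # (4, 16): avg and max features side by side
    assert out.shape == (4, 16)
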

class MLP(torch.nn.Module):
    def __init__(self, layer_sizes, biases=False, sigmoid=False, dropout=None):
        super().__init__()
        layers = []
        prev_size = layer_sizes[0]
        for i, s in enumerate(layer_sizes[1:]):
            # Dropout before every layer except the first.
            if i != 0 and dropout is not None:
                layers.append(torch.nn.Dropout(dropout))

            layers.append(torch.nn.Linear(in_features=prev_size, out_features=s, bias=biases))
            # No activation or normalization after the final layer (raw outputs).
            if i != len(layer_sizes) - 2:
                if sigmoid:
                    # Note: despite the flag name, Tanh is used rather than Sigmoid.
                    layers.append(torch.nn.Tanh())
                else:
                    layers.append(torch.nn.ReLU())

                layers.append(torch.nn.BatchNorm1d(s))

            prev_size = s

        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
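

# Hedged construction example for MLP (the sizes below are illustrative, not
# taken from the original code): hidden layers get activation + BatchNorm,
# while the final layer stays linear.
def _demo_mlp():
    mlp = MLP(layer_sizes=[128, 64, 10], dropout=0.2)
    x = torch.randn(32, 128)
    out = mlp(x)                # (32, 10) raw outputs
    assert out.shape == (32, 10)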


class SimpleCNN(torch.nn.Module):
    def __init__(self, k, num_filters, sigmoid=False, additional_layer=False):
        super().__init__()
        self.sigmoid = sigmoid
        # in_channels=4: the input is assumed to be one-hot encoded over a
        # 4-letter alphabet (e.g. DNA), shaped (batch, 4, length).
        self.cnn = torch.nn.Conv1d(in_channels=4, out_channels=num_filters, kernel_size=k, bias=biases)

        self.additional_layer = additional_layer
        if additional_layer:
            # Optional BatchNorm + 1x1 convolution after the first activation.
            self.bn = nn.BatchNorm1d(num_filters)
            self.cnn2 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters, kernel_size=1, bias=biases)

        self.post = Pool2BN(num_filters)

    def forward(self, x):
        x = self.cnn(x)
        # As in MLP, the `sigmoid` flag actually selects Tanh.
        x = (torch.tanh if self.sigmoid else torch.relu)(x)

        if self.additional_layer:
            x = self.bn(x)
            x = self.cnn2(x)
            x = (torch.tanh if self.sigmoid else torch.relu)(x)

        x = self.post(x)
        return x
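

# Usage sketch for SimpleCNN, assuming one-hot encoded sequences over a
# 4-letter alphabet; kernel size, filter count, and lengths are illustrative.
def _demo_simple_cnn():
    model = SimpleCNN(k=8, num_filters=64, additional_layer=True)
    x = torch.randn(16, 4, 200)     # (batch, 4 channels, sequence length)
    out = model(x)                  # (16, 128) after Pool2BN
    assert out.shape == (16, 128)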


class ResNet1dBlock(torch.nn.Module):
    """Pre-activation residual block (BN -> ReLU -> Conv, twice).

    Padding is chosen so the sequence length is preserved, which assumes odd
    kernel sizes; otherwise the residual addition would not be shape-compatible.
    """

    def __init__(self, num_filters, k1, internal_filters, k2, dropout=None, dilation=None):
        super().__init__()

        self.init_do = torch.nn.Dropout(dropout) if dropout is not None else None
        self.bn1 = torch.nn.BatchNorm1d(num_filters)
        if dilation is None:
            dilation = 1

        self.cnn1 = torch.nn.Conv1d(in_channels=num_filters, out_channels=internal_filters, kernel_size=k1, bias=biases,
                                    dilation=dilation,
                                    padding=(k1 // 2) * dilation)

        self.bn2 = torch.nn.BatchNorm1d(internal_filters)
        self.cnn2 = torch.nn.Conv1d(in_channels=internal_filters, out_channels=num_filters, kernel_size=k2, bias=biases,
                                    padding=k2 // 2)

    def forward(self, x):
        x_orig = x

        x = self.bn1(x)
        x = torch.relu(x)
        if self.init_do is not None:
            x = self.init_do(x)

        x = self.cnn1(x)

        x = self.bn2(x)
        x = torch.relu(x)
        x = self.cnn2(x)

        # Residual connection.
        return x + x_orig
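

# Sketch of a single block (values assumed, not from the original code): with
# odd kernel sizes the output shape matches the input, so the residual sum works.
def _demo_resnet1d_block():
    block = ResNet1dBlock(num_filters=32, k1=5, internal_filters=16, k2=3, dropout=0.1)
    x = torch.randn(8, 32, 100)
    out = block(x)                  # (8, 32, 100): shape preserved
    assert out.shape == x.shape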


class ResNet1d(torch.nn.Module):
    def __init__(self, num_filters, block_spec, dropout=None, dilation=None):
        super().__init__()

        # Each spec is a (k1, internal_filters, k2) tuple, unpacked into ResNet1dBlock.
        blocks = [ResNet1dBlock(num_filters, *spec, dropout=dropout, dilation=dilation) for spec in block_spec]
        self.blocks = torch.nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)
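

# Sketch of stacking blocks via block_spec, where each entry is
# (k1, internal_filters, k2); the specific values here are illustrative.
def _demo_resnet1d():
    net = ResNet1d(num_filters=32, block_spec=[(5, 16, 3), (5, 16, 3)], dropout=0.1)
    x = torch.randn(8, 32, 100)
    out = net(x)                    # (8, 32, 100)
    assert out.shape == x.shape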



class LogisticRegressionTorch(nn.Module):
    """Linear classification head with input batch normalization; returns raw logits."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(num_features=input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.batch_norm(x)
        return self.linear(x)
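

# Sketch with assumed dimensions: the head outputs unnormalized logits, so a
# loss such as nn.CrossEntropyLoss is expected to handle the softmax downstream.
def _demo_logistic_regression():
    head = LogisticRegressionTorch(input_dim=768, output_dim=2)
    x = torch.randn(8, 768)
    logits = head(x)                # (8, 2) raw logits
    assert logits.shape == (8, 2)
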

class BertClassifier(nn.Module):
    def __init__(self, bert_model: AutoModel, classifier: LogisticRegressionTorch, num_labels: int):
        super().__init__()
        self.bert = bert_model
        self.classifier = classifier
        self.num_labels = num_labels

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Pool by taking the final-layer hidden state of the first ([CLS]) token.
        pooled_output = outputs.hidden_states[-1][:, 0, :]
        logits = self.classifier(pooled_output)
        return logits
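

# Hedged end-to-end wiring sketch for BertClassifier. The checkpoint name and
# hidden size 768 are placeholder assumptions, not necessarily what the
# original project used.
def _demo_bert_classifier():
    bert = AutoModel.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    head = LogisticRegressionTorch(input_dim=768, output_dim=2)
    model = BertClassifier(bert, head, num_labels=2)
    model.eval()                    # eval mode so BatchNorm accepts batch size 1
    batch = tokenizer(["an example input"], return_tensors="pt")
    with torch.no_grad():
        logits = model(batch["input_ids"], attention_mask=batch["attention_mask"])
    assert logits.shape == (1, 2)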