import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer


def weight_init_normal(module, model):
    """Re-initialize a submodule with the backbone's normal-init scheme.

    Mirrors the Hugging Face `_init_weights` convention: Linear and Embedding
    weights are drawn from N(0, initializer_range), biases and padding rows
    are zeroed, and LayerNorm is reset to identity.
    """
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)



class MeanPooling(nn.Module):
    """Masked mean pooling over the token dimension of the last hidden state."""

    def __init__(self):
        super(MeanPooling, self).__init__()

    def forward(self, last_hidden_state, attention_mask):
        # Broadcast the attention mask over the hidden dimension so that
        # padded positions contribute nothing to the average.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)  # guard against division by zero
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings


class MeanPoolingLayer(nn.Module):
    """Masked mean pooling followed by a dropout -> linear -> sigmoid head."""

    def __init__(self, 
        hidden_size,
        target_size,
        dropout = 0,
    ):
        super(MeanPoolingLayer, self).__init__()
        self.pool = MeanPooling()
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, target_size),
            nn.Sigmoid()
        )
        
    def forward(self, inputs, mask):
        last_hidden_states = inputs[0]  # index 0 of the backbone output is last_hidden_state
        feature = self.pool(last_hidden_states, mask)
        outputs = self.fc(feature)
        return outputs



class HSLanguageModel(nn.Module):
    """Transformer backbone with a masked mean-pooling head.

    Supports re-initializing the last `reinit_nlayers` encoder layers,
    freezing the embeddings plus the first `freeze_nlayers` encoder layers,
    and enabling gradient checkpointing.
    """

    def __init__(self,
        backbone = 'microsoft/deberta-v3-small',
        target_size = 1,
        head_dropout = 0,
        reinit_nlayers = 0,
        freeze_nlayers = 0,
        reinit_head = True,
        grad_checkpointing = False,
    ):
        super(HSLanguageModel, self).__init__()

        self.config = AutoConfig.from_pretrained(backbone, output_hidden_states=True)
        self.model = AutoModel.from_pretrained(backbone, config=self.config)
        self.head = MeanPoolingLayer(
            self.config.hidden_size,
            target_size,
            head_dropout
        )
        self.tokenizer = AutoTokenizer.from_pretrained(backbone)

        if grad_checkpointing:
            print('Gradient ckpt enabled')
            self.model.gradient_checkpointing_enable()
            
        if reinit_nlayers > 0:
            # Reinit last n encoder layers
            # [TODO] Check if it is autoencoding model: Bert, Roberta, DistilBert, Albert, XLMRoberta, BertModel
            for layer in self.model.encoder.layer[-reinit_nlayers:]: 
                self._init_weights(layer)
        
        if freeze_nlayers > 0:
            # Freeze the embeddings and the first n encoder layers
            self.model.embeddings.requires_grad_(False)
            self.model.encoder.layer[:freeze_nlayers].requires_grad_(False)
        
        if reinit_head:
            # Reinit layers in head
            self._init_weights(self.head)
        
        
    def _init_weights(self, layer):
        # Re-initialize every submodule of the given layer (or head)
        for module in layer.modules():
            weight_init_normal(module, self)
    

    def forward(self, inputs):
        outputs = self.model(**inputs)
        outputs = self.head(outputs, inputs['attention_mask'])
        return outputs


if __name__ == '__main__':
    
    model = HSLanguageModel()
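    # Hedged usage sketch, not part of the training pipeline: run a single
    # forward pass to sanity-check shapes. Assumes the default
    # 'microsoft/deberta-v3-small' checkpoint and tokenizer can be downloaded.
    inputs = model.tokenizer(
        ['A short example sentence to exercise the forward pass.'],
        return_tensors='pt',
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        preds = model(inputs)
    # Expected shape: (batch_size, target_size) -> torch.Size([1, 1]) by default.
    print(preds.shape)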