from transformers import LEDConfig, LEDModel, LEDPreTrainedModel
import torch.nn as nn

class CustomLEDForQAModel(LEDPreTrainedModel):
    """LED encoder with a linear span-prediction head for extractive question answering."""

    config_class = LEDConfig
    
    def __init__(self, config: LEDConfig, checkpoint=None):
        super().__init__(config)
        # Two labels per token: one logit for the span start, one for the span end.
        config.num_labels = 2
        self.num_labels = config.num_labels

        # Keep only the encoder: extractive QA needs per-token representations,
        # not the decoder. Load pretrained weights when a checkpoint is given.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()
            
        # Span head: projects each hidden state to a (start, end) logit pair.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        start_positions=None,
        end_positions=None,
    ):

        # Encoder-only forward pass; last_hidden_state is (batch, seq_len, hidden_size).
        outputs = self.led(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        
        # Split the (batch, seq_len, 2) logits into per-token start and end scores.
        logits = self.qa_outputs(outputs.last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None

        if start_positions is not None and end_positions is not None:
            # On multi-GPU, gathered labels can carry an extra dimension.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # Positions outside the sequence would crash CrossEntropyLoss, so
            # clamp them to seq_len and tell the loss to ignore that index
            # (the same guard the stock Hugging Face QA heads use).
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return {
            'loss': total_loss,
            'start_logits': start_logits,
            'end_logits': end_logits,
        }
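

# A minimal usage sketch, assuming the "allenai/led-base-16384" checkpoint and
# its tokenizer (any LED checkpoint should work the same way). The question and
# context below are made up purely to exercise the forward pass.
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    checkpoint = "allenai/led-base-16384"
    config = LEDConfig.from_pretrained(checkpoint)
    model = CustomLEDForQAModel(config, checkpoint)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    inputs = tokenizer(
        "Who wrote the report?",
        "The report was written by the survey team in 2021.",
        return_tensors="pt",
    )

    # LED expects global attention on at least the first token for QA-style tasks.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    # Inference-style call: no positions supplied, so 'loss' comes back as None.
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
    )
    print(outputs["start_logits"].shape, outputs["end_logits"].shape)

    # Training-style call: supplying gold start/end positions yields a loss.
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        start_positions=torch.tensor([7]),
        end_positions=torch.tensor([11]),
    )
    print(outputs["loss"])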