# from transformers.models.led.modeling_led import LEDEncoder
from transformers import LEDConfig, LEDModel, LEDPreTrainedModel
import torch.nn as nn


# NOTE: subclass LEDPreTrainedModel (not nn.Module) so weight initialization,
# save_pretrained, and from_pretrained behave as expected.
class CustomLEDForQAModel(LEDPreTrainedModel):
    config_class = LEDConfig

    def __init__(self, config: LEDConfig, checkpoint=None):
        super().__init__(config)
        config.num_labels = 2
        self.num_labels = config.num_labels

        # Extractive QA only needs token-level representations of the input,
        # so keep just the LED encoder and drop the decoder.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()

        # Linear head producing one start logit and one end logit per token.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights not loaded from the checkpoint (the QA head).
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        start_positions=None,
        end_positions=None,
    ):
        outputs = self.led(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )

        # (batch, seq_len, 2) -> two (batch, seq_len) logit tensors.
        logits = self.qa_outputs(outputs.last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Collapse a trailing singleton dimension, e.g. (batch, 1) -> (batch,).
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            # start_loss = loss_fct(start_logits[index], start_positions[index][0])
            # end_loss = loss_fct(end_logits[index], end_positions[index][0])
            total_loss = (start_loss + end_loss) / 2

        return {
            "loss": total_loss,
            "start_logits": start_logits,
            "end_logits": end_logits,
        }
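

# Minimal usage sketch (not part of the original file): load encoder weights
# from a public LED checkpoint and run one forward pass. The checkpoint name
# "allenai/led-base-16384" and the dummy question/context strings are
# illustrative assumptions, not values taken from the original code.
if __name__ == "__main__":
    import torch
    from transformers import LEDTokenizerFast

    checkpoint = "allenai/led-base-16384"  # assumed checkpoint, for illustration only
    tokenizer = LEDTokenizerFast.from_pretrained(checkpoint)
    config = LEDConfig.from_pretrained(checkpoint)
    model = CustomLEDForQAModel(config, checkpoint)
    model.eval()

    encoding = tokenizer(
        "Who wrote the report?",
        "The report was written by the audit team in 2021.",
        return_tensors="pt",
    )

    # LED/Longformer-style attention expects at least one token with global
    # attention; here we give it to the first token only.
    global_attention_mask = torch.zeros_like(encoding["input_ids"])
    global_attention_mask[:, 0] = 1

    with torch.no_grad():
        out = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            global_attention_mask=global_attention_mask,
        )

    # Both logit tensors have shape (batch_size, sequence_length).
    print(out["start_logits"].shape, out["end_logits"].shape)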