Praise2112's picture
Upload modelling.py
7f80132 verified
raw
history blame
4.97 kB
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import ModernBertModel, ModernBertPreTrainedModel, ModernBertConfig
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from transformers.models.modernbert.modeling_modernbert import ModernBertPredictionHead
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
    """ModernBERT encoder with a span-extraction head for question answering.

    The encoder output is passed through a ``ModernBertPredictionHead``,
    dropout, and a linear layer with two outputs per token: one logit for the
    answer start and one for the answer end.
    """

    def __init__(self, config: ModernBertConfig):
        """Build the encoder, the prediction head and the start/end projection.

        Args:
            config: Model configuration. ``hidden_size`` sizes the projection
                input and ``classifier_dropout`` sets the dropout probability.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)  # 2 for start/end position logits
        # Initialised explicitly, by writing into the parameter tensors.
        # NOTE(review): presumably the base class's weight-init hook does not
        # cover this layer, so this init must stay in place -- confirm before
        # replacing it with `nn.init.*` calls.
        self.qa_outputs.weight.data.normal_(mean=0.0, std=0.02)
        self.qa_outputs.bias.data.zero_()
        self.drop = torch.nn.Dropout(config.classifier_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        Compute start/end span logits and, when labels are given, the span loss.

        Every argument other than `start_positions`, `end_positions`, `return_dict` and `**kwargs` is forwarded
        unchanged to the underlying `ModernBertModel`. Extra `**kwargs` are accepted but not used.

        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to `[0, sequence_length]`. Positions at or beyond `sequence_length` are not taken
            into account for computing the loss; negative positions are clamped to `0` and are scored as index 0.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to `[0, sequence_length]`. Positions at or beyond `sequence_length` are not taken
            into account for computing the loss; negative positions are clamped to `0` and are scored as index 0.

        Returns:
            A `QuestionAnsweringModelOutput` when `return_dict` resolves to true. Otherwise a tuple
            `(start_logits, end_logits, *encoder_extras)`, prefixed with the loss when both label tensors were
            provided. `encoder_extras` are the encoder outputs that follow the last hidden state.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        # Always request a structured output from the encoder so that
        # `.hidden_states` / `.attentions` can be read by name below.
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        sequence_output = outputs[0]
        sequence_output = self.drop(self.head(sequence_output))

        # Project every token to two scores, then separate them into
        # per-token start and end logits.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Labels may arrive with an extra trailing dimension (e.g. when
            # gathered across multiple GPUs); drop it.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # Sometimes the start/end positions are outside our model inputs.
            # Clamp them to the sequence length and use that value as the
            # ignore index so such examples contribute nothing to the loss.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # The encoder output holds no pooled tensor, so everything after
            # index 0 (hidden states, attentions) belongs in the tuple.
            # Slicing from 2 would silently drop the hidden states.
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )